feat(annotate): client_concurrency for parallel in-flight requests

Adds vlm.client_concurrency (default 16) which uses a ThreadPoolExecutor
to fan out batched chat.completions calls. vllm batches them internally
on the server side, giving big throughput wins on a single TP=1 server
without needing DP/TP and the NCCL setup it requires.

Module 3 now batches all per-episode VQA calls into a single
generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-28 23:07:48 +02:00
parent 47f2ea17bb
commit 545d7eb713
3 changed files with 66 additions and 24 deletions
@@ -121,6 +121,11 @@ class VlmConfig:
``serve_port + i``) and round-robin client requests across them. ``serve_port + i``) and round-robin client requests across them.
Useful when DP/TP NCCL setup is broken on the node — single-GPU Useful when DP/TP NCCL setup is broken on the node — single-GPU
replicas don't need cross-GPU communication.""" replicas don't need cross-GPU communication."""
client_concurrency: int = 16
"""Maximum number of in-flight chat requests the client issues in
parallel. vllm batches them internally for free, so bumping this
typically gives big throughput wins on a single TP=1 server. Set to
``1`` for strict serial calls."""
serve_ready_timeout_s: float = 600.0 serve_ready_timeout_s: float = 600.0
"""Max seconds to wait for the server to start serving requests.""" """Max seconds to wait for the server to start serving requests."""
max_new_tokens: int = 512 max_new_tokens: int = 512
@@ -98,11 +98,24 @@ class GeneralVqaModule:
anchor_idx = _emission_anchor_indices( anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
) )
rows: list[dict[str, Any]] = [] # Build all messages first, then issue them as a single batched
# generate_json call so the client can fan them out concurrently.
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
for idx in anchor_idx: for idx in anchor_idx:
ts = float(record.frame_timestamps[idx]) ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types) qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype, ts) messages = self._build_messages(record, qtype, ts)
per_call.append((ts, qtype, messages))
if not per_call:
staging.write("module_3", [])
return
results = self.vlm.generate_json([m for _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, _qtype, _messages), result in zip(per_call, results):
qa = self._postprocess(result)
if qa is None: if qa is None:
continue continue
question, answer = qa question, answer = qa
@@ -126,17 +139,18 @@ class GeneralVqaModule:
) )
staging.write("module_3", rows) staging.write("module_3", rows)
def _generate_one( def _build_messages(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None: ) -> list[dict[str, Any]]:
prompt = load_prompt("module_3_vqa").format( prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task, episode_task=record.episode_task,
question_type=question_type, question_type=question_type,
) )
images = self.frame_provider.frames_at(record, [frame_timestamp]) images = self.frame_provider.frames_at(record, [frame_timestamp])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}] content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}] return [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict): if not isinstance(result, dict):
return None return None
question = result.get("question") question = result.get("question")
@@ -150,3 +164,10 @@ class GeneralVqaModule:
if classify_vqa_answer(answer) is None: if classify_vqa_answer(answer) is None:
return None return None
return question.strip(), answer return question.strip(), answer
def _generate_one(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
messages = self._build_messages(record, question_type, frame_timestamp)
result = self.vlm.generate_json([messages])[0]
return self._postprocess(result)
@@ -34,6 +34,7 @@ from __future__ import annotations
import json import json
import os import os
import threading
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
@@ -390,27 +391,42 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
"LEROBOT_OPENAI_SEND_MM_KWARGS", "" "LEROBOT_OPENAI_SEND_MM_KWARGS", ""
).lower() in {"1", "true", "yes"} ).lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(
messages: Sequence[dict[str, Any]], max_tok: int, temp: float
) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
if send_mm_kwargs and mm_kwargs:
kwargs["extra_body"] = {
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
}
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen( def _gen(
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
) -> list[str]: ) -> list[str]:
outs: list[str] = [] if len(batch) <= 1 or config.client_concurrency <= 1:
for messages in batch: return [_one_call(messages, max_tok, temp) for messages in batch]
api_messages, mm_kwargs = _to_openai_messages(messages) # Parallel fan-out — vllm batches these on the server side.
kwargs: dict[str, Any] = { from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
"model": config.model_id,
"messages": api_messages, max_workers = min(config.client_concurrency, len(batch))
"max_tokens": max_tok, with ThreadPoolExecutor(max_workers=max_workers) as pool:
"temperature": temp, futures = [
} pool.submit(_one_call, messages, max_tok, temp) for messages in batch
if send_mm_kwargs and mm_kwargs: ]
kwargs["extra_body"] = { return [f.result() for f in futures]
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
}
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
outs.append(response.choices[0].message.content or "")
return outs
return _GenericTextClient(_gen, config) return _GenericTextClient(_gen, config)