feat(annotate): client_concurrency for parallel in-flight requests

Adds vlm.client_concurrency (default 16) which uses a ThreadPoolExecutor
to fan out batched chat.completions calls. vllm batches them internally
on the server side, giving big throughput wins on a single TP=1 server
without needing DP/TP and the NCCL setup it requires.

Module 3 now batches all per-episode VQA calls into a single
generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-28 23:07:48 +02:00
parent 3d7e60cee4
commit 5722d365c5
3 changed files with 66 additions and 24 deletions
@@ -121,6 +121,11 @@ class VlmConfig:
``serve_port + i``) and round-robin client requests across them.
Useful when DP/TP NCCL setup is broken on the node — single-GPU
replicas don't need cross-GPU communication."""
client_concurrency: int = 16
"""Maximum number of in-flight chat requests the client issues in
parallel. vllm batches them internally for free, so bumping this
typically gives big throughput wins on a single TP=1 server. Set to
``1`` for strict serial calls."""
serve_ready_timeout_s: float = 600.0
"""Max seconds to wait for the server to start serving requests."""
max_new_tokens: int = 512
@@ -98,11 +98,24 @@ class GeneralVqaModule:
anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
)
rows: list[dict[str, Any]] = []
# Build all messages first, then issue them as a single batched
# generate_json call so the client can fan them out concurrently.
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype, ts)
messages = self._build_messages(record, qtype, ts)
per_call.append((ts, qtype, messages))
if not per_call:
staging.write("module_3", [])
return
results = self.vlm.generate_json([m for _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, _qtype, _messages), result in zip(per_call, results):
qa = self._postprocess(result)
if qa is None:
continue
question, answer = qa
@@ -126,17 +139,18 @@ class GeneralVqaModule:
)
staging.write("module_3", rows)
def _generate_one(
def _build_messages(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
) -> list[dict[str, Any]]:
prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(record, [frame_timestamp])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
return [{"role": "user", "content": content}]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict):
return None
question = result.get("question")
@@ -150,3 +164,10 @@ class GeneralVqaModule:
if classify_vqa_answer(answer) is None:
return None
return question.strip(), answer
def _generate_one(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
messages = self._build_messages(record, question_type, frame_timestamp)
result = self.vlm.generate_json([messages])[0]
return self._postprocess(result)
@@ -34,6 +34,7 @@ from __future__ import annotations
import json
import os
import threading
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Protocol
@@ -390,27 +391,42 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
"LEROBOT_OPENAI_SEND_MM_KWARGS", ""
).lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(
messages: Sequence[dict[str, Any]], max_tok: int, temp: float
) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
if send_mm_kwargs and mm_kwargs:
kwargs["extra_body"] = {
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
}
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen(
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
) -> list[str]:
outs: list[str] = []
for messages in batch:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
if send_mm_kwargs and mm_kwargs:
kwargs["extra_body"] = {
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
}
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
outs.append(response.choices[0].message.content or "")
return outs
if len(batch) <= 1 or config.client_concurrency <= 1:
return [_one_call(messages, max_tok, temp) for messages in batch]
# Parallel fan-out — vllm batches these on the server side.
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
max_workers = min(config.client_concurrency, len(batch))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [
pool.submit(_one_call, messages, max_tok, temp) for messages in batch
]
return [f.result() for f in futures]
return _GenericTextClient(_gen, config)