From 5722d365c5f3d1772e113f71a7074d46ac2b4c7c Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 28 Apr 2026 23:07:48 +0200
Subject: [PATCH] feat(annotate): client_concurrency for parallel in-flight
 requests

Adds vlm.client_concurrency (default 16) which uses a ThreadPoolExecutor
to fan out batched chat.completions calls. vllm batches them internally
on the server side, giving big throughput wins on a single TP=1 server
without needing DP/TP and the NCCL setup it requires.

Module 3 now batches all per-episode VQA calls into a single
generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../annotations/steerable_pipeline/config.py  |  5 ++
 .../steerable_pipeline/modules/general_vqa.py | 33 +++++++++---
 .../steerable_pipeline/vlm_client.py          | 52 ++++++++++++-------
 3 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py
index 0aa515174..b494ba09e 100644
--- a/src/lerobot/annotations/steerable_pipeline/config.py
+++ b/src/lerobot/annotations/steerable_pipeline/config.py
@@ -121,6 +121,11 @@ class VlmConfig:
     ``serve_port + i``) and round-robin client requests across them.
     Useful when DP/TP NCCL setup is broken on the node — single-GPU
     replicas don't need cross-GPU communication."""
+    client_concurrency: int = 16
+    """Maximum number of in-flight chat requests the client issues in
+    parallel. vllm batches them internally for free, so bumping this
+    typically gives big throughput wins on a single TP=1 server. Set to
+    ``1`` for strict serial calls."""
     serve_ready_timeout_s: float = 600.0
     """Max seconds to wait for the server to start serving requests."""
     max_new_tokens: int = 512
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
index b9930d63a..df34a2772 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
@@ -98,11 +98,24 @@ class GeneralVqaModule:
         anchor_idx = _emission_anchor_indices(
             record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
         )
-        rows: list[dict[str, Any]] = []
+        # Build all messages first, then issue them as a single batched
+        # generate_json call so the client can fan them out concurrently.
+        per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
         for idx in anchor_idx:
             ts = float(record.frame_timestamps[idx])
             qtype = rng.choice(self.config.question_types)
-            qa = self._generate_one(record, qtype, ts)
+            messages = self._build_messages(record, qtype, ts)
+            per_call.append((ts, qtype, messages))
+
+        if not per_call:
+            staging.write("module_3", [])
+            return
+
+        results = self.vlm.generate_json([m for _, _, m in per_call])
+
+        rows: list[dict[str, Any]] = []
+        for (ts, _qtype, _messages), result in zip(per_call, results):
+            qa = self._postprocess(result)
             if qa is None:
                 continue
             question, answer = qa
@@ -126,17 +139,18 @@ class GeneralVqaModule:
             )
         staging.write("module_3", rows)
 
-    def _generate_one(
+    def _build_messages(
         self, record: EpisodeRecord, question_type: str, frame_timestamp: float
-    ) -> tuple[str, dict[str, Any]] | None:
+    ) -> list[dict[str, Any]]:
         prompt = load_prompt("module_3_vqa").format(
             episode_task=record.episode_task,
             question_type=question_type,
         )
         images = self.frame_provider.frames_at(record, [frame_timestamp])
         content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
+        return [{"role": "user", "content": content}]
+
+    def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
         if not isinstance(result, dict):
             return None
         question = result.get("question")
@@ -150,3 +164,10 @@ class GeneralVqaModule:
         if classify_vqa_answer(answer) is None:
             return None
         return question.strip(), answer
+
+    def _generate_one(
+        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+    ) -> tuple[str, dict[str, Any]] | None:
+        messages = self._build_messages(record, question_type, frame_timestamp)
+        result = self.vlm.generate_json([messages])[0]
+        return self._postprocess(result)
diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
index 9151c081c..5dbbb3cb7 100644
--- a/src/lerobot/annotations/steerable_pipeline/vlm_client.py
+++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
@@ -34,6 +34,7 @@ from __future__ import annotations
 
 import json
 import os
+import threading
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
 from typing import Any, Protocol
@@ -390,27 +391,42 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
         "LEROBOT_OPENAI_SEND_MM_KWARGS", ""
     ).lower() in {"1", "true", "yes"}
 
+    rr_lock = threading.Lock()
+
+    def _one_call(
+        messages: Sequence[dict[str, Any]], max_tok: int, temp: float
+    ) -> str:
+        api_messages, mm_kwargs = _to_openai_messages(messages)
+        kwargs: dict[str, Any] = {
+            "model": config.model_id,
+            "messages": api_messages,
+            "max_tokens": max_tok,
+            "temperature": temp,
+        }
+        if send_mm_kwargs and mm_kwargs:
+            kwargs["extra_body"] = {
+                "mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
+            }
+        with rr_lock:
+            chosen = clients[rr_counter["i"] % len(clients)]
+            rr_counter["i"] += 1
+        response = chosen.chat.completions.create(**kwargs)
+        return response.choices[0].message.content or ""
+
     def _gen(
         batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
     ) -> list[str]:
-        outs: list[str] = []
-        for messages in batch:
-            api_messages, mm_kwargs = _to_openai_messages(messages)
-            kwargs: dict[str, Any] = {
-                "model": config.model_id,
-                "messages": api_messages,
-                "max_tokens": max_tok,
-                "temperature": temp,
-            }
-            if send_mm_kwargs and mm_kwargs:
-                kwargs["extra_body"] = {
-                    "mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
-                }
-            chosen = clients[rr_counter["i"] % len(clients)]
-            rr_counter["i"] += 1
-            response = chosen.chat.completions.create(**kwargs)
-            outs.append(response.choices[0].message.content or "")
-        return outs
+        if len(batch) <= 1 or config.client_concurrency <= 1:
+            return [_one_call(messages, max_tok, temp) for messages in batch]
+        # Parallel fan-out — vllm batches these on the server side.
+        from concurrent.futures import ThreadPoolExecutor  # noqa: PLC0415
+
+        max_workers = min(config.client_concurrency, len(batch))
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            futures = [
+                pool.submit(_one_call, messages, max_tok, temp) for messages in batch
+            ]
+            return [f.result() for f in futures]
 
     return _GenericTextClient(_gen, config)
 