Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-14 08:09:45 +00:00
feat(annotate): client_concurrency for parallel in-flight requests
Adds vlm.client_concurrency (default 16), which uses a ThreadPoolExecutor to fan out batched chat.completions calls. vllm batches them internally on the server side, giving big throughput wins on a single TP=1 server without needing DP/TP and the NCCL setup it requires. Module 3 now batches all per-episode VQA calls into a single generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
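The core trick, roughly sketched below with placeholder endpoint, model id, and prompts (none of this is code from the commit): keep several chat.completions requests in flight from worker threads and let a single vllm server batch them internally.

# Hedged sketch: fan N requests out over threads so one OpenAI-compatible vllm
# server can batch them via continuous batching. URL, model, prompts are placeholders.
from concurrent.futures import ThreadPoolExecutor

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed local vllm server
prompts = [f"Describe step {i} of the task." for i in range(32)]       # dummy prompts

def ask(prompt: str) -> str:
    response = client.chat.completions.create(
        model="some-vlm",  # placeholder model id
        messages=[{"role": "user", "content": prompt}],
        max_tokens=128,
        temperature=0.0,
    )
    return response.choices[0].message.content or ""

with ThreadPoolExecutor(max_workers=16) as pool:  # mirrors client_concurrency=16
    answers = list(pool.map(ask, prompts))        # outputs come back in prompt order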
@@ -121,6 +121,11 @@ class VlmConfig:
     ``serve_port + i``) and round-robin client requests across them.
     Useful when DP/TP NCCL setup is broken on the node — single-GPU
     replicas don't need cross-GPU communication."""
+    client_concurrency: int = 16
+    """Maximum number of in-flight chat requests the client issues in
+    parallel. vllm batches them internally for free, so bumping this
+    typically gives big throughput wins on a single TP=1 server. Set to
+    ``1`` for strict serial calls."""
     serve_ready_timeout_s: float = 600.0
     """Max seconds to wait for the server to start serving requests."""
     max_new_tokens: int = 512

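For illustration only, since the diff shows the fields but not how VlmConfig is constructed or imported, the new knob might be set like this:

# Assumption: VlmConfig is a plain dataclass built with keyword arguments; only
# model_id and client_concurrency are fields confirmed by this diff, and the
# import path is omitted because the diff does not show it.
cfg_parallel = VlmConfig(model_id="my-vlm", client_concurrency=16)  # fan out up to 16 requests
cfg_serial = VlmConfig(model_id="my-vlm", client_concurrency=1)     # strict serial calls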
@@ -98,11 +98,24 @@ class GeneralVqaModule:
         anchor_idx = _emission_anchor_indices(
             record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
         )
-        rows: list[dict[str, Any]] = []
+        # Build all messages first, then issue them as a single batched
+        # generate_json call so the client can fan them out concurrently.
+        per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
         for idx in anchor_idx:
             ts = float(record.frame_timestamps[idx])
             qtype = rng.choice(self.config.question_types)
-            qa = self._generate_one(record, qtype, ts)
+            messages = self._build_messages(record, qtype, ts)
+            per_call.append((ts, qtype, messages))
+
+        if not per_call:
+            staging.write("module_3", [])
+            return
+
+        results = self.vlm.generate_json([m for _, _, m in per_call])
+
+        rows: list[dict[str, Any]] = []
+        for (ts, _qtype, _messages), result in zip(per_call, results):
+            qa = self._postprocess(result)
             if qa is None:
                 continue
             question, answer = qa
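The pattern relies on generate_json returning one result per input message list, in input order, so zip re-associates each result with its timestamp and question type. A toy sketch with a stub client (not the real VlmClient):

# Toy illustration of the build-then-batch pattern; _StubVlm stands in for the
# real VlmClient and just echoes one JSON dict per input, in order.
class _StubVlm:
    def generate_json(self, batch):
        return [{"question": f"q{i}", "answer": f"a{i}"} for i, _ in enumerate(batch)]

per_call = [
    (0.5, "spatial", [{"role": "user", "content": "..."}]),
    (1.0, "temporal", [{"role": "user", "content": "..."}]),
]
results = _StubVlm().generate_json([m for _, _, m in per_call])
for (ts, qtype, _messages), result in zip(per_call, results):
    print(ts, qtype, result["question"])  # each result stays paired with its anchor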
@@ -126,17 +139,18 @@ class GeneralVqaModule:
             )
         staging.write("module_3", rows)
 
-    def _generate_one(
+    def _build_messages(
         self, record: EpisodeRecord, question_type: str, frame_timestamp: float
-    ) -> tuple[str, dict[str, Any]] | None:
+    ) -> list[dict[str, Any]]:
         prompt = load_prompt("module_3_vqa").format(
             episode_task=record.episode_task,
             question_type=question_type,
         )
         images = self.frame_provider.frames_at(record, [frame_timestamp])
         content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
-        messages = [{"role": "user", "content": content}]
-        result = self.vlm.generate_json([messages])[0]
+        return [{"role": "user", "content": content}]
+
+    def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
         if not isinstance(result, dict):
             return None
         question = result.get("question")
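_build_messages returns a single OpenAI-style chat message whose content mixes image blocks and the formatted prompt. The exact block layout produced by to_image_blocks is not shown in this diff; a common shape it might resemble:

# Assumed shape only: to_image_blocks is not defined in this diff, so the image
# block below uses the widely used OpenAI-style "image_url" layout as a stand-in.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
            {"type": "text", "text": "Ask one spatial question about the scene."},
        ],
    }
]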
@@ -150,3 +164,10 @@ class GeneralVqaModule:
         if classify_vqa_answer(answer) is None:
             return None
         return question.strip(), answer
+
+    def _generate_one(
+        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+    ) -> tuple[str, dict[str, Any]] | None:
+        messages = self._build_messages(record, question_type, frame_timestamp)
+        result = self.vlm.generate_json([messages])[0]
+        return self._postprocess(result)

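With the split, _generate_one stays as a serial convenience wrapper while callers that want throughput build every message list up front, make one batched generate_json round-trip, and post-process afterwards. A hedged sketch of the two call styles (the module and record arguments are whatever GeneralVqaModule and EpisodeRecord instances the caller already holds):

# Sketch of both call styles enabled by the refactor; `module` and `record`
# are passed in rather than constructed here because this diff does not show
# how they are built.
def both_styles(module, record):
    # Serial convenience path (unchanged public behaviour):
    single = module._generate_one(record, "spatial", 0.5)

    # Batched path: build all messages, one generate_json call, then post-process.
    msgs = [module._build_messages(record, "spatial", t) for t in (0.5, 1.0, 1.5)]
    results = module.vlm.generate_json(msgs)
    batched = [module._postprocess(r) for r in results]
    return single, batched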
@@ -34,6 +34,7 @@ from __future__ import annotations
 
 import json
 import os
+import threading
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass
 from typing import Any, Protocol
@@ -390,27 +391,42 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
         "LEROBOT_OPENAI_SEND_MM_KWARGS", ""
     ).lower() in {"1", "true", "yes"}
 
+    rr_lock = threading.Lock()
+
+    def _one_call(
+        messages: Sequence[dict[str, Any]], max_tok: int, temp: float
+    ) -> str:
+        api_messages, mm_kwargs = _to_openai_messages(messages)
+        kwargs: dict[str, Any] = {
+            "model": config.model_id,
+            "messages": api_messages,
+            "max_tokens": max_tok,
+            "temperature": temp,
+        }
+        if send_mm_kwargs and mm_kwargs:
+            kwargs["extra_body"] = {
+                "mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
+            }
+        with rr_lock:
+            chosen = clients[rr_counter["i"] % len(clients)]
+            rr_counter["i"] += 1
+        response = chosen.chat.completions.create(**kwargs)
+        return response.choices[0].message.content or ""
+
     def _gen(
         batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
     ) -> list[str]:
-        outs: list[str] = []
-        for messages in batch:
-            api_messages, mm_kwargs = _to_openai_messages(messages)
-            kwargs: dict[str, Any] = {
-                "model": config.model_id,
-                "messages": api_messages,
-                "max_tokens": max_tok,
-                "temperature": temp,
-            }
-            if send_mm_kwargs and mm_kwargs:
-                kwargs["extra_body"] = {
-                    "mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
-                }
-            chosen = clients[rr_counter["i"] % len(clients)]
-            rr_counter["i"] += 1
-            response = chosen.chat.completions.create(**kwargs)
-            outs.append(response.choices[0].message.content or "")
-        return outs
+        if len(batch) <= 1 or config.client_concurrency <= 1:
+            return [_one_call(messages, max_tok, temp) for messages in batch]
+        # Parallel fan-out — vllm batches these on the server side.
+        from concurrent.futures import ThreadPoolExecutor  # noqa: PLC0415
+
+        max_workers = min(config.client_concurrency, len(batch))
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            futures = [
+                pool.submit(_one_call, messages, max_tok, temp) for messages in batch
+            ]
+            return [f.result() for f in futures]
 
     return _GenericTextClient(_gen, config)
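Two details worth noting in _gen: futures are collected in submit order, so outputs stay aligned with the input batch, and the round-robin counter is read and bumped under rr_lock because worker threads now share it. A minimal standalone sketch of that counter pattern (the client labels are placeholders, not real OpenAI clients):

# Standalone sketch of the thread-safe round-robin used above; "clients" here
# are just strings so the snippet runs without any server.
import threading
from concurrent.futures import ThreadPoolExecutor

clients = ["server-a", "server-b"]
rr_counter = {"i": 0}
rr_lock = threading.Lock()

def pick_client() -> str:
    with rr_lock:  # guard the shared counter across worker threads
        chosen = clients[rr_counter["i"] % len(clients)]
        rr_counter["i"] += 1
    return chosen

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(pick_client) for _ in range(8)]
    picks = [f.result() for f in futures]  # results come back in submit order
print(picks)  # alternates between the two labels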