mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
feat(annotate): client_concurrency for parallel in-flight requests
Adds vlm.client_concurrency (default 16): when a batch holds more than one request, the OpenAI client uses a ThreadPoolExecutor to fan out the batched chat.completions calls, keeping at most that many in flight at once (set it to 1 for strictly serial calls). vLLM batches them internally on the server side, giving big throughput wins on a single TP=1 server without needing DP/TP and the NCCL setup they require. Module 3 now batches all per-episode VQA calls into a single generate_json invocation so they fire in parallel. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -121,6 +121,11 @@ class VlmConfig:
|
|||||||
``serve_port + i``) and round-robin client requests across them.
|
``serve_port + i``) and round-robin client requests across them.
|
||||||
Useful when DP/TP NCCL setup is broken on the node — single-GPU
|
Useful when DP/TP NCCL setup is broken on the node — single-GPU
|
||||||
replicas don't need cross-GPU communication."""
|
replicas don't need cross-GPU communication."""
|
||||||
|
client_concurrency: int = 16
|
||||||
|
"""Maximum number of in-flight chat requests the client issues in
|
||||||
|
parallel. vllm batches them internally for free, so bumping this
|
||||||
|
typically gives big throughput wins on a single TP=1 server. Set to
|
||||||
|
``1`` for strict serial calls."""
|
||||||
serve_ready_timeout_s: float = 600.0
|
serve_ready_timeout_s: float = 600.0
|
||||||
"""Max seconds to wait for the server to start serving requests."""
|
"""Max seconds to wait for the server to start serving requests."""
|
||||||
max_new_tokens: int = 512
|
max_new_tokens: int = 512
|
||||||
|
|||||||
@@ -98,11 +98,24 @@ class GeneralVqaModule:
|
|||||||
anchor_idx = _emission_anchor_indices(
|
anchor_idx = _emission_anchor_indices(
|
||||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||||
)
|
)
|
||||||
rows: list[dict[str, Any]] = []
|
# Build all messages first, then issue them as a single batched
|
||||||
|
# generate_json call so the client can fan them out concurrently.
|
||||||
|
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
|
||||||
for idx in anchor_idx:
|
for idx in anchor_idx:
|
||||||
ts = float(record.frame_timestamps[idx])
|
ts = float(record.frame_timestamps[idx])
|
||||||
qtype = rng.choice(self.config.question_types)
|
qtype = rng.choice(self.config.question_types)
|
||||||
qa = self._generate_one(record, qtype, ts)
|
messages = self._build_messages(record, qtype, ts)
|
||||||
|
per_call.append((ts, qtype, messages))
|
||||||
|
|
||||||
|
if not per_call:
|
||||||
|
staging.write("module_3", [])
|
||||||
|
return
|
||||||
|
|
||||||
|
results = self.vlm.generate_json([m for _, _, m in per_call])
|
||||||
|
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for (ts, _qtype, _messages), result in zip(per_call, results):
|
||||||
|
qa = self._postprocess(result)
|
||||||
if qa is None:
|
if qa is None:
|
||||||
continue
|
continue
|
||||||
question, answer = qa
|
question, answer = qa
|
||||||
@@ -126,17 +139,18 @@ class GeneralVqaModule:
|
|||||||
)
|
)
|
||||||
staging.write("module_3", rows)
|
staging.write("module_3", rows)
|
||||||
|
|
||||||
def _generate_one(
|
def _build_messages(
|
||||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||||
) -> tuple[str, dict[str, Any]] | None:
|
) -> list[dict[str, Any]]:
|
||||||
prompt = load_prompt("module_3_vqa").format(
|
prompt = load_prompt("module_3_vqa").format(
|
||||||
episode_task=record.episode_task,
|
episode_task=record.episode_task,
|
||||||
question_type=question_type,
|
question_type=question_type,
|
||||||
)
|
)
|
||||||
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
||||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
messages = [{"role": "user", "content": content}]
|
return [{"role": "user", "content": content}]
|
||||||
result = self.vlm.generate_json([messages])[0]
|
|
||||||
|
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
return None
|
return None
|
||||||
question = result.get("question")
|
question = result.get("question")
|
||||||
@@ -150,3 +164,10 @@ class GeneralVqaModule:
|
|||||||
if classify_vqa_answer(answer) is None:
|
if classify_vqa_answer(answer) is None:
|
||||||
return None
|
return None
|
||||||
return question.strip(), answer
|
return question.strip(), answer
|
||||||
|
|
||||||
|
def _generate_one(
|
||||||
|
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||||
|
) -> tuple[str, dict[str, Any]] | None:
|
||||||
|
messages = self._build_messages(record, question_type, frame_timestamp)
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
return self._postprocess(result)
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
@@ -390,27 +391,42 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
|
|||||||
"LEROBOT_OPENAI_SEND_MM_KWARGS", ""
|
"LEROBOT_OPENAI_SEND_MM_KWARGS", ""
|
||||||
).lower() in {"1", "true", "yes"}
|
).lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
rr_lock = threading.Lock()
|
||||||
|
|
||||||
|
def _one_call(
|
||||||
|
messages: Sequence[dict[str, Any]], max_tok: int, temp: float
|
||||||
|
) -> str:
|
||||||
|
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": config.model_id,
|
||||||
|
"messages": api_messages,
|
||||||
|
"max_tokens": max_tok,
|
||||||
|
"temperature": temp,
|
||||||
|
}
|
||||||
|
if send_mm_kwargs and mm_kwargs:
|
||||||
|
kwargs["extra_body"] = {
|
||||||
|
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
|
||||||
|
}
|
||||||
|
with rr_lock:
|
||||||
|
chosen = clients[rr_counter["i"] % len(clients)]
|
||||||
|
rr_counter["i"] += 1
|
||||||
|
response = chosen.chat.completions.create(**kwargs)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|
||||||
def _gen(
|
def _gen(
|
||||||
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
outs: list[str] = []
|
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||||
for messages in batch:
|
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
# Parallel fan-out — vllm batches these on the server side.
|
||||||
kwargs: dict[str, Any] = {
|
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
|
||||||
"model": config.model_id,
|
|
||||||
"messages": api_messages,
|
max_workers = min(config.client_concurrency, len(batch))
|
||||||
"max_tokens": max_tok,
|
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
"temperature": temp,
|
futures = [
|
||||||
}
|
pool.submit(_one_call, messages, max_tok, temp) for messages in batch
|
||||||
if send_mm_kwargs and mm_kwargs:
|
]
|
||||||
kwargs["extra_body"] = {
|
return [f.result() for f in futures]
|
||||||
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
|
|
||||||
}
|
|
||||||
chosen = clients[rr_counter["i"] % len(clients)]
|
|
||||||
rr_counter["i"] += 1
|
|
||||||
response = chosen.chat.completions.create(**kwargs)
|
|
||||||
outs.append(response.choices[0].message.content or "")
|
|
||||||
return outs
|
|
||||||
|
|
||||||
return _GenericTextClient(_gen, config)
|
return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user