Compare commits

...

18 Commits

Author SHA1 Message Date
Pepijn 72d0fc0dce refactor(annotate): drop HF Inference Providers code path
Default backend is now a local OpenAI-compatible server (vllm /
transformers) which auto_serve spawns. Removes the
use_hf_inference_providers config flag and the router.huggingface.co
routing branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 00:53:08 +02:00
Pepijn 3c6a6b39a2 feat(annotate): --vlm.push_to_hub uploads the annotated dataset
After the pipeline completes, optionally create/locate a dataset repo
and upload the dataset root (excluding .annotate_staging/). Add
push_private and push_commit_message knobs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 00:28:38 +02:00
Pepijn 39f6167fa3 feat(annotate): parallelize episodes within each module phase
Saturates parallel_servers + client_concurrency. Previously the
executor processed one episode at a time, so each Module 1 episode's
3-5 dependent VLM calls hit a single server with the others idle. Now
defaults to 16 episodes in flight; configurable via
ExecutorConfig.episode_parallelism.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 23:59:02 +02:00
Pepijn caef184c82 fix(annotate): probe /v1/models for spawn-helper readiness
vllm with --uvicorn-log-level warning suppresses the "Uvicorn running"
banner that the readiness watcher waited for, so the spawn helper hung
forever even after the API was live. Add an HTTP probe in parallel with
the log watcher and broaden the log markers to include vllm's own
"Starting vLLM API server" / "Available routes are" lines.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 23:47:52 +02:00
Pepijn 7bbf5777a2 fix(annotate): lock-protect per-line writes for parallel server streams
8 server-streaming threads writing chars unsynchronized cause UTF-8
sequences from different servers to interleave mid-byte, garbling the
terminal output. Switch to line-buffered reads with a single shared
print lock — output stays readable, ready-marker detection still works
on the line containing 'Uvicorn running' / 'Application startup
complete'.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 23:19:37 +02:00
Pepijn 545d7eb713 feat(annotate): client_concurrency for parallel in-flight requests
Adds vlm.client_concurrency (default 16) which uses a ThreadPoolExecutor
to fan out batched chat.completions calls. vllm batches them internally
on the server side, giving big throughput wins on a single TP=1 server
without needing DP/TP and the NCCL setup it requires.

Module 3 now batches all per-episode VQA calls into a single
generate_json invocation so they fire in parallel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 23:07:48 +02:00
Pepijn 47f2ea17bb feat(annotate): parallel_servers spawns N independent vllm replicas
Adds --vlm.parallel_servers=N. Spawns N independent vllm processes
(each pinned to GPU i via CUDA_VISIBLE_DEVICES, listening on
serve_port+i) and round-robins requests across them. Sidesteps DP/TP
NCCL setup failures on nodes with restricted P2P/SHM.

Default serve_command for parallel mode: vllm serve <model_id>
--tensor-parallel-size 1 --max-model-len 32768 --uvicorn-log-level
warning. Override via --vlm.serve_command (use {port} placeholder).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 23:06:20 +02:00
Pepijn 5119d22f1f feat(annotate): per-episode progress logs in executor 2026-04-28 22:56:03 +02:00
Pepijn 916b419af3 fix(annotate): don't crash pipeline on persistent JSON parse failure
Some prompts/models occasionally return pure prose with no JSON object
even on retry. Returning None (and logging a preview) lets the pipeline
skip that one VLM call cleanly instead of aborting the whole episode.
The modules already check for None / non-dict results and degrade
gracefully (no row emitted from that call).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 22:33:54 +02:00
Pepijn 7c10c4fcdd fix(annotate): robust JSON extraction (think tags + first balanced object)
Models often wrap JSON in prose or <think>...</think> blocks. Strip the
think tags first, then try direct json.loads, then fall back to scanning
for the first balanced {...} substring (ignoring braces inside strings).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 22:15:25 +02:00
Pepijn 421e84497b fix(annotate): stream child stdout char-by-char so tqdm \\r progress flushes 2026-04-28 21:58:12 +02:00
Pepijn 9d38477728 test(annotate): adjust video-block test for fps-based frame sampling 2026-04-28 19:49:08 +02:00
Pepijn b895e3b057 feat(annotate): Module 1 samples image frames at fps rate
Replace the fixed max_video_frames count with a rate (default 1 fps).
A 30 s episode now sends 30 frames; a 5 s episode sends 5; capped at
max_video_frames (default 128) to avoid blowing up the payload on long
episodes.

Override with --module_1.frames_per_second=2.0 for denser sampling, or
--module_1.frames_per_second=0.5 for sparser.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:48:25 +02:00
Pepijn a8aa6b08ba feat(annotate): use cached HF token from huggingface-cli login
Fall back to huggingface_hub.get_token() when HF_TOKEN/HUGGINGFACE_API_KEY
env vars aren't set. That picks up the token cached by
'huggingface-cli login' so users don't need to export it on every shell.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:36:24 +02:00
Pepijn 4ac6c58ab1 feat(annotate): default to HF Inference Providers, no local GPU needed
Flip the default backend to 'openai' with use_hf_inference_providers=True
and a Qwen3-VL-30B-A3B-Instruct:novita default model_id. The CLI now
runs end-to-end without a local model load — annotations are produced
by sending video_url + prompt to https://router.huggingface.co/v1.

Switch back to local inference with --vlm.backend=vllm or
--vlm.use_hf_inference_providers=false.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:33:34 +02:00
Pepijn d5559a9445 feat(annotate): one-flag HF Inference Providers backend
Setting --vlm.use_hf_inference_providers=true routes requests through
https://router.huggingface.co/v1 using HF_TOKEN as the API key, and
disables auto_serve so no local server is spawned. Combine with a
provider-pinned model id like 'Qwen/Qwen3-VL-30B-A3B-Instruct:novita'
or any plain model id to let HF route.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:28:13 +02:00
Pepijn 7a7b8ac111 fix(annotate): omit mm_processor_kwargs by default; transformers serve rejects it
transformers serve returns HTTP 422 'Unexpected fields' when
mm_processor_kwargs is in extra_body — that field is vllm-specific.
Drop it by default; opt in via LEROBOT_OPENAI_SEND_MM_KWARGS=1 when
talking to vllm serve.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 19:11:58 +02:00
Pepijn 504bad6342 fix(annotate): mm_processor_kwargs in extra_body; inline file URLs as data URLs
Two fixes for video_url with transformers serve:
- fps must be in extra_body.mm_processor_kwargs, not in the content
  block; otherwise the server discards it as unknown kwargs.
- file:// URLs aren't fetched by transformers serve. Read the local mp4
  and inline it as a base64 data:video/mp4 URL so the server sees the
  bytes directly.

Both surface as std::bad_alloc on the server side when wrong, which is
unhelpful but explains what we hit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-28 18:53:43 +02:00
7 changed files with 489 additions and 71 deletions
@@ -32,7 +32,14 @@ class Module1Config:
""" """
enabled: bool = True enabled: bool = True
max_video_frames: int = 32 frames_per_second: float = 1.0
"""Sample one image-frame per ``1/fps`` seconds across the episode for
Module 1's subtask-decomposition prompt. ``1.0`` = 1 fps. Capped by
``max_video_frames`` to avoid blowing up the request payload."""
max_video_frames: int = 128
"""Hard cap on the number of frames Module 1 sends. With ``fps=1`` and
a 30 s episode this yields 30 frames. Bumped from 32 since each frame
is small (~30-100 KB PNG when base64'd)."""
min_subtask_seconds: float = 1.5 min_subtask_seconds: float = 1.5
plan_max_steps: int = 8 plan_max_steps: int = 8
use_video_url: bool = False use_video_url: bool = False
@@ -72,13 +79,12 @@ class Module3Config:
class VlmConfig: class VlmConfig:
"""Shared Qwen-VL client configuration.""" """Shared Qwen-VL client configuration."""
backend: str = "vllm" backend: str = "openai"
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only). """One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
The ``openai`` backend talks to any OpenAI-compatible server — works Default ``openai`` talks to a local OpenAI-compatible server (vllm /
with ``vllm serve``, ``transformers serve``, ``ktransformers serve``, transformers) which the CLI auto-spawns when ``auto_serve=True``."""
or hosted endpoints. Set ``api_base`` and (optionally) ``api_key``.""" model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
model_id: str = "Qwen/Qwen3.6-27B-FP8"
api_base: str = "http://localhost:8000/v1" api_base: str = "http://localhost:8000/v1"
"""Base URL for the ``openai`` backend.""" """Base URL for the ``openai`` backend."""
api_key: str = "EMPTY" api_key: str = "EMPTY"
@@ -95,7 +101,21 @@ class VlmConfig:
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically.""" """Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
serve_command: str | None = None serve_command: str | None = None
"""Override the auto-serve command (full shell command). When ``None``, """Override the auto-serve command (full shell command). When ``None``,
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``.""" we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``.
When ``parallel_servers > 1``, the literal ``{port}`` placeholder in
this command (if present) is substituted per-replica."""
parallel_servers: int = 1
"""When >1, spawn this many independent inference servers (each pinned
to one GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
``serve_port + i``) and round-robin client requests across them.
Useful when DP/TP NCCL setup is broken on the node — single-GPU
replicas don't need cross-GPU communication."""
client_concurrency: int = 16
"""Maximum number of in-flight chat requests the client issues in
parallel. vllm batches them internally for free, so bumping this
typically gives big throughput wins on a single TP=1 server. Set to
``1`` for strict serial calls."""
serve_ready_timeout_s: float = 600.0 serve_ready_timeout_s: float = 600.0
"""Max seconds to wait for the server to start serving requests.""" """Max seconds to wait for the server to start serving requests."""
max_new_tokens: int = 512 max_new_tokens: int = 512
@@ -132,6 +152,13 @@ class ExecutorConfig:
slurm_gpus: int = 1 slurm_gpus: int = 1
slurm_time: str = "06:00:00" slurm_time: str = "06:00:00"
workers: int = 1 workers: int = 1
episode_parallelism: int = 16
"""Number of episodes processed concurrently within each module phase.
Each in-flight episode sends 35 dependent VLM calls; bumping this is
how you actually saturate ``parallel_servers`` and ``client_concurrency``
— without it, the executor loops one episode at a time and the
inference servers sit ~90% idle. Set to ``1`` for strict serial
execution."""
@dataclass @dataclass
@@ -164,5 +191,14 @@ class AnnotationPipelineConfig:
skip_validation: bool = False skip_validation: bool = False
only_episodes: tuple[int, ...] | None = None only_episodes: tuple[int, ...] | None = None
push_to_hub: str | None = None
"""If set, after the pipeline completes, upload the annotated dataset
root to the Hugging Face Hub as a dataset repo with this id (e.g.
``pepijn/super_poulain_steerable``). Creates the repo if missing."""
push_private: bool = False
"""When ``push_to_hub`` is set, create the repo as private."""
push_commit_message: str | None = None
"""Override the commit message used for the hub upload."""
def resolved_staging_dir(self, root: Path) -> Path: def resolved_staging_dir(self, root: Path) -> Path:
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging" return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
@@ -101,7 +101,7 @@ class Executor:
raise ValueError(f"No episodes found under {root}/data/") raise ValueError(f"No episodes found under {root}/data/")
executor_kind = select_executor_class(n, self.config.executor) executor_kind = select_executor_class(n, self.config.executor)
logger.info("annotate: %d episodes; executor=%s", n, executor_kind) print(f"[annotate] {n} episodes total; executor={executor_kind}", flush=True)
staging_dir = self.config.resolved_staging_dir(root) staging_dir = self.config.resolved_staging_dir(root)
staging_dir.mkdir(parents=True, exist_ok=True) staging_dir.mkdir(parents=True, exist_ok=True)
@@ -117,11 +117,15 @@ class Executor:
# Phase 4: Module 3 (VQA) # Phase 4: Module 3 (VQA)
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3)) phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
print("[annotate] running validator...", flush=True)
report = self.validator.validate(records, staging_dir) report = self.validator.validate(records, staging_dir)
if not report.ok and not self.config.skip_validation: if not report.ok and not self.config.skip_validation:
raise RuntimeError(f"Staging validation failed: {report.summary()}") raise RuntimeError(f"Staging validation failed: {report.summary()}")
print(f"[annotate] validator: {report.summary()}", flush=True)
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
written = self.writer.write_all(records, staging_dir, root) written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report) return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
def _run_module_phase( def _run_module_phase(
@@ -131,16 +135,56 @@ class Executor:
staging_dir: Path, staging_dir: Path,
module: Any, module: Any,
) -> PhaseResult: ) -> PhaseResult:
import time as _time # noqa: PLC0415
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
if not module.enabled: if not module.enabled:
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records)) return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
processed = 0 n = len(records)
for record in records: parallelism = max(1, min(self.config.executor.episode_parallelism, n))
print(
f"[annotate] phase={name} starting on {n} episode(s) "
f"(parallelism={parallelism})",
flush=True,
)
t0 = _time.time()
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
i, record = idx_record
ep_start = _time.time()
staging = EpisodeStaging(staging_dir, record.episode_index) staging = EpisodeStaging(staging_dir, record.episode_index)
module.run_episode(record, staging) module.run_episode(record, staging)
return i, record.episode_index, _time.time() - ep_start
processed = 0
if parallelism == 1:
for i, record in enumerate(records, 1):
_, ep_idx, elapsed = _do((i, record))
processed += 1 processed += 1
print(
f"[annotate] {name} episode {i}/{n} "
f"(idx={ep_idx}) done in {elapsed:.1f}s",
flush=True,
)
else:
with ThreadPoolExecutor(max_workers=parallelism) as pool:
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
for fut in as_completed(futures):
i, ep_idx, elapsed = fut.result()
processed += 1
print(
f"[annotate] {name} episode {processed}/{n} "
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
flush=True,
)
total = _time.time() - t0
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0) return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult: def _run_plan_update_phase( # noqa: PLR0915
self, records: list[EpisodeRecord], staging_dir: Path
) -> PhaseResult:
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2. """Re-emit ``plan`` rows at each interjection timestamp from Module 2.
Module 1 owns the prompt; Module 2 produced the timestamps. This phase Module 1 owns the prompt; Module 2 produced the timestamps. This phase
@@ -98,11 +98,24 @@ class GeneralVqaModule:
anchor_idx = _emission_anchor_indices( anchor_idx = _emission_anchor_indices(
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
) )
rows: list[dict[str, Any]] = [] # Build all messages first, then issue them as a single batched
# generate_json call so the client can fan them out concurrently.
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
for idx in anchor_idx: for idx in anchor_idx:
ts = float(record.frame_timestamps[idx]) ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types) qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype, ts) messages = self._build_messages(record, qtype, ts)
per_call.append((ts, qtype, messages))
if not per_call:
staging.write("module_3", [])
return
results = self.vlm.generate_json([m for _, _, m in per_call])
rows: list[dict[str, Any]] = []
for (ts, _qtype, _messages), result in zip(per_call, results):
qa = self._postprocess(result)
if qa is None: if qa is None:
continue continue
question, answer = qa question, answer = qa
@@ -126,17 +139,18 @@ class GeneralVqaModule:
) )
staging.write("module_3", rows) staging.write("module_3", rows)
def _generate_one( def _build_messages(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None: ) -> list[dict[str, Any]]:
prompt = load_prompt("module_3_vqa").format( prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task, episode_task=record.episode_task,
question_type=question_type, question_type=question_type,
) )
images = self.frame_provider.frames_at(record, [frame_timestamp]) images = self.frame_provider.frames_at(record, [frame_timestamp])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}] content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}] return [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
if not isinstance(result, dict): if not isinstance(result, dict):
return None return None
question = result.get("question") question = result.get("question")
@@ -150,3 +164,10 @@ class GeneralVqaModule:
if classify_vqa_answer(answer) is None: if classify_vqa_answer(answer) is None:
return None return None
return question.strip(), answer return question.strip(), answer
def _generate_one(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
messages = self._build_messages(record, question_type, frame_timestamp)
result = self.vlm.generate_json([messages])[0]
return self._postprocess(result)
@@ -175,9 +175,12 @@ class PlanSubtasksMemoryModule:
else [] else []
) )
else: else:
video_frames = self.frame_provider.video_for_episode( target_count = max(
record, self.config.max_video_frames 1,
int(round(episode_duration * self.config.frames_per_second)),
) )
target_count = min(target_count, self.config.max_video_frames)
video_frames = self.frame_provider.video_for_episode(record, target_count)
video_block = to_video_block(video_frames) video_block = to_video_block(video_frames)
content = [*video_block, {"type": "text", "text": prompt}] content = [*video_block, {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}] messages = [{"role": "user", "content": content}]
@@ -33,6 +33,8 @@ The client speaks one method, :meth:`VlmClient.generate_json`, which:
from __future__ import annotations from __future__ import annotations
import json import json
import os
import threading
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
@@ -75,13 +77,57 @@ class StubVlmClient:
def _strip_to_json(text: str) -> Any: def _strip_to_json(text: str) -> Any:
text = text.strip() text = text.strip()
# Strip <think>...</think> blocks (Qwen3 Thinking style)
while "<think>" in text and "</think>" in text:
start = text.find("<think>")
end = text.find("</think>", start) + len("</think>")
text = (text[:start] + text[end:]).strip()
# Strip ```json ... ``` fences from chat-tuned backbones
if text.startswith("```"): if text.startswith("```"):
# tolerate ```json ... ``` fences from chat-tuned backbones
first = text.find("\n") first = text.find("\n")
last = text.rfind("```") last = text.rfind("```")
if first != -1 and last != -1 and last > first: if first != -1 and last != -1 and last > first:
text = text[first + 1 : last].strip() text = text[first + 1 : last].strip()
try:
return json.loads(text) return json.loads(text)
except (ValueError, json.JSONDecodeError):
pass
# Fall back to extracting the first balanced {...} block.
obj_text = _extract_first_json_object(text)
if obj_text is None:
raise json.JSONDecodeError("No JSON object found", text, 0)
return json.loads(obj_text)
def _extract_first_json_object(text: str) -> str | None:
"""Return the first balanced ``{...}`` substring, ignoring braces in
string literals. Returns ``None`` if no balanced block is found."""
start = text.find("{")
if start < 0:
return None
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
ch = text[i]
if escape:
escape = False
continue
if ch == "\\":
escape = True
continue
if ch == '"' and not escape:
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : i + 1]
return None
@dataclass @dataclass
@@ -119,7 +165,17 @@ class _GenericTextClient:
}, },
] ]
retry_text = self.generate_text([retry], max_tok, temp)[0] retry_text = self.generate_text([retry], max_tok, temp)[0]
try:
out.append(_strip_to_json(retry_text)) out.append(_strip_to_json(retry_text))
except (ValueError, json.JSONDecodeError):
# After retry: log preview and return None instead of crashing
# the whole pipeline. Modules treat None as "skip".
preview = retry_text.strip().replace("\n", " ")[:200]
print(
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
flush=True,
)
out.append(None)
return out return out
@@ -276,39 +332,206 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
) from exc ) from exc
api_base = config.api_base api_base = config.api_base
api_key = config.api_key
auto_serve = config.auto_serve
api_bases: list[str] = [api_base]
print( print(
f"[lerobot-annotate] backend=openai model={config.model_id} " f"[lerobot-annotate] backend=openai model={config.model_id} "
f"api_base={api_base} auto_serve={config.auto_serve}", f"api_base={api_base} auto_serve={auto_serve}",
flush=True, flush=True,
) )
if config.auto_serve: if auto_serve:
if _server_is_up(api_base): if config.parallel_servers > 1:
print(
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
flush=True,
)
api_bases = _spawn_parallel_inference_servers(config)
elif _server_is_up(api_base):
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True) print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
else: else:
print("[lerobot-annotate] no server reachable; spawning one", flush=True) print("[lerobot-annotate] no server reachable; spawning one", flush=True)
api_base = _spawn_inference_server(config) api_base = _spawn_inference_server(config)
api_bases = [api_base]
print(f"[lerobot-annotate] server ready at {api_base}", flush=True) print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
client = OpenAI(base_url=api_base, api_key=config.api_key) clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
client = clients[0]
# round-robin counter for parallel mode
rr_counter = {"i": 0}
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
# rejects it with HTTP 422. Send it only when explicitly opted in via
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
send_mm_kwargs = os.environ.get(
"LEROBOT_OPENAI_SEND_MM_KWARGS", ""
).lower() in {"1", "true", "yes"}
rr_lock = threading.Lock()
def _one_call(
messages: Sequence[dict[str, Any]], max_tok: int, temp: float
) -> str:
api_messages, mm_kwargs = _to_openai_messages(messages)
kwargs: dict[str, Any] = {
"model": config.model_id,
"messages": api_messages,
"max_tokens": max_tok,
"temperature": temp,
}
if send_mm_kwargs and mm_kwargs:
kwargs["extra_body"] = {
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
}
with rr_lock:
chosen = clients[rr_counter["i"] % len(clients)]
rr_counter["i"] += 1
response = chosen.chat.completions.create(**kwargs)
return response.choices[0].message.content or ""
def _gen( def _gen(
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
) -> list[str]: ) -> list[str]:
outs: list[str] = [] if len(batch) <= 1 or config.client_concurrency <= 1:
for messages in batch: return [_one_call(messages, max_tok, temp) for messages in batch]
api_messages = [_to_openai_message(m) for m in messages] # Parallel fan-out — vllm batches these on the server side.
response = client.chat.completions.create( from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
model=config.model_id,
messages=api_messages, max_workers = min(config.client_concurrency, len(batch))
max_tokens=max_tok, with ThreadPoolExecutor(max_workers=max_workers) as pool:
temperature=temp, futures = [
) pool.submit(_one_call, messages, max_tok, temp) for messages in batch
outs.append(response.choices[0].message.content or "") ]
return outs return [f.result() for f in futures]
return _GenericTextClient(_gen, config) return _GenericTextClient(_gen, config)
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
"""Spawn ``config.parallel_servers`` independent vllm replicas.
Each replica:
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
- listens on ``serve_port + i``
- is shut down via the same atexit hook as the single-server path
Returns the list of ``api_base`` URLs the client should round-robin
across.
"""
import atexit # noqa: PLC0415
import os as _os # noqa: PLC0415
import shlex # noqa: PLC0415
import signal # noqa: PLC0415
import subprocess # noqa: PLC0415
import sys # noqa: PLC0415
import threading # noqa: PLC0415
import time # noqa: PLC0415
n = config.parallel_servers
api_bases: list[str] = []
procs: list[subprocess.Popen] = []
ready_events: list[threading.Event] = []
# Multiple readiness signals — uvicorn's own banner is suppressed at
# ``--uvicorn-log-level warning``, so we also accept vllm's own
# "Starting vLLM API server" line and the route-listing line. The
# HTTP probe below is the ultimate fallback.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
# Single lock for all server-stream threads so multibyte chars from
# different servers don't interleave and tear UTF-8 sequences.
print_lock = threading.Lock()
base_cmd = config.serve_command or (
f"vllm serve {shlex.quote(config.model_id)} "
f"--tensor-parallel-size 1 "
f"--max-model-len {config.max_model_len or 32768} "
f"--uvicorn-log-level warning"
)
for i in range(n):
port = config.serve_port + i
env = _os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(i)
cmd = base_cmd
if "{port}" in cmd:
cmd = cmd.replace("{port}", str(port))
else:
cmd = f"{cmd} --port {port}"
api_base = f"http://localhost:{port}/v1"
api_bases.append(api_base)
print(f"[server-{i}] launching on GPU {i} port {port}: {cmd}", flush=True)
proc = subprocess.Popen(
shlex.split(cmd),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
env=env,
)
procs.append(proc)
ready = threading.Event()
ready_events.append(ready)
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
# Read whole lines and emit each line atomically under the
# shared print_lock so output from N servers stays readable.
assert p.stdout is not None
for line in iter(p.stdout.readline, ""):
with print_lock:
sys.stdout.write(f"[server-{idx}] {line}")
if not line.endswith(("\n", "\r")):
sys.stdout.write("\n")
sys.stdout.flush()
if any(m in line for m in ready_markers):
ev.set()
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
while not ev.is_set() and p.poll() is None:
if _server_is_up(base):
print(f"[server-{idx}] ready (http probe)", flush=True)
ev.set()
return
time.sleep(2)
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
def _shutdown() -> None:
for i, p in enumerate(procs):
if p.poll() is None:
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
p.send_signal(signal.SIGINT)
for p in procs:
try:
p.wait(timeout=15)
except subprocess.TimeoutExpired:
p.kill()
p.wait(timeout=5)
atexit.register(_shutdown)
deadline = time.monotonic() + config.serve_ready_timeout_s
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
for i, p in enumerate(procs):
if p.poll() is not None:
raise RuntimeError(
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
)
time.sleep(2)
if any(not ev.is_set() for ev in ready_events):
raise RuntimeError(
f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s"
)
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
return api_bases
def _server_is_up(api_base: str) -> bool: def _server_is_up(api_base: str) -> bool:
"""Return True if ``api_base/models`` answers 200 within 2 seconds.""" """Return True if ``api_base/models`` answers 200 within 2 seconds."""
import urllib.request # noqa: PLC0415 import urllib.request # noqa: PLC0415
@@ -361,15 +584,49 @@ def _spawn_inference_server(config: VlmConfig) -> str:
# rescans its cache on every model-list request, which can exceed # rescans its cache on every model-list request, which can exceed
# the urllib timeout and trigger an infinite probe loop. # the urllib timeout and trigger an infinite probe loop.
ready_event = threading.Event() ready_event = threading.Event()
ready_markers = ("Uvicorn running", "Application startup complete") # See _spawn_parallel_inference_servers for why we accept these.
ready_markers = (
"Uvicorn running",
"Application startup complete",
"Starting vLLM API server",
"Available routes are",
)
def _probe() -> None:
while not ready_event.is_set() and proc.poll() is None:
if _server_is_up(api_base):
print("[server] ready (http probe)", flush=True)
ready_event.set()
return
time.sleep(2)
threading.Thread(target=_probe, daemon=True).start()
def _stream_output() -> None: def _stream_output() -> None:
# Read raw chunks instead of iterating lines so tqdm progress
# bars (which overwrite using \r) flush in real time.
assert proc.stdout is not None assert proc.stdout is not None
for line in proc.stdout: buf = ""
sys.stdout.write(f"[server] {line}") prefix_started = False
while True:
ch = proc.stdout.read(1)
if ch == "":
# process exited; flush any tail
if buf:
sys.stdout.write(buf)
sys.stdout.flush() sys.stdout.flush()
if any(marker in line for marker in ready_markers): return
if not prefix_started:
sys.stdout.write("[server] ")
prefix_started = True
sys.stdout.write(ch)
sys.stdout.flush()
buf += ch
if ch in ("\n", "\r"):
if any(marker in buf for marker in ready_markers):
ready_event.set() ready_event.set()
buf = ""
prefix_started = False
threading.Thread(target=_stream_output, daemon=True).start() threading.Thread(target=_stream_output, daemon=True).start()
@@ -400,15 +657,25 @@ def _spawn_inference_server(config: VlmConfig) -> str:
) )
def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]: def _to_openai_messages(
"""Convert an internal message dict to OpenAI chat format. messages: Sequence[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
"""Convert internal messages to OpenAI chat format.
Internal image/video blocks (using PIL.Image objects) become Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
OpenAI ``image_url``/``video_url`` items via base64 data URLs. (``fps`` from ``video_url`` blocks) are extracted out so the caller
can pass them via ``extra_body.mm_processor_kwargs`` rather than
inside the content blocks (which transformers serve rejects).
File-URL video blocks are inlined as base64 data URLs.
""" """
out_messages: list[dict[str, Any]] = []
mm_kwargs: dict[str, Any] = {}
for message in messages:
content = message.get("content") content = message.get("content")
if not isinstance(content, list): if not isinstance(content, list):
return {"role": message["role"], "content": content} out_messages.append({"role": message["role"], "content": content})
continue
out_blocks: list[dict[str, Any]] = [] out_blocks: list[dict[str, Any]] = []
for block in content: for block in content:
block_type = block.get("type") if isinstance(block, dict) else None block_type = block.get("type") if isinstance(block, dict) else None
@@ -425,11 +692,27 @@ def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]:
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}} {"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
) )
elif block_type == "video_url": elif block_type == "video_url":
# Pass through to the OpenAI-compatible server unchanged. video_url = dict(block["video_url"])
out_blocks.append({"type": "video_url", "video_url": block["video_url"]}) url = video_url.get("url", "")
if url.startswith("file://"):
video_url["url"] = _file_to_data_url(url[len("file://") :])
out_blocks.append({"type": "video_url", "video_url": video_url})
fps = block.get("fps")
if fps is not None:
mm_kwargs["fps"] = fps
else: else:
out_blocks.append(block) out_blocks.append(block)
return {"role": message["role"], "content": out_blocks} out_messages.append({"role": message["role"], "content": out_blocks})
return out_messages, mm_kwargs
def _file_to_data_url(path: str) -> str:
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
import base64 # noqa: PLC0415
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
return f"data:video/mp4;base64,{b64}"
def _pil_to_data_url(image: Any) -> str: def _pil_to_data_url(image: Any) -> str:
+28
View File
@@ -95,6 +95,34 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
for w in summary.validation_report.warnings: for w in summary.validation_report.warnings:
logger.warning(w) logger.warning(w)
if cfg.push_to_hub:
_push_to_hub(root, cfg)
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
"""Upload the annotated dataset directory to the Hugging Face Hub."""
from huggingface_hub import HfApi # noqa: PLC0415
repo_id = cfg.push_to_hub
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
api = HfApi()
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
api.create_repo(
repo_id=repo_id,
repo_type="dataset",
private=cfg.push_private,
exist_ok=True,
)
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
api.upload_folder(
folder_path=str(root),
repo_id=repo_id,
repo_type="dataset",
commit_message=commit_message,
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
)
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
def main() -> None: def main() -> None:
annotate() annotate()
+5 -2
View File
@@ -202,7 +202,7 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
provider = _StubFrameProvider() provider = _StubFrameProvider()
module = PlanSubtasksMemoryModule( module = PlanSubtasksMemoryModule(
vlm=StubVlmClient(responder=responder), vlm=StubVlmClient(responder=responder),
config=Module1Config(max_video_frames=5), config=Module1Config(max_video_frames=5, frames_per_second=10.0),
frame_provider=provider, frame_provider=provider,
) )
record = next(iter_episodes(fixture_dataset_root)) record = next(iter_episodes(fixture_dataset_root))
@@ -222,7 +222,10 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
# video block must wrap a list of frames covering the episode # video block must wrap a list of frames covering the episode
assert isinstance(video_blocks[0]["video"], list) assert isinstance(video_blocks[0]["video"], list)
assert len(video_blocks[0]["video"]) <= 5 assert len(video_blocks[0]["video"]) <= 5
assert provider.video_calls == [(record.episode_index, 5)] # provider is called with target_count = min(duration * fps, max). With
# fps=10 on a ~1s episode that requests >max, so max=5 wins.
assert provider.video_calls and provider.video_calls[0][0] == record.episode_index
assert provider.video_calls[0][1] <= 5
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None: def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None: