mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 16:09:44 +00:00
Compare commits
18 Commits
aded6214ab
...
72d0fc0dce
| Author | SHA1 | Date | |
|---|---|---|---|
| 72d0fc0dce | |||
| 3c6a6b39a2 | |||
| 39f6167fa3 | |||
| caef184c82 | |||
| 7bbf5777a2 | |||
| 545d7eb713 | |||
| 47f2ea17bb | |||
| 5119d22f1f | |||
| 916b419af3 | |||
| 7c10c4fcdd | |||
| 421e84497b | |||
| 9d38477728 | |||
| b895e3b057 | |||
| a8aa6b08ba | |||
| 4ac6c58ab1 | |||
| d5559a9445 | |||
| 7a7b8ac111 | |||
| 504bad6342 |
@@ -32,7 +32,14 @@ class Module1Config:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
max_video_frames: int = 32
|
frames_per_second: float = 1.0
|
||||||
|
"""Sample one image-frame per ``1/fps`` seconds across the episode for
|
||||||
|
Module 1's subtask-decomposition prompt. ``1.0`` = 1 fps. Capped by
|
||||||
|
``max_video_frames`` to avoid blowing up the request payload."""
|
||||||
|
max_video_frames: int = 128
|
||||||
|
"""Hard cap on the number of frames Module 1 sends. With ``fps=1`` and
|
||||||
|
a 30 s episode this yields 30 frames. Bumped from 32 since each frame
|
||||||
|
is small (~30-100 KB PNG when base64'd)."""
|
||||||
min_subtask_seconds: float = 1.5
|
min_subtask_seconds: float = 1.5
|
||||||
plan_max_steps: int = 8
|
plan_max_steps: int = 8
|
||||||
use_video_url: bool = False
|
use_video_url: bool = False
|
||||||
@@ -72,13 +79,12 @@ class Module3Config:
|
|||||||
class VlmConfig:
|
class VlmConfig:
|
||||||
"""Shared Qwen-VL client configuration."""
|
"""Shared Qwen-VL client configuration."""
|
||||||
|
|
||||||
backend: str = "vllm"
|
backend: str = "openai"
|
||||||
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
|
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
|
||||||
|
|
||||||
The ``openai`` backend talks to any OpenAI-compatible server — works
|
Default ``openai`` talks to a local OpenAI-compatible server (vllm /
|
||||||
with ``vllm serve``, ``transformers serve``, ``ktransformers serve``,
|
transformers) which the CLI auto-spawns when ``auto_serve=True``."""
|
||||||
or hosted endpoints. Set ``api_base`` and (optionally) ``api_key``."""
|
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||||
model_id: str = "Qwen/Qwen3.6-27B-FP8"
|
|
||||||
api_base: str = "http://localhost:8000/v1"
|
api_base: str = "http://localhost:8000/v1"
|
||||||
"""Base URL for the ``openai`` backend."""
|
"""Base URL for the ``openai`` backend."""
|
||||||
api_key: str = "EMPTY"
|
api_key: str = "EMPTY"
|
||||||
@@ -95,7 +101,21 @@ class VlmConfig:
|
|||||||
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
|
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
|
||||||
serve_command: str | None = None
|
serve_command: str | None = None
|
||||||
"""Override the auto-serve command (full shell command). When ``None``,
|
"""Override the auto-serve command (full shell command). When ``None``,
|
||||||
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``."""
|
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``.
|
||||||
|
|
||||||
|
When ``parallel_servers > 1``, the literal ``{port}`` placeholder in
|
||||||
|
this command (if present) is substituted per-replica."""
|
||||||
|
parallel_servers: int = 1
|
||||||
|
"""When >1, spawn this many independent inference servers (each pinned
|
||||||
|
to one GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
|
||||||
|
``serve_port + i``) and round-robin client requests across them.
|
||||||
|
Useful when DP/TP NCCL setup is broken on the node — single-GPU
|
||||||
|
replicas don't need cross-GPU communication."""
|
||||||
|
client_concurrency: int = 16
|
||||||
|
"""Maximum number of in-flight chat requests the client issues in
|
||||||
|
parallel. vllm batches them internally for free, so bumping this
|
||||||
|
typically gives big throughput wins on a single TP=1 server. Set to
|
||||||
|
``1`` for strict serial calls."""
|
||||||
serve_ready_timeout_s: float = 600.0
|
serve_ready_timeout_s: float = 600.0
|
||||||
"""Max seconds to wait for the server to start serving requests."""
|
"""Max seconds to wait for the server to start serving requests."""
|
||||||
max_new_tokens: int = 512
|
max_new_tokens: int = 512
|
||||||
@@ -132,6 +152,13 @@ class ExecutorConfig:
|
|||||||
slurm_gpus: int = 1
|
slurm_gpus: int = 1
|
||||||
slurm_time: str = "06:00:00"
|
slurm_time: str = "06:00:00"
|
||||||
workers: int = 1
|
workers: int = 1
|
||||||
|
episode_parallelism: int = 16
|
||||||
|
"""Number of episodes processed concurrently within each module phase.
|
||||||
|
Each in-flight episode sends 3–5 dependent VLM calls; bumping this is
|
||||||
|
how you actually saturate ``parallel_servers`` and ``client_concurrency``
|
||||||
|
— without it, the executor loops one episode at a time and the
|
||||||
|
inference servers sit ~90% idle. Set to ``1`` for strict serial
|
||||||
|
execution."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -164,5 +191,14 @@ class AnnotationPipelineConfig:
|
|||||||
skip_validation: bool = False
|
skip_validation: bool = False
|
||||||
only_episodes: tuple[int, ...] | None = None
|
only_episodes: tuple[int, ...] | None = None
|
||||||
|
|
||||||
|
push_to_hub: str | None = None
|
||||||
|
"""If set, after the pipeline completes, upload the annotated dataset
|
||||||
|
root to the Hugging Face Hub as a dataset repo with this id (e.g.
|
||||||
|
``pepijn/super_poulain_steerable``). Creates the repo if missing."""
|
||||||
|
push_private: bool = False
|
||||||
|
"""When ``push_to_hub`` is set, create the repo as private."""
|
||||||
|
push_commit_message: str | None = None
|
||||||
|
"""Override the commit message used for the hub upload."""
|
||||||
|
|
||||||
def resolved_staging_dir(self, root: Path) -> Path:
|
def resolved_staging_dir(self, root: Path) -> Path:
|
||||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class Executor:
|
|||||||
raise ValueError(f"No episodes found under {root}/data/")
|
raise ValueError(f"No episodes found under {root}/data/")
|
||||||
|
|
||||||
executor_kind = select_executor_class(n, self.config.executor)
|
executor_kind = select_executor_class(n, self.config.executor)
|
||||||
logger.info("annotate: %d episodes; executor=%s", n, executor_kind)
|
print(f"[annotate] {n} episodes total; executor={executor_kind}", flush=True)
|
||||||
|
|
||||||
staging_dir = self.config.resolved_staging_dir(root)
|
staging_dir = self.config.resolved_staging_dir(root)
|
||||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -117,11 +117,15 @@ class Executor:
|
|||||||
# Phase 4: Module 3 (VQA)
|
# Phase 4: Module 3 (VQA)
|
||||||
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
||||||
|
|
||||||
|
print("[annotate] running validator...", flush=True)
|
||||||
report = self.validator.validate(records, staging_dir)
|
report = self.validator.validate(records, staging_dir)
|
||||||
if not report.ok and not self.config.skip_validation:
|
if not report.ok and not self.config.skip_validation:
|
||||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||||
|
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||||
|
|
||||||
|
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||||
written = self.writer.write_all(records, staging_dir, root)
|
written = self.writer.write_all(records, staging_dir, root)
|
||||||
|
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||||
|
|
||||||
def _run_module_phase(
|
def _run_module_phase(
|
||||||
@@ -131,16 +135,56 @@ class Executor:
|
|||||||
staging_dir: Path,
|
staging_dir: Path,
|
||||||
module: Any,
|
module: Any,
|
||||||
) -> PhaseResult:
|
) -> PhaseResult:
|
||||||
|
import time as _time # noqa: PLC0415
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
|
||||||
|
|
||||||
if not module.enabled:
|
if not module.enabled:
|
||||||
|
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||||
processed = 0
|
n = len(records)
|
||||||
for record in records:
|
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||||
|
print(
|
||||||
|
f"[annotate] phase={name} starting on {n} episode(s) "
|
||||||
|
f"(parallelism={parallelism})",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
t0 = _time.time()
|
||||||
|
|
||||||
|
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||||
|
i, record = idx_record
|
||||||
|
ep_start = _time.time()
|
||||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||||
module.run_episode(record, staging)
|
module.run_episode(record, staging)
|
||||||
|
return i, record.episode_index, _time.time() - ep_start
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
if parallelism == 1:
|
||||||
|
for i, record in enumerate(records, 1):
|
||||||
|
_, ep_idx, elapsed = _do((i, record))
|
||||||
processed += 1
|
processed += 1
|
||||||
|
print(
|
||||||
|
f"[annotate] {name} episode {i}/{n} "
|
||||||
|
f"(idx={ep_idx}) done in {elapsed:.1f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||||
|
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||||
|
for fut in as_completed(futures):
|
||||||
|
i, ep_idx, elapsed = fut.result()
|
||||||
|
processed += 1
|
||||||
|
print(
|
||||||
|
f"[annotate] {name} episode {processed}/{n} "
|
||||||
|
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
total = _time.time() - t0
|
||||||
|
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||||
|
|
||||||
def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult:
|
def _run_plan_update_phase( # noqa: PLR0915
|
||||||
|
self, records: list[EpisodeRecord], staging_dir: Path
|
||||||
|
) -> PhaseResult:
|
||||||
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
||||||
|
|
||||||
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
||||||
|
|||||||
@@ -98,11 +98,24 @@ class GeneralVqaModule:
|
|||||||
anchor_idx = _emission_anchor_indices(
|
anchor_idx = _emission_anchor_indices(
|
||||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||||
)
|
)
|
||||||
rows: list[dict[str, Any]] = []
|
# Build all messages first, then issue them as a single batched
|
||||||
|
# generate_json call so the client can fan them out concurrently.
|
||||||
|
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
|
||||||
for idx in anchor_idx:
|
for idx in anchor_idx:
|
||||||
ts = float(record.frame_timestamps[idx])
|
ts = float(record.frame_timestamps[idx])
|
||||||
qtype = rng.choice(self.config.question_types)
|
qtype = rng.choice(self.config.question_types)
|
||||||
qa = self._generate_one(record, qtype, ts)
|
messages = self._build_messages(record, qtype, ts)
|
||||||
|
per_call.append((ts, qtype, messages))
|
||||||
|
|
||||||
|
if not per_call:
|
||||||
|
staging.write("module_3", [])
|
||||||
|
return
|
||||||
|
|
||||||
|
results = self.vlm.generate_json([m for _, _, m in per_call])
|
||||||
|
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for (ts, _qtype, _messages), result in zip(per_call, results):
|
||||||
|
qa = self._postprocess(result)
|
||||||
if qa is None:
|
if qa is None:
|
||||||
continue
|
continue
|
||||||
question, answer = qa
|
question, answer = qa
|
||||||
@@ -126,17 +139,18 @@ class GeneralVqaModule:
|
|||||||
)
|
)
|
||||||
staging.write("module_3", rows)
|
staging.write("module_3", rows)
|
||||||
|
|
||||||
def _generate_one(
|
def _build_messages(
|
||||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||||
) -> tuple[str, dict[str, Any]] | None:
|
) -> list[dict[str, Any]]:
|
||||||
prompt = load_prompt("module_3_vqa").format(
|
prompt = load_prompt("module_3_vqa").format(
|
||||||
episode_task=record.episode_task,
|
episode_task=record.episode_task,
|
||||||
question_type=question_type,
|
question_type=question_type,
|
||||||
)
|
)
|
||||||
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
||||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
messages = [{"role": "user", "content": content}]
|
return [{"role": "user", "content": content}]
|
||||||
result = self.vlm.generate_json([messages])[0]
|
|
||||||
|
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
return None
|
return None
|
||||||
question = result.get("question")
|
question = result.get("question")
|
||||||
@@ -150,3 +164,10 @@ class GeneralVqaModule:
|
|||||||
if classify_vqa_answer(answer) is None:
|
if classify_vqa_answer(answer) is None:
|
||||||
return None
|
return None
|
||||||
return question.strip(), answer
|
return question.strip(), answer
|
||||||
|
|
||||||
|
def _generate_one(
|
||||||
|
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||||
|
) -> tuple[str, dict[str, Any]] | None:
|
||||||
|
messages = self._build_messages(record, question_type, frame_timestamp)
|
||||||
|
result = self.vlm.generate_json([messages])[0]
|
||||||
|
return self._postprocess(result)
|
||||||
|
|||||||
@@ -175,9 +175,12 @@ class PlanSubtasksMemoryModule:
|
|||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
video_frames = self.frame_provider.video_for_episode(
|
target_count = max(
|
||||||
record, self.config.max_video_frames
|
1,
|
||||||
|
int(round(episode_duration * self.config.frames_per_second)),
|
||||||
)
|
)
|
||||||
|
target_count = min(target_count, self.config.max_video_frames)
|
||||||
|
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||||
video_block = to_video_block(video_frames)
|
video_block = to_video_block(video_frames)
|
||||||
content = [*video_block, {"type": "text", "text": prompt}]
|
content = [*video_block, {"type": "text", "text": prompt}]
|
||||||
messages = [{"role": "user", "content": content}]
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
@@ -75,13 +77,57 @@ class StubVlmClient:
|
|||||||
|
|
||||||
def _strip_to_json(text: str) -> Any:
|
def _strip_to_json(text: str) -> Any:
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
|
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||||
|
while "<think>" in text and "</think>" in text:
|
||||||
|
start = text.find("<think>")
|
||||||
|
end = text.find("</think>", start) + len("</think>")
|
||||||
|
text = (text[:start] + text[end:]).strip()
|
||||||
|
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||||
if text.startswith("```"):
|
if text.startswith("```"):
|
||||||
# tolerate ```json ... ``` fences from chat-tuned backbones
|
|
||||||
first = text.find("\n")
|
first = text.find("\n")
|
||||||
last = text.rfind("```")
|
last = text.rfind("```")
|
||||||
if first != -1 and last != -1 and last > first:
|
if first != -1 and last != -1 and last > first:
|
||||||
text = text[first + 1 : last].strip()
|
text = text[first + 1 : last].strip()
|
||||||
|
try:
|
||||||
return json.loads(text)
|
return json.loads(text)
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
pass
|
||||||
|
# Fall back to extracting the first balanced {...} block.
|
||||||
|
obj_text = _extract_first_json_object(text)
|
||||||
|
if obj_text is None:
|
||||||
|
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||||
|
return json.loads(obj_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_first_json_object(text: str) -> str | None:
|
||||||
|
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||||
|
string literals. Returns ``None`` if no balanced block is found."""
|
||||||
|
start = text.find("{")
|
||||||
|
if start < 0:
|
||||||
|
return None
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
ch = text[i]
|
||||||
|
if escape:
|
||||||
|
escape = False
|
||||||
|
continue
|
||||||
|
if ch == "\\":
|
||||||
|
escape = True
|
||||||
|
continue
|
||||||
|
if ch == '"' and not escape:
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if ch == "{":
|
||||||
|
depth += 1
|
||||||
|
elif ch == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return text[start : i + 1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -119,7 +165,17 @@ class _GenericTextClient:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||||
|
try:
|
||||||
out.append(_strip_to_json(retry_text))
|
out.append(_strip_to_json(retry_text))
|
||||||
|
except (ValueError, json.JSONDecodeError):
|
||||||
|
# After retry: log preview and return None instead of crashing
|
||||||
|
# the whole pipeline. Modules treat None as "skip".
|
||||||
|
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||||
|
print(
|
||||||
|
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
out.append(None)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -276,39 +332,206 @@ def _make_openai_client(config: VlmConfig) -> VlmClient:
|
|||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
api_base = config.api_base
|
api_base = config.api_base
|
||||||
|
api_key = config.api_key
|
||||||
|
auto_serve = config.auto_serve
|
||||||
|
api_bases: list[str] = [api_base]
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||||
f"api_base={api_base} auto_serve={config.auto_serve}",
|
f"api_base={api_base} auto_serve={auto_serve}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
if config.auto_serve:
|
if auto_serve:
|
||||||
if _server_is_up(api_base):
|
if config.parallel_servers > 1:
|
||||||
|
print(
|
||||||
|
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
api_bases = _spawn_parallel_inference_servers(config)
|
||||||
|
elif _server_is_up(api_base):
|
||||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||||
else:
|
else:
|
||||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||||
api_base = _spawn_inference_server(config)
|
api_base = _spawn_inference_server(config)
|
||||||
|
api_bases = [api_base]
|
||||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||||
|
|
||||||
client = OpenAI(base_url=api_base, api_key=config.api_key)
|
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||||
|
client = clients[0]
|
||||||
|
# round-robin counter for parallel mode
|
||||||
|
rr_counter = {"i": 0}
|
||||||
|
|
||||||
|
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||||
|
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||||
|
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||||
|
send_mm_kwargs = os.environ.get(
|
||||||
|
"LEROBOT_OPENAI_SEND_MM_KWARGS", ""
|
||||||
|
).lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
rr_lock = threading.Lock()
|
||||||
|
|
||||||
|
def _one_call(
|
||||||
|
messages: Sequence[dict[str, Any]], max_tok: int, temp: float
|
||||||
|
) -> str:
|
||||||
|
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"model": config.model_id,
|
||||||
|
"messages": api_messages,
|
||||||
|
"max_tokens": max_tok,
|
||||||
|
"temperature": temp,
|
||||||
|
}
|
||||||
|
if send_mm_kwargs and mm_kwargs:
|
||||||
|
kwargs["extra_body"] = {
|
||||||
|
"mm_processor_kwargs": {**mm_kwargs, "do_sample_frames": True}
|
||||||
|
}
|
||||||
|
with rr_lock:
|
||||||
|
chosen = clients[rr_counter["i"] % len(clients)]
|
||||||
|
rr_counter["i"] += 1
|
||||||
|
response = chosen.chat.completions.create(**kwargs)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
|
||||||
def _gen(
|
def _gen(
|
||||||
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
outs: list[str] = []
|
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||||
for messages in batch:
|
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||||
api_messages = [_to_openai_message(m) for m in messages]
|
# Parallel fan-out — vllm batches these on the server side.
|
||||||
response = client.chat.completions.create(
|
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
|
||||||
model=config.model_id,
|
|
||||||
messages=api_messages,
|
max_workers = min(config.client_concurrency, len(batch))
|
||||||
max_tokens=max_tok,
|
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
temperature=temp,
|
futures = [
|
||||||
)
|
pool.submit(_one_call, messages, max_tok, temp) for messages in batch
|
||||||
outs.append(response.choices[0].message.content or "")
|
]
|
||||||
return outs
|
return [f.result() for f in futures]
|
||||||
|
|
||||||
return _GenericTextClient(_gen, config)
|
return _GenericTextClient(_gen, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||||
|
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||||
|
|
||||||
|
Each replica:
|
||||||
|
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||||
|
- listens on ``serve_port + i``
|
||||||
|
- is shut down via the same atexit hook as the single-server path
|
||||||
|
|
||||||
|
Returns the list of ``api_base`` URLs the client should round-robin
|
||||||
|
across.
|
||||||
|
"""
|
||||||
|
import atexit # noqa: PLC0415
|
||||||
|
import os as _os # noqa: PLC0415
|
||||||
|
import shlex # noqa: PLC0415
|
||||||
|
import signal # noqa: PLC0415
|
||||||
|
import subprocess # noqa: PLC0415
|
||||||
|
import sys # noqa: PLC0415
|
||||||
|
import threading # noqa: PLC0415
|
||||||
|
import time # noqa: PLC0415
|
||||||
|
|
||||||
|
n = config.parallel_servers
|
||||||
|
api_bases: list[str] = []
|
||||||
|
procs: list[subprocess.Popen] = []
|
||||||
|
ready_events: list[threading.Event] = []
|
||||||
|
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||||
|
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||||
|
# "Starting vLLM API server" line and the route-listing line. The
|
||||||
|
# HTTP probe below is the ultimate fallback.
|
||||||
|
ready_markers = (
|
||||||
|
"Uvicorn running",
|
||||||
|
"Application startup complete",
|
||||||
|
"Starting vLLM API server",
|
||||||
|
"Available routes are",
|
||||||
|
)
|
||||||
|
# Single lock for all server-stream threads so multibyte chars from
|
||||||
|
# different servers don't interleave and tear UTF-8 sequences.
|
||||||
|
print_lock = threading.Lock()
|
||||||
|
|
||||||
|
base_cmd = config.serve_command or (
|
||||||
|
f"vllm serve {shlex.quote(config.model_id)} "
|
||||||
|
f"--tensor-parallel-size 1 "
|
||||||
|
f"--max-model-len {config.max_model_len or 32768} "
|
||||||
|
f"--uvicorn-log-level warning"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
port = config.serve_port + i
|
||||||
|
env = _os.environ.copy()
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = str(i)
|
||||||
|
cmd = base_cmd
|
||||||
|
if "{port}" in cmd:
|
||||||
|
cmd = cmd.replace("{port}", str(port))
|
||||||
|
else:
|
||||||
|
cmd = f"{cmd} --port {port}"
|
||||||
|
api_base = f"http://localhost:{port}/v1"
|
||||||
|
api_bases.append(api_base)
|
||||||
|
print(f"[server-{i}] launching on GPU {i} port {port}: {cmd}", flush=True)
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
shlex.split(cmd),
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
procs.append(proc)
|
||||||
|
ready = threading.Event()
|
||||||
|
ready_events.append(ready)
|
||||||
|
|
||||||
|
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||||
|
# Read whole lines and emit each line atomically under the
|
||||||
|
# shared print_lock so output from N servers stays readable.
|
||||||
|
assert p.stdout is not None
|
||||||
|
for line in iter(p.stdout.readline, ""):
|
||||||
|
with print_lock:
|
||||||
|
sys.stdout.write(f"[server-{idx}] {line}")
|
||||||
|
if not line.endswith(("\n", "\r")):
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
if any(m in line for m in ready_markers):
|
||||||
|
ev.set()
|
||||||
|
|
||||||
|
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||||
|
|
||||||
|
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||||
|
while not ev.is_set() and p.poll() is None:
|
||||||
|
if _server_is_up(base):
|
||||||
|
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||||
|
ev.set()
|
||||||
|
return
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||||
|
|
||||||
|
def _shutdown() -> None:
|
||||||
|
for i, p in enumerate(procs):
|
||||||
|
if p.poll() is None:
|
||||||
|
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||||
|
p.send_signal(signal.SIGINT)
|
||||||
|
for p in procs:
|
||||||
|
try:
|
||||||
|
p.wait(timeout=15)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
p.kill()
|
||||||
|
p.wait(timeout=5)
|
||||||
|
|
||||||
|
atexit.register(_shutdown)
|
||||||
|
|
||||||
|
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||||
|
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||||
|
for i, p in enumerate(procs):
|
||||||
|
if p.poll() is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||||
|
)
|
||||||
|
time.sleep(2)
|
||||||
|
if any(not ev.is_set() for ev in ready_events):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s"
|
||||||
|
)
|
||||||
|
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||||
|
return api_bases
|
||||||
|
|
||||||
|
|
||||||
def _server_is_up(api_base: str) -> bool:
|
def _server_is_up(api_base: str) -> bool:
|
||||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||||
import urllib.request # noqa: PLC0415
|
import urllib.request # noqa: PLC0415
|
||||||
@@ -361,15 +584,49 @@ def _spawn_inference_server(config: VlmConfig) -> str:
|
|||||||
# rescans its cache on every model-list request, which can exceed
|
# rescans its cache on every model-list request, which can exceed
|
||||||
# the urllib timeout and trigger an infinite probe loop.
|
# the urllib timeout and trigger an infinite probe loop.
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
ready_markers = ("Uvicorn running", "Application startup complete")
|
# See _spawn_parallel_inference_servers for why we accept these.
|
||||||
|
ready_markers = (
|
||||||
|
"Uvicorn running",
|
||||||
|
"Application startup complete",
|
||||||
|
"Starting vLLM API server",
|
||||||
|
"Available routes are",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _probe() -> None:
|
||||||
|
while not ready_event.is_set() and proc.poll() is None:
|
||||||
|
if _server_is_up(api_base):
|
||||||
|
print("[server] ready (http probe)", flush=True)
|
||||||
|
ready_event.set()
|
||||||
|
return
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
threading.Thread(target=_probe, daemon=True).start()
|
||||||
|
|
||||||
def _stream_output() -> None:
|
def _stream_output() -> None:
|
||||||
|
# Read raw chunks instead of iterating lines so tqdm progress
|
||||||
|
# bars (which overwrite using \r) flush in real time.
|
||||||
assert proc.stdout is not None
|
assert proc.stdout is not None
|
||||||
for line in proc.stdout:
|
buf = ""
|
||||||
sys.stdout.write(f"[server] {line}")
|
prefix_started = False
|
||||||
|
while True:
|
||||||
|
ch = proc.stdout.read(1)
|
||||||
|
if ch == "":
|
||||||
|
# process exited; flush any tail
|
||||||
|
if buf:
|
||||||
|
sys.stdout.write(buf)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
if any(marker in line for marker in ready_markers):
|
return
|
||||||
|
if not prefix_started:
|
||||||
|
sys.stdout.write("[server] ")
|
||||||
|
prefix_started = True
|
||||||
|
sys.stdout.write(ch)
|
||||||
|
sys.stdout.flush()
|
||||||
|
buf += ch
|
||||||
|
if ch in ("\n", "\r"):
|
||||||
|
if any(marker in buf for marker in ready_markers):
|
||||||
ready_event.set()
|
ready_event.set()
|
||||||
|
buf = ""
|
||||||
|
prefix_started = False
|
||||||
|
|
||||||
threading.Thread(target=_stream_output, daemon=True).start()
|
threading.Thread(target=_stream_output, daemon=True).start()
|
||||||
|
|
||||||
@@ -400,15 +657,25 @@ def _spawn_inference_server(config: VlmConfig) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]:
|
def _to_openai_messages(
|
||||||
"""Convert an internal message dict to OpenAI chat format.
|
messages: Sequence[dict[str, Any]],
|
||||||
|
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||||
|
"""Convert internal messages to OpenAI chat format.
|
||||||
|
|
||||||
Internal image/video blocks (using PIL.Image objects) become
|
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||||
OpenAI ``image_url``/``video_url`` items via base64 data URLs.
|
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||||
|
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||||
|
inside the content blocks (which transformers serve rejects).
|
||||||
|
|
||||||
|
File-URL video blocks are inlined as base64 data URLs.
|
||||||
"""
|
"""
|
||||||
|
out_messages: list[dict[str, Any]] = []
|
||||||
|
mm_kwargs: dict[str, Any] = {}
|
||||||
|
for message in messages:
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
if not isinstance(content, list):
|
if not isinstance(content, list):
|
||||||
return {"role": message["role"], "content": content}
|
out_messages.append({"role": message["role"], "content": content})
|
||||||
|
continue
|
||||||
out_blocks: list[dict[str, Any]] = []
|
out_blocks: list[dict[str, Any]] = []
|
||||||
for block in content:
|
for block in content:
|
||||||
block_type = block.get("type") if isinstance(block, dict) else None
|
block_type = block.get("type") if isinstance(block, dict) else None
|
||||||
@@ -425,11 +692,27 @@ def _to_openai_message(message: dict[str, Any]) -> dict[str, Any]:
|
|||||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
|
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
|
||||||
)
|
)
|
||||||
elif block_type == "video_url":
|
elif block_type == "video_url":
|
||||||
# Pass through to the OpenAI-compatible server unchanged.
|
video_url = dict(block["video_url"])
|
||||||
out_blocks.append({"type": "video_url", "video_url": block["video_url"]})
|
url = video_url.get("url", "")
|
||||||
|
if url.startswith("file://"):
|
||||||
|
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||||
|
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||||
|
fps = block.get("fps")
|
||||||
|
if fps is not None:
|
||||||
|
mm_kwargs["fps"] = fps
|
||||||
else:
|
else:
|
||||||
out_blocks.append(block)
|
out_blocks.append(block)
|
||||||
return {"role": message["role"], "content": out_blocks}
|
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||||
|
return out_messages, mm_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _file_to_data_url(path: str) -> str:
|
||||||
|
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||||
|
import base64 # noqa: PLC0415
|
||||||
|
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||||
|
return f"data:video/mp4;base64,{b64}"
|
||||||
|
|
||||||
|
|
||||||
def _pil_to_data_url(image: Any) -> str:
|
def _pil_to_data_url(image: Any) -> str:
|
||||||
|
|||||||
@@ -95,6 +95,34 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
|||||||
for w in summary.validation_report.warnings:
|
for w in summary.validation_report.warnings:
|
||||||
logger.warning(w)
|
logger.warning(w)
|
||||||
|
|
||||||
|
if cfg.push_to_hub:
|
||||||
|
_push_to_hub(root, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||||
|
"""Upload the annotated dataset directory to the Hugging Face Hub."""
|
||||||
|
from huggingface_hub import HfApi # noqa: PLC0415
|
||||||
|
|
||||||
|
repo_id = cfg.push_to_hub
|
||||||
|
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||||
|
api = HfApi()
|
||||||
|
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||||
|
api.create_repo(
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
private=cfg.push_private,
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
|
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||||
|
api.upload_folder(
|
||||||
|
folder_path=str(root),
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
commit_message=commit_message,
|
||||||
|
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||||
|
)
|
||||||
|
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
annotate()
|
annotate()
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
|
|||||||
provider = _StubFrameProvider()
|
provider = _StubFrameProvider()
|
||||||
module = PlanSubtasksMemoryModule(
|
module = PlanSubtasksMemoryModule(
|
||||||
vlm=StubVlmClient(responder=responder),
|
vlm=StubVlmClient(responder=responder),
|
||||||
config=Module1Config(max_video_frames=5),
|
config=Module1Config(max_video_frames=5, frames_per_second=10.0),
|
||||||
frame_provider=provider,
|
frame_provider=provider,
|
||||||
)
|
)
|
||||||
record = next(iter_episodes(fixture_dataset_root))
|
record = next(iter_episodes(fixture_dataset_root))
|
||||||
@@ -222,7 +222,10 @@ def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Pa
|
|||||||
# video block must wrap a list of frames covering the episode
|
# video block must wrap a list of frames covering the episode
|
||||||
assert isinstance(video_blocks[0]["video"], list)
|
assert isinstance(video_blocks[0]["video"], list)
|
||||||
assert len(video_blocks[0]["video"]) <= 5
|
assert len(video_blocks[0]["video"]) <= 5
|
||||||
assert provider.video_calls == [(record.episode_index, 5)]
|
# provider is called with target_count = min(duration * fps, max). With
|
||||||
|
# fps=10 on a ~1s episode that requests >max, so max=5 wins.
|
||||||
|
assert provider.video_calls and provider.video_calls[0][0] == record.episode_index
|
||||||
|
assert provider.video_calls[0][1] <= 5
|
||||||
|
|
||||||
|
|
||||||
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user