mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
feat(annotate): attach camera keyframes to module prompts; default to Qwen3.6-27B-FP8
Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type":"image", "image":<PIL>}` content blocks to the
VLM prompts.
- New `frames.py`:
- `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
dataset's first `observation.images.*` stream via
`LeRobotDatasetMetadata.get_video_file_path` and
`decode_video_frames`, with the same `from_timestamp` shift the main
dataset uses.
- Per-process LRU cache so co-timestamped Module 1 plan-update + Module
2 calls share decode work.
- `make_frame_provider` falls back to a null provider when the dataset
has no video tracks → text-only prompts (graceful absence).
- Modules 1/2/3 take an optional `frame_provider` (default null) and
prepend image blocks before the text block.
- Module 1 attaches `keyframes_per_episode` keyframes to the subtask
decomposition prompt.
- Module 2 attaches the frame at the interjection timestamp.
- Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
`Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
`--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size` so 27B-FP8 sharded
on 2× GPUs works out of the box.
- New test `test_module3_attaches_frame_image_block_to_prompt` asserts
  Module 3 emits exactly one image block per VQA prompt, at the exact
  emission timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
attachment behaviour and the no-video fallback.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -30,11 +30,18 @@ Install the extra and invoke the console script:
 ```bash
 uv sync --extra annotations
 uv run lerobot-annotate \
-    --root=/path/to/dataset \
-    --vlm.backend=transformers \
-    --vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
+    --repo_id=imstevenpmwork/super_poulain_draft \
+    --vlm.backend=vllm \
+    --vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
+    --vlm.tensor_parallel_size=2
 ```
+
+The pipeline attaches camera keyframes to every Module 1/2/3 prompt by
+default, decoded from the dataset's first `observation.images.*` stream.
+Override with `--vlm.camera_key=observation.images.<name>` to pin a
+specific viewpoint. Datasets with no video tracks fall back to text-only
+prompts automatically.
 
 The executor picks `LocalPipelineExecutor` for small datasets and
 `SlurmPipelineExecutor` for large ones based on
 `--executor.auto_threshold` (default 32 episodes). Force local with

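As a sanity check of the fallback described above, here is a minimal sketch that drives the provider directly, outside the CLI; the dataset path is hypothetical, the import paths match the new `frames.py` and the existing `reader` module:

```python
# Sketch: exercising the keyframe provider directly, outside the CLI.
from pathlib import Path

from lerobot.annotations.steerable_pipeline.frames import make_frame_provider, to_image_blocks
from lerobot.annotations.steerable_pipeline.reader import iter_episodes

root = Path("/path/to/dataset")  # hypothetical local dataset root
provider = make_frame_provider(root, camera_key=None)  # None -> first observation.images.* stream

record = next(iter_episodes(root))
frames = provider.frames_at(record, [0.0])
# A dataset with no video tracks yields the null provider, so frames == []
# and the prompt degrades gracefully to text-only.
blocks = to_image_blocks(frames)
```
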
@@ -54,12 +54,16 @@ class Module3Config:
 class VlmConfig:
     """Shared Qwen-VL client configuration."""
 
-    backend: Literal["vllm", "transformers", "stub"] = "transformers"
-    model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
+    backend: Literal["vllm", "transformers", "stub"] = "vllm"
+    model_id: str = "Qwen/Qwen3.6-27B-FP8"
     max_new_tokens: int = 512
     temperature: float = 0.2
     json_mode: bool = True
     batch_size: int = 4
+    tensor_parallel_size: int = 1
+    camera_key: str | None = None
+    """Override the camera stream used for keyframe attachment. ``None`` picks
+    the first ``observation.images.*`` key the dataset declares."""
 
 
 @dataclass

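The CLI flags map one-to-one onto these dataclass fields; a minimal sketch of the equivalent programmatic setup (the `VlmConfig` import path is assumed from the test imports later in this diff):

```python
# Sketch: programmatic equivalent of the README's CLI flags.
# Import path assumed from the config imports used by the tests below.
from lerobot.annotations.steerable_pipeline.config import VlmConfig

config = VlmConfig(
    backend="vllm",
    model_id="Qwen/Qwen3.6-27B-FP8",
    tensor_parallel_size=2,  # shard the FP8 checkpoint across 2 GPUs
    camera_key=None,         # None -> first observation.images.* stream
)
```
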
@@ -0,0 +1,150 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Keyframe extraction for the annotation pipeline.
+
+Modules attach decoded camera frames to their VLM prompts so the model can
+ground subtask decomposition, interjection scenarios, and VQA in actual
+visual content. The pipeline shares one provider across modules and one
+episode at a time, with a small per-episode cache so multiple modules
+querying the same timestamp pay decode cost once.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Protocol
+
+from .reader import EpisodeRecord
+
+
+class FrameProvider(Protocol):
+    """Decodes camera frames at episode-relative timestamps."""
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        """Return one PIL.Image per timestamp; empty list if no camera available."""
+
+
+@dataclass
+class _NullProvider:
+    """No-op provider used when the dataset has no video keys or in tests."""
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        return []
+
+
+def null_provider() -> FrameProvider:
+    return _NullProvider()
+
+
+@dataclass
+class VideoFrameProvider:
+    """Decodes frames from the dataset's first ``observation.images.*`` stream.
+
+    The first camera key is used unconditionally — Module 1/2/3 prompts care
+    about *what is happening*, not which camera angle the model sees, so a
+    single canonical viewpoint is enough. Override ``camera_key`` if you
+    want a specific stream.
+
+    Caches up to ``cache_size`` decoded frames per process to keep
+    co-timestamped Module 2 + Module 1 plan-update calls cheap.
+    """
+
+    root: Path
+    camera_key: str | None = None
+    tolerance_s: float = 1e-2
+    cache_size: int = 256
+    _meta: Any = field(default=None, init=False, repr=False)
+    _cache: dict = field(default_factory=dict, init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata  # noqa: PLC0415
+
+        self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
+        if self.camera_key is None:
+            keys = self._meta.video_keys
+            self.camera_key = keys[0] if keys else None
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        if not timestamps or self.camera_key is None:
+            return []
+
+        out: list[Any] = []
+        misses: list[float] = []
+        miss_indices: list[int] = []
+        for i, ts in enumerate(timestamps):
+            key = (record.episode_index, round(float(ts), 6))
+            cached = self._cache.get(key)
+            if cached is not None:
+                # Re-insert on hit so the eviction below is LRU, not FIFO.
+                self._cache[key] = self._cache.pop(key)
+                out.append(cached)
+            else:
+                out.append(None)
+                misses.append(float(ts))
+                miss_indices.append(i)
+
+        if misses:
+            decoded = self._decode(record.episode_index, misses)
+            # A failed decode returns fewer frames than requested, so zip
+            # non-strictly; leftover ``None`` slots are dropped below.
+            for i, img in zip(miss_indices, decoded, strict=False):
+                out[i] = img
+                key = (record.episode_index, round(float(timestamps[i]), 6))
+                if len(self._cache) >= self.cache_size:
+                    self._cache.pop(next(iter(self._cache)))
+                self._cache[key] = img
+        # filter out any None left over from decode failures
+        return [img for img in out if img is not None]
+
+    def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
+        from PIL import Image  # noqa: PLC0415
+
+        from lerobot.datasets.video_utils import decode_video_frames  # noqa: PLC0415
+
+        ep = self._meta.episodes[episode_index]
+        from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
+        shifted = [from_timestamp + ts for ts in timestamps]
+        video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
+        try:
+            frames = decode_video_frames(
+                video_path,
+                shifted,
+                self.tolerance_s,
+                return_uint8=True,
+            )
+        except Exception:
+            return []
+        # frames: [N, C, H, W] uint8, RGB
+        out: list[Any] = []
+        arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
+        for i in range(arr.shape[0]):
+            chw = arr[i]
+            hwc = chw.transpose(1, 2, 0)
+            out.append(Image.fromarray(hwc, mode="RGB"))
+        return out
+
+
+def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
+    """Build a :class:`VideoFrameProvider` if videos are present, else null."""
+    try:
+        provider = VideoFrameProvider(root=root, camera_key=camera_key)
+    except Exception:
+        return null_provider()
+    if provider.camera_key is None:
+        return null_provider()
+    return provider
+
+
+def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
+    """Convert PIL images to Qwen-VL-compatible content blocks."""
+    return [{"type": "image", "image": img} for img in images]

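Since `FrameProvider` is a `typing.Protocol`, any object with a structurally matching `frames_at` satisfies it; a minimal sketch of a synthetic provider (solid-color frames, purely illustrative) that the modules would accept:

```python
# Sketch: a synthetic FrameProvider for offline experiments. Any object
# with this frames_at signature satisfies the Protocol structurally.
from typing import Any

from PIL import Image

from lerobot.annotations.steerable_pipeline.frames import to_image_blocks


class SolidColorProvider:
    def frames_at(self, record: Any, timestamps: list[float]) -> list[Any]:
        # One gray 224x224 frame per requested timestamp.
        return [Image.new("RGB", (224, 224), (128, 128, 128)) for _ in timestamps]


blocks = to_image_blocks(SolidColorProvider().frames_at(None, [0.0, 1.5]))
assert [b["type"] for b in blocks] == ["image", "image"]
```
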
@@ -30,10 +30,11 @@ from __future__ import annotations
 import json
 import random
 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any
 
 from ..config import Module3Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord
 from ..staging import EpisodeStaging

@@ -83,6 +84,7 @@ class GeneralVqaModule:
     vlm: VlmClient
     config: Module3Config
     seed: int = 1729
+    frame_provider: FrameProvider = field(default_factory=null_provider)
 
     @property
     def enabled(self) -> bool:

@@ -100,7 +102,7 @@ class GeneralVqaModule:
         for idx in anchor_idx:
             ts = float(record.frame_timestamps[idx])
             qtype = rng.choice(self.config.question_types)
-            qa = self._generate_one(record, qtype)
+            qa = self._generate_one(record, qtype, ts)
             if qa is None:
                 continue
             question, answer = qa

@@ -124,12 +126,16 @@ class GeneralVqaModule:
             )
         staging.write("module_3", rows)
 
-    def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None:
+    def _generate_one(
+        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+    ) -> tuple[str, dict[str, Any]] | None:
         prompt = load_prompt("module_3_vqa").format(
             episode_task=record.episode_task,
             question_type=question_type,
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        images = self.frame_provider.frames_at(record, [frame_timestamp])
+        content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+        messages = [{"role": "user", "content": content}]
         result = self.vlm.generate_json([messages])[0]
         if not isinstance(result, dict):
             return None

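With one frame attached, the payload Module 3 sends is a single user turn with the image block ahead of the text block; a runnable sketch of that shape, with a stand-in frame and prompt:

```python
# Sketch of the resulting Qwen-VL chat payload: image blocks first,
# then the text block, in a single user turn.
from PIL import Image

frame = Image.new("RGB", (64, 64))          # stands in for a decoded keyframe
prompt = "How many cups are on the table?"  # stands in for the rendered template

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": frame},
            {"type": "text", "text": prompt},
        ],
    }
]
```
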
@@ -34,10 +34,11 @@ from __future__ import annotations
 
 import random
 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any
 
 from ..config import Module2Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord
 from ..staging import EpisodeStaging

@@ -58,6 +59,7 @@ class InterjectionsAndSpeechModule:
     vlm: VlmClient
     config: Module2Config
     seed: int = 1729
+    frame_provider: FrameProvider = field(default_factory=null_provider)
 
     @property
     def enabled(self) -> bool:

@@ -106,7 +108,9 @@ class InterjectionsAndSpeechModule:
                 current_subtask=current_subtask,
                 timestamp=t_snap,
             )
-            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            images = self.frame_provider.frames_at(record, [t_snap])
+            content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+            messages = [{"role": "user", "content": content}]
             result = self.vlm.generate_json([messages])[0]
             if not isinstance(result, dict):
                 continue

@@ -18,10 +18,11 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any
 
 from ..config import Module1Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord, keyframe_indices
 from ..staging import EpisodeStaging

@@ -53,6 +54,7 @@ class PlanSubtasksMemoryModule:
 
     vlm: VlmClient
     config: Module1Config
+    frame_provider: FrameProvider = field(default_factory=null_provider)
 
     @property
     def enabled(self) -> bool:

@@ -150,6 +152,8 @@ class PlanSubtasksMemoryModule:
             return []
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
         keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
+        keyframe_ts = [float(record.frame_timestamps[i]) for i in keyframe_local]
+        images = self.frame_provider.frames_at(record, keyframe_ts)
         prompt = load_prompt("module_1_subtasks").format(
             episode_task=record.episode_task,
             num_keyframes=len(keyframe_local),

@@ -157,7 +161,8 @@ class PlanSubtasksMemoryModule:
             max_steps=self.config.plan_max_steps,
             episode_duration=f"{episode_duration:.3f}",
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+        messages = [{"role": "user", "content": content}]
         result = self.vlm.generate_json([messages])[0]
         spans = result.get("subtasks") if isinstance(result, dict) else None
         if not spans:

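`keyframe_indices` lives in `reader.py` and is not part of this diff; for intuition only, a hypothetical evenly-spaced implementation (an assumption, not the actual code):

```python
# Assumed behavior of reader.keyframe_indices, NOT the actual implementation:
# pick n frame indices spread evenly across the episode, endpoints included.
def keyframe_indices_sketch(num_frames: int, n: int) -> list[int]:
    if n <= 0 or num_frames == 0:
        return []
    if n == 1:
        return [0]
    step = (num_frames - 1) / (n - 1)
    return sorted({round(i * step) for i in range(n)})
```
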
@@ -148,7 +148,7 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
         raise ImportError(
             "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
         ) from exc
-    llm = LLM(model=config.model_id)
+    llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)
 
     def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
         params = SamplingParams(

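To sanity-check the sharded backend outside the pipeline, a sketch against vLLM's public `LLM.chat` API (assumes 2 visible GPUs and the model from the README example):

```python
# Sketch: standalone sanity check of tensor-parallel vLLM serving.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3.6-27B-FP8", tensor_parallel_size=2)
params = SamplingParams(max_tokens=512, temperature=0.2)
out = llm.chat(
    [{"role": "user", "content": 'Reply with the JSON {"ok": true}.'}],
    sampling_params=params,
)
print(out[0].outputs[0].text)
```
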
@@ -33,6 +33,7 @@ from pathlib import Path
 
 from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
 from lerobot.annotations.steerable_pipeline.executor import Executor
+from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
 from lerobot.annotations.steerable_pipeline.modules import (
     GeneralVqaModule,
     InterjectionsAndSpeechModule,

@@ -64,9 +65,12 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
     logger.info("annotate: root=%s", root)
 
     vlm = make_vlm_client(cfg.vlm)
-    module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1)
-    module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed)
-    module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed)
+    frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
+    module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
+    module_2 = InterjectionsAndSpeechModule(
+        vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
+    )
+    module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
     writer = LanguageColumnsWriter()
     validator = StagingValidator()
 
@@ -18,7 +18,9 @@
 from __future__ import annotations
 
 import json
+from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any
 
 from lerobot.annotations.steerable_pipeline.config import (
     Module1Config,

@@ -32,10 +34,31 @@ from lerobot.annotations.steerable_pipeline.modules import (
 )
 from lerobot.annotations.steerable_pipeline.reader import iter_episodes
 from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
 from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
 
 from ._helpers import make_canned_responder
 
+
+@dataclass
+class _StubFrameProvider:
+    """Returns one sentinel object per requested timestamp."""
+
+    sentinel: Any = field(default_factory=lambda: object())
+    calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
+
+    def frames_at(self, record, timestamps):
+        self.calls.append((record.episode_index, tuple(timestamps)))
+        return [self.sentinel] * len(timestamps)
+
+
+def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
+    def responder(messages):
+        captured.append(list(messages))
+        return reply
+
+    return StubVlmClient(responder=responder)
+
+
 def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
     vlm = make_canned_responder(
         {

@@ -145,6 +168,39 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path)
         assert ts in frame_set
 
 
+def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
+    """Each VQA prompt must carry a single image block at the emission frame."""
+    captured: list[list[dict[str, Any]]] = []
+    payload = {
+        "question": "How many cups?",
+        "answer": {"label": "cup", "count": 1},
+    }
+    provider = _StubFrameProvider()
+    module = GeneralVqaModule(
+        vlm=_spy_responder(captured, payload),
+        config=Module3Config(vqa_emission_hz=1.0, K=1),
+        seed=0,
+        frame_provider=provider,
+    )
+    record = next(iter_episodes(single_episode_root))
+    staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
+    module.run_episode(record, staging)
+
+    assert captured, "no VLM calls made"
+    for messages in captured:
+        content = messages[0]["content"]
+        image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
+        text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
+        assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
+        assert image_blocks[0]["image"] is provider.sentinel
+        assert len(text_blocks) == 1
+    # provider was called once per emission with the exact emission timestamp
+    for ep_idx, ts_tuple in provider.calls:
+        assert ep_idx == record.episode_index
+        assert len(ts_tuple) == 1
+        assert ts_tuple[0] in record.frame_timestamps
+
+
 def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
     payload = {
         "question": "Where is the cup?",