From 80b7708a615652b0cd43eb03a47d87735f6feceb Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 27 Apr 2026 16:58:45 +0200
Subject: [PATCH] feat(annotate): attach camera keyframes to module prompts;
 default to Qwen3.6-27B-FP8
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type": "image", "image": <PIL.Image>}` content blocks
to the VLM prompts.

- New `frames.py`:
  - `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
    dataset's first `observation.images.*` stream via
    `LeRobotDatasetMetadata.get_video_file_path` and
    `decode_video_frames`, with the same `from_timestamp` shift the
    main dataset uses.
  - Per-process LRU cache so co-timestamped Module 1 plan-update +
    Module 2 calls share decode work.
  - `make_frame_provider` falls back to a null provider when the
    dataset has no video tracks, degrading gracefully to text-only
    prompts.
- Modules 1/2/3 take an optional `frame_provider` (default null) and
  prepend image blocks before the text block.
  - Module 1 attaches `keyframes_per_episode` keyframes to the subtask
    decomposition prompt.
  - Module 2 attaches the frame at the interjection timestamp.
  - Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
  `Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
  `--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size`, so the 27B-FP8
  model sharded across 2 GPUs works out of the box.
- `test_module3_attaches_frame_image_block_to_prompt` asserts modules
  emit one image block per VQA prompt at the exact emission timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
  Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
  attachment behaviour and the no-video fallback.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 docs/source/annotation_pipeline.mdx           |  13 +-
 .../annotations/steerable_pipeline/config.py  |   8 +-
 .../annotations/steerable_pipeline/frames.py  | 150 ++++++++++++++++++
 .../steerable_pipeline/modules/general_vqa.py |  14 +-
 .../modules/interjections_and_speech.py       |   8 +-
 .../modules/plan_subtasks_memory.py           |   9 +-
 .../steerable_pipeline/vlm_client.py          |   2 +-
 src/lerobot/scripts/lerobot_annotate.py       |  10 +-
 tests/annotations/test_modules.py             |  56 +++++++
 9 files changed, 253 insertions(+), 17 deletions(-)
 create mode 100644 src/lerobot/annotations/steerable_pipeline/frames.py

diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx
index 6b88decc6..7104e7224 100644
--- a/docs/source/annotation_pipeline.mdx
+++ b/docs/source/annotation_pipeline.mdx
@@ -30,11 +30,18 @@ Install the extra and invoke the console script:

 ```bash
 uv sync --extra annotations
 uv run lerobot-annotate \
-  --root=/path/to/dataset \
-  --vlm.backend=transformers \
-  --vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
+  --repo_id=imstevenpmwork/super_poulain_draft \
+  --vlm.backend=vllm \
+  --vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
+  --vlm.tensor_parallel_size=2
 ```

+The pipeline attaches camera keyframes to every Module 1/2/3 prompt by
+default, decoded from the dataset's first `observation.images.*` stream.
+Override with `--vlm.camera_key=observation.images.<camera>` to pin a
+specific viewpoint. Datasets with no video tracks fall back to text-only
+prompts automatically.
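+
+For example, to pin a specific viewpoint (the camera name below is a
+placeholder; substitute one of your dataset's `observation.images.*`
+keys):
+
+```bash
+uv run lerobot-annotate \
+  --repo_id=imstevenpmwork/super_poulain_draft \
+  --vlm.camera_key=observation.images.top
+```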
+
 The executor picks `LocalPipelineExecutor` for small datasets and
 `SlurmPipelineExecutor` for large ones based on
 `--executor.auto_threshold` (default 32 episodes). Force local with
diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py
index eed745086..97b0e273d 100644
--- a/src/lerobot/annotations/steerable_pipeline/config.py
+++ b/src/lerobot/annotations/steerable_pipeline/config.py
@@ -54,12 +54,16 @@ class Module3Config:
 class VlmConfig:
     """Shared Qwen-VL client configuration."""

-    backend: Literal["vllm", "transformers", "stub"] = "transformers"
-    model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
+    backend: Literal["vllm", "transformers", "stub"] = "vllm"
+    model_id: str = "Qwen/Qwen3.6-27B-FP8"
     max_new_tokens: int = 512
     temperature: float = 0.2
     json_mode: bool = True
     batch_size: int = 4
+    tensor_parallel_size: int = 1
+    camera_key: str | None = None
+    """Override the camera stream used for keyframe attachment. ``None`` picks
+    the first ``observation.images.*`` key the dataset declares."""


 @dataclass
diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py
new file mode 100644
index 000000000..a28f30180
--- /dev/null
+++ b/src/lerobot/annotations/steerable_pipeline/frames.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Keyframe extraction for the annotation pipeline.
+
+Modules attach decoded camera frames to their VLM prompts so the model can
+ground subtask decomposition, interjection scenarios, and VQA in actual
+visual content. The pipeline shares one provider across modules, one
+episode at a time, with a small per-process cache so multiple modules
+querying the same timestamp pay the decode cost once.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Protocol
+
+from .reader import EpisodeRecord
+
+
+class FrameProvider(Protocol):
+    """Decodes camera frames at episode-relative timestamps."""
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        """Return one PIL.Image per decodable timestamp; empty list if no camera is available."""
+
+
+@dataclass
+class _NullProvider:
+    """No-op provider used when the dataset has no video keys or in tests."""
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        return []
+
+
+def null_provider() -> FrameProvider:
+    return _NullProvider()
+
+
+@dataclass
+class VideoFrameProvider:
+    """Decodes frames from the dataset's first ``observation.images.*`` stream.
+
+    The first camera key is used by default: Module 1/2/3 prompts care
+    about *what is happening*, not which camera angle the model sees, so a
+    single canonical viewpoint is enough. Set ``camera_key`` to pin a
+    specific stream.
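+
+    Sketch of standalone use (the path and camera name are illustrative,
+    not defaults)::
+
+        provider = VideoFrameProvider(
+            root=Path("/data/my_dataset"),
+            camera_key="observation.images.top",
+        )
+        frames = provider.frames_at(record, [0.0, 1.5])  # -> [PIL.Image, PIL.Image]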
+
+    Caches up to ``cache_size`` decoded frames per process to keep
+    co-timestamped Module 2 + Module 1 plan-update calls cheap.
+    """
+
+    root: Path
+    camera_key: str | None = None
+    tolerance_s: float = 1e-2
+    cache_size: int = 256
+    _meta: Any = field(default=None, init=False, repr=False)
+    _cache: dict = field(default_factory=dict, init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata  # noqa: PLC0415
+
+        self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
+        if self.camera_key is None:
+            keys = self._meta.video_keys
+            self.camera_key = keys[0] if keys else None
+
+    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+        if not timestamps or self.camera_key is None:
+            return []
+
+        out: list[Any] = []
+        misses: list[float] = []
+        miss_indices: list[int] = []
+        for i, ts in enumerate(timestamps):
+            key = (record.episode_index, round(float(ts), 6))
+            cached = self._cache.get(key)
+            if cached is not None:
+                # Re-insert on hit so the eviction below is LRU, not FIFO.
+                self._cache[key] = self._cache.pop(key)
+                out.append(cached)
+            else:
+                out.append(None)
+                misses.append(float(ts))
+                miss_indices.append(i)
+
+        if misses:
+            # _decode may return fewer frames than requested (or none at all)
+            # on failure; leftover Nones are filtered out below.
+            decoded = self._decode(record.episode_index, misses)
+            for i, img in zip(miss_indices, decoded):
+                out[i] = img
+                key = (record.episode_index, round(float(timestamps[i]), 6))
+                if len(self._cache) >= self.cache_size:
+                    self._cache.pop(next(iter(self._cache)))
+                self._cache[key] = img
+        return [img for img in out if img is not None]
+
+    def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
+        from PIL import Image  # noqa: PLC0415
+
+        from lerobot.datasets.video_utils import decode_video_frames  # noqa: PLC0415
+
+        ep = self._meta.episodes[episode_index]
+        from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
+        shifted = [from_timestamp + ts for ts in timestamps]
+        video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
+        try:
+            frames = decode_video_frames(
+                video_path,
+                shifted,
+                self.tolerance_s,
+                return_uint8=True,
+            )
+        except Exception:
+            return []
+        # frames: [N, C, H, W] uint8, RGB
+        out: list[Any] = []
+        arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
+        for i in range(arr.shape[0]):
+            chw = arr[i]
+            hwc = chw.transpose(1, 2, 0)
+            out.append(Image.fromarray(hwc, mode="RGB"))
+        return out
+
+
+def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
+    """Build a :class:`VideoFrameProvider` if videos are present, else null."""
+    try:
+        provider = VideoFrameProvider(root=root, camera_key=camera_key)
+    except Exception:
+        return null_provider()
+    if provider.camera_key is None:
+        return null_provider()
+    return provider
+
+
+def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
+    """Convert PIL images to Qwen-VL-compatible content blocks."""
+    return [{"type": "image", "image": img} for img in images]
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
index 8ea19ab00..b9930d63a 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
@@ -30,10 +30,11 @@ from __future__ import annotations

 import json
 import random
 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any

 from ..config import Module3Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord
 from ..staging import EpisodeStaging
@@ -83,6 +84,7 @@ class GeneralVqaModule:
     vlm: VlmClient
     config: Module3Config
     seed: int = 1729
+    frame_provider: FrameProvider = field(default_factory=null_provider)

     @property
     def enabled(self) -> bool:
@@ -100,7 +102,7 @@
         for idx in anchor_idx:
             ts = float(record.frame_timestamps[idx])
             qtype = rng.choice(self.config.question_types)
-            qa = self._generate_one(record, qtype)
+            qa = self._generate_one(record, qtype, ts)
             if qa is None:
                 continue
             question, answer = qa
@@ -124,12 +126,16 @@
             )
         staging.write("module_3", rows)

-    def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None:
+    def _generate_one(
+        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+    ) -> tuple[str, dict[str, Any]] | None:
         prompt = load_prompt("module_3_vqa").format(
             episode_task=record.episode_task,
             question_type=question_type,
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        images = self.frame_provider.frames_at(record, [frame_timestamp])
+        content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+        messages = [{"role": "user", "content": content}]
         result = self.vlm.generate_json([messages])[0]
         if not isinstance(result, dict):
             return None
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py
index 776cfb79c..d9b19959a 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py
@@ -34,10 +34,11 @@ from __future__ import annotations

 import random
 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any

 from ..config import Module2Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord
 from ..staging import EpisodeStaging
@@ -58,6 +59,7 @@ class InterjectionsAndSpeechModule:
     vlm: VlmClient
     config: Module2Config
     seed: int = 1729
+    frame_provider: FrameProvider = field(default_factory=null_provider)

     @property
     def enabled(self) -> bool:
@@ -106,7 +108,9 @@
                 current_subtask=current_subtask,
                 timestamp=t_snap,
             )
-            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+            images = self.frame_provider.frames_at(record, [t_snap])
+            content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+            messages = [{"role": "user", "content": content}]
             result = self.vlm.generate_json([messages])[0]
             if not isinstance(result, dict):
                 continue
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
index 47c905e1d..bbb6f6a86 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py
@@ -18,10 +18,11 @@

 from __future__ import annotations

 from collections.abc import Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any

 from ..config import Module1Config
+from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
 from ..reader import EpisodeRecord, keyframe_indices
 from ..staging import EpisodeStaging
@@ -53,6 +54,7 @@ class PlanSubtasksMemoryModule:

     vlm: VlmClient
     config: Module1Config
+    frame_provider: FrameProvider = field(default_factory=null_provider)

     @property
     def enabled(self) -> bool:
@@ -150,6 +152,8 @@
             return []
         episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
         keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
+        keyframe_ts = [float(record.frame_timestamps[i]) for i in keyframe_local]
+        images = self.frame_provider.frames_at(record, keyframe_ts)
         prompt = load_prompt("module_1_subtasks").format(
             episode_task=record.episode_task,
             num_keyframes=len(keyframe_local),
@@ -157,7 +161,8 @@
             max_steps=self.config.plan_max_steps,
             episode_duration=f"{episode_duration:.3f}",
         )
-        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
+        messages = [{"role": "user", "content": content}]
         result = self.vlm.generate_json([messages])[0]
         spans = result.get("subtasks") if isinstance(result, dict) else None
         if not spans:
diff --git a/src/lerobot/annotations/steerable_pipeline/vlm_client.py b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
index 923445f35..60e86627b 100644
--- a/src/lerobot/annotations/steerable_pipeline/vlm_client.py
+++ b/src/lerobot/annotations/steerable_pipeline/vlm_client.py
@@ -148,7 +148,7 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
         raise ImportError(
             "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
        ) from exc

-    llm = LLM(model=config.model_id)
+    llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)

     def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
         params = SamplingParams(
diff --git a/src/lerobot/scripts/lerobot_annotate.py b/src/lerobot/scripts/lerobot_annotate.py
index b71d3b3ba..08c52fd82 100644
--- a/src/lerobot/scripts/lerobot_annotate.py
+++ b/src/lerobot/scripts/lerobot_annotate.py
@@ -33,6 +33,7 @@ from pathlib import Path

 from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
 from lerobot.annotations.steerable_pipeline.executor import Executor
+from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
 from lerobot.annotations.steerable_pipeline.modules import (
     GeneralVqaModule,
     InterjectionsAndSpeechModule,
@@ -64,9 +65,12 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
     logger.info("annotate: root=%s", root)

     vlm = make_vlm_client(cfg.vlm)
-    module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1)
-    module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed)
-    module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed)
+    frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
+    module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
+    module_2 = InterjectionsAndSpeechModule(
+        vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
+    )
+    module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)

     writer = LanguageColumnsWriter()
     validator = StagingValidator()
diff --git a/tests/annotations/test_modules.py b/tests/annotations/test_modules.py
index 27ab92e3d..96ab6b9b3 100644
--- a/tests/annotations/test_modules.py
+++ b/tests/annotations/test_modules.py
@@ -18,7 +18,9 @@

 from __future__ import annotations

 import json
+from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any

 from lerobot.annotations.steerable_pipeline.config import (
     Module1Config,
@@ -32,10 +34,31 @@
 from lerobot.annotations.steerable_pipeline.modules import (
 )
 from lerobot.annotations.steerable_pipeline.reader import iter_episodes
 from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
+from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient

 from ._helpers import make_canned_responder


+@dataclass
+class _StubFrameProvider:
+    """Returns one sentinel object per requested timestamp."""
+
+    sentinel: Any = field(default_factory=lambda: object())
+    calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
+
+    def frames_at(self, record, timestamps):
+        self.calls.append((record.episode_index, tuple(timestamps)))
+        return [self.sentinel] * len(timestamps)
+
+
+def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any) -> StubVlmClient:
+    def responder(messages):
+        captured.append(list(messages))
+        return reply
+
+    return StubVlmClient(responder=responder)
+
+
 def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
     vlm = make_canned_responder(
         {
@@ -145,6 +168,39 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None:
         assert ts in frame_set


+def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
+    """Each VQA prompt must carry a single image block at the emission frame."""
+    captured: list[list[dict[str, Any]]] = []
+    payload = {
+ "question": "How many cups?", + "answer": {"label": "cup", "count": 1}, + } + provider = _StubFrameProvider() + module = GeneralVqaModule( + vlm=_spy_responder(captured, payload), + config=Module3Config(vqa_emission_hz=1.0, K=1), + seed=0, + frame_provider=provider, + ) + record = next(iter_episodes(single_episode_root)) + staging = EpisodeStaging(tmp_path / "stage", record.episode_index) + module.run_episode(record, staging) + + assert captured, "no VLM calls made" + for messages in captured: + content = messages[0]["content"] + image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"] + text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"] + assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}" + assert image_blocks[0]["image"] is provider.sentinel + assert len(text_blocks) == 1 + # provider was called once per emission with the exact emission timestamp + for ep_idx, ts_tuple in provider.calls: + assert ep_idx == record.episode_index + assert len(ts_tuple) == 1 + assert ts_tuple[0] in record.frame_timestamps + + def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None: payload = { "question": "Where is the cup?",