feat(annotate): attach camera keyframes to module prompts; default to Qwen3.6-27B-FP8

Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type":"image", "image":<PIL>}` content blocks to the
VLM prompts.
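
Concretely, each module now builds its user message along these lines (a
minimal sketch: the frames, prompt text, and `vlm` client are stand-ins for
the module-specific ones):

```python
from PIL import Image

# Stand-in inputs; in the pipeline the frames come from the FrameProvider and
# the prompt text from the module's template.
images = [Image.new("RGB", (224, 224))]
prompt = "Describe the current subtask."

content = [{"type": "image", "image": img} for img in images]
content.append({"type": "text", "text": prompt})
messages = [{"role": "user", "content": content}]
# Each module then calls: result = vlm.generate_json([messages])[0]
```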

- New `frames.py`:
  - `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
    dataset's first `observation.images.*` stream via
    `LeRobotDatasetMetadata.get_video_file_path` and
    `decode_video_frames`, with the same `from_timestamp` shift the main
    dataset uses.
  - Small per-process frame cache so co-timestamped Module 1 plan-update and
    Module 2 calls share decode work.
  - `make_frame_provider` falls back to a null provider when the dataset
    has no video tracks, so prompts degrade gracefully to text-only (see the
    usage sketch after this list).
- Modules 1/2/3 take an optional `frame_provider` (default null) and
  prepend image blocks before the text block.
  - Module 1 attaches `keyframes_per_episode` keyframes to the subtask
    decomposition prompt.
  - Module 2 attaches the frame at the interjection timestamp.
  - Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
  `Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
  `--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size` so 27B-FP8 sharded
  on 2× GPUs works out of the box.
- `test_module3_attaches_frame_image_block_to_prompt` asserts Module 3
  emits exactly one image block per VQA prompt, pinned to the exact emission
  timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
  Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
  attachment behaviour and the no-video fallback.
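
Usage sketch of the new frames API (illustrative only; the dataset path is a
placeholder and the import paths are the ones this commit adds):

```python
from pathlib import Path

from lerobot.annotations.steerable_pipeline.frames import (
    make_frame_provider,
    to_image_blocks,
)
from lerobot.annotations.steerable_pipeline.reader import iter_episodes

root = Path("/path/to/dataset")  # placeholder

# Falls back to a text-only null provider when the dataset has no video streams.
provider = make_frame_provider(root, camera_key=None)

record = next(iter_episodes(root))
# Decode the frame at the episode's first timestamp (cached per process).
images = provider.frames_at(record, [float(record.frame_timestamps[0])])

# Qwen-VL-style content blocks, ready to prepend to a module's text prompt.
blocks = to_image_blocks(images)
```

In the pipeline itself `lerobot-annotate` wires a single provider into all
three modules (see the `main.py` hunk below); direct calls like the above are
only needed for custom tooling or tests.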

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Author: Pepijn
Date: 2026-04-27 16:58:45 +02:00
Commit: 80b7708a61 (parent: a635a32290)
9 changed files with 253 additions and 17 deletions
@@ -30,11 +30,18 @@ Install the extra and invoke the console script:
```bash
uv sync --extra annotations
uv run lerobot-annotate \
--root=/path/to/dataset \
--vlm.backend=transformers \
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
--repo_id=imstevenpmwork/super_poulain_draft \
--vlm.backend=vllm \
--vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
--vlm.tensor_parallel_size=2
```
The pipeline attaches camera keyframes to every Module 1/2/3 prompt by
default, decoded from the dataset's first `observation.images.*` stream.
Override with `--vlm.camera_key=observation.images.<name>` to pin a
specific viewpoint. Datasets with no video tracks fall back to text-only
prompts automatically.
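
For example, pinning a specific viewpoint looks like this (the camera name
below is hypothetical; use a key your dataset actually declares):

```bash
uv run lerobot-annotate \
    --repo_id=imstevenpmwork/super_poulain_draft \
    --vlm.camera_key=observation.images.top
```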
The executor picks `LocalPipelineExecutor` for small datasets and
`SlurmPipelineExecutor` for large ones based on
`--executor.auto_threshold` (default 32 episodes). Force local with
@@ -54,12 +54,16 @@ class Module3Config:
class VlmConfig:
"""Shared Qwen-VL client configuration."""
backend: Literal["vllm", "transformers", "stub"] = "transformers"
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
backend: Literal["vllm", "transformers", "stub"] = "vllm"
model_id: str = "Qwen/Qwen3.6-27B-FP8"
max_new_tokens: int = 512
temperature: float = 0.2
json_mode: bool = True
batch_size: int = 4
tensor_parallel_size: int = 1
camera_key: str | None = None
"""Override the camera stream used for keyframe attachment. ``None`` picks
the first ``observation.images.*`` key the dataset declares."""
@dataclass
@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keyframe extraction for the annotation pipeline.
Modules attach decoded camera frames to their VLM prompts so the model can
ground subtask decomposition, interjection scenarios, and VQA in actual
visual content. The pipeline shares one provider across modules, processing
one episode at a time, with a small per-process cache keyed by (episode,
timestamp) so multiple modules querying the same timestamp pay the decode
cost once.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Protocol
from .reader import EpisodeRecord
class FrameProvider(Protocol):
"""Decodes camera frames at episode-relative timestamps."""
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
"""Return one PIL.Image per timestamp; empty list if no camera available."""
@dataclass
class _NullProvider:
"""No-op provider used when the dataset has no video keys or in tests."""
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
return []
def null_provider() -> FrameProvider:
return _NullProvider()
@dataclass
class VideoFrameProvider:
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
The first camera key is used by default: Module 1/2/3 prompts care about
*what is happening*, not which camera angle the model sees, so a single
canonical viewpoint is enough. Override ``camera_key`` to pin a specific
stream.
Caches up to ``cache_size`` decoded frames per process to keep
co-timestamped Module 2 + Module 1 plan-update calls cheap.
"""
root: Path
camera_key: str | None = None
tolerance_s: float = 1e-2
cache_size: int = 256
_meta: Any = field(default=None, init=False, repr=False)
_cache: dict = field(default_factory=dict, init=False, repr=False)
def __post_init__(self) -> None:
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
if self.camera_key is None:
keys = self._meta.video_keys
self.camera_key = keys[0] if keys else None
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
if not timestamps or self.camera_key is None:
return []
out: list[Any] = []
misses: list[float] = []
miss_indices: list[int] = []
for i, ts in enumerate(timestamps):
key = (record.episode_index, round(float(ts), 6))
cached = self._cache.get(key)
if cached is not None:
out.append(cached)
else:
out.append(None)
misses.append(float(ts))
miss_indices.append(i)
if misses:
decoded = self._decode(record.episode_index, misses)
# _decode may return fewer frames than requested (e.g. on decode failure);
# missing entries stay None and are filtered out below.
for i, img in zip(miss_indices, decoded):
out[i] = img
key = (record.episode_index, round(float(timestamps[i]), 6))
if len(self._cache) >= self.cache_size:
self._cache.pop(next(iter(self._cache)))
self._cache[key] = img
# filter out any None left over from decode failures
return [img for img in out if img is not None]
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
from PIL import Image # noqa: PLC0415
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
ep = self._meta.episodes[episode_index]
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
shifted = [from_timestamp + ts for ts in timestamps]
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
try:
frames = decode_video_frames(
video_path,
shifted,
self.tolerance_s,
return_uint8=True,
)
except Exception:
return []
# frames: [N, C, H, W] uint8, RGB
out: list[Any] = []
arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
for i in range(arr.shape[0]):
chw = arr[i]
hwc = chw.transpose(1, 2, 0)
out.append(Image.fromarray(hwc, mode="RGB"))
return out
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
try:
provider = VideoFrameProvider(root=root, camera_key=camera_key)
except Exception:
return null_provider()
if provider.camera_key is None:
return null_provider()
return provider
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
"""Convert PIL images to Qwen-VL-compatible content blocks."""
return [{"type": "image", "image": img} for img in images]
@@ -30,10 +30,11 @@ from __future__ import annotations
import json
import random
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from ..config import Module3Config
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
@@ -83,6 +84,7 @@ class GeneralVqaModule:
vlm: VlmClient
config: Module3Config
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
@@ -100,7 +102,7 @@ class GeneralVqaModule:
for idx in anchor_idx:
ts = float(record.frame_timestamps[idx])
qtype = rng.choice(self.config.question_types)
qa = self._generate_one(record, qtype)
qa = self._generate_one(record, qtype, ts)
if qa is None:
continue
question, answer = qa
@@ -124,12 +126,16 @@ class GeneralVqaModule:
)
staging.write("module_3", rows)
def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None:
def _generate_one(
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
) -> tuple[str, dict[str, Any]] | None:
prompt = load_prompt("module_3_vqa").format(
episode_task=record.episode_task,
question_type=question_type,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
images = self.frame_provider.frames_at(record, [frame_timestamp])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
return None
@@ -34,10 +34,11 @@ from __future__ import annotations
import random
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from ..config import Module2Config
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord
from ..staging import EpisodeStaging
@@ -58,6 +59,7 @@ class InterjectionsAndSpeechModule:
vlm: VlmClient
config: Module2Config
seed: int = 1729
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
@@ -106,7 +108,9 @@ class InterjectionsAndSpeechModule:
current_subtask=current_subtask,
timestamp=t_snap,
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
images = self.frame_provider.frames_at(record, [t_snap])
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
if not isinstance(result, dict):
continue
@@ -18,10 +18,11 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
from ..config import Module1Config
from ..frames import FrameProvider, null_provider, to_image_blocks
from ..prompts import load as load_prompt
from ..reader import EpisodeRecord, keyframe_indices
from ..staging import EpisodeStaging
@@ -53,6 +54,7 @@ class PlanSubtasksMemoryModule:
vlm: VlmClient
config: Module1Config
frame_provider: FrameProvider = field(default_factory=null_provider)
@property
def enabled(self) -> bool:
@@ -150,6 +152,8 @@ class PlanSubtasksMemoryModule:
return []
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
keyframe_ts = [float(record.frame_timestamps[i]) for i in keyframe_local]
images = self.frame_provider.frames_at(record, keyframe_ts)
prompt = load_prompt("module_1_subtasks").format(
episode_task=record.episode_task,
num_keyframes=len(keyframe_local),
@@ -157,7 +161,8 @@ class PlanSubtasksMemoryModule:
max_steps=self.config.plan_max_steps,
episode_duration=f"{episode_duration:.3f}",
)
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
messages = [{"role": "user", "content": content}]
result = self.vlm.generate_json([messages])[0]
spans = result.get("subtasks") if isinstance(result, dict) else None
if not spans:
@@ -148,7 +148,7 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
raise ImportError(
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
) from exc
llm = LLM(model=config.model_id)
llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
params = SamplingParams(
@@ -33,6 +33,7 @@ from pathlib import Path
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
from lerobot.annotations.steerable_pipeline.executor import Executor
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
from lerobot.annotations.steerable_pipeline.modules import (
GeneralVqaModule,
InterjectionsAndSpeechModule,
@@ -64,9 +65,12 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
logger.info("annotate: root=%s", root)
vlm = make_vlm_client(cfg.vlm)
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1)
module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed)
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed)
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
module_2 = InterjectionsAndSpeechModule(
vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
)
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
writer = LanguageColumnsWriter()
validator = StagingValidator()
@@ -18,7 +18,9 @@
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.annotations.steerable_pipeline.config import (
Module1Config,
@@ -32,10 +34,31 @@ from lerobot.annotations.steerable_pipeline.modules import (
)
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
from ._helpers import make_canned_responder
@dataclass
class _StubFrameProvider:
"""Returns one sentinel object per requested timestamp."""
sentinel: Any = field(default_factory=lambda: object())
calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
def frames_at(self, record, timestamps):
self.calls.append((record.episode_index, tuple(timestamps)))
return [self.sentinel] * len(timestamps)
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
def responder(messages):
captured.append(list(messages))
return reply
return StubVlmClient(responder=responder)
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
vlm = make_canned_responder(
{
@@ -145,6 +168,39 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path)
assert ts in frame_set
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
"""Each VQA prompt must carry a single image block at the emission frame."""
captured: list[list[dict[str, Any]]] = []
payload = {
"question": "How many cups?",
"answer": {"label": "cup", "count": 1},
}
provider = _StubFrameProvider()
module = GeneralVqaModule(
vlm=_spy_responder(captured, payload),
config=Module3Config(vqa_emission_hz=1.0, K=1),
seed=0,
frame_provider=provider,
)
record = next(iter_episodes(single_episode_root))
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
module.run_episode(record, staging)
assert captured, "no VLM calls made"
for messages in captured:
content = messages[0]["content"]
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
assert image_blocks[0]["image"] is provider.sentinel
assert len(text_blocks) == 1
# provider was called once per emission with the exact emission timestamp
for ep_idx, ts_tuple in provider.calls:
assert ep_idx == record.episode_index
assert len(ts_tuple) == 1
assert ts_tuple[0] in record.frame_timestamps
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
payload = {
"question": "Where is the cup?",