mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(annotate): attach camera keyframes to module prompts; default to Qwen3.6-27B-FP8
Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type":"image", "image":<PIL>}` content blocks to the
VLM prompts.
- New `frames.py`:
- `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
dataset's first `observation.images.*` stream via
`LeRobotDatasetMetadata.get_video_file_path` and
`decode_video_frames`, with the same `from_timestamp` shift the main
dataset uses.
- Per-process LRU cache so co-timestamped Module 1 plan-update + Module
2 calls share decode work.
- `make_frame_provider` falls back to a null provider when the dataset
has no video tracks → text-only prompts (graceful absence).
- Modules 1/2/3 take an optional `frame_provider` (default null) and
prepend image blocks before the text block.
- Module 1 attaches `keyframes_per_episode` keyframes to the subtask
decomposition prompt.
- Module 2 attaches the frame at the interjection timestamp.
- Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
`Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
`--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size` so 27B-FP8 sharded
on 2× GPUs works out of the box.
- `test_module3_attaches_frame_image_block_to_prompt` asserts modules
emit one image block per VQA prompt at the exact emission timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
attachment behaviour and the no-video fallback.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -30,11 +30,18 @@ Install the extra and invoke the console script:
|
|||||||
```bash
|
```bash
|
||||||
uv sync --extra annotations
|
uv sync --extra annotations
|
||||||
uv run lerobot-annotate \
|
uv run lerobot-annotate \
|
||||||
--root=/path/to/dataset \
|
--repo_id=imstevenpmwork/super_poulain_draft \
|
||||||
--vlm.backend=transformers \
|
--vlm.backend=vllm \
|
||||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
--vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
|
||||||
|
--vlm.tensor_parallel_size=2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The pipeline attaches camera keyframes to every Module 1/2/3 prompt by
|
||||||
|
default, decoded from the dataset's first `observation.images.*` stream.
|
||||||
|
Override with `--vlm.camera_key=observation.images.<name>` to pin a
|
||||||
|
specific viewpoint. Datasets with no video tracks fall back to text-only
|
||||||
|
prompts automatically.
|
||||||
|
|
||||||
The executor picks `LocalPipelineExecutor` for small datasets and
|
The executor picks `LocalPipelineExecutor` for small datasets and
|
||||||
`SlurmPipelineExecutor` for large ones based on
|
`SlurmPipelineExecutor` for large ones based on
|
||||||
`--executor.auto_threshold` (default 32 episodes). Force local with
|
`--executor.auto_threshold` (default 32 episodes). Force local with
|
||||||
|
|||||||
@@ -54,12 +54,16 @@ class Module3Config:
|
|||||||
class VlmConfig:
|
class VlmConfig:
|
||||||
"""Shared Qwen-VL client configuration."""
|
"""Shared Qwen-VL client configuration."""
|
||||||
|
|
||||||
backend: Literal["vllm", "transformers", "stub"] = "transformers"
|
backend: Literal["vllm", "transformers", "stub"] = "vllm"
|
||||||
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
|
model_id: str = "Qwen/Qwen3.6-27B-FP8"
|
||||||
max_new_tokens: int = 512
|
max_new_tokens: int = 512
|
||||||
temperature: float = 0.2
|
temperature: float = 0.2
|
||||||
json_mode: bool = True
|
json_mode: bool = True
|
||||||
batch_size: int = 4
|
batch_size: int = 4
|
||||||
|
tensor_parallel_size: int = 1
|
||||||
|
camera_key: str | None = None
|
||||||
|
"""Override the camera stream used for keyframe attachment. ``None`` picks
|
||||||
|
the first ``observation.images.*`` key the dataset declares."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Keyframe extraction for the annotation pipeline.
|
||||||
|
|
||||||
|
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||||
|
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||||
|
visual content. The pipeline shares one provider across modules and one
|
||||||
|
episode at a time, with a small per-episode cache so multiple modules
|
||||||
|
querying the same timestamp pay decode cost once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from .reader import EpisodeRecord
|
||||||
|
|
||||||
|
|
||||||
|
class FrameProvider(Protocol):
|
||||||
|
"""Decodes camera frames at episode-relative timestamps."""
|
||||||
|
|
||||||
|
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||||
|
"""Return one PIL.Image per timestamp; empty list if no camera available."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _NullProvider:
|
||||||
|
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||||
|
|
||||||
|
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def null_provider() -> FrameProvider:
|
||||||
|
return _NullProvider()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoFrameProvider:
|
||||||
|
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
|
||||||
|
|
||||||
|
The first camera key is used unconditionally — Module 1/2/3 prompts care
|
||||||
|
about *what is happening*, not which camera angle the model sees, so a
|
||||||
|
single canonical viewpoint is enough. Override ``camera_key`` if you
|
||||||
|
want a specific stream.
|
||||||
|
|
||||||
|
Caches up to ``cache_size`` decoded frames per process to keep
|
||||||
|
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||||
|
"""
|
||||||
|
|
||||||
|
root: Path
|
||||||
|
camera_key: str | None = None
|
||||||
|
tolerance_s: float = 1e-2
|
||||||
|
cache_size: int = 256
|
||||||
|
_meta: Any = field(default=None, init=False, repr=False)
|
||||||
|
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||||
|
|
||||||
|
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||||
|
if self.camera_key is None:
|
||||||
|
keys = self._meta.video_keys
|
||||||
|
self.camera_key = keys[0] if keys else None
|
||||||
|
|
||||||
|
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||||
|
if not timestamps or self.camera_key is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
out: list[Any] = []
|
||||||
|
misses: list[float] = []
|
||||||
|
miss_indices: list[int] = []
|
||||||
|
for i, ts in enumerate(timestamps):
|
||||||
|
key = (record.episode_index, round(float(ts), 6))
|
||||||
|
cached = self._cache.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
out.append(cached)
|
||||||
|
else:
|
||||||
|
out.append(None)
|
||||||
|
misses.append(float(ts))
|
||||||
|
miss_indices.append(i)
|
||||||
|
|
||||||
|
if misses:
|
||||||
|
decoded = self._decode(record.episode_index, misses)
|
||||||
|
for i, img in zip(miss_indices, decoded, strict=True):
|
||||||
|
out[i] = img
|
||||||
|
key = (record.episode_index, round(float(timestamps[i]), 6))
|
||||||
|
if len(self._cache) >= self.cache_size:
|
||||||
|
self._cache.pop(next(iter(self._cache)))
|
||||||
|
self._cache[key] = img
|
||||||
|
# filter out any None left over from decode failures
|
||||||
|
return [img for img in out if img is not None]
|
||||||
|
|
||||||
|
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
||||||
|
from PIL import Image # noqa: PLC0415
|
||||||
|
|
||||||
|
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||||
|
|
||||||
|
ep = self._meta.episodes[episode_index]
|
||||||
|
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
||||||
|
shifted = [from_timestamp + ts for ts in timestamps]
|
||||||
|
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
||||||
|
try:
|
||||||
|
frames = decode_video_frames(
|
||||||
|
video_path,
|
||||||
|
shifted,
|
||||||
|
self.tolerance_s,
|
||||||
|
return_uint8=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
# frames: [N, C, H, W] uint8, RGB
|
||||||
|
out: list[Any] = []
|
||||||
|
arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
|
||||||
|
for i in range(arr.shape[0]):
|
||||||
|
chw = arr[i]
|
||||||
|
hwc = chw.transpose(1, 2, 0)
|
||||||
|
out.append(Image.fromarray(hwc, mode="RGB"))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||||
|
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||||
|
try:
|
||||||
|
provider = VideoFrameProvider(root=root, camera_key=camera_key)
|
||||||
|
except Exception:
|
||||||
|
return null_provider()
|
||||||
|
if provider.camera_key is None:
|
||||||
|
return null_provider()
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
|
||||||
|
"""Convert PIL images to Qwen-VL-compatible content blocks."""
|
||||||
|
return [{"type": "image", "image": img} for img in images]
|
||||||
@@ -30,10 +30,11 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ..config import Module3Config
|
from ..config import Module3Config
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
from ..prompts import load as load_prompt
|
from ..prompts import load as load_prompt
|
||||||
from ..reader import EpisodeRecord
|
from ..reader import EpisodeRecord
|
||||||
from ..staging import EpisodeStaging
|
from ..staging import EpisodeStaging
|
||||||
@@ -83,6 +84,7 @@ class GeneralVqaModule:
|
|||||||
vlm: VlmClient
|
vlm: VlmClient
|
||||||
config: Module3Config
|
config: Module3Config
|
||||||
seed: int = 1729
|
seed: int = 1729
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
@@ -100,7 +102,7 @@ class GeneralVqaModule:
|
|||||||
for idx in anchor_idx:
|
for idx in anchor_idx:
|
||||||
ts = float(record.frame_timestamps[idx])
|
ts = float(record.frame_timestamps[idx])
|
||||||
qtype = rng.choice(self.config.question_types)
|
qtype = rng.choice(self.config.question_types)
|
||||||
qa = self._generate_one(record, qtype)
|
qa = self._generate_one(record, qtype, ts)
|
||||||
if qa is None:
|
if qa is None:
|
||||||
continue
|
continue
|
||||||
question, answer = qa
|
question, answer = qa
|
||||||
@@ -124,12 +126,16 @@ class GeneralVqaModule:
|
|||||||
)
|
)
|
||||||
staging.write("module_3", rows)
|
staging.write("module_3", rows)
|
||||||
|
|
||||||
def _generate_one(self, record: EpisodeRecord, question_type: str) -> tuple[str, dict[str, Any]] | None:
|
def _generate_one(
|
||||||
|
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||||
|
) -> tuple[str, dict[str, Any]] | None:
|
||||||
prompt = load_prompt("module_3_vqa").format(
|
prompt = load_prompt("module_3_vqa").format(
|
||||||
episode_task=record.episode_task,
|
episode_task=record.episode_task,
|
||||||
question_type=question_type,
|
question_type=question_type,
|
||||||
)
|
)
|
||||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
||||||
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
result = self.vlm.generate_json([messages])[0]
|
result = self.vlm.generate_json([messages])[0]
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -34,10 +34,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ..config import Module2Config
|
from ..config import Module2Config
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
from ..prompts import load as load_prompt
|
from ..prompts import load as load_prompt
|
||||||
from ..reader import EpisodeRecord
|
from ..reader import EpisodeRecord
|
||||||
from ..staging import EpisodeStaging
|
from ..staging import EpisodeStaging
|
||||||
@@ -58,6 +59,7 @@ class InterjectionsAndSpeechModule:
|
|||||||
vlm: VlmClient
|
vlm: VlmClient
|
||||||
config: Module2Config
|
config: Module2Config
|
||||||
seed: int = 1729
|
seed: int = 1729
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
@@ -106,7 +108,9 @@ class InterjectionsAndSpeechModule:
|
|||||||
current_subtask=current_subtask,
|
current_subtask=current_subtask,
|
||||||
timestamp=t_snap,
|
timestamp=t_snap,
|
||||||
)
|
)
|
||||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
images = self.frame_provider.frames_at(record, [t_snap])
|
||||||
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
result = self.vlm.generate_json([messages])[0]
|
result = self.vlm.generate_json([messages])[0]
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -18,10 +18,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from ..config import Module1Config
|
from ..config import Module1Config
|
||||||
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||||
from ..prompts import load as load_prompt
|
from ..prompts import load as load_prompt
|
||||||
from ..reader import EpisodeRecord, keyframe_indices
|
from ..reader import EpisodeRecord, keyframe_indices
|
||||||
from ..staging import EpisodeStaging
|
from ..staging import EpisodeStaging
|
||||||
@@ -53,6 +54,7 @@ class PlanSubtasksMemoryModule:
|
|||||||
|
|
||||||
vlm: VlmClient
|
vlm: VlmClient
|
||||||
config: Module1Config
|
config: Module1Config
|
||||||
|
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
@@ -150,6 +152,8 @@ class PlanSubtasksMemoryModule:
|
|||||||
return []
|
return []
|
||||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||||
keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
|
keyframe_local = keyframe_indices(record, self.config.keyframes_per_episode)
|
||||||
|
keyframe_ts = [float(record.frame_timestamps[i]) for i in keyframe_local]
|
||||||
|
images = self.frame_provider.frames_at(record, keyframe_ts)
|
||||||
prompt = load_prompt("module_1_subtasks").format(
|
prompt = load_prompt("module_1_subtasks").format(
|
||||||
episode_task=record.episode_task,
|
episode_task=record.episode_task,
|
||||||
num_keyframes=len(keyframe_local),
|
num_keyframes=len(keyframe_local),
|
||||||
@@ -157,7 +161,8 @@ class PlanSubtasksMemoryModule:
|
|||||||
max_steps=self.config.plan_max_steps,
|
max_steps=self.config.plan_max_steps,
|
||||||
episode_duration=f"{episode_duration:.3f}",
|
episode_duration=f"{episode_duration:.3f}",
|
||||||
)
|
)
|
||||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
result = self.vlm.generate_json([messages])[0]
|
result = self.vlm.generate_json([messages])[0]
|
||||||
spans = result.get("subtasks") if isinstance(result, dict) else None
|
spans = result.get("subtasks") if isinstance(result, dict) else None
|
||||||
if not spans:
|
if not spans:
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
|||||||
raise ImportError(
|
raise ImportError(
|
||||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||||
) from exc
|
) from exc
|
||||||
llm = LLM(model=config.model_id)
|
llm = LLM(model=config.model_id, tensor_parallel_size=config.tensor_parallel_size)
|
||||||
|
|
||||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||||
params = SamplingParams(
|
params = SamplingParams(
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||||
|
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||||
from lerobot.annotations.steerable_pipeline.modules import (
|
from lerobot.annotations.steerable_pipeline.modules import (
|
||||||
GeneralVqaModule,
|
GeneralVqaModule,
|
||||||
InterjectionsAndSpeechModule,
|
InterjectionsAndSpeechModule,
|
||||||
@@ -64,9 +65,12 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
|||||||
logger.info("annotate: root=%s", root)
|
logger.info("annotate: root=%s", root)
|
||||||
|
|
||||||
vlm = make_vlm_client(cfg.vlm)
|
vlm = make_vlm_client(cfg.vlm)
|
||||||
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1)
|
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
|
||||||
module_2 = InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed)
|
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
|
||||||
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed)
|
module_2 = InterjectionsAndSpeechModule(
|
||||||
|
vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
|
||||||
|
)
|
||||||
|
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
||||||
writer = LanguageColumnsWriter()
|
writer = LanguageColumnsWriter()
|
||||||
validator = StagingValidator()
|
validator = StagingValidator()
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.annotations.steerable_pipeline.config import (
|
from lerobot.annotations.steerable_pipeline.config import (
|
||||||
Module1Config,
|
Module1Config,
|
||||||
@@ -32,10 +34,31 @@ from lerobot.annotations.steerable_pipeline.modules import (
|
|||||||
)
|
)
|
||||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||||
|
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||||
|
|
||||||
from ._helpers import make_canned_responder
|
from ._helpers import make_canned_responder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _StubFrameProvider:
|
||||||
|
"""Returns one sentinel object per requested timestamp."""
|
||||||
|
|
||||||
|
sentinel: Any = field(default_factory=lambda: object())
|
||||||
|
calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def frames_at(self, record, timestamps):
|
||||||
|
self.calls.append((record.episode_index, tuple(timestamps)))
|
||||||
|
return [self.sentinel] * len(timestamps)
|
||||||
|
|
||||||
|
|
||||||
|
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
|
||||||
|
def responder(messages):
|
||||||
|
captured.append(list(messages))
|
||||||
|
return reply
|
||||||
|
|
||||||
|
return StubVlmClient(responder=responder)
|
||||||
|
|
||||||
|
|
||||||
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||||
vlm = make_canned_responder(
|
vlm = make_canned_responder(
|
||||||
{
|
{
|
||||||
@@ -145,6 +168,39 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path)
|
|||||||
assert ts in frame_set
|
assert ts in frame_set
|
||||||
|
|
||||||
|
|
||||||
|
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
||||||
|
"""Each VQA prompt must carry a single image block at the emission frame."""
|
||||||
|
captured: list[list[dict[str, Any]]] = []
|
||||||
|
payload = {
|
||||||
|
"question": "How many cups?",
|
||||||
|
"answer": {"label": "cup", "count": 1},
|
||||||
|
}
|
||||||
|
provider = _StubFrameProvider()
|
||||||
|
module = GeneralVqaModule(
|
||||||
|
vlm=_spy_responder(captured, payload),
|
||||||
|
config=Module3Config(vqa_emission_hz=1.0, K=1),
|
||||||
|
seed=0,
|
||||||
|
frame_provider=provider,
|
||||||
|
)
|
||||||
|
record = next(iter_episodes(single_episode_root))
|
||||||
|
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||||
|
module.run_episode(record, staging)
|
||||||
|
|
||||||
|
assert captured, "no VLM calls made"
|
||||||
|
for messages in captured:
|
||||||
|
content = messages[0]["content"]
|
||||||
|
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||||
|
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||||
|
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||||
|
assert image_blocks[0]["image"] is provider.sentinel
|
||||||
|
assert len(text_blocks) == 1
|
||||||
|
# provider was called once per emission with the exact emission timestamp
|
||||||
|
for ep_idx, ts_tuple in provider.calls:
|
||||||
|
assert ep_idx == record.episode_index
|
||||||
|
assert len(ts_tuple) == 1
|
||||||
|
assert ts_tuple[0] in record.frame_timestamps
|
||||||
|
|
||||||
|
|
||||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||||
payload = {
|
payload = {
|
||||||
"question": "Where is the cup?",
|
"question": "Where is the cup?",
|
||||||
|
|||||||
Reference in New Issue
Block a user