diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py
index c015a146f..8602d2a28 100644
--- a/src/lerobot/annotations/steerable_pipeline/frames.py
+++ b/src/lerobot/annotations/steerable_pipeline/frames.py
@@ -34,10 +34,29 @@ from .reader import EpisodeRecord
 class FrameProvider(Protocol):
     """Decodes camera frames at episode-relative timestamps."""
 
-    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
-        """Return one PIL.Image per timestamp; empty list if no camera available."""
+    @property
+    def camera_keys(self) -> list[str]:
+        """All ``observation.images.*`` feature keys this provider can decode."""
 
-    def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
+    def frames_at(
+        self,
+        record: EpisodeRecord,
+        timestamps: list[float],
+        camera_key: str | None = None,
+    ) -> list[Any]:
+        """Return one PIL.Image per timestamp from ``camera_key`` (or default).
+
+        Empty list if the camera is unavailable. ``camera_key=None`` falls back
+        to the provider's default camera so existing single-camera callers
+        (Module 1, Module 2) keep working unchanged.
+        """
+
+    def video_for_episode(
+        self,
+        record: EpisodeRecord,
+        max_frames: int,
+        camera_key: str | None = None,
+    ) -> list[Any]:
         """Return up to ``max_frames`` PIL images covering the whole episode.
 
         Sampling is uniform across the episode duration. The returned list is
@@ -51,10 +70,24 @@ class FrameProvider(Protocol):
 class _NullProvider:
     """No-op provider used when the dataset has no video keys or in tests."""
 
-    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
+    @property
+    def camera_keys(self) -> list[str]:
         return []
 
-    def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
+    def frames_at(
+        self,
+        record: EpisodeRecord,
+        timestamps: list[float],
+        camera_key: str | None = None,
+    ) -> list[Any]:
+        return []
+
+    def video_for_episode(
+        self,
+        record: EpisodeRecord,
+        max_frames: int,
+        camera_key: str | None = None,
+    ) -> list[Any]:
         return []
 
 
@@ -64,12 +97,18 @@ def null_provider() -> FrameProvider:
 
 @dataclass
 class VideoFrameProvider:
-    """Decodes frames from the dataset's first ``observation.images.*`` stream.
+    """Decodes frames from the dataset's ``observation.images.*`` streams.
 
-    The first camera key is used unconditionally — Module 1/2/3 prompts care
-    about *what is happening*, not which camera angle the model sees, so a
-    single canonical viewpoint is enough. Override ``camera_key`` if you
-    want a specific stream.
+    By default the *first* camera key is used for Module 1 (subtask
+    decomposition) and Module 2 (interjection scenarios) — those prompts care
+    about *what is happening*, not which angle. Module 3 (VQA) instead
+    iterates over every camera in :attr:`camera_keys` so each frame's
+    grounded answer (bbox/keypoint/...) is tagged with the camera it was
+    grounded against.
+
+    ``camera_key`` overrides the default-camera choice but does not restrict
+    :attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
+    ``video_for_episode`` to read a non-default stream.
 
     Caches up to ``cache_size`` decoded frames per process to keep
     co-timestamped Module 2 + Module 1 plan-update calls cheap.
@@ -81,24 +120,37 @@ class VideoFrameProvider:
     cache_size: int = 256
     _meta: Any = field(default=None, init=False, repr=False)
     _cache: dict = field(default_factory=dict, init=False, repr=False)
+    _camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
 
     def __post_init__(self) -> None:
         from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata  # noqa: PLC0415
 
         self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
+        keys = list(self._meta.video_keys or [])
+        self._camera_keys = keys
         if self.camera_key is None:
-            keys = self._meta.video_keys
             self.camera_key = keys[0] if keys else None
 
-    def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
-        if not timestamps or self.camera_key is None:
+    @property
+    def camera_keys(self) -> list[str]:
+        """All ``observation.images.*`` keys available on this dataset."""
+        return list(self._camera_keys)
+
+    def frames_at(
+        self,
+        record: EpisodeRecord,
+        timestamps: list[float],
+        camera_key: str | None = None,
+    ) -> list[Any]:
+        target = camera_key if camera_key is not None else self.camera_key
+        if not timestamps or target is None:
             return []
         out: list[Any] = []
         misses: list[float] = []
         miss_indices: list[int] = []
         for i, ts in enumerate(timestamps):
-            key = (record.episode_index, round(float(ts), 6))
+            key = (record.episode_index, target, round(float(ts), 6))
             cached = self._cache.get(key)
             if cached is not None:
                 out.append(cached)
@@ -108,20 +160,22 @@ class VideoFrameProvider:
             miss_indices.append(i)
 
         if misses:
-            decoded = self._decode(record.episode_index, misses)
+            decoded = self._decode(record.episode_index, misses, target)
             # decoder may return fewer frames than requested when some
             # timestamps fall outside the video; pair what we have and
             # leave the rest as None to be filtered below.
             for i, img in zip(miss_indices, decoded):
                 out[i] = img
-                key = (record.episode_index, round(float(timestamps[i]), 6))
+                key = (record.episode_index, target, round(float(timestamps[i]), 6))
                 if len(self._cache) >= self.cache_size:
                     self._cache.pop(next(iter(self._cache)))
                 self._cache[key] = img
 
         # filter out any None left over from decode failures
         return [img for img in out if img is not None]
 
-    def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
+    def _decode(
+        self, episode_index: int, timestamps: list[float], camera_key: str
+    ) -> list[Any]:
         import os as _os  # noqa: PLC0415
 
         from PIL import Image  # noqa: PLC0415
@@ -129,9 +183,9 @@ class VideoFrameProvider:
         from lerobot.datasets.video_utils import decode_video_frames  # noqa: PLC0415
 
         ep = self._meta.episodes[episode_index]
-        from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
+        from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
         shifted = [from_timestamp + ts for ts in timestamps]
-        video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
+        video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
         # ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
         # some environments; default to ``pyav`` (always available via
         # the ``av`` package) and let users override with
@@ -156,13 +210,19 @@ class VideoFrameProvider:
             out.append(Image.fromarray(hwc, mode="RGB"))
         return out
 
-    def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
+    def video_for_episode(
+        self,
+        record: EpisodeRecord,
+        max_frames: int,
+        camera_key: str | None = None,
+    ) -> list[Any]:
         """Return up to ``max_frames`` images uniformly sampled across the episode.
 
         The whole episode duration is covered; the model picks subtask
         boundaries from the temporal pooling it does internally.
         """
-        if max_frames <= 0 or self.camera_key is None or not record.frame_timestamps:
+        target = camera_key if camera_key is not None else self.camera_key
+        if max_frames <= 0 or target is None or not record.frame_timestamps:
             return []
         n_frames = min(max_frames, len(record.frame_timestamps))
         if n_frames == len(record.frame_timestamps):
@@ -175,7 +235,7 @@
         else:
             step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
         timestamps = [float(t0 + i * step) for i in range(n_frames)]
-        return self.frames_at(record, timestamps)
+        return self.frames_at(record, timestamps, camera_key=target)
 
 
 def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
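The `frames.py` changes keep the old call shape while exposing the full camera list. A minimal usage sketch, not part of the patch: the dataset root and camera names are hypothetical, and it assumes `iter_episodes` lives in the same `reader` module this package already imports `EpisodeRecord` from.

```python
# Illustrative only; the path and camera names are hypothetical.
from pathlib import Path

from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
from lerobot.annotations.steerable_pipeline.reader import iter_episodes  # assumed location

root = Path("/data/my_dataset")  # hypothetical local dataset root
provider = make_frame_provider(root)  # default camera = first video key
record = next(iter_episodes(root))

for camera in provider.camera_keys:
    # One PIL image per timestamp from this specific stream; an empty
    # list means the camera could not be decoded at these timestamps.
    frames = provider.frames_at(record, [0.0, 0.5], camera_key=camera)
    print(camera, len(frames))

# Omitting camera_key preserves the pre-patch single-camera behaviour
# that Module 1 / Module 2 callers rely on.
default_frames = provider.frames_at(record, [0.0])
```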
diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
index df34a2772..2fe71d5dc 100644
--- a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
+++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py
@@ -16,8 +16,15 @@
 """Module 3: general VQA at a timed cadence.
 
 Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
-emission so each frame gets at most one ``(vqa, user)`` and one
-``(vqa, assistant)`` pair — keeps the resolver contract scalar.
+emission. For datasets with multiple cameras, every emission tick produces
+one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
+generated against that camera's frame and stamped with the matching
+``camera`` field on the emitted rows. The resolver disambiguates via
+``camera=...``; recipes that consume VQA do so through one sub-recipe
+per camera (see ``recipes/pi05_hirobot.yaml``).
+
+Within a single (frame, camera) we still emit at most one ``(vqa, user)``
+and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
 
 Question types covered (per the plan's Module 3 table): bbox, keypoint,
 count, attribute, spatial. The assistant's ``content`` is a JSON string
@@ -98,23 +105,37 @@ class GeneralVqaModule:
         anchor_idx = _emission_anchor_indices(
             record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
         )
-        # Build all messages first, then issue them as a single batched
-        # generate_json call so the client can fan them out concurrently.
-        per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
+        cameras = self._target_cameras()
+        if not cameras:
+            # No camera available — keep behaviour parity with previous
+            # text-only stub: emit nothing rather than producing untagged
+            # rows that would fail validation.
+            staging.write("module_3", [])
+            return
+
+        # Build all messages first (one per (frame, camera)), then issue them
+        # as a single batched generate_json call so the client can fan them
+        # out concurrently.
+        per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
         for idx in anchor_idx:
             ts = float(record.frame_timestamps[idx])
             qtype = rng.choice(self.config.question_types)
-            messages = self._build_messages(record, qtype, ts)
-            per_call.append((ts, qtype, messages))
+            for camera in cameras:
+                messages = self._build_messages(record, qtype, ts, camera)
+                # Skip cameras that decoded to zero frames at this ts: no point
+                # asking the VLM to ground a bbox without an image.
+                if not _has_image_block(messages):
+                    continue
+                per_call.append((ts, camera, qtype, messages))
         if not per_call:
             staging.write("module_3", [])
             return
 
-        results = self.vlm.generate_json([m for _, _, m in per_call])
+        results = self.vlm.generate_json([m for _, _, _, m in per_call])
         rows: list[dict[str, Any]] = []
-        for (ts, _qtype, _messages), result in zip(per_call, results):
+        for (ts, camera, _qtype, _messages), result in zip(per_call, results):
             qa = self._postprocess(result)
             if qa is None:
                 continue
@@ -125,6 +146,7 @@ class GeneralVqaModule:
                     "content": question,
                     "style": "vqa",
                     "timestamp": ts,
+                    "camera": camera,
                     "tool_calls": None,
                 }
             )
@@ -134,19 +156,35 @@ class GeneralVqaModule:
                     "content": json.dumps(answer, sort_keys=True),
                     "style": "vqa",
                     "timestamp": ts,
+                    "camera": camera,
                     "tool_calls": None,
                 }
             )
         staging.write("module_3", rows)
 
+    def _target_cameras(self) -> list[str]:
+        """Return the cameras Module 3 should iterate per emission tick.
+
+        Defaults to every camera the provider exposes. Datasets with no
+        cameras (or test/null providers) yield an empty list, which makes
+        ``run_episode`` a no-op.
+        """
+        return list(getattr(self.frame_provider, "camera_keys", []) or [])
+
     def _build_messages(
-        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+        self,
+        record: EpisodeRecord,
+        question_type: str,
+        frame_timestamp: float,
+        camera_key: str,
     ) -> list[dict[str, Any]]:
         prompt = load_prompt("module_3_vqa").format(
             episode_task=record.episode_task,
             question_type=question_type,
         )
-        images = self.frame_provider.frames_at(record, [frame_timestamp])
+        images = self.frame_provider.frames_at(
+            record, [frame_timestamp], camera_key=camera_key
+        )
         content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
         return [{"role": "user", "content": content}]
 
@@ -166,8 +204,24 @@ class GeneralVqaModule:
         return question.strip(), answer
 
     def _generate_one(
-        self, record: EpisodeRecord, question_type: str, frame_timestamp: float
+        self,
+        record: EpisodeRecord,
+        question_type: str,
+        frame_timestamp: float,
+        camera_key: str,
    ) -> tuple[str, dict[str, Any]] | None:
-        messages = self._build_messages(record, question_type, frame_timestamp)
+        messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
         result = self.vlm.generate_json([messages])[0]
         return self._postprocess(result)
+
+
+def _has_image_block(messages: list[dict[str, Any]]) -> bool:
+    """Return True if any user content block is a populated image block."""
+    for msg in messages:
+        content = msg.get("content")
+        if not isinstance(content, list):
+            continue
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "image":
+                return True
+    return False
diff --git a/src/lerobot/annotations/steerable_pipeline/validator.py b/src/lerobot/annotations/steerable_pipeline/validator.py
index ccc79bc38..a847ba29f 100644
--- a/src/lerobot/annotations/steerable_pipeline/validator.py
+++ b/src/lerobot/annotations/steerable_pipeline/validator.py
@@ -43,6 +43,8 @@ from lerobot.datasets.language import (
     LANGUAGE_EVENTS,
     LANGUAGE_PERSISTENT,
     column_for_style,
+    is_view_dependent_style,
+    validate_camera_field,
 )
 
 from .reader import EpisodeRecord
@@ -98,6 +100,11 @@ class StagingValidator:
     """Walks the staging tree and produces a :class:`ValidationReport`."""
 
     timestamp_atol: float = 0.0  # exact-match by default
+    dataset_camera_keys: tuple[str, ...] | None = None
+    """Known ``observation.images.*`` keys on the dataset. When set, the
+    validator additionally enforces that every view-dependent row's
+    ``camera`` field references one of these keys. Pass ``None`` (default)
+    to skip that cross-check (e.g. in unit tests with no real dataset)."""
 
     def validate(
         self,
@@ -130,6 +137,9 @@
         persistent: list[dict[str, Any]] = []
         for row in all_rows:
             self._check_column_routing(row, report, record.episode_index)
+            self._check_camera_field(
+                row, report, record.episode_index, self.dataset_camera_keys
+            )
             if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
                 persistent.append(row)
             else:
@@ -141,6 +151,59 @@
         self._check_speech_interjection_pairs(events, report, record.episode_index)
         self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
         self._check_vqa_json(events, report, record.episode_index)
+        self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
+
+    def _check_camera_field(
+        self,
+        row: dict[str, Any],
+        report: ValidationReport,
+        episode_index: int,
+        dataset_camera_keys: Sequence[str] | None,
+    ) -> None:
+        """Enforce the camera invariant and that the key matches the
+        dataset's cameras."""
+        style = row.get("style")
+        camera = row.get("camera")
+        try:
+            validate_camera_field(style, camera)
+        except ValueError as exc:
+            report.add_error(
+                f"ep={episode_index} module={row.get('_module')}: {exc}"
+            )
+            return
+        if (
+            is_view_dependent_style(style)
+            and dataset_camera_keys
+            and camera not in dataset_camera_keys
+        ):
+            report.add_error(
+                f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
+                f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
+            )
+
+    def _check_vqa_uniqueness_per_frame_camera(
+        self,
+        events: Iterable[dict[str, Any]],
+        report: ValidationReport,
+        episode_index: int,
+    ) -> None:
+        """Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
+        counts: dict[tuple[float, str, str], int] = {}
+        for row in events:
+            if row.get("style") != "vqa":
+                continue
+            ts = row.get("timestamp")
+            camera = row.get("camera")
+            role = row.get("role")
+            if ts is None or camera is None or role is None:
+                continue  # other validators flag these
+            key = (float(ts), str(camera), str(role))
+            counts[key] = counts.get(key, 0) + 1
+        for (ts, camera, role), n in counts.items():
+            if n > 1:
+                report.add_error(
+                    f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
+                    f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
+                )
 
     def _check_column_routing(
         self,
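`is_view_dependent_style` and `validate_camera_field` are imported from `lerobot.datasets.language` and are not part of this diff. A stand-in sketch of the contract the validator assumes, under the assumption that `"vqa"` is the view-dependent style (the real style set may be larger):

```python
# Stand-in only: the real helpers live in lerobot.datasets.language
# and are not shown in this diff.
def is_view_dependent_style_sketch(style: str | None) -> bool:
    # Assumption: "vqa" is view-dependent; the real set may be larger.
    return style == "vqa"


def validate_camera_field_sketch(style: str | None, camera: str | None) -> None:
    """Raise ValueError when the camera field violates the invariant."""
    if is_view_dependent_style_sketch(style):
        if camera is None:
            raise ValueError(f"style {style!r} requires a camera field")
    elif camera is not None:
        raise ValueError(f"style {style!r} must not carry a camera field")
```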
str(row["role"]), "content": None if row.get("content") is None else str(row["content"]), "style": style, "timestamp": float(row["timestamp"]), + "camera": None if camera is None else str(camera), "tool_calls": _normalize_tool_calls(row.get("tool_calls")), } @@ -119,10 +127,13 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]: ) if column_for_style(style) != LANGUAGE_EVENTS: raise ValueError(f"event row with style {style!r} would not route to language_events") + camera = row.get("camera") + validate_camera_field(style, camera) return { "role": str(row["role"]), "content": None if row.get("content") is None else str(row["content"]), "style": style, + "camera": None if camera is None else str(camera), "tool_calls": _normalize_tool_calls(row.get("tool_calls")), } @@ -311,6 +322,7 @@ def speech_atom(timestamp: float, text: str) -> dict[str, Any]: "content": None, "style": None, "timestamp": float(timestamp), + "camera": None, "tool_calls": [ { "type": "function", diff --git a/src/lerobot/scripts/lerobot_annotate.py b/src/lerobot/scripts/lerobot_annotate.py index 6c2ccd72b..61790b1bb 100644 --- a/src/lerobot/scripts/lerobot_annotate.py +++ b/src/lerobot/scripts/lerobot_annotate.py @@ -72,7 +72,9 @@ def annotate(cfg: AnnotationPipelineConfig) -> None: ) module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider) writer = LanguageColumnsWriter() - validator = StagingValidator() + validator = StagingValidator( + dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None, + ) executor = Executor( config=cfg, diff --git a/tests/annotations/test_modules.py b/tests/annotations/test_modules.py index ec7116556..7e7ecace3 100644 --- a/tests/annotations/test_modules.py +++ b/tests/annotations/test_modules.py @@ -44,15 +44,20 @@ class _StubFrameProvider: """Returns one sentinel object per requested timestamp.""" sentinel: Any = field(default_factory=lambda: object()) - calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list) - video_calls: list[tuple[int, int]] = field(default_factory=list) + cameras: tuple[str, ...] 
= ("observation.images.top",) + calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list) + video_calls: list[tuple[int, int, str | None]] = field(default_factory=list) - def frames_at(self, record, timestamps): - self.calls.append((record.episode_index, tuple(timestamps))) + @property + def camera_keys(self) -> list[str]: + return list(self.cameras) + + def frames_at(self, record, timestamps, camera_key=None): + self.calls.append((record.episode_index, tuple(timestamps), camera_key)) return [self.sentinel] * len(timestamps) - def video_for_episode(self, record, max_frames): - self.video_calls.append((record.episode_index, max_frames)) + def video_for_episode(self, record, max_frames, camera_key=None): + self.video_calls.append((record.episode_index, max_frames, camera_key)) n = min(max_frames, len(record.frame_timestamps)) return [self.sentinel] * n @@ -148,7 +153,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech( assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches) -def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None: +def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None: payload = { "question": "How many cups?", "answer": {"label": "cup", "count": 2, "note": "white & blue"}, @@ -158,19 +163,34 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) vlm=vlm, config=Module3Config(vqa_emission_hz=1.0, K=3), seed=1, + frame_provider=_StubFrameProvider( + cameras=("observation.images.top", "observation.images.wrist") + ), ) record = next(iter_episodes(single_episode_root)) staging = EpisodeStaging(tmp_path / "stage", record.episode_index) module.run_episode(record, staging) rows = staging.read("module_3") - user_ts = [r["timestamp"] for r in rows if r["role"] == "user" and r["style"] == "vqa"] - assistant_ts = [r["timestamp"] for r in rows if r["role"] == "assistant" and r["style"] == "vqa"] - # at most one user (vqa) per frame; same for assistant - assert len(user_ts) == len(set(user_ts)) - assert len(assistant_ts) == len(set(assistant_ts)) + # every vqa row must carry a camera tag and one of the configured cameras + for r in rows: + assert r["style"] == "vqa" + assert r.get("camera") in {"observation.images.top", "observation.images.wrist"} + # at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera) + user_keys = [ + (r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa" + ] + assistant_keys = [ + (r["timestamp"], r["camera"]) + for r in rows + if r["role"] == "assistant" and r["style"] == "vqa" + ] + assert len(user_keys) == len(set(user_keys)) + assert len(assistant_keys) == len(set(assistant_keys)) + # both cameras must be represented + assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"} # every emitted timestamp must be an exact source frame timestamp frame_set = set(record.frame_timestamps) - for ts in user_ts + assistant_ts: + for ts, _ in user_keys + assistant_keys: assert ts in frame_set @@ -254,11 +274,12 @@ def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}" assert image_blocks[0]["image"] is provider.sentinel assert len(text_blocks) == 1 - # provider was called once per emission with the exact emission timestamp - for ep_idx, ts_tuple in provider.calls: + # provider was called once per emission per 
camera with the exact emission timestamp + for ep_idx, ts_tuple, camera in provider.calls: assert ep_idx == record.episode_index assert len(ts_tuple) == 1 assert ts_tuple[0] in record.frame_timestamps + assert camera in provider.cameras def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None: @@ -271,6 +292,7 @@ def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_ vlm=vlm, config=Module3Config(vqa_emission_hz=1.0, K=2), seed=2, + frame_provider=_StubFrameProvider(), ) record = next(iter_episodes(single_episode_root)) staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
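An end-to-end sketch of the new wiring, reusing the test double defined above: the validator learns the dataset's cameras from the provider, just as `lerobot_annotate` now does. The unknown camera key and the error text shown in the comment are illustrative, though the format follows `_check_camera_field`.

```python
# Sketch: wire the validator the way lerobot_annotate does, with the
# stub provider from the tests standing in for a real dataset.
provider = _StubFrameProvider(cameras=("observation.images.top",))
validator = StagingValidator(
    dataset_camera_keys=tuple(provider.camera_keys) or None,
)
# A vqa row tagged with an unknown camera would be reported roughly as:
#   ep=0 module=module_3: camera 'observation.images.front' on style 'vqa'
#   is not one of the dataset's video keys ['observation.images.top']
```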