From 5a6aa64570e9c16841edbf7c9ff16f87eba1e869 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 30 Apr 2026 10:48:17 +0200 Subject: [PATCH] feat(language): per-camera tagging on view-dependent styles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a nullable `camera` field to the language row struct (both persistent and event variants) so view-dependent styles like `vqa` can carry which `observation.images.*` view they were grounded against. Without this, multi-camera datasets ended up with multiple `(vqa, role)` rows at the same timestamp that the resolver could not disambiguate. - `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS, to both Arrow struct types and the HF datasets feature mappings; introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus `is_view_dependent_style` and `validate_camera_field` helpers (camera required iff style is view-dependent). - `language_render.py`: thread an optional `camera=` kwarg through every resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`. Without a `camera` filter, multi-row matches keep raising the existing ambiguity error — which is the desired behaviour on multi-camera data. - `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying the matching image block), keeping the original 0.20 budget and documenting the customization point for datasets with different cameras. - Tests: schema test asserts the new field order; new tests cover `is_view_dependent_style`, `validate_camera_field` (both required and forbidden directions), per-camera `emitted_at` filtering, and the ambiguity error when two cameras emit `(vqa, assistant)` at the same timestamp without a `camera=` filter. RenderMessagesStep + dataset passthrough fixtures updated to include the new field. - `docs/source/language_and_recipes.mdx`: document the `camera` field, the per-camera resolver pattern, and the canonical recipe convention. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/language_and_recipes.mdx | 40 ++++- src/lerobot/configs/recipes/pi05_hirobot.yaml | 33 ++++- src/lerobot/datasets/language.py | 37 ++++- src/lerobot/datasets/language_render.py | 73 ++++++--- tests/configs/test_recipe.py | 3 +- tests/datasets/test_language.py | 51 ++++++- tests/datasets/test_language_render.py | 139 +++++++++++++++++- .../test_render_messages_processor.py | 1 + 8 files changed, 344 insertions(+), 33 deletions(-) diff --git a/docs/source/language_and_recipes.mdx b/docs/source/language_and_recipes.mdx index 135aa6301..952b6ef09 100644 --- a/docs/source/language_and_recipes.mdx +++ b/docs/source/language_and_recipes.mdx @@ -6,16 +6,24 @@ The two optional columns are: - `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`. - `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls. 
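+For example, a single frame in a multi-camera episode might carry (values are
+illustrative):
+
+```text
+language_persistent:
+  [{role: assistant, content: "pick up the cup", style: subtask,
+    timestamp: 2.0, camera: null, tool_calls: null}]
+language_events:
+  [{role: assistant, content: '{"count": 3}', style: vqa,
+    camera: observation.images.top, tool_calls: null}]
+```
+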
-Both columns share the same row shape: +Both columns share the same row shape (event rows omit `timestamp` because the +frame the row sits on already provides it): ```text role: string content: string | null style: string | null -timestamp: float64 +timestamp: float64 # persistent rows only +camera: string | null # observation.images.* feature key, view-dependent rows only tool_calls: list[Json] | null ``` +The `camera` field tags rows whose `content` is grounded in a specific camera +view. Rows of view-dependent styles (`vqa`, and the reserved `motion` / +`trace`) MUST set `camera` to the matching `observation.images.*` feature key. +Rows of every other style MUST leave `camera` as `null`. Pipeline writers and +the validator enforce this via `validate_camera_field(style, camera)`. + `meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations. ## Architecture @@ -39,11 +47,37 @@ Persistent styles are active after emission until replaced: Event styles only exist on their exact timestamp: - `emitted_at(t, style=interjection)` -- `emitted_at(t, style=vqa, role=user)` +- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)` - `emitted_at(t, role=assistant, tool_name=say)` Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data. +## View-dependent resolution + +For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a +`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple +cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per +camera at the same timestamp; without `camera=`, those resolvers see two +matches and raise an ambiguity error. Recipes consume each camera through its +own binding plus a matching image block, e.g. + +```yaml +ask_vqa_top: + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - { type: image, feature: observation.images.top } + - { type: text, text: "${vqa_query}" } + - { role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa } +``` + +Add one such sub-recipe per camera the dataset records. + ## Recipe anatomy Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. diff --git a/src/lerobot/configs/recipes/pi05_hirobot.yaml b/src/lerobot/configs/recipes/pi05_hirobot.yaml index 3dbfb44be..7cd6b009f 100644 --- a/src/lerobot/configs/recipes/pi05_hirobot.yaml +++ b/src/lerobot/configs/recipes/pi05_hirobot.yaml @@ -40,8 +40,35 @@ blend: - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} - {role: assistant, content: "${subtask}", stream: low_level, target: true} - ask_vqa: - weight: 0.20 + # VQA is view-dependent: bbox / keypoint / count answers only make sense for + # the camera they were grounded against. Each camera gets its own sub-recipe + # so the resolver can disambiguate via `camera=...` and the user-turn carries + # the matching image block. Adjust the camera keys (and add more sub-recipes) + # to match the cameras present on your dataset. 
+ ask_vqa_top: + weight: 0.10 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)" messages: - - {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query} + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.top} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} + + ask_vqa_wrist: + weight: 0.10 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.wrist} + - {type: text, text: "${vqa_query}"} - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index b8cc649bf..69866e6bd 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -24,8 +24,8 @@ import pyarrow as pa LANGUAGE_PERSISTENT = "language_persistent" LANGUAGE_EVENTS = "language_events" LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) -PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls") -EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls") +PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls") +EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls") CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"} EXTENDED_STYLES = set() @@ -34,6 +34,11 @@ STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"} EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"} +# Styles whose ``content`` is grounded in a specific camera view. Rows of these +# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*`` +# feature key. Rows of every other style MUST have ``camera=None``. 
+VIEW_DEPENDENT_STYLES = {"vqa", "motion", "trace"} + LanguageColumn = Literal["language_persistent", "language_events"] @@ -59,6 +64,7 @@ def language_persistent_row_arrow_type() -> pa.StructType: pa.field("content", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True), pa.field("timestamp", pa.float64(), nullable=False), + pa.field("camera", pa.string(), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True), ] ) @@ -75,6 +81,7 @@ def language_event_row_arrow_type() -> pa.StructType: pa.field("role", pa.string(), nullable=False), pa.field("content", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True), + pa.field("camera", pa.string(), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True), ] ) @@ -97,6 +104,7 @@ def language_persistent_row_feature() -> dict[str, object]: "content": datasets.Value("string"), "style": datasets.Value("string"), "timestamp": datasets.Value("float64"), + "camera": datasets.Value("string"), "tool_calls": datasets.List(_json_feature()), } @@ -107,6 +115,7 @@ def language_event_row_feature() -> dict[str, object]: "role": datasets.Value("string"), "content": datasets.Value("string"), "style": datasets.Value("string"), + "camera": datasets.Value("string"), "tool_calls": datasets.List(_json_feature()), } @@ -134,6 +143,30 @@ def is_language_column(key: str) -> bool: return key in LANGUAGE_COLUMNS +def is_view_dependent_style(style: str | None) -> bool: + """Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key.""" + return style in VIEW_DEPENDENT_STYLES + + +def validate_camera_field(style: str | None, camera: str | None) -> None: + """Enforce the ``camera`` invariant: required iff ``style`` is view-dependent. + + Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if + a non-view-dependent style carries one. Pipeline writers and the validator + should call this on every emitted row. + """ + if is_view_dependent_style(style): + if not camera: + raise ValueError( + f"Rows of view-dependent style {style!r} require a non-empty 'camera' " + f"field referencing an 'observation.images.*' feature key." + ) + elif camera is not None: + raise ValueError( + f"Rows of style {style!r} must have camera=None; got camera={camera!r}." + ) + + def column_for_style(style: str | None) -> LanguageColumn: """Map a language style to the column where rows of that style are stored. diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 954ac8141..42cd03a9a 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -46,18 +46,23 @@ def active_at( style: str | None = None, role: str | None = None, tool_name: str | None = None, + camera: str | None = None, ) -> LanguageRow | None: """Return the persistent row of ``style`` that is active at time ``t``. A persistent row is "active" at ``t`` when its own ``timestamp`` is the - most recent one ``<= t`` for the given ``style``/``role``/``tool_name`` - selector. ``events`` is accepted for resolver-signature uniformity but is - not consulted: only persistent styles are valid here. + most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/ + ``camera`` selector. ``events`` is accepted for resolver-signature + uniformity but is not consulted: only persistent styles are valid here. 
""" _validate_persistent_resolver("active_at", style) - matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name) + matches = _matching_rows( + persistent, style=style, role=role, tool_name=tool_name, camera=camera + ) matches = [row for row in matches if _timestamp(row) <= t] - return _select_latest(matches, style=style, role=role, tool_name=tool_name) + return _select_latest( + matches, style=style, role=role, tool_name=tool_name, camera=camera + ) def emitted_at( @@ -68,26 +73,45 @@ def emitted_at( style: str | None = None, role: str | None = None, tool_name: str | None = None, + camera: str | None = None, ) -> LanguageRow | None: """Return the row of ``style`` emitted at exactly time ``t``. For persistent styles, this matches persistent rows whose own ``timestamp`` equals ``t``. For event styles, the ``events`` list is assumed to come from the dataset row at frame ``t`` (event rows carry no timestamp of their own), - so all matching event rows are considered emitted at ``t``. + so all matching event rows are considered emitted at ``t``. ``camera`` + filters by the row's ``camera`` field — required to disambiguate when + multiple view-dependent rows share ``(t, role)`` across cameras. """ column = column_for_style(style) if column == LANGUAGE_PERSISTENT: matches = [ row - for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name) + for row in _matching_rows( + persistent, style=style, role=role, tool_name=tool_name, camera=camera + ) if _timestamp(row) == t ] return _select_one( - matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key + matches, + style=style, + role=role, + tool_name=tool_name, + camera=camera, + sort_key=_persistent_sort_key, ) - matches = _matching_rows(events, style=style, role=role, tool_name=tool_name) - return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key) + matches = _matching_rows( + events, style=style, role=role, tool_name=tool_name, camera=camera + ) + return _select_one( + matches, + style=style, + role=role, + tool_name=tool_name, + camera=camera, + sort_key=_event_sort_key, + ) def nth_prev( @@ -99,12 +123,14 @@ def nth_prev( offset: int = 1, role: str | None = None, tool_name: str | None = None, + camera: str | None = None, ) -> LanguageRow | None: """Return the persistent row that was active ``offset`` steps before ``t``. Walks back through chronologically sorted persistent rows of ``style`` - (filtered by optional ``role``/``tool_name``) and returns the one ``offset`` - positions before the row active at ``t``. Only valid for persistent styles. + (filtered by optional ``role``/``tool_name``/``camera``) and returns the + one ``offset`` positions before the row active at ``t``. Only valid for + persistent styles. """ return _nth_relative( t, @@ -113,6 +139,7 @@ def nth_prev( offset=-offset, role=role, tool_name=tool_name, + camera=camera, resolver_name="nth_prev", ) @@ -126,12 +153,14 @@ def nth_next( offset: int = 1, role: str | None = None, tool_name: str | None = None, + camera: str | None = None, ) -> LanguageRow | None: """Return the persistent row that becomes active ``offset`` steps after ``t``. Walks forward through chronologically sorted persistent rows of ``style`` - (filtered by optional ``role``/``tool_name``) and returns the one ``offset`` - positions after the row active at ``t``. Only valid for persistent styles. 
+    (filtered by optional ``role``/``tool_name``/``camera``) and returns the
+    one ``offset`` positions after the row active at ``t``. Only valid for
+    persistent styles.
     """
     return _nth_relative(
         t,
         persistent,
         style=style,
         offset=offset,
         role=role,
         tool_name=tool_name,
+        camera=camera,
         resolver_name="nth_next",
     )
 
@@ -376,6 +406,7 @@ def _nth_relative(
     offset: int,
     role: str | None,
     tool_name: str | None,
+    camera: str | None,
     resolver_name: str,
 ) -> LanguageRow | None:
     """Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
@@ -384,7 +415,7 @@
     raise ValueError(f"{resolver_name} offset must be non-zero.")
 
     rows = sorted(
-        _matching_rows(persistent, style=style, role=role, tool_name=tool_name),
+        _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
         key=_persistent_sort_key,
     )
     if not rows:
@@ -420,14 +451,16 @@ def _matching_rows(
     style: str | None,
     role: str | None,
     tool_name: str | None,
+    camera: str | None,
 ) -> list[LanguageRow]:
-    """Return ``rows`` filtered by optional ``style``/``role``/``tool_name`` selectors."""
+    """Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
     return [
         row
         for row in rows
         if (style is None or row.get("style") == style)
        and (role is None or row.get("role") == role)
         and (tool_name is None or _row_has_tool_name(row, tool_name))
+        and (camera is None or row.get("camera") == camera)
     ]
 
 
@@ -437,6 +470,7 @@ def _select_latest(
     style: str | None,
     role: str | None,
     tool_name: str | None,
+    camera: str | None,
 ) -> LanguageRow | None:
     """Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
     if not rows:
@@ -448,6 +482,7 @@
         style=style,
         role=role,
         tool_name=tool_name,
+        camera=camera,
         sort_key=_persistent_sort_key,
     )
 
@@ -458,14 +493,20 @@ def _select_one(
     style: str | None,
     role: str | None,
     tool_name: str | None,
+    camera: str | None,
     sort_key: Any,
 ) -> LanguageRow | None:
     """Return the single matching row, or raise if the selectors are ambiguous."""
     if not rows:
         return None
-    if len(rows) > 1 and role is None and tool_name is None:
+    # Rows from multiple cameras can only be told apart by camera=, so keep
+    # raising even when role= or tool_name= already narrowed the matches.
+    unfiltered = role is None and tool_name is None and camera is None
+    cross_camera = camera is None and len({row.get("camera") for row in rows}) > 1
+    if len(rows) > 1 and (unfiltered or cross_camera):
         raise ValueError(
-            f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
+            f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
+            f"or camera=... to disambiguate."
) return sorted(rows, key=sort_key)[0] diff --git a/tests/configs/test_recipe.py b/tests/configs/test_recipe.py index e03f27a75..bd71d540f 100644 --- a/tests/configs/test_recipe.py +++ b/tests/configs/test_recipe.py @@ -26,6 +26,7 @@ def test_canonical_recipe_loads(): "user_interjection_response", "high_level_subtask", "low_level_execution", - "ask_vqa", + "ask_vqa_top", + "ask_vqa_wrist", } assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96) diff --git a/tests/datasets/test_language.py b/tests/datasets/test_language.py index 4f4066249..f01075d2c 100644 --- a/tests/datasets/test_language.py +++ b/tests/datasets/test_language.py @@ -13,10 +13,13 @@ from lerobot.datasets.language import ( LANGUAGE_PERSISTENT, PERSISTENT_STYLES, STYLE_REGISTRY, + VIEW_DEPENDENT_STYLES, column_for_style, + is_view_dependent_style, language_events_arrow_type, language_feature_info, language_persistent_arrow_type, + validate_camera_field, ) from lerobot.datasets.utils import DEFAULT_DATA_PATH @@ -26,10 +29,17 @@ def test_language_arrow_schema_has_expected_fields(): event_row_type = language_events_arrow_type().value_type assert isinstance(persistent_row_type, pa.StructType) - assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"] + assert persistent_row_type.names == [ + "role", + "content", + "style", + "timestamp", + "camera", + "tool_calls", + ] assert isinstance(event_row_type, pa.StructType) - assert event_row_type.names == ["role", "content", "style", "tool_calls"] + assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"] def test_style_registry_routes_columns(): @@ -47,6 +57,41 @@ def test_style_registry_routes_columns(): assert column_for_style(None) == LANGUAGE_EVENTS +def test_view_dependent_styles(): + assert {"vqa", "motion", "trace"} == VIEW_DEPENDENT_STYLES + assert is_view_dependent_style("vqa") + assert is_view_dependent_style("motion") + assert is_view_dependent_style("trace") + assert not is_view_dependent_style("subtask") + assert not is_view_dependent_style("plan") + assert not is_view_dependent_style("interjection") + assert not is_view_dependent_style(None) + + +def test_validate_camera_field_requires_camera_for_view_dependent_styles(): + validate_camera_field("vqa", "observation.images.top") + validate_camera_field("motion", "observation.images.wrist") + validate_camera_field("trace", "observation.images.front") + with pytest.raises(ValueError, match="view-dependent"): + validate_camera_field("vqa", None) + with pytest.raises(ValueError, match="view-dependent"): + validate_camera_field("motion", "") + + +def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles(): + validate_camera_field("subtask", None) + validate_camera_field("plan", None) + validate_camera_field("memory", None) + validate_camera_field("interjection", None) + validate_camera_field(None, None) + with pytest.raises(ValueError, match="must have camera=None"): + validate_camera_field("subtask", "observation.images.top") + with pytest.raises(ValueError, match="must have camera=None"): + validate_camera_field("interjection", "observation.images.top") + with pytest.raises(ValueError, match="must have camera=None"): + validate_camera_field(None, "observation.images.top") + + def test_unknown_style_rejected(): with pytest.raises(ValueError, match="Unknown language style"): column_for_style("surprise") @@ -70,6 +115,7 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot "content": 
"reach for the cup", "style": "subtask", "timestamp": 0.0, + "camera": None, "tool_calls": None, } ] @@ -77,6 +123,7 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot "role": "user", "content": "what is visible?", "style": "vqa", + "camera": "observation.images.top", "tool_calls": None, } data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index 6ae22f211..ddf8ca263 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -8,21 +8,23 @@ from lerobot.configs.recipe import MessageTurn, TrainingRecipe from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample -def persistent_row(role, content, style, timestamp, tool_calls=None): +def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None): return { "role": role, "content": content, "style": style, "timestamp": timestamp, + "camera": camera, "tool_calls": tool_calls, } -def event_row(role, content, style, tool_calls=None): +def event_row(role, content, style, tool_calls=None, camera=None): return { "role": role, "content": content, "style": style, + "camera": camera, "tool_calls": tool_calls, } @@ -35,8 +37,8 @@ PERSISTENT = [ persistent_row("assistant", "subtask 1", "subtask", 1.0), ] EVENTS_AT_1 = [ - event_row("user", "what is visible?", "vqa"), - event_row("assistant", '{"count": 2}', "vqa"), + event_row("user", "what is visible?", "vqa", camera="observation.images.top"), + event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"), ] EVENTS_AT_2 = [ event_row("user", "skip wiping", "interjection"), @@ -47,6 +49,15 @@ EVENTS_AT_2 = [ [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}], ), ] +# Same emission tick, two cameras: triggers per-camera disambiguation in +# resolvers, mirroring how Module 3 of the annotation pipeline writes one +# (vqa, user) + (vqa, assistant) pair per camera. 
+EVENTS_AT_3_TWO_CAMERAS = [ + event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"), + event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"), + event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"), + event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"), +] def test_resolver_temporal_semantics(): @@ -158,6 +169,126 @@ def test_deterministic_blend_sampling(): assert first == second +def test_emitted_at_filters_vqa_by_camera(): + top = emitted_at( + 3.0, + persistent=PERSISTENT, + events=EVENTS_AT_3_TWO_CAMERAS, + style="vqa", + role="assistant", + camera="observation.images.top", + ) + wrist = emitted_at( + 3.0, + persistent=PERSISTENT, + events=EVENTS_AT_3_TWO_CAMERAS, + style="vqa", + role="assistant", + camera="observation.images.wrist", + ) + assert top["content"] == '{"count": 3}' + assert wrist["content"] == '{"count": 1}' + + +def test_emitted_at_raises_on_ambiguous_per_camera_vqa(): + with pytest.raises(ValueError, match="Ambiguous resolver"): + emitted_at( + 3.0, + persistent=PERSISTENT, + events=EVENTS_AT_3_TWO_CAMERAS, + style="vqa", + role="assistant", + ) + + +def test_per_camera_blend_renders_both_views(): + recipe = TrainingRecipe( + blend={ + "top": TrainingRecipe( + weight=1.0, + bindings={ + "vqa_query": ( + "emitted_at(t, style=vqa, role=user, camera=observation.images.top)" + ), + "vqa": ( + "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)" + ), + }, + messages=[ + MessageTurn( + role="user", + content=[ + {"type": "image", "feature": "observation.images.top"}, + {"type": "text", "text": "${vqa_query}"}, + ], + stream="high_level", + if_present="vqa_query", + ), + MessageTurn( + role="assistant", + content="${vqa}", + stream="high_level", + target=True, + if_present="vqa", + ), + ], + ), + "wrist": TrainingRecipe( + weight=1.0, + bindings={ + "vqa_query": ( + "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)" + ), + "vqa": ( + "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)" + ), + }, + messages=[ + MessageTurn( + role="user", + content=[ + {"type": "image", "feature": "observation.images.wrist"}, + {"type": "text", "text": "${vqa_query}"}, + ], + stream="high_level", + if_present="vqa_query", + ), + MessageTurn( + role="assistant", + content="${vqa}", + stream="high_level", + target=True, + if_present="vqa", + ), + ], + ), + } + ) + + rendered_top = render_sample( + recipe=recipe.blend["top"], + persistent=PERSISTENT, + events=EVENTS_AT_3_TWO_CAMERAS, + t=3.0, + sample_idx=0, + ) + rendered_wrist = render_sample( + recipe=recipe.blend["wrist"], + persistent=PERSISTENT, + events=EVENTS_AT_3_TWO_CAMERAS, + t=3.0, + sample_idx=0, + ) + + assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top" + assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?" + assert rendered_top["messages"][1]["content"] == '{"count": 3}' + + assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist" + assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?" 
+ assert rendered_wrist["messages"][1]["content"] == '{"count": 1}' + + def test_canonical_recipe_can_render_low_level_branch(): recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]}) diff --git a/tests/processor/test_render_messages_processor.py b/tests/processor/test_render_messages_processor.py index c218b9152..ff808f38f 100644 --- a/tests/processor/test_render_messages_processor.py +++ b/tests/processor/test_render_messages_processor.py @@ -38,6 +38,7 @@ def test_render_messages_step_renders_and_drops_raw_language(): "content": "reach carefully", "style": "subtask", "timestamp": 0.0, + "camera": None, "tool_calls": None, } ],
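+            # camera stays None on the row above because "subtask" is not a
+            # view-dependent style; validate_camera_field("subtask", None)
+            # accepts exactly this shape and rejects any non-null camera.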