feat(language): per-camera tagging on view-dependent styles

Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.

- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
  to both Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
  `is_view_dependent_style` and `validate_camera_field` helpers (camera
  required iff style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
  resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
  `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
  VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
  `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
  the matching image block), keeping the original 0.20 budget and
  documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
  `is_view_dependent_style`, `validate_camera_field` (both required and
  forbidden directions), per-camera `emitted_at` filtering, and the
  ambiguity error when two cameras emit `(vqa, assistant)` at the same
  timestamp without a `camera=` filter. The `RenderMessagesStep` and dataset
  passthrough fixtures are updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
  the per-camera resolver pattern, and the canonical recipe convention.
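
A minimal sketch of the `camera=` filter described above, assuming rows are
plain dicts. `matching_rows` here is a simplified, hypothetical stand-in for
`_matching_rows` (it leaves out `tool_name`), not the code in
`language_render.py`:

```python
from typing import Any

Row = dict[str, Any]


def matching_rows(rows: list[Row], *, style=None, role=None, camera=None) -> list[Row]:
    """Simplified stand-in for `_matching_rows`: a selector left as None matches any row."""
    return [
        row
        for row in rows
        if (style is None or row.get("style") == style)
        and (role is None or row.get("role") == role)
        and (camera is None or row.get("camera") == camera)
    ]


# Two cameras emit a (vqa, assistant) row on the same frame.
events = [
    {"role": "assistant", "content": '{"count": 3}', "style": "vqa", "camera": "observation.images.top"},
    {"role": "assistant", "content": '{"count": 1}', "style": "vqa", "camera": "observation.images.wrist"},
]

# style + role alone still match both rows; adding camera narrows to one.
unfiltered = matching_rows(events, style="vqa", role="assistant")
top_only = matching_rows(events, style="vqa", role="assistant", camera="observation.images.top")
print(len(unfiltered), len(top_only))  # prints: 2 1
```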

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-04-30 10:48:17 +02:00
parent 0b06790da0
commit 5a6aa64570
8 changed files with 344 additions and 33 deletions
+37 -3
@@ -6,16 +6,24 @@ The two optional columns are:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`. - `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls. - `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape: Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text ```text
role: string role: string
content: string | null content: string | null
style: string | null style: string | null
timestamp: float64 timestamp: float64 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null tool_calls: list[Json] | null
``` ```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa`, and the reserved `motion` /
`trace`) MUST set `camera` to the matching `observation.images.*` feature key.
Rows of every other style MUST leave `camera` as `null`. Pipeline writers and
the validator enforce this via `validate_camera_field(style, camera)`.
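As an illustration of the invariant, the sketch below checks a batch of rows before they are written. It is self-contained, so `validate_camera_field` is re-declared locally as a stand-in with the same rule (in the repository it would be imported from `lerobot.datasets.language`), and `check_rows` is a hypothetical helper that is not part of the library.

```python
VIEW_DEPENDENT_STYLES = {"vqa", "motion", "trace"}


def validate_camera_field(style, camera):
    """Local stand-in for the library helper: camera is required iff style is view-dependent."""
    if style in VIEW_DEPENDENT_STYLES:
        if not camera:
            raise ValueError(f"Rows of view-dependent style {style!r} require a 'camera' field.")
    elif camera is not None:
        raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")


def check_rows(rows):
    """Hypothetical writer-side helper: collect (index, message) for every row that breaks the rule."""
    problems = []
    for index, row in enumerate(rows):
        try:
            validate_camera_field(row.get("style"), row.get("camera"))
        except ValueError as exc:
            problems.append((index, str(exc)))
    return problems


rows = [
    {"role": "user", "content": "what is visible?", "style": "vqa", "camera": "observation.images.top"},
    {"role": "assistant", "content": "reach for the cup", "style": "subtask", "camera": None},
    {"role": "user", "content": "how many cups?", "style": "vqa", "camera": None},  # missing camera
]
problems = check_rows(rows)
print(problems)  # only the third row (index 2) is reported
```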
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations. `meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
## Architecture ## Architecture
@@ -39,11 +47,37 @@ Persistent styles are active after emission until replaced:
Event styles only exist on their exact timestamp: Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)` - `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user)` - `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)` - `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data. Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
## View-dependent resolution
For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- { role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa }
```
Add one such sub-recipe per camera the dataset records.
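For datasets with a different camera set, the sub-recipes can also be generated rather than hand-written. The following is a sketch only: `per_camera_vqa_blend` is a hypothetical helper, the even split of a 0.20 budget mirrors the canonical recipe, and the output is a plain dict that would still need to be merged into the recipe's `blend` section.

```python
def per_camera_vqa_blend(camera_keys, total_weight=0.20):
    """Build one `ask_vqa_<name>` sub-recipe dict per camera, splitting `total_weight` evenly."""
    if not camera_keys:
        raise ValueError("camera_keys must not be empty")
    weight = total_weight / len(camera_keys)
    blend = {}
    for key in camera_keys:
        name = key.rsplit(".", 1)[-1]  # "observation.images.top" -> "top"
        blend[f"ask_vqa_{name}"] = {
            "weight": weight,
            "bindings": {
                "vqa_query": f"emitted_at(t, style=vqa, role=user, camera={key})",
                "vqa": f"emitted_at(t, style=vqa, role=assistant, camera={key})",
            },
            "messages": [
                {
                    "role": "user",
                    "stream": "high_level",
                    "if_present": "vqa_query",
                    "content": [
                        {"type": "image", "feature": key},
                        {"type": "text", "text": "${vqa_query}"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": "${vqa}",
                    "stream": "high_level",
                    "target": True,
                    "if_present": "vqa",
                },
            ],
        }
    return blend


blend = per_camera_vqa_blend(["observation.images.top", "observation.images.wrist"])
print(sorted(blend))  # prints: ['ask_vqa_top', 'ask_vqa_wrist']
```

Deriving the sub-recipe name from the last component of the feature key keeps each binding and its image block pointing at the same camera.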
## Recipe anatomy ## Recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.
+30 -3
@@ -40,8 +40,35 @@ blend:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: low_level, target: true} - {role: assistant, content: "${subtask}", stream: low_level, target: true}
ask_vqa: # VQA is view-dependent: bbox / keypoint / count answers only make sense for
weight: 0.20 # the camera they were grounded against. Each camera gets its own sub-recipe
# so the resolver can disambiguate via `camera=...` and the user-turn carries
# the matching image block. Adjust the camera keys (and add more sub-recipes)
# to match the cameras present on your dataset.
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages: messages:
- {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query} - role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.top}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+35 -2
@@ -24,8 +24,8 @@ import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent" LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events" LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls") PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls") EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"} CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
EXTENDED_STYLES = set() EXTENDED_STYLES = set()
@@ -34,6 +34,11 @@ STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"} PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"} EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``.
VIEW_DEPENDENT_STYLES = {"vqa", "motion", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"] LanguageColumn = Literal["language_persistent", "language_events"]
@@ -59,6 +64,7 @@ def language_persistent_row_arrow_type() -> pa.StructType:
pa.field("content", pa.string(), nullable=True), pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False), pa.field("timestamp", pa.float64(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
] ]
) )
@@ -75,6 +81,7 @@ def language_event_row_arrow_type() -> pa.StructType:
pa.field("role", pa.string(), nullable=False), pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True), pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
] ]
) )
@@ -97,6 +104,7 @@ def language_persistent_row_feature() -> dict[str, object]:
"content": datasets.Value("string"), "content": datasets.Value("string"),
"style": datasets.Value("string"), "style": datasets.Value("string"),
"timestamp": datasets.Value("float64"), "timestamp": datasets.Value("float64"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()), "tool_calls": datasets.List(_json_feature()),
} }
@@ -107,6 +115,7 @@ def language_event_row_feature() -> dict[str, object]:
"role": datasets.Value("string"), "role": datasets.Value("string"),
"content": datasets.Value("string"), "content": datasets.Value("string"),
"style": datasets.Value("string"), "style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()), "tool_calls": datasets.List(_json_feature()),
} }
@@ -134,6 +143,30 @@ def is_language_column(key: str) -> bool:
return key in LANGUAGE_COLUMNS return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(
f"Rows of style {style!r} must have camera=None; got camera={camera!r}."
)
def column_for_style(style: str | None) -> LanguageColumn: def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored. """Map a language style to the column where rows of that style are stored.
+55 -18
@@ -46,18 +46,23 @@ def active_at(
style: str | None = None, style: str | None = None,
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``. """Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name`` most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
selector. ``events`` is accepted for resolver-signature uniformity but is ``camera`` selector. ``events`` is accepted for resolver-signature
not consulted: only persistent styles are valid here. uniformity but is not consulted: only persistent styles are valid here.
""" """
_validate_persistent_resolver("active_at", style) _validate_persistent_resolver("active_at", style)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name) matches = _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
matches = [row for row in matches if _timestamp(row) <= t] matches = [row for row in matches if _timestamp(row) <= t]
return _select_latest(matches, style=style, role=role, tool_name=tool_name) return _select_latest(
matches, style=style, role=role, tool_name=tool_name, camera=camera
)
def emitted_at( def emitted_at(
@@ -68,26 +73,45 @@ def emitted_at(
style: str | None = None, style: str | None = None,
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``. """Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp`` For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own), the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``. so all matching event rows are considered emitted at ``t``. ``camera``
filters by the row's ``camera`` field — required to disambiguate when
multiple view-dependent rows share ``(t, role)`` across cameras.
""" """
column = column_for_style(style) column = column_for_style(style)
if column == LANGUAGE_PERSISTENT: if column == LANGUAGE_PERSISTENT:
matches = [ matches = [
row row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name) for row in _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
if _timestamp(row) == t if _timestamp(row) == t
] ]
return _select_one( return _select_one(
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key matches,
style=style,
role=role,
tool_name=tool_name,
camera=camera,
sort_key=_persistent_sort_key,
) )
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name) matches = _matching_rows(
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key) events, style=style, role=role, tool_name=tool_name, camera=camera
)
return _select_one(
matches,
style=style,
role=role,
tool_name=tool_name,
camera=camera,
sort_key=_event_sort_key,
)
def nth_prev( def nth_prev(
@@ -99,12 +123,14 @@ def nth_prev(
offset: int = 1, offset: int = 1,
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``. """Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style`` Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset`` (filtered by optional ``role``/``tool_name``/``camera``) and returns the
positions before the row active at ``t``. Only valid for persistent styles. one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
""" """
return _nth_relative( return _nth_relative(
t, t,
@@ -113,6 +139,7 @@ def nth_prev(
offset=-offset, offset=-offset,
role=role, role=role,
tool_name=tool_name, tool_name=tool_name,
camera=camera,
resolver_name="nth_prev", resolver_name="nth_prev",
) )
@@ -126,12 +153,14 @@ def nth_next(
offset: int = 1, offset: int = 1,
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``. """Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style`` Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset`` (filtered by optional ``role``/``tool_name``/``camera``) and returns the
positions after the row active at ``t``. Only valid for persistent styles. one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
""" """
return _nth_relative( return _nth_relative(
t, t,
@@ -140,6 +169,7 @@ def nth_next(
offset=offset, offset=offset,
role=role, role=role,
tool_name=tool_name, tool_name=tool_name,
camera=camera,
resolver_name="nth_next", resolver_name="nth_next",
) )
@@ -376,6 +406,7 @@ def _nth_relative(
offset: int, offset: int,
role: str | None, role: str | None,
tool_name: str | None, tool_name: str | None,
camera: str | None,
resolver_name: str, resolver_name: str,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``.""" """Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
@@ -384,7 +415,7 @@ def _nth_relative(
raise ValueError(f"{resolver_name} offset must be non-zero.") raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = sorted( rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name), _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_persistent_sort_key, key=_persistent_sort_key,
) )
if not rows: if not rows:
@@ -420,14 +451,16 @@ def _matching_rows(
style: str | None, style: str | None,
role: str | None, role: str | None,
tool_name: str | None, tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]: ) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name`` selectors.""" """Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [ return [
row row
for row in rows for row in rows
if (style is None or row.get("style") == style) if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role) and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name)) and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
] ]
@@ -437,6 +470,7 @@ def _select_latest(
style: str | None, style: str | None,
role: str | None, role: str | None,
tool_name: str | None, tool_name: str | None,
camera: str | None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors).""" """Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
if not rows: if not rows:
@@ -448,6 +482,7 @@ def _select_latest(
style=style, style=style,
role=role, role=role,
tool_name=tool_name, tool_name=tool_name,
camera=camera,
sort_key=_persistent_sort_key, sort_key=_persistent_sort_key,
) )
@@ -458,14 +493,16 @@ def _select_one(
style: str | None, style: str | None,
role: str | None, role: str | None,
tool_name: str | None, tool_name: str | None,
camera: str | None,
sort_key: Any, sort_key: Any,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the single matching row, or raise if the selectors are ambiguous.""" """Return the single matching row, or raise if the selectors are ambiguous."""
if not rows: if not rows:
return None return None
if len(rows) > 1 and role is None and tool_name is None: if len(rows) > 1 and role is None and tool_name is None and camera is None:
raise ValueError( raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate." f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
f"or camera=... to disambiguate."
) )
return sorted(rows, key=sort_key)[0] return sorted(rows, key=sort_key)[0]
+2 -1
@@ -26,6 +26,7 @@ def test_canonical_recipe_loads():
"user_interjection_response", "user_interjection_response",
"high_level_subtask", "high_level_subtask",
"low_level_execution", "low_level_execution",
"ask_vqa", "ask_vqa_top",
"ask_vqa_wrist",
} }
assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96) assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96)
+49 -2
@@ -13,10 +13,13 @@ from lerobot.datasets.language import (
LANGUAGE_PERSISTENT, LANGUAGE_PERSISTENT,
PERSISTENT_STYLES, PERSISTENT_STYLES,
STYLE_REGISTRY, STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style, column_for_style,
is_view_dependent_style,
language_events_arrow_type, language_events_arrow_type,
language_feature_info, language_feature_info,
language_persistent_arrow_type, language_persistent_arrow_type,
validate_camera_field,
) )
from lerobot.datasets.utils import DEFAULT_DATA_PATH from lerobot.datasets.utils import DEFAULT_DATA_PATH
@@ -26,10 +29,17 @@ def test_language_arrow_schema_has_expected_fields():
event_row_type = language_events_arrow_type().value_type event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType) assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"] assert persistent_row_type.names == [
"role",
"content",
"style",
"timestamp",
"camera",
"tool_calls",
]
assert isinstance(event_row_type, pa.StructType) assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "tool_calls"] assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
def test_style_registry_routes_columns(): def test_style_registry_routes_columns():
@@ -47,6 +57,41 @@ def test_style_registry_routes_columns():
assert column_for_style(None) == LANGUAGE_EVENTS assert column_for_style(None) == LANGUAGE_EVENTS
def test_view_dependent_styles():
assert {"vqa", "motion", "trace"} == VIEW_DEPENDENT_STYLES
assert is_view_dependent_style("vqa")
assert is_view_dependent_style("motion")
assert is_view_dependent_style("trace")
assert not is_view_dependent_style("subtask")
assert not is_view_dependent_style("plan")
assert not is_view_dependent_style("interjection")
assert not is_view_dependent_style(None)
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
validate_camera_field("vqa", "observation.images.top")
validate_camera_field("motion", "observation.images.wrist")
validate_camera_field("trace", "observation.images.front")
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("vqa", None)
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("motion", "")
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
validate_camera_field("subtask", None)
validate_camera_field("plan", None)
validate_camera_field("memory", None)
validate_camera_field("interjection", None)
validate_camera_field(None, None)
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("subtask", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("interjection", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field(None, "observation.images.top")
def test_unknown_style_rejected(): def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"): with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise") column_for_style("surprise")
@@ -70,6 +115,7 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot
"content": "reach for the cup", "content": "reach for the cup",
"style": "subtask", "style": "subtask",
"timestamp": 0.0, "timestamp": 0.0,
"camera": None,
"tool_calls": None, "tool_calls": None,
} }
] ]
@@ -77,6 +123,7 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot
"role": "user", "role": "user",
"content": "what is visible?", "content": "what is visible?",
"style": "vqa", "style": "vqa",
"camera": "observation.images.top",
"tool_calls": None, "tool_calls": None,
} }
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
+135 -4
@@ -8,21 +8,23 @@ from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
def persistent_row(role, content, style, timestamp, tool_calls=None): def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
return { return {
"role": role, "role": role,
"content": content, "content": content,
"style": style, "style": style,
"timestamp": timestamp, "timestamp": timestamp,
"camera": camera,
"tool_calls": tool_calls, "tool_calls": tool_calls,
} }
def event_row(role, content, style, tool_calls=None): def event_row(role, content, style, tool_calls=None, camera=None):
return { return {
"role": role, "role": role,
"content": content, "content": content,
"style": style, "style": style,
"camera": camera,
"tool_calls": tool_calls, "tool_calls": tool_calls,
} }
@@ -35,8 +37,8 @@ PERSISTENT = [
persistent_row("assistant", "subtask 1", "subtask", 1.0), persistent_row("assistant", "subtask 1", "subtask", 1.0),
] ]
EVENTS_AT_1 = [ EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa"), event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 2}', "vqa"), event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
] ]
EVENTS_AT_2 = [ EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"), event_row("user", "skip wiping", "interjection"),
@@ -47,6 +49,15 @@ EVENTS_AT_2 = [
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}], [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
), ),
] ]
# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]
def test_resolver_temporal_semantics(): def test_resolver_temporal_semantics():
@@ -158,6 +169,126 @@ def test_deterministic_blend_sampling():
assert first == second assert first == second
def test_emitted_at_filters_vqa_by_camera():
top = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.top",
)
wrist = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.wrist",
)
assert top["content"] == '{"count": 3}'
assert wrist["content"] == '{"count": 1}'
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
with pytest.raises(ValueError, match="Ambiguous resolver"):
emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
)
def test_per_camera_blend_renders_both_views():
recipe = TrainingRecipe(
blend={
"top": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
),
},
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${vqa_query}"},
],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
"wrist": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
),
},
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.wrist"},
{"type": "text", "text": "${vqa_query}"},
],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
}
)
rendered_top = render_sample(
recipe=recipe.blend["top"],
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
rendered_wrist = render_sample(
recipe=recipe.blend["wrist"],
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
def test_canonical_recipe_can_render_low_level_branch(): def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]}) low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
@@ -38,6 +38,7 @@ def test_render_messages_step_renders_and_drops_raw_language():
"content": "reach carefully", "content": "reach carefully",
"style": "subtask", "style": "subtask",
"timestamp": 0.0, "timestamp": 0.0,
"camera": None,
"tool_calls": None, "tool_calls": None,
} }
], ],