From e8327b8e6289c7fd105f59090b10d9e34b216689 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 6 May 2026 13:15:45 +0200 Subject: [PATCH] refactor(language): unify resolver dispatch and prune redundant test scaffolding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally. * Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes — the `(style, role)` tiebreaker is unchanged. * Inline `_select_latest` into `active_at` (its only caller). * Deduplicate `emitted_at`'s two per-branch `_select_one` calls into a single shared call; the persistent/event branch itself remains and now only selects the candidate rows. * Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check. * Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks. Net -98 LOC; behavior, public resolver names, and test expectations unchanged. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/datasets/language_render.py | 186 ++++++++---------------- tests/datasets/test_language_render.py | 102 +++++-------- 2 files changed, 95 insertions(+), 193 deletions(-) diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 1f4ed2749..da1e6d81e 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -24,12 +24,7 @@ from typing import Any from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe -from .language import ( - EVENT_ONLY_STYLES, - LANGUAGE_PERSISTENT, - PERSISTENT_STYLES, - column_for_style, -) +from .language import LANGUAGE_PERSISTENT, column_for_style LanguageRow = dict[str, Any] RenderedMessages = dict[str, list[Any]] @@ -42,7 +37,6 @@ def active_at( t: float, *, persistent: Sequence[LanguageRow], - events: Sequence[LanguageRow] | None = None, style: str | None = None, role: str | None = None, tool_name: str | None = None, @@ -52,16 +46,23 @@ def active_at( A persistent row is "active" at ``t`` when its own ``timestamp`` is the most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/ - ``camera`` selector. ``events`` is accepted for resolver-signature - uniformity but is not consulted: only persistent styles are valid here. + ``camera`` selector. Only valid for persistent styles. 
""" _validate_persistent_resolver("active_at", style) - matches = _matching_rows( - persistent, style=style, role=role, tool_name=tool_name, camera=camera - ) - matches = [row for row in matches if _timestamp(row) <= t] - return _select_latest( - matches, style=style, role=role, tool_name=tool_name, camera=camera + matches = [ + row + for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) + if _timestamp(row) <= t + ] + if not matches: + return None + latest_ts = max(_timestamp(row) for row in matches) + return _select_one( + [row for row in matches if _timestamp(row) == latest_ts], + style=style, + role=role, + tool_name=tool_name, + camera=camera, ) @@ -84,41 +85,21 @@ def emitted_at( filters by the row's ``camera`` field — required to disambiguate when multiple view-dependent rows share ``(t, role)`` across cameras. """ - column = column_for_style(style) - if column == LANGUAGE_PERSISTENT: + if column_for_style(style) == LANGUAGE_PERSISTENT: matches = [ row - for row in _matching_rows( - persistent, style=style, role=role, tool_name=tool_name, camera=camera - ) + for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) if _timestamp(row) == t ] - return _select_one( - matches, - style=style, - role=role, - tool_name=tool_name, - camera=camera, - sort_key=_persistent_sort_key, - ) - matches = _matching_rows( - events, style=style, role=role, tool_name=tool_name, camera=camera - ) - return _select_one( - matches, - style=style, - role=role, - tool_name=tool_name, - camera=camera, - sort_key=_event_sort_key, - ) + else: + matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera) + return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera) def nth_prev( t: float, *, persistent: Sequence[LanguageRow], - events: Sequence[LanguageRow] | None = None, style: str | None = None, offset: int = 1, role: str | None = None, @@ 
-132,23 +113,13 @@ def nth_prev( one ``offset`` positions before the row active at ``t``. Only valid for persistent styles. """ - return _nth_relative( - t, - persistent=persistent, - style=style, - offset=-offset, - role=role, - tool_name=tool_name, - camera=camera, - resolver_name="nth_prev", - ) + return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera) def nth_next( t: float, *, persistent: Sequence[LanguageRow], - events: Sequence[LanguageRow] | None = None, style: str | None = None, offset: int = 1, role: str | None = None, @@ -162,16 +133,7 @@ def nth_next( one ``offset`` positions after the row active at ``t``. Only valid for persistent styles. """ - return _nth_relative( - t, - persistent=persistent, - style=style, - offset=offset, - role=role, - tool_name=tool_name, - camera=camera, - resolver_name="nth_next", - ) + return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera) def render_sample( @@ -239,9 +201,7 @@ def _resolve_bindings( ) -> dict[str, LanguageRow | str | None]: """Resolve every binding in ``recipe`` (plus ``task``) at time ``t``.""" bindings: dict[str, LanguageRow | str | None] = { - "task": _resolve_task( - task, dataset_ctx, persistent=persistent, sample_idx=sample_idx - ), + "task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx), } specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})} for name, spec in specs.items(): @@ -275,18 +235,12 @@ def _resolve_task( if task is not None: return task - aug_rows = [ - r - for r in persistent - if r.get("style") == "task_aug" and r.get("role") == "user" - ] + aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"] if aug_rows: # Deterministic, blake2b-based pick keyed on sample_idx so the # rotation is reproducible across runs (Python's built-in ``hash`` # is process-randomized). 
- digest = hashlib.blake2b( - f"task_aug:{sample_idx}".encode(), digest_size=8 - ).digest() + digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest() idx = int.from_bytes(digest, "big") % len(aug_rows) chosen = aug_rows[idx].get("content") if chosen: @@ -314,15 +268,15 @@ def _resolve_spec( kwargs = _parse_resolver_args(match.group("args")) kwargs.pop("t_arg", None) - resolvers = { - "active_at": active_at, - "emitted_at": emitted_at, - "nth_prev": nth_prev, - "nth_next": nth_next, - } - if name not in resolvers: - raise ValueError(f"Unknown language resolver: {name!r}") - return resolvers[name](t, persistent=persistent, events=events, **kwargs) + if name == "emitted_at": + return emitted_at(t, persistent=persistent, events=events, **kwargs) + if name == "active_at": + return active_at(t, persistent=persistent, **kwargs) + if name == "nth_prev": + return nth_prev(t, persistent=persistent, **kwargs) + if name == "nth_next": + return nth_next(t, persistent=persistent, **kwargs) + raise ValueError(f"Unknown language resolver: {name!r}") def _parse_resolver_args(args: str) -> dict[str, Any]: @@ -444,24 +398,23 @@ def _validate_rendered(rendered: RenderedMessages) -> None: def _nth_relative( + name: str, t: float, - *, persistent: Sequence[LanguageRow], style: str | None, offset: int, role: str | None, tool_name: str | None, camera: str | None, - resolver_name: str, ) -> LanguageRow | None: """Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``.""" - _validate_persistent_resolver(resolver_name, style) + _validate_persistent_resolver(name, style) if abs(offset) < 1: - raise ValueError(f"{resolver_name} offset must be non-zero.") + raise ValueError(f"{name} offset must be non-zero.") rows = sorted( _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera), - key=_persistent_sort_key, + key=_row_sort_key, ) if not rows: return None @@ -480,14 +433,12 @@ def _nth_relative( return rows[target_idx] -def 
_validate_persistent_resolver(resolver_name: str, style: str | None) -> None: +def _validate_persistent_resolver(name: str, style: str | None) -> None: """Reject calls with missing or event-only ``style`` for persistent resolvers.""" if style is None: - raise ValueError(f"{resolver_name} requires a persistent style.") - if style in EVENT_ONLY_STYLES: - raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.") - if style not in PERSISTENT_STYLES: - column_for_style(style) + raise ValueError(f"{name} requires a persistent style.") + if column_for_style(style) != LANGUAGE_PERSISTENT: + raise ValueError(f"{name} cannot be used with event-only style {style!r}.") def _matching_rows( @@ -509,29 +460,6 @@ def _matching_rows( ] -def _select_latest( - rows: Sequence[LanguageRow], - *, - style: str | None, - role: str | None, - tool_name: str | None, - camera: str | None, -) -> LanguageRow | None: - """Return the row tied for the latest ``timestamp`` (disambiguated by selectors).""" - if not rows: - return None - rows = sorted(rows, key=_persistent_sort_key) - latest_ts = _timestamp(rows[-1]) - return _select_one( - [row for row in rows if _timestamp(row) == latest_ts], - style=style, - role=role, - tool_name=tool_name, - camera=camera, - sort_key=_persistent_sort_key, - ) - - def _select_one( rows: Sequence[LanguageRow], *, @@ -539,9 +467,13 @@ def _select_one( role: str | None, tool_name: str | None, camera: str | None, - sort_key: Any, ) -> LanguageRow | None: - """Return the single matching row, or raise if the selectors are ambiguous.""" + """Return the single matching row, or raise if the selectors are ambiguous. + + Ties are broken deterministically by ``_row_sort_key`` so that + multiple rows with identical ``(style, role, tool_name, camera)`` still + resolve to a stable choice. 
+ """ if not rows: return None if len(rows) > 1 and role is None and tool_name is None and camera is None: @@ -549,17 +481,21 @@ def _select_one( f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., " f"or camera=... to disambiguate." ) - return sorted(rows, key=sort_key)[0] + return sorted(rows, key=_row_sort_key)[0] -def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]: - """Sort key for persistent rows: ``(timestamp, style, role)``.""" - return (_timestamp(row), row.get("style") or "", row.get("role") or "") +def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]: + """Stable sort key for both persistent and event rows. - -def _event_sort_key(row: LanguageRow) -> tuple[str, str]: - """Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame).""" - return (row.get("style") or "", row.get("role") or "") + Event rows lack ``timestamp`` (it is implicit in the frame), so default + to ``0.0`` — within a single frame all event rows share the same sort + bucket and are tiebroken by ``(style, role)``. 
+ """ + timestamp = row.get("timestamp") + ts = ( + float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0 + ) + return (ts, row.get("style") or "", row.get("role") or "") def _timestamp(row: LanguageRow) -> float: diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index d8befecac..5aefea733 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -199,84 +199,50 @@ def test_emitted_at_raises_on_ambiguous_per_camera_vqa(): ) -def test_per_camera_blend_renders_both_views(): - recipe = TrainingRecipe( - blend={ - "top": TrainingRecipe( - weight=1.0, - bindings={ - "vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"), - "vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"), - }, - messages=[ - MessageTurn( - role="user", - content=[ - {"type": "image", "feature": "observation.images.top"}, - {"type": "text", "text": "${vqa_query}"}, - ], - stream="high_level", - if_present="vqa_query", - ), - MessageTurn( - role="assistant", - content="${vqa}", - stream="high_level", - target=True, - if_present="vqa", - ), - ], +def _vqa_subrecipe(camera: str) -> TrainingRecipe: + return TrainingRecipe( + weight=1.0, + bindings={ + "vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})", + "vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})", + }, + messages=[ + MessageTurn( + role="user", + content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}], + stream="high_level", + if_present="vqa_query", ), - "wrist": TrainingRecipe( - weight=1.0, - bindings={ - "vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"), - "vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"), - }, - messages=[ - MessageTurn( - role="user", - content=[ - {"type": "image", "feature": 
"observation.images.wrist"}, - {"type": "text", "text": "${vqa_query}"}, - ], - stream="high_level", - if_present="vqa_query", - ), - MessageTurn( - role="assistant", - content="${vqa}", - stream="high_level", - target=True, - if_present="vqa", - ), - ], + MessageTurn( + role="assistant", + content="${vqa}", + stream="high_level", + target=True, + if_present="vqa", ), - } + ], ) - rendered_top = render_sample( - recipe=recipe.blend["top"], - persistent=PERSISTENT, - events=EVENTS_AT_3_TWO_CAMERAS, - t=3.0, - sample_idx=0, - ) - rendered_wrist = render_sample( - recipe=recipe.blend["wrist"], + +@pytest.mark.parametrize( + ("camera", "expected_query", "expected_answer"), + [ + ("observation.images.top", "how many cups (top)?", '{"count": 3}'), + ("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'), + ], +) +def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer): + rendered = render_sample( + recipe=_vqa_subrecipe(camera), persistent=PERSISTENT, events=EVENTS_AT_3_TWO_CAMERAS, t=3.0, sample_idx=0, ) - assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top" - assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?" - assert rendered_top["messages"][1]["content"] == '{"count": 3}' - - assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist" - assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?" - assert rendered_wrist["messages"][1]["content"] == '{"count": 1}' + assert rendered["messages"][0]["content"][0]["feature"] == camera + assert rendered["messages"][0]["content"][1]["text"] == expected_query + assert rendered["messages"][1]["content"] == expected_answer def test_resolve_task_picks_rephrasing_deterministically_per_sample():