From e7e5fca5de0643e52a42f5ee06e7e3964300cda8 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 6 May 2026 19:55:08 +0200 Subject: [PATCH] review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * **Float tolerance in `emitted_at` for persistent styles.** The ``_timestamp(row) == t`` exact-equality check silently missed rows for any caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``) even though the parquet timestamp would only differ by ULPs. Added ``EMITTED_AT_TOLERANCE_S = 0.1`` and now check ``abs(...) <= tolerance`` instead, with a docstring explaining why exact equality wasn't enough and why 0.1 s is safe at typical 30–100 Hz control rates. Test asserts the new behavior at half-window (matches) and double-window (no match) using the constant so it stays in sync. * **`MessageTurn.stream` is required at construction.** It was typed ``MessageStream | None = None`` so YAML could omit ``stream:`` and pass the dataclass invariant — but ``_validate_rendered`` rejected ``None`` streams later, surfacing the error at the first sample instead of at recipe load. Now ``__post_init__`` raises ``ValueError`` if ``stream`` is ``None``, with the list of valid streams in the message. The redundant late-stage check in ``_validate_rendered`` is replaced with a short comment that cites the upstream invariant. Test pins the new construction-time rejection. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/configs/recipe.py | 12 +++++++++- src/lerobot/datasets/language_render.py | 30 +++++++++++++++++-------- tests/configs/test_recipe.py | 11 +++++++++ tests/datasets/test_language_render.py | 16 +++++++++++++ 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/src/lerobot/configs/recipe.py b/src/lerobot/configs/recipe.py index 27d4c5b8b..28e5a0db3 100644 --- a/src/lerobot/configs/recipe.py +++ b/src/lerobot/configs/recipe.py @@ -64,7 +64,17 @@ class MessageTurn: """Validate role, stream, and content after dataclass construction.""" if self.role not in _VALID_ROLES: raise ValueError(f"Unsupported message role: {self.role!r}") - if self.stream is not None and self.stream not in _VALID_STREAMS: + # ``stream`` is typed Optional only so the dataclass can keep its + # field ordering, but recipes must always tag every turn with a + # stream — the renderer's ``_validate_rendered`` would reject + # ``None`` later on. Fail at construction so the bad recipe is + # caught at YAML load time rather than at the first sample. + if self.stream is None: + raise ValueError( + f"MessageTurn(role={self.role!r}) is missing a stream — " + f"every turn must declare one of {sorted(_VALID_STREAMS)}." + ) + if self.stream not in _VALID_STREAMS: raise ValueError(f"Unsupported message stream: {self.stream!r}") if self.content is None and self.tool_calls_from is None: raise ValueError("MessageTurn.content is required unless tool_calls_from is set.") diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 8b1a364a0..bc99ef0fd 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -65,6 +65,16 @@ def active_at( ) +EMITTED_AT_TOLERANCE_S = 0.1 +"""Half-window for matching persistent rows to a frame timestamp in +``emitted_at``. 
Persistent timestamps come from parquet (float64) and ``t`` +is also a float64 from parquet, so in the ideal hot path an exact match +would suffice — but any caller that derives ``t`` arithmetically (e.g. +``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers +common arithmetic drift without admitting frames that are visibly far +apart at typical control rates (30–100 Hz).""" + + def emitted_at( t: float, *, @@ -78,17 +88,19 @@ def emitted_at( """Return the row of ``style`` emitted at exactly time ``t``. For persistent styles, this matches persistent rows whose own ``timestamp`` - equals ``t``. For event styles, the ``events`` list is assumed to come from - the dataset row at frame ``t`` (event rows carry no timestamp of their own), - so all matching event rows are considered emitted at ``t``. ``camera`` - filters by the row's ``camera`` field — required to disambiguate when - multiple view-dependent rows share ``(t, role)`` across cameras. + is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why + we use a tolerance instead of bit-equality). For event styles, the + ``events`` list is assumed to come from the dataset row at frame ``t`` + (event rows carry no timestamp of their own), so all matching event rows + are considered emitted at ``t``. ``camera`` filters by the row's + ``camera`` field — required to disambiguate when multiple view-dependent + rows share ``(t, role)`` across cameras. 
""" if column_for_style(style) == LANGUAGE_PERSISTENT: matches = [ row for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) - if _timestamp(row) == t + if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S ] else: matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera) @@ -391,9 +403,9 @@ def _validate_rendered(rendered: RenderedMessages) -> None: for idx in target_indices: if idx < 0 or idx >= len(messages): raise ValueError(f"Target message index {idx} is out of bounds.") - for idx, stream in enumerate(streams): - if stream is None: - raise ValueError(f"Rendered message {idx} has no stream.") + # ``stream`` is enforced non-None at MessageTurn construction time + # (see ``MessageTurn.__post_init__``), so a missing stream here would + # mean the dataclass invariant was bypassed; no need to re-check. def _nth_relative( diff --git a/tests/configs/test_recipe.py b/tests/configs/test_recipe.py index 332e40333..b4954efbf 100644 --- a/tests/configs/test_recipe.py +++ b/tests/configs/test_recipe.py @@ -29,6 +29,17 @@ def test_message_recipe_validates_unknown_binding(): ) +def test_message_turn_requires_a_stream(): + """Every turn must declare a stream — None is rejected at construction. + + Previously this only failed at render time (``_validate_rendered``); + catching it here means a malformed recipe YAML errors at load instead + of at the first training sample. 
+ """ + with pytest.raises(ValueError, match="missing a stream"): + MessageTurn(role="user", content="${task}") + + def test_message_recipe_requires_at_least_one_target(): with pytest.raises(ValueError, match="target"): TrainingRecipe( diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index a401f8aad..fcef41fd8 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -6,6 +6,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402 from lerobot.datasets.language_render import ( # noqa: E402 + EMITTED_AT_TOLERANCE_S, active_at, emitted_at, nth_next, @@ -342,6 +343,21 @@ def test_resolve_task_explicit_override_beats_rephrasings(): assert rendered["messages"][0]["content"] == "explicit override wins" +def test_emitted_at_persistent_tolerates_small_timestamp_drift(): + """Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S + so callers that derive ``t`` arithmetically (``frame_idx / fps``) still + line up with the parquet-stored timestamp. + """ + rows = [persistent_row("assistant", "memo", "memory", 1.0)] + # Half a tolerance window — bit-different float, comfortably inside + inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory") + assert inside is not None and inside["content"] == "memo" + + # Just past the window — no match + outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory") + assert outside is None + + def test_render_sample_rejects_non_dict_language_rows(): """``_normalize_rows`` must surface malformed inputs as TypeError.