diff --git a/src/lerobot/configs/recipe.py b/src/lerobot/configs/recipe.py
index 27d4c5b8b..28e5a0db3 100644
--- a/src/lerobot/configs/recipe.py
+++ b/src/lerobot/configs/recipe.py
@@ -64,7 +64,17 @@ class MessageTurn:
         """Validate role, stream, and content after dataclass construction."""
         if self.role not in _VALID_ROLES:
             raise ValueError(f"Unsupported message role: {self.role!r}")
-        if self.stream is not None and self.stream not in _VALID_STREAMS:
+        # ``stream`` is typed Optional only so the dataclass can keep its
+        # field ordering, but recipes must always tag every turn with a
+        # stream — the renderer's ``_validate_rendered`` would reject
+        # ``None`` later on. Fail at construction so the bad recipe is
+        # caught at YAML load time rather than at the first sample.
+        if self.stream is None:
+            raise ValueError(
+                f"MessageTurn(role={self.role!r}) is missing a stream — "
+                f"every turn must declare one of {sorted(_VALID_STREAMS)}."
+            )
+        if self.stream not in _VALID_STREAMS:
             raise ValueError(f"Unsupported message stream: {self.stream!r}")
         if self.content is None and self.tool_calls_from is None:
             raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py
index 8b1a364a0..bc99ef0fd 100644
--- a/src/lerobot/datasets/language_render.py
+++ b/src/lerobot/datasets/language_render.py
@@ -65,6 +65,16 @@ def active_at(
     )
 
 
+EMITTED_AT_TOLERANCE_S = 1e-3
+"""Half-window for matching persistent rows to a frame timestamp in
+``emitted_at``. Persistent timestamps and ``t`` both normally come from
+parquet (float64), so an exact match would suffice on the hot path — but
+any caller that derives ``t`` arithmetically (e.g. ``frame_idx / fps``)
+breaks bit-equality. 1 ms absorbs that drift while staying well under
+half a frame period at typical control rates (30–100 Hz: 5–16.7 ms), so
+a row emitted on a neighbouring frame can never match."""
+
+
 def emitted_at(
     t: float,
     *,
@@ -78,17 +88,19 @@ def emitted_at(
     """Return the row of ``style`` emitted at exactly time ``t``.
 
     For persistent styles, this matches persistent rows whose own ``timestamp``
-    equals ``t``. For event styles, the ``events`` list is assumed to come from
-    the dataset row at frame ``t`` (event rows carry no timestamp of their own),
-    so all matching event rows are considered emitted at ``t``. ``camera``
-    filters by the row's ``camera`` field — required to disambiguate when
-    multiple view-dependent rows share ``(t, role)`` across cameras.
+    is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
+    we use a tolerance instead of bit-equality). For event styles, the
+    ``events`` list is assumed to come from the dataset row at frame ``t``
+    (event rows carry no timestamp of their own), so all matching event rows
+    are considered emitted at ``t``. ``camera`` filters by the row's
+    ``camera`` field — required to disambiguate when multiple view-dependent
+    rows share ``(t, role)`` across cameras.
     """
     if column_for_style(style) == LANGUAGE_PERSISTENT:
         matches = [
             row
             for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
-            if _timestamp(row) == t
+            if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
         ]
     else:
         matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
@@ -391,9 +403,9 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
     for idx in target_indices:
         if idx < 0 or idx >= len(messages):
             raise ValueError(f"Target message index {idx} is out of bounds.")
-    for idx, stream in enumerate(streams):
-        if stream is None:
-            raise ValueError(f"Rendered message {idx} has no stream.")
+    # ``stream`` is enforced non-None at MessageTurn construction time
+    # (see ``MessageTurn.__post_init__``), so a missing stream here would
+    # mean the dataclass invariant was bypassed; no need to re-check.
 
 
 def _nth_relative(
diff --git a/tests/configs/test_recipe.py b/tests/configs/test_recipe.py
index 332e40333..b4954efbf 100644
--- a/tests/configs/test_recipe.py
+++ b/tests/configs/test_recipe.py
@@ -29,6 +29,17 @@ def test_message_recipe_validates_unknown_binding():
     )
 
 
+def test_message_turn_requires_a_stream():
+    """Every turn must declare a stream — None is rejected at construction.
+
+    Previously this only failed at render time (``_validate_rendered``);
+    catching it here means a malformed recipe YAML errors at load instead
+    of at the first training sample.
+    """
+    with pytest.raises(ValueError, match="missing a stream"):
+        MessageTurn(role="user", content="${task}")
+
+
 def test_message_recipe_requires_at_least_one_target():
     with pytest.raises(ValueError, match="target"):
         TrainingRecipe(
diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py
index a401f8aad..fcef41fd8 100644
--- a/tests/datasets/test_language_render.py
+++ b/tests/datasets/test_language_render.py
@@ -6,6 +6,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
 
 from lerobot.configs.recipe import MessageTurn, TrainingRecipe  # noqa: E402
 from lerobot.datasets.language_render import (  # noqa: E402
+    EMITTED_AT_TOLERANCE_S,
     active_at,
     emitted_at,
     nth_next,
@@ -342,6 +343,21 @@ def test_resolve_task_explicit_override_beats_rephrasings():
     assert rendered["messages"][0]["content"] == "explicit override wins"
 
 
+def test_emitted_at_persistent_tolerates_small_timestamp_drift():
+    """Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
+    so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
+    line up with the parquet-stored timestamp.
+    """
+    rows = [persistent_row("assistant", "memo", "memory", 1.0)]
+    # Half a tolerance window — bit-different float, comfortably inside
+    inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
+    assert inside is not None and inside["content"] == "memo"
+
+    # Twice the window — clearly outside, no match
+    outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
+    assert outside is None
+
+
 def test_render_sample_rejects_non_dict_language_rows():
     """``_normalize_rows`` must surface malformed inputs as TypeError.
 