review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction

* **Float tolerance in `emitted_at` for persistent styles.** The
  ``_timestamp(row) == t`` exact-equality check silently missed any
  caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``)
  even though the parquet timestamp would only differ by ULPs. Added
  ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance``
  instead, with a docstring explaining why exact equality wasn't
  enough and why 0.1 s is safe at typical 30–100 Hz control rates.
  Test asserts the new behavior at half-window (matches) and
  double-window (no match) using the constant so it stays in sync.

* **`MessageTurn.stream` is required at construction.** It was typed
  ``MessageStream | None = None`` so YAML could omit ``stream:`` and
  pass the dataclass invariant — but ``_validate_rendered`` rejected
  ``None`` streams later, surfacing the error at the first sample
  instead of at recipe load. Now ``__post_init__`` raises
  ``ValueError`` if ``stream`` is ``None``, with the list of valid
  streams in the message. The redundant late-stage check in
  ``_validate_rendered`` is replaced with a one-line comment that
  cites the upstream invariant. Test pins the new construction-time
  rejection.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-06 19:55:08 +02:00
parent beb22afd81
commit e7e5fca5de
4 changed files with 59 additions and 10 deletions
+11 -1
View File
@@ -64,7 +64,17 @@ class MessageTurn:
"""Validate role, stream, and content after dataclass construction.""" """Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES: if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}") raise ValueError(f"Unsupported message role: {self.role!r}")
if self.stream is not None and self.stream not in _VALID_STREAMS: # ``stream`` is typed Optional only so the dataclass can keep its
# field ordering, but recipes must always tag every turn with a
# stream — the renderer's ``_validate_rendered`` would reject
# ``None`` later on. Fail at construction so the bad recipe is
# caught at YAML load time rather than at the first sample.
if self.stream is None:
raise ValueError(
f"MessageTurn(role={self.role!r}) is missing a stream — "
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
)
if self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}") raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None: if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.") raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
+21 -9
View File
@@ -65,6 +65,16 @@ def active_at(
) )
EMITTED_AT_TOLERANCE_S = 0.1
"""Half-window for matching persistent rows to a frame timestamp in
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
is also a float64 from parquet, so in the ideal hot path an exact match
would suffice — but any caller that derives ``t`` arithmetically (e.g.
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
common arithmetic drift without admitting frames that are visibly far
apart at typical control rates (30100 Hz)."""
def emitted_at( def emitted_at(
t: float, t: float,
*, *,
@@ -78,17 +88,19 @@ def emitted_at(
"""Return the row of ``style`` emitted at exactly time ``t``. """Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp`` For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
the dataset row at frame ``t`` (event rows carry no timestamp of their own), we use a tolerance instead of bit-equality). For event styles, the
so all matching event rows are considered emitted at ``t``. ``camera`` ``events`` list is assumed to come from the dataset row at frame ``t``
filters by the row's ``camera`` field — required to disambiguate when (event rows carry no timestamp of their own), so all matching event rows
multiple view-dependent rows share ``(t, role)`` across cameras. are considered emitted at ``t``. ``camera`` filters by the row's
``camera`` field — required to disambiguate when multiple view-dependent
rows share ``(t, role)`` across cameras.
""" """
if column_for_style(style) == LANGUAGE_PERSISTENT: if column_for_style(style) == LANGUAGE_PERSISTENT:
matches = [ matches = [
row row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera) for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
if _timestamp(row) == t if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
] ]
else: else:
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera) matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
@@ -391,9 +403,9 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
for idx in target_indices: for idx in target_indices:
if idx < 0 or idx >= len(messages): if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.") raise ValueError(f"Target message index {idx} is out of bounds.")
for idx, stream in enumerate(streams): # ``stream`` is enforced non-None at MessageTurn construction time
if stream is None: # (see ``MessageTurn.__post_init__``), so a missing stream here would
raise ValueError(f"Rendered message {idx} has no stream.") # mean the dataclass invariant was bypassed; no need to re-check.
def _nth_relative( def _nth_relative(
+11
View File
@@ -29,6 +29,17 @@ def test_message_recipe_validates_unknown_binding():
) )
def test_message_turn_requires_a_stream():
"""Every turn must declare a stream — None is rejected at construction.
Previously this only failed at render time (``_validate_rendered``);
catching it here means a malformed recipe YAML errors at load instead
of at the first training sample.
"""
with pytest.raises(ValueError, match="missing a stream"):
MessageTurn(role="user", content="${task}")
def test_message_recipe_requires_at_least_one_target(): def test_message_recipe_requires_at_least_one_target():
with pytest.raises(ValueError, match="target"): with pytest.raises(ValueError, match="target"):
TrainingRecipe( TrainingRecipe(
+16
View File
@@ -6,6 +6,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402 from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
from lerobot.datasets.language_render import ( # noqa: E402 from lerobot.datasets.language_render import ( # noqa: E402
EMITTED_AT_TOLERANCE_S,
active_at, active_at,
emitted_at, emitted_at,
nth_next, nth_next,
@@ -342,6 +343,21 @@ def test_resolve_task_explicit_override_beats_rephrasings():
assert rendered["messages"][0]["content"] == "explicit override wins" assert rendered["messages"][0]["content"] == "explicit override wins"
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
line up with the parquet-stored timestamp.
"""
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
# Half a tolerance window — bit-different float, comfortably inside
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
assert inside is not None and inside["content"] == "memo"
# Just past the window — no match
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
assert outside is None
def test_render_sample_rejects_non_dict_language_rows(): def test_render_sample_rejects_non_dict_language_rows():
"""``_normalize_rows`` must surface malformed inputs as TypeError. """``_normalize_rows`` must surface malformed inputs as TypeError.