mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction
* **Float tolerance in `emitted_at` for persistent styles.** The ``_timestamp(row) == t`` exact-equality check silently missed any caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``) even though the parquet timestamp would only differ by ULPs. Added ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance`` instead, with a docstring explaining why exact equality wasn't enough and why 0.1 s is safe at typical 30–100 Hz control rates. Test asserts the new behavior at half-window (matches) and double-window (no match) using the constant so it stays in sync. * **`MessageTurn.stream` is required at construction.** It was typed ``MessageStream | None = None`` so YAML could omit ``stream:`` and pass the dataclass invariant — but ``_validate_rendered`` rejected ``None`` streams later, surfacing the error at the first sample instead of at recipe load. Now ``__post_init__`` raises ``ValueError`` if ``stream`` is ``None``, with the list of valid streams in the message. The redundant late-stage check in ``_validate_rendered`` is replaced with a one-line comment that cites the upstream invariant. Test pins the new construction-time rejection. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -64,7 +64,17 @@ class MessageTurn:
|
|||||||
"""Validate role, stream, and content after dataclass construction."""
|
"""Validate role, stream, and content after dataclass construction."""
|
||||||
if self.role not in _VALID_ROLES:
|
if self.role not in _VALID_ROLES:
|
||||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||||
if self.stream is not None and self.stream not in _VALID_STREAMS:
|
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||||
|
# field ordering, but recipes must always tag every turn with a
|
||||||
|
# stream — the renderer's ``_validate_rendered`` would reject
|
||||||
|
# ``None`` later on. Fail at construction so the bad recipe is
|
||||||
|
# caught at YAML load time rather than at the first sample.
|
||||||
|
if self.stream is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||||
|
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||||
|
)
|
||||||
|
if self.stream not in _VALID_STREAMS:
|
||||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||||
if self.content is None and self.tool_calls_from is None:
|
if self.content is None and self.tool_calls_from is None:
|
||||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||||
|
|||||||
@@ -65,6 +65,16 @@ def active_at(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EMITTED_AT_TOLERANCE_S = 0.1
|
||||||
|
"""Half-window for matching persistent rows to a frame timestamp in
|
||||||
|
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
|
||||||
|
is also a float64 from parquet, so in the ideal hot path an exact match
|
||||||
|
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||||
|
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||||
|
common arithmetic drift without admitting frames that are visibly far
|
||||||
|
apart at typical control rates (30–100 Hz)."""
|
||||||
|
|
||||||
|
|
||||||
def emitted_at(
|
def emitted_at(
|
||||||
t: float,
|
t: float,
|
||||||
*,
|
*,
|
||||||
@@ -78,17 +88,19 @@ def emitted_at(
|
|||||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||||
|
|
||||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||||
equals ``t``. For event styles, the ``events`` list is assumed to come from
|
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||||
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
|
we use a tolerance instead of bit-equality). For event styles, the
|
||||||
so all matching event rows are considered emitted at ``t``. ``camera``
|
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||||
filters by the row's ``camera`` field — required to disambiguate when
|
(event rows carry no timestamp of their own), so all matching event rows
|
||||||
multiple view-dependent rows share ``(t, role)`` across cameras.
|
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||||
|
``camera`` field — required to disambiguate when multiple view-dependent
|
||||||
|
rows share ``(t, role)`` across cameras.
|
||||||
"""
|
"""
|
||||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||||
matches = [
|
matches = [
|
||||||
row
|
row
|
||||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
if _timestamp(row) == t
|
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
@@ -391,9 +403,9 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
|||||||
for idx in target_indices:
|
for idx in target_indices:
|
||||||
if idx < 0 or idx >= len(messages):
|
if idx < 0 or idx >= len(messages):
|
||||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||||
for idx, stream in enumerate(streams):
|
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||||
if stream is None:
|
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||||
raise ValueError(f"Rendered message {idx} has no stream.")
|
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||||
|
|
||||||
|
|
||||||
def _nth_relative(
|
def _nth_relative(
|
||||||
|
|||||||
@@ -29,6 +29,17 @@ def test_message_recipe_validates_unknown_binding():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_turn_requires_a_stream():
|
||||||
|
"""Every turn must declare a stream — None is rejected at construction.
|
||||||
|
|
||||||
|
Previously this only failed at render time (``_validate_rendered``);
|
||||||
|
catching it here means a malformed recipe YAML errors at load instead
|
||||||
|
of at the first training sample.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError, match="missing a stream"):
|
||||||
|
MessageTurn(role="user", content="${task}")
|
||||||
|
|
||||||
|
|
||||||
def test_message_recipe_requires_at_least_one_target():
|
def test_message_recipe_requires_at_least_one_target():
|
||||||
with pytest.raises(ValueError, match="target"):
|
with pytest.raises(ValueError, match="target"):
|
||||||
TrainingRecipe(
|
TrainingRecipe(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
|||||||
|
|
||||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||||
from lerobot.datasets.language_render import ( # noqa: E402
|
from lerobot.datasets.language_render import ( # noqa: E402
|
||||||
|
EMITTED_AT_TOLERANCE_S,
|
||||||
active_at,
|
active_at,
|
||||||
emitted_at,
|
emitted_at,
|
||||||
nth_next,
|
nth_next,
|
||||||
@@ -342,6 +343,21 @@ def test_resolve_task_explicit_override_beats_rephrasings():
|
|||||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||||
|
|
||||||
|
|
||||||
|
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
|
||||||
|
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
|
||||||
|
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
|
||||||
|
line up with the parquet-stored timestamp.
|
||||||
|
"""
|
||||||
|
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
|
||||||
|
# Half a tolerance window — bit-different float, comfortably inside
|
||||||
|
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
|
||||||
|
assert inside is not None and inside["content"] == "memo"
|
||||||
|
|
||||||
|
# Just past the window — no match
|
||||||
|
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
|
||||||
|
assert outside is None
|
||||||
|
|
||||||
|
|
||||||
def test_render_sample_rejects_non_dict_language_rows():
|
def test_render_sample_rejects_non_dict_language_rows():
|
||||||
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user