mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
`recipe.py` and `language_render.py`. Promote to module-level
`PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
hardcoded `{"language_persistent", "language_events"}` literals at
two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
silently filtered language fields from samples that didn't have
them, which would hand downstream consumers a preserved list
shorter than the tensor batch. Now: if any sample carries a key,
every sample in the batch must carry it; otherwise raise a
`ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
or multi-element list fell through and triggered confusing
`float([])` errors downstream. Now raises `ValueError` with the
actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
of `key = {... if ... else {}}` plus an 11-line splat dict with a
single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
comment. Add a docstring explaining it's an intentional extension
point: downstream modules append project-local styles before
`column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
page referenced `src/lerobot/tools/`, `registry.py`, and
`get_tools(meta)` — none exist in this PR. Added a callout at the
start of "How to add your own tool" plus a note on the
implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
validation.** `test_recipe.py` grew from 1 case to 12 covering:
blend-or-messages exclusivity, target-turn requirement, blend
emptiness, weight presence/positivity, nested-blend rejection,
`from_dict` with nested blends, `from_yaml` / `load_recipe`
agreement, top-level non-mapping rejection. Added a malformed-row
test for `_normalize_rows` that asserts non-dict entries raise
`TypeError`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
|
||||
|
||||
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
|
||||
return MessageTurn(role="user", content=content, stream="high_level")
|
||||
|
||||
|
||||
def _minimal_target_turn() -> MessageTurn:
|
||||
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
|
||||
|
||||
|
||||
# ── Message-recipe validation ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_message_recipe_validates_unknown_binding():
|
||||
@@ -10,6 +24,134 @@ def test_message_recipe_validates_unknown_binding():
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${missing}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
_minimal_target_turn(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_message_recipe_requires_at_least_one_target():
|
||||
with pytest.raises(ValueError, match="target"):
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
_minimal_message_turn(),
|
||||
MessageTurn(role="assistant", content="no target", stream="high_level"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_both_messages_and_blend():
|
||||
with pytest.raises(ValueError, match="only one"):
|
||||
TrainingRecipe(
|
||||
messages=[_minimal_message_turn(), _minimal_target_turn()],
|
||||
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_neither_messages_nor_blend():
|
||||
with pytest.raises(ValueError, match="must set one"):
|
||||
TrainingRecipe()
|
||||
|
||||
|
||||
# ── Blend validation ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_blend_must_be_non_empty():
|
||||
with pytest.raises(ValueError, match="at least one component"):
|
||||
TrainingRecipe(blend={})
|
||||
|
||||
|
||||
def test_blend_component_must_define_weight():
|
||||
with pytest.raises(ValueError, match="weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_weight_must_be_positive():
|
||||
with pytest.raises(ValueError, match="positive weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_must_define_messages():
|
||||
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
|
||||
# going through __post_init__ to exercise the blend-level validator.
|
||||
bad = TrainingRecipe.__new__(TrainingRecipe)
|
||||
bad.messages = None
|
||||
bad.bindings = None
|
||||
bad.blend = None
|
||||
bad.weight = 1.0
|
||||
with pytest.raises(ValueError, match="must define messages"):
|
||||
TrainingRecipe(blend={"a": bad})
|
||||
|
||||
|
||||
def test_blend_components_cannot_themselves_define_a_blend():
|
||||
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
|
||||
# Force-bypass the inner component's normal validation so the test
|
||||
# exercises the outer blend's "no nested blends" rule directly.
|
||||
nested = TrainingRecipe.__new__(TrainingRecipe)
|
||||
nested.messages = None
|
||||
nested.bindings = None
|
||||
nested.blend = inner.blend
|
||||
nested.weight = 1.0
|
||||
with pytest.raises(ValueError, match="cannot itself define a blend"):
|
||||
TrainingRecipe(blend={"outer": nested})
|
||||
|
||||
|
||||
# ── from_dict / from_yaml round-trips ────────────────────────────────
|
||||
|
||||
|
||||
def test_from_dict_with_nested_blend():
|
||||
recipe = TrainingRecipe.from_dict(
|
||||
{
|
||||
"blend": {
|
||||
"a": {
|
||||
"weight": 1.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
"b": {
|
||||
"weight": 2.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert recipe.blend is not None
|
||||
assert set(recipe.blend) == {"a", "b"}
|
||||
assert recipe.blend["b"].weight == 2.0
|
||||
# Inner messages were promoted to MessageTurn instances.
|
||||
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
|
||||
|
||||
|
||||
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
|
||||
yaml_text = dedent(
|
||||
"""
|
||||
bindings:
|
||||
custom: "active_at(t, style=subtask)"
|
||||
messages:
|
||||
- {role: user, content: "${task}: ${custom}", stream: high_level}
|
||||
- {role: assistant, content: "ok", stream: high_level, target: true}
|
||||
"""
|
||||
).strip()
|
||||
path = tmp_path / "recipe.yaml"
|
||||
path.write_text(yaml_text)
|
||||
|
||||
via_classmethod = TrainingRecipe.from_yaml(path)
|
||||
via_helper = load_recipe(path)
|
||||
|
||||
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
|
||||
assert via_classmethod.messages[1].target is True
|
||||
# ``load_recipe`` is just a wrapper, but assert the two paths agree
|
||||
# on the structural result so a future divergence is caught here.
|
||||
assert via_helper.bindings == via_classmethod.bindings
|
||||
assert len(via_helper.messages) == len(via_classmethod.messages)
|
||||
|
||||
|
||||
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
|
||||
path = tmp_path / "bad.yaml"
|
||||
path.write_text("- just\n- a\n- list\n")
|
||||
with pytest.raises(ValueError, match="mapping at the top level"):
|
||||
TrainingRecipe.from_yaml(path)
|
||||
|
||||
@@ -342,6 +342,29 @@ def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_render_sample_rejects_non_dict_language_rows():
|
||||
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||
|
||||
A pipeline that hands the renderer a non-dict (e.g. a stray string)
|
||||
is a real upstream bug — silent skipping would let it propagate.
|
||||
"""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
with pytest.raises(TypeError, match="must be dictionaries"):
|
||||
render_sample(
|
||||
recipe=recipe,
|
||||
persistent=["not a dict"],
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="x",
|
||||
)
|
||||
|
||||
|
||||
def test_low_level_branch_renders_active_subtask():
|
||||
low_level = TrainingRecipe(
|
||||
blend={
|
||||
|
||||
Reference in New Issue
Block a user