mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
beb22afd81
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
`recipe.py` and `language_render.py`. Promote to module-level
`PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
hardcoded `{"language_persistent", "language_events"}` literals at
two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
silently filtered language fields from samples that didn't have
them, which would hand downstream consumers a preserved list
shorter than the tensor batch. Now: if any sample carries a key,
every sample in the batch must carry it; otherwise raise a
`ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
or multi-element list fell through and triggered confusing
`float([])` errors downstream. Now raises `ValueError` with the
actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
of `key = {... if ... else {}}` plus an 11-line splat dict with a
single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
comment. Add a docstring explaining it's an intentional extension
point: downstream modules append project-local styles before
`column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
page referenced `src/lerobot/tools/`, `registry.py`, and
`get_tools(meta)` — none exist in this PR. Added a callout at the
start of "How to add your own tool" plus a note on the
implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
validation.** `test_recipe.py` grew from 1 case to 12 covering:
blend-or-messages exclusivity, target-turn requirement, blend
emptiness, weight presence/positivity, nested-blend rejection,
`from_dict` with nested blends, `from_yaml` / `load_recipe`
agreement, top-level non-mapping rejection. Added a malformed-row
test for `_normalize_rows` that asserts non-dict entries raise
`TypeError`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
#!/usr/bin/env python
|
|
|
|
from pathlib import Path
|
|
from textwrap import dedent
|
|
|
|
import pytest
|
|
|
|
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
|
|
|
|
|
|
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
    """Build the smallest valid non-target user turn used across these tests."""
    turn = MessageTurn(role="user", content=content, stream="high_level")
    return turn
|
|
|
|
|
|
def _minimal_target_turn() -> MessageTurn:
    """Build the smallest valid assistant turn flagged as a training target."""
    return MessageTurn(
        role="assistant",
        content="ok",
        stream="high_level",
        target=True,
    )
|
|
|
|
|
|
# ── Message-recipe validation ────────────────────────────────────────
|
|
|
|
|
|
def test_message_recipe_validates_unknown_binding():
    """A placeholder with no matching binding must be rejected at construction."""
    turns = [
        MessageTurn(role="user", content="${missing}", stream="high_level"),
        _minimal_target_turn(),
    ]
    with pytest.raises(ValueError, match="unknown binding"):
        TrainingRecipe(messages=turns)
|
|
|
|
|
|
def test_message_recipe_requires_at_least_one_target():
    """A recipe whose turns all lack target=True must fail validation."""
    untargeted = MessageTurn(role="assistant", content="no target", stream="high_level")
    with pytest.raises(ValueError, match="target"):
        TrainingRecipe(messages=[_minimal_message_turn(), untargeted])
|
|
|
|
|
|
def test_recipe_rejects_both_messages_and_blend():
    """``messages`` and ``blend`` are mutually exclusive; setting both raises."""
    component = TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])
    with pytest.raises(ValueError, match="only one"):
        TrainingRecipe(
            messages=[_minimal_message_turn(), _minimal_target_turn()],
            blend={"a": component},
        )
|
|
|
|
|
|
def test_recipe_rejects_neither_messages_nor_blend():
    """A recipe that defines neither messages nor a blend is invalid."""
    with pytest.raises(ValueError, match="must set one"):
        TrainingRecipe()
|
|
|
|
|
|
# ── Blend validation ─────────────────────────────────────────────────
|
|
|
|
|
|
def test_blend_must_be_non_empty():
    """An empty blend mapping must be rejected."""
    no_components: dict = {}
    with pytest.raises(ValueError, match="at least one component"):
        TrainingRecipe(blend=no_components)
|
|
|
|
|
|
def test_blend_component_must_define_weight():
    """Every blend component must carry an explicit weight."""
    weightless = TrainingRecipe(messages=[_minimal_target_turn()])
    with pytest.raises(ValueError, match="weight"):
        TrainingRecipe(blend={"a": weightless})
|
|
|
|
|
|
def test_blend_component_weight_must_be_positive():
    """A zero component weight must be rejected by the blend validator."""
    zero_weighted = TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])
    with pytest.raises(ValueError, match="positive weight"):
        TrainingRecipe(blend={"a": zero_weighted})
|
|
|
|
|
|
def test_blend_component_must_define_messages():
    """The blend-level validator itself must reject message-less components."""
    # A bare TrainingRecipe(weight=1.0) would itself raise; build it without
    # going through __post_init__ to exercise the blend-level validator.
    bad = TrainingRecipe.__new__(TrainingRecipe)
    for attr, value in (
        ("messages", None),
        ("bindings", None),
        ("blend", None),
        ("weight", 1.0),
    ):
        setattr(bad, attr, value)
    with pytest.raises(ValueError, match="must define messages"):
        TrainingRecipe(blend={"a": bad})
|
|
|
|
|
|
def test_blend_components_cannot_themselves_define_a_blend():
    """Blends must stay one level deep: a component carrying a blend is invalid."""
    inner = TrainingRecipe(
        blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])}
    )
    # Force-bypass the inner component's normal validation so the test
    # exercises the outer blend's "no nested blends" rule directly.
    nested = TrainingRecipe.__new__(TrainingRecipe)
    for attr, value in (
        ("messages", None),
        ("bindings", None),
        ("blend", inner.blend),
        ("weight", 1.0),
    ):
        setattr(nested, attr, value)
    with pytest.raises(ValueError, match="cannot itself define a blend"):
        TrainingRecipe(blend={"outer": nested})
|
|
|
|
|
|
# ── from_dict / from_yaml round-trips ────────────────────────────────
|
|
|
|
|
|
def test_from_dict_with_nested_blend():
    """``from_dict`` must recursively promote nested blend payloads."""

    def raw_component(weight: float, reply: str) -> dict:
        # Minimal pre-promotion payload for one blend component.
        return {
            "weight": weight,
            "messages": [
                {"role": "user", "content": "${task}", "stream": "high_level"},
                {"role": "assistant", "content": reply, "stream": "high_level", "target": True},
            ],
        }

    payload = {"blend": {"a": raw_component(1.0, "a"), "b": raw_component(2.0, "b")}}
    recipe = TrainingRecipe.from_dict(payload)

    assert recipe.blend is not None
    assert set(recipe.blend) == {"a", "b"}
    assert recipe.blend["b"].weight == 2.0
    # Inner messages were promoted to MessageTurn instances.
    assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
|
|
|
|
|
|
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
    """``from_yaml`` and the ``load_recipe`` helper must agree on one file."""
    recipe_file = tmp_path / "recipe.yaml"
    recipe_file.write_text(
        dedent(
            """
            bindings:
              custom: "active_at(t, style=subtask)"
            messages:
              - {role: user, content: "${task}: ${custom}", stream: high_level}
              - {role: assistant, content: "ok", stream: high_level, target: true}
            """
        ).strip()
    )

    via_classmethod = TrainingRecipe.from_yaml(recipe_file)
    via_helper = load_recipe(recipe_file)

    assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
    assert via_classmethod.messages[1].target is True
    # ``load_recipe`` is just a wrapper, but assert the two paths agree
    # on the structural result so a future divergence is caught here.
    assert via_helper.bindings == via_classmethod.bindings
    assert len(via_helper.messages) == len(via_classmethod.messages)
|
|
|
|
|
|
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
    """A YAML document whose top level is not a mapping must be rejected."""
    bad_file = tmp_path / "bad.yaml"
    bad_file.write_text("- just\n- a\n- list\n")
    with pytest.raises(ValueError, match="mapping at the top level"):
        TrainingRecipe.from_yaml(bad_file)
|