mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
`recipe.py` and `language_render.py`. Promote to module-level
`PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
hardcoded `{"language_persistent", "language_events"}` literals at
two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
silently filtered language fields from samples that didn't have
them, which would hand downstream consumers a preserved list
shorter than the tensor batch. Now: if any sample carries a key,
every sample in the batch must carry it; otherwise raise a
`ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
or multi-element list fell through and triggered confusing
`float([])` errors downstream. Now raises `ValueError` with the
actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
of `key = {... if ... else {}}` plus an 11-line splat dict with a
single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
comment. Add a docstring explaining it's an intentional extension
point: downstream modules append project-local styles before
`column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
page referenced `src/lerobot/tools/`, `registry.py`, and
`get_tools(meta)` — none exist in this PR. Added a callout at the
start of "How to add your own tool" plus a note on the
implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
validation.** `test_recipe.py` grew from 1 case to 12 covering:
blend-or-messages exclusivity, target-turn requirement, blend
emptiness, weight presence/positivity, nested-blend rejection,
`from_dict` with nested blends, `from_yaml` / `load_recipe`
agreement, top-level non-mapping rejection. Added a malformed-row
test for `_normalize_rows` that asserts non-dict entries raise
`TypeError`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+13
-3
@@ -66,9 +66,11 @@ prompt_str = tokenizer.apply_chat_template(
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — live under
|
||||
`src/lerobot/tools/`, one file per tool. The canonical `say`
|
||||
implementation wraps Kyutai's pocket-tts model.
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) land
|
||||
in a follow-up PR; this PR ships only the catalog storage and
|
||||
fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
@@ -114,6 +116,14 @@ the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which lands in a follow-up PR. Today (this PR
|
||||
> only), Step 1 is enough to make the tool visible to the chat
|
||||
> template via `meta.tools` so the model can learn to _generate_ the
|
||||
> call. Executing the call at inference is what the follow-up PR
|
||||
> wires up.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
@@ -34,7 +34,10 @@ DEFAULT_BINDINGS = {
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
@@ -178,13 +181,13 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(_PLACEHOLDER_RE.findall(content))
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(_PLACEHOLDER_RE.findall(value))
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
@@ -256,7 +257,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in {"language_persistent", "language_events"}:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
@@ -297,12 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in [
|
||||
"task",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -37,7 +37,13 @@ CORE_STYLES = {
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
EXTENDED_STYLES = set()
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
|
||||
@@ -22,7 +22,7 @@ import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
@@ -30,7 +30,6 @@ LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
|
||||
def active_at(
|
||||
@@ -376,7 +375,7 @@ def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) ->
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return _PLACEHOLDER_RE.sub(replace, template)
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
@@ -153,49 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
|
||||
language_persistent_key = (
|
||||
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
|
||||
)
|
||||
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
|
||||
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
|
||||
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
|
||||
target_message_indices_key = (
|
||||
{"target_message_indices": batch["target_message_indices"]}
|
||||
if "target_message_indices" in batch
|
||||
else {}
|
||||
)
|
||||
|
||||
return {
|
||||
**pad_keys,
|
||||
**task_key,
|
||||
**index_key,
|
||||
**task_index_key,
|
||||
**episode_index_key,
|
||||
**timestamp_key,
|
||||
**language_persistent_key,
|
||||
**language_events_key,
|
||||
**messages_key,
|
||||
**message_streams_key,
|
||||
**target_message_indices_key,
|
||||
}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -87,6 +87,8 @@ def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list) and len(value) == 1:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
@@ -36,11 +36,22 @@ def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | N
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
preserved = {
|
||||
key: [sample[key] for sample in batch if key in sample]
|
||||
for key in _PYTHON_LIST_KEYS
|
||||
if any(key in sample for sample in batch)
|
||||
}
|
||||
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
|
||||
# carry `messages` and half don't) is a real bug in the upstream
|
||||
# rendering step — silently filtering would hand downstream consumers a
|
||||
# preserved list shorter than the tensor batch. Raise instead so the
|
||||
# mismatch surfaces at the boundary.
|
||||
preserved: dict[str, list[Any]] = {}
|
||||
for key in _PYTHON_LIST_KEYS:
|
||||
presence = [key in sample for sample in batch]
|
||||
if not any(presence):
|
||||
continue
|
||||
if not all(presence):
|
||||
raise ValueError(
|
||||
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
|
||||
f"every sample in a batch must agree."
|
||||
)
|
||||
preserved[key] = [sample[key] for sample in batch]
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
|
||||
|
||||
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
|
||||
return MessageTurn(role="user", content=content, stream="high_level")
|
||||
|
||||
|
||||
def _minimal_target_turn() -> MessageTurn:
|
||||
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
|
||||
|
||||
|
||||
# ── Message-recipe validation ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_message_recipe_validates_unknown_binding():
|
||||
@@ -10,6 +24,134 @@ def test_message_recipe_validates_unknown_binding():
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${missing}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
_minimal_target_turn(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_message_recipe_requires_at_least_one_target():
|
||||
with pytest.raises(ValueError, match="target"):
|
||||
TrainingRecipe(
|
||||
messages=[
|
||||
_minimal_message_turn(),
|
||||
MessageTurn(role="assistant", content="no target", stream="high_level"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_both_messages_and_blend():
|
||||
with pytest.raises(ValueError, match="only one"):
|
||||
TrainingRecipe(
|
||||
messages=[_minimal_message_turn(), _minimal_target_turn()],
|
||||
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
|
||||
)
|
||||
|
||||
|
||||
def test_recipe_rejects_neither_messages_nor_blend():
|
||||
with pytest.raises(ValueError, match="must set one"):
|
||||
TrainingRecipe()
|
||||
|
||||
|
||||
# ── Blend validation ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_blend_must_be_non_empty():
|
||||
with pytest.raises(ValueError, match="at least one component"):
|
||||
TrainingRecipe(blend={})
|
||||
|
||||
|
||||
def test_blend_component_must_define_weight():
|
||||
with pytest.raises(ValueError, match="weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_weight_must_be_positive():
|
||||
with pytest.raises(ValueError, match="positive weight"):
|
||||
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
|
||||
|
||||
|
||||
def test_blend_component_must_define_messages():
|
||||
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
|
||||
# going through __post_init__ to exercise the blend-level validator.
|
||||
bad = TrainingRecipe.__new__(TrainingRecipe)
|
||||
bad.messages = None
|
||||
bad.bindings = None
|
||||
bad.blend = None
|
||||
bad.weight = 1.0
|
||||
with pytest.raises(ValueError, match="must define messages"):
|
||||
TrainingRecipe(blend={"a": bad})
|
||||
|
||||
|
||||
def test_blend_components_cannot_themselves_define_a_blend():
|
||||
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
|
||||
# Force-bypass the inner component's normal validation so the test
|
||||
# exercises the outer blend's "no nested blends" rule directly.
|
||||
nested = TrainingRecipe.__new__(TrainingRecipe)
|
||||
nested.messages = None
|
||||
nested.bindings = None
|
||||
nested.blend = inner.blend
|
||||
nested.weight = 1.0
|
||||
with pytest.raises(ValueError, match="cannot itself define a blend"):
|
||||
TrainingRecipe(blend={"outer": nested})
|
||||
|
||||
|
||||
# ── from_dict / from_yaml round-trips ────────────────────────────────
|
||||
|
||||
|
||||
def test_from_dict_with_nested_blend():
|
||||
recipe = TrainingRecipe.from_dict(
|
||||
{
|
||||
"blend": {
|
||||
"a": {
|
||||
"weight": 1.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
"b": {
|
||||
"weight": 2.0,
|
||||
"messages": [
|
||||
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
|
||||
],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert recipe.blend is not None
|
||||
assert set(recipe.blend) == {"a", "b"}
|
||||
assert recipe.blend["b"].weight == 2.0
|
||||
# Inner messages were promoted to MessageTurn instances.
|
||||
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
|
||||
|
||||
|
||||
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
|
||||
yaml_text = dedent(
|
||||
"""
|
||||
bindings:
|
||||
custom: "active_at(t, style=subtask)"
|
||||
messages:
|
||||
- {role: user, content: "${task}: ${custom}", stream: high_level}
|
||||
- {role: assistant, content: "ok", stream: high_level, target: true}
|
||||
"""
|
||||
).strip()
|
||||
path = tmp_path / "recipe.yaml"
|
||||
path.write_text(yaml_text)
|
||||
|
||||
via_classmethod = TrainingRecipe.from_yaml(path)
|
||||
via_helper = load_recipe(path)
|
||||
|
||||
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
|
||||
assert via_classmethod.messages[1].target is True
|
||||
# ``load_recipe`` is just a wrapper, but assert the two paths agree
|
||||
# on the structural result so a future divergence is caught here.
|
||||
assert via_helper.bindings == via_classmethod.bindings
|
||||
assert len(via_helper.messages) == len(via_classmethod.messages)
|
||||
|
||||
|
||||
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
|
||||
path = tmp_path / "bad.yaml"
|
||||
path.write_text("- just\n- a\n- list\n")
|
||||
with pytest.raises(ValueError, match="mapping at the top level"):
|
||||
TrainingRecipe.from_yaml(path)
|
||||
|
||||
@@ -342,6 +342,29 @@ def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_render_sample_rejects_non_dict_language_rows():
|
||||
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||
|
||||
A pipeline that hands the renderer a non-dict (e.g. a stray string)
|
||||
is a real upstream bug — silent skipping would let it propagate.
|
||||
"""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
with pytest.raises(TypeError, match="must be dictionaries"):
|
||||
render_sample(
|
||||
recipe=recipe,
|
||||
persistent=["not a dict"],
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="x",
|
||||
)
|
||||
|
||||
|
||||
def test_low_level_branch_renders_active_subtask():
|
||||
low_level = TrainingRecipe(
|
||||
blend={
|
||||
|
||||
Reference in New Issue
Block a user