review: dedupe regex, centralize column names, harden collate, more tests

* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
  `recipe.py` and `language_render.py`. Promote to module-level
  `PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
  template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
  hardcoded `{"language_persistent", "language_events"}` literals at
  two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
  rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
  silently filtered language fields from samples that didn't have
  them, which would hand downstream consumers a preserved list
  shorter than the tensor batch. Now: if any sample carries a key,
  every sample in the batch must carry it; otherwise raise a
  `ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
  or multi-element list fell through and triggered confusing
  `float([])` errors downstream. Now raises `ValueError` with the
  actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
  of `key = {... if ... else {}}` plus an 11-line splat dict with a
  single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
  comment. Add a docstring explaining it's an intentional extension
  point: downstream modules append project-local styles before
  `column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
  page referenced `src/lerobot/tools/`, `registry.py`, and
  `get_tools(meta)` — none exist in this PR. Added a callout at the
  start of "How to add your own tool" plus a note on the
  implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
  validation.** `test_recipe.py` grew from 1 case to 12 covering:
  blend-or-messages exclusivity, target-turn requirement, blend
  emptiness, weight presence/positivity, nested-blend rejection,
  `from_dict` with nested blends, `from_yaml` / `load_recipe`
  agreement, top-level non-mapping rejection. Added a malformed-row
  test for `_normalize_rows` that asserts non-dict entries raise
  `TypeError`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-06 19:06:38 +02:00
parent d55b581ca1
commit beb22afd81
10 changed files with 238 additions and 63 deletions
+13 -3
View File
@@ -66,9 +66,11 @@ prompt_str = tokenizer.apply_chat_template(
) )
``` ```
**The implementations** — runnable Python — live under **The implementations** — runnable Python — will live under
`src/lerobot/tools/`, one file per tool. The canonical `say` `src/lerobot/tools/`, one file per tool. The runtime dispatcher and
implementation wraps Kyutai's pocket-tts model. the canonical `say` implementation (wrapping Kyutai's pocket-tts) land
in a follow-up PR; this PR ships only the catalog storage and
fallback constant.
## Per-row tool _invocations_ ## Per-row tool _invocations_
@@ -114,6 +116,14 @@ the matching implementation.
## How to add your own tool ## How to add your own tool
> **Note:** Steps 2 and 3 below describe the runtime layer
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
> `get_tools(meta)`) which lands in a follow-up PR. Today (this PR
> only), Step 1 is enough to make the tool visible to the chat
> template via `meta.tools` so the model can learn to _generate_ the
> call. Executing the call at inference is what the follow-up PR
> wires up.
Three steps. Concrete example: a `record_observation` tool the policy Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control can call to capture an extra observation outside the regular control
loop. loop.
+6 -3
View File
@@ -34,7 +34,10 @@ DEFAULT_BINDINGS = {
"vqa_query": "emitted_at(t, style=vqa, role=user)", "vqa_query": "emitted_at(t, style=vqa, role=user)",
} }
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
"""``${name}`` placeholder pattern used by both recipe binding-reference
discovery (here) and rendered-message substitution (in ``language_render``)."""
_VALID_ROLES = frozenset(get_args(MessageRole)) _VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream)) _VALID_STREAMS = frozenset(get_args(MessageStream))
@@ -178,13 +181,13 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
if content is None: if content is None:
return set() return set()
if isinstance(content, str): if isinstance(content, str):
return set(_PLACEHOLDER_RE.findall(content)) return set(PLACEHOLDER_RE.findall(content))
names: set[str] = set() names: set[str] = set()
for block in content: for block in content:
for value in block.values(): for value in block.values():
if isinstance(value, str): if isinstance(value, str):
names.update(_PLACEHOLDER_RE.findall(value)) names.update(PLACEHOLDER_RE.findall(value))
return names return names
+4 -6
View File
@@ -31,6 +31,7 @@ from torchvision import transforms
from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
from .language import LANGUAGE_COLUMNS
from .utils import ( from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
@@ -256,7 +257,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors. dict: The batch with items converted to torch tensors.
""" """
for key in items_dict: for key in items_dict:
if key in {"language_persistent", "language_events"}: if key in LANGUAGE_COLUMNS:
continue continue
first_item = items_dict[key][0] first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image): if isinstance(first_item, PILImage.Image):
@@ -297,12 +298,9 @@ def item_to_torch(item: dict) -> dict:
Returns: Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor. dict: Dictionary with all tensor-like items converted to torch.Tensor.
""" """
skip_keys = {"task", *LANGUAGE_COLUMNS}
for key, val in item.items(): for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in [ if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
"task",
"language_persistent",
"language_events",
]:
# Convert numpy arrays and lists to torch tensors # Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val) item[key] = torch.tensor(val)
return item return item
+7 -1
View File
@@ -37,7 +37,13 @@ CORE_STYLES = {
"trace", "trace",
"task_aug", "task_aug",
} }
EXTENDED_STYLES = set() # Project-local styles can be registered at import time by appending to
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
# validation. Empty by default — populate from a downstream module that
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
# the new style's column.
EXTENDED_STYLES: set[str] = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"} PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
+2 -3
View File
@@ -22,7 +22,7 @@ import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
from .language import LANGUAGE_PERSISTENT, column_for_style from .language import LANGUAGE_PERSISTENT, column_for_style
@@ -30,7 +30,6 @@ LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]] RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$") _RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
def active_at( def active_at(
@@ -376,7 +375,7 @@ def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) ->
return "" if content is None else str(content) return "" if content is None else str(content)
return str(value) return str(value)
return _PLACEHOLDER_RE.sub(replace, template) return PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None: def _validate_rendered(rendered: RenderedMessages) -> None:
+20 -39
View File
@@ -153,49 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
return x return x
_COMPLEMENTARY_KEYS = (
"task",
"index",
"task_index",
"episode_index",
"timestamp",
"language_persistent",
"language_events",
"messages",
"message_streams",
"target_message_indices",
)
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
""" """Extract complementary data from a batch dictionary.
Extract complementary data from a batch dictionary.
This includes padding flags, task description, and indices. Includes padding flags (any key containing ``_is_pad``) plus the fixed
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS``
Args: each only when present in ``batch``.
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
""" """
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {} extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
index_key = {"index": batch["index"]} if "index" in batch else {} return {**pad_keys, **extras}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
language_persistent_key = (
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
)
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
target_message_indices_key = (
{"target_message_indices": batch["target_message_indices"]}
if "target_message_indices" in batch
else {}
)
return {
**pad_keys,
**task_key,
**index_key,
**task_index_key,
**episode_index_key,
**timestamp_key,
**language_persistent_key,
**language_events_key,
**messages_key,
**message_streams_key,
**target_message_indices_key,
}
def create_transition( def create_transition(
@@ -87,6 +87,8 @@ def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar.""" """Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"): if hasattr(value, "item"):
return value.item() return value.item()
if isinstance(value, list) and len(value) == 1: if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0]) return _scalar(value[0])
return value return value
+16 -5
View File
@@ -36,11 +36,22 @@ def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | N
if not batch: if not batch:
return None return None
preserved = { # All-or-nothing per key: a partial-presence batch (e.g. half the samples
key: [sample[key] for sample in batch if key in sample] # carry `messages` and half don't) is a real bug in the upstream
for key in _PYTHON_LIST_KEYS # rendering step — silently filtering would hand downstream consumers a
if any(key in sample for sample in batch) # preserved list shorter than the tensor batch. Raise instead so the
} # mismatch surfaces at the boundary.
preserved: dict[str, list[Any]] = {}
for key in _PYTHON_LIST_KEYS:
presence = [key in sample for sample in batch]
if not any(presence):
continue
if not all(presence):
raise ValueError(
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
f"every sample in a batch must agree."
)
preserved[key] = [sample[key] for sample in batch]
tensorizable = [ tensorizable = [
{ {
key: value key: value
+144 -2
View File
@@ -1,8 +1,22 @@
#!/usr/bin/env python #!/usr/bin/env python
from pathlib import Path
from textwrap import dedent
import pytest import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
return MessageTurn(role="user", content=content, stream="high_level")
def _minimal_target_turn() -> MessageTurn:
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
# ── Message-recipe validation ────────────────────────────────────────
def test_message_recipe_validates_unknown_binding(): def test_message_recipe_validates_unknown_binding():
@@ -10,6 +24,134 @@ def test_message_recipe_validates_unknown_binding():
TrainingRecipe( TrainingRecipe(
messages=[ messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"), MessageTurn(role="user", content="${missing}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True), _minimal_target_turn(),
] ]
) )
def test_message_recipe_requires_at_least_one_target():
with pytest.raises(ValueError, match="target"):
TrainingRecipe(
messages=[
_minimal_message_turn(),
MessageTurn(role="assistant", content="no target", stream="high_level"),
]
)
def test_recipe_rejects_both_messages_and_blend():
with pytest.raises(ValueError, match="only one"):
TrainingRecipe(
messages=[_minimal_message_turn(), _minimal_target_turn()],
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
)
def test_recipe_rejects_neither_messages_nor_blend():
with pytest.raises(ValueError, match="must set one"):
TrainingRecipe()
# ── Blend validation ─────────────────────────────────────────────────
def test_blend_must_be_non_empty():
with pytest.raises(ValueError, match="at least one component"):
TrainingRecipe(blend={})
def test_blend_component_must_define_weight():
with pytest.raises(ValueError, match="weight"):
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
def test_blend_component_weight_must_be_positive():
with pytest.raises(ValueError, match="positive weight"):
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
def test_blend_component_must_define_messages():
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
# going through __post_init__ to exercise the blend-level validator.
bad = TrainingRecipe.__new__(TrainingRecipe)
bad.messages = None
bad.bindings = None
bad.blend = None
bad.weight = 1.0
with pytest.raises(ValueError, match="must define messages"):
TrainingRecipe(blend={"a": bad})
def test_blend_components_cannot_themselves_define_a_blend():
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
# Force-bypass the inner component's normal validation so the test
# exercises the outer blend's "no nested blends" rule directly.
nested = TrainingRecipe.__new__(TrainingRecipe)
nested.messages = None
nested.bindings = None
nested.blend = inner.blend
nested.weight = 1.0
with pytest.raises(ValueError, match="cannot itself define a blend"):
TrainingRecipe(blend={"outer": nested})
# ── from_dict / from_yaml round-trips ────────────────────────────────
def test_from_dict_with_nested_blend():
recipe = TrainingRecipe.from_dict(
{
"blend": {
"a": {
"weight": 1.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
],
},
"b": {
"weight": 2.0,
"messages": [
{"role": "user", "content": "${task}", "stream": "high_level"},
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
],
},
}
}
)
assert recipe.blend is not None
assert set(recipe.blend) == {"a", "b"}
assert recipe.blend["b"].weight == 2.0
# Inner messages were promoted to MessageTurn instances.
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
yaml_text = dedent(
"""
bindings:
custom: "active_at(t, style=subtask)"
messages:
- {role: user, content: "${task}: ${custom}", stream: high_level}
- {role: assistant, content: "ok", stream: high_level, target: true}
"""
).strip()
path = tmp_path / "recipe.yaml"
path.write_text(yaml_text)
via_classmethod = TrainingRecipe.from_yaml(path)
via_helper = load_recipe(path)
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
assert via_classmethod.messages[1].target is True
# ``load_recipe`` is just a wrapper, but assert the two paths agree
# on the structural result so a future divergence is caught here.
assert via_helper.bindings == via_classmethod.bindings
assert len(via_helper.messages) == len(via_classmethod.messages)
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
path = tmp_path / "bad.yaml"
path.write_text("- just\n- a\n- list\n")
with pytest.raises(ValueError, match="mapping at the top level"):
TrainingRecipe.from_yaml(path)
+23
View File
@@ -342,6 +342,29 @@ def test_resolve_task_explicit_override_beats_rephrasings():
assert rendered["messages"][0]["content"] == "explicit override wins" assert rendered["messages"][0]["content"] == "explicit override wins"
def test_render_sample_rejects_non_dict_language_rows():
"""``_normalize_rows`` must surface malformed inputs as TypeError.
A pipeline that hands the renderer a non-dict (e.g. a stray string)
is a real upstream bug silent skipping would let it propagate.
"""
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
with pytest.raises(TypeError, match="must be dictionaries"):
render_sample(
recipe=recipe,
persistent=["not a dict"],
events=[],
t=0.0,
sample_idx=0,
task="x",
)
def test_low_level_branch_renders_active_subtask(): def test_low_level_branch_renders_active_subtask():
low_level = TrainingRecipe( low_level = TrainingRecipe(
blend={ blend={