diff --git a/docs/source/tools.mdx b/docs/source/tools.mdx index 3309be8cd..04d5da6b9 100644 --- a/docs/source/tools.mdx +++ b/docs/source/tools.mdx @@ -66,9 +66,11 @@ prompt_str = tokenizer.apply_chat_template( ) ``` -**The implementations** — runnable Python — live under -`src/lerobot/tools/`, one file per tool. The canonical `say` -implementation wraps Kyutai's pocket-tts model. +**The implementations** — runnable Python — will live under +`src/lerobot/tools/`, one file per tool. The runtime dispatcher and +the canonical `say` implementation (wrapping Kyutai's pocket-tts) land +in a follow-up PR; this PR ships only the catalog storage and +fallback constant. ## Per-row tool _invocations_ @@ -114,6 +116,14 @@ the matching implementation. ## How to add your own tool +> **Note:** Steps 2 and 3 below describe the runtime layer +> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`, +> `get_tools(meta)`) which lands in a follow-up PR. Today (this PR +> only), Step 1 is enough to make the tool visible to the chat +> template via `meta.tools` so the model can learn to _generate_ the +> call. Executing the call at inference is what the follow-up PR +> wires up. + Three steps. Concrete example: a `record_observation` tool the policy can call to capture an extra observation outside the regular control loop. diff --git a/src/lerobot/configs/recipe.py b/src/lerobot/configs/recipe.py index ef496c3c4..27d4c5b8b 100644 --- a/src/lerobot/configs/recipe.py +++ b/src/lerobot/configs/recipe.py @@ -34,7 +34,10 @@ DEFAULT_BINDINGS = { "vqa_query": "emitted_at(t, style=vqa, role=user)", } -_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") +PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") +"""``${name}`` placeholder pattern used by both recipe binding-reference +discovery (here) and rendered-message substitution (in ``language_render``).""" + _VALID_ROLES = frozenset(get_args(MessageRole)) _VALID_STREAMS = frozenset(get_args(MessageStream)) @@ -178,13 +181,13 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[ if content is None: return set() if isinstance(content, str): - return set(_PLACEHOLDER_RE.findall(content)) + return set(PLACEHOLDER_RE.findall(content)) names: set[str] = set() for block in content: for value in block.values(): if isinstance(value, str): - names.update(_PLACEHOLDER_RE.findall(value)) + names.update(PLACEHOLDER_RE.findall(value)) return names diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index f4552c8b9..a41f34704 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -31,6 +31,7 @@ from torchvision import transforms from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict +from .language import LANGUAGE_COLUMNS from .utils import ( DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_EPISODES_PATH, @@ -256,7 +257,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to dict: The batch with items converted to torch tensors. """ for key in items_dict: - if key in {"language_persistent", "language_events"}: + if key in LANGUAGE_COLUMNS: continue first_item = items_dict[key][0] if isinstance(first_item, PILImage.Image): @@ -297,12 +298,9 @@ def item_to_torch(item: dict) -> dict: Returns: dict: Dictionary with all tensor-like items converted to torch.Tensor. """ + skip_keys = {"task", *LANGUAGE_COLUMNS} for key, val in item.items(): - if isinstance(val, (np.ndarray | list)) and key not in [ - "task", - "language_persistent", - "language_events", - ]: + if isinstance(val, (np.ndarray | list)) and key not in skip_keys: # Convert numpy arrays and lists to torch tensors item[key] = torch.tensor(val) return item diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index 8d518df9d..27c26ac06 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -37,7 +37,13 @@ CORE_STYLES = { "trace", "task_aug", } -EXTENDED_STYLES = set() +# Project-local styles can be registered at import time by appending to +# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added +# here is treated as a known style alongside ``CORE_STYLES`` for resolver +# validation. Empty by default — populate from a downstream module that +# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare +# the new style's column. +EXTENDED_STYLES: set[str] = set() STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"} diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index 1069980b8..8b1a364a0 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -22,7 +22,7 @@ import re from collections.abc import Sequence from typing import Any -from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe +from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe from .language import LANGUAGE_PERSISTENT, column_for_style @@ -30,7 +30,6 @@ LanguageRow = dict[str, Any] RenderedMessages = dict[str, list[Any]] _RESOLVER_RE = re.compile(r"^(?P[A-Za-z_][A-Za-z0-9_]*)\((?P.*)\)$") -_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") def active_at( @@ -376,7 +375,7 @@ def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> return "" if content is None else str(content) return str(value) - return _PLACEHOLDER_RE.sub(replace, template) + return PLACEHOLDER_RE.sub(replace, template) def _validate_rendered(rendered: RenderedMessages) -> None: diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 98e9253b0..faa4d5cd9 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -153,49 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An return x +_COMPLEMENTARY_KEYS = ( + "task", + "index", + "task_index", + "episode_index", + "timestamp", + "language_persistent", + "language_events", + "messages", + "message_streams", + "target_message_indices", +) + + def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: - """ - Extract complementary data from a batch dictionary. + """Extract complementary data from a batch dictionary. - This includes padding flags, task description, and indices. - - Args: - batch: The batch dictionary. - - Returns: - A dictionary with the extracted complementary data. + Includes padding flags (any key containing ``_is_pad``) plus the fixed + set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` — + each only when present in ``batch``. """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} - task_key = {"task": batch["task"]} if "task" in batch else {} - index_key = {"index": batch["index"]} if "index" in batch else {} - task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} - episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {} - language_persistent_key = ( - {"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {} - ) - language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {} - messages_key = {"messages": batch["messages"]} if "messages" in batch else {} - message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {} - target_message_indices_key = ( - {"target_message_indices": batch["target_message_indices"]} - if "target_message_indices" in batch - else {} - ) - - return { - **pad_keys, - **task_key, - **index_key, - **task_index_key, - **episode_index_key, - **timestamp_key, - **language_persistent_key, - **language_events_key, - **messages_key, - **message_streams_key, - **target_message_indices_key, - } + extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch} + return {**pad_keys, **extras} def create_transition( diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py index 7d88fab73..b5bb81f9d 100644 --- a/src/lerobot/processor/render_messages_processor.py +++ b/src/lerobot/processor/render_messages_processor.py @@ -87,6 +87,8 @@ def _scalar(value: Any) -> float | int: """Unwrap a tensor/array/single-element list into a Python scalar.""" if hasattr(value, "item"): return value.item() - if isinstance(value, list) and len(value) == 1: + if isinstance(value, list): + if len(value) != 1: + raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}") return _scalar(value[0]) return value diff --git a/src/lerobot/utils/collate.py b/src/lerobot/utils/collate.py index ca32430cd..fce7e6b42 100644 --- a/src/lerobot/utils/collate.py +++ b/src/lerobot/utils/collate.py @@ -36,11 +36,22 @@ def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | N if not batch: return None - preserved = { - key: [sample[key] for sample in batch if key in sample] - for key in _PYTHON_LIST_KEYS - if any(key in sample for sample in batch) - } + # All-or-nothing per key: a partial-presence batch (e.g. half the samples + # carry `messages` and half don't) is a real bug in the upstream + # rendering step — silently filtering would hand downstream consumers a + # preserved list shorter than the tensor batch. Raise instead so the + # mismatch surfaces at the boundary. + preserved: dict[str, list[Any]] = {} + for key in _PYTHON_LIST_KEYS: + presence = [key in sample for sample in batch] + if not any(presence): + continue + if not all(presence): + raise ValueError( + f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; " + f"every sample in a batch must agree." + ) + preserved[key] = [sample[key] for sample in batch] tensorizable = [ { key: value diff --git a/tests/configs/test_recipe.py b/tests/configs/test_recipe.py index ba0ad9117..332e40333 100644 --- a/tests/configs/test_recipe.py +++ b/tests/configs/test_recipe.py @@ -1,8 +1,22 @@ #!/usr/bin/env python +from pathlib import Path +from textwrap import dedent + import pytest -from lerobot.configs.recipe import MessageTurn, TrainingRecipe +from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe + + +def _minimal_message_turn(content: str = "${task}") -> MessageTurn: + return MessageTurn(role="user", content=content, stream="high_level") + + +def _minimal_target_turn() -> MessageTurn: + return MessageTurn(role="assistant", content="ok", stream="high_level", target=True) + + +# ── Message-recipe validation ──────────────────────────────────────── def test_message_recipe_validates_unknown_binding(): @@ -10,6 +24,134 @@ def test_message_recipe_validates_unknown_binding(): TrainingRecipe( messages=[ MessageTurn(role="user", content="${missing}", stream="high_level"), - MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + _minimal_target_turn(), ] ) + + +def test_message_recipe_requires_at_least_one_target(): + with pytest.raises(ValueError, match="target"): + TrainingRecipe( + messages=[ + _minimal_message_turn(), + MessageTurn(role="assistant", content="no target", stream="high_level"), + ] + ) + + +def test_recipe_rejects_both_messages_and_blend(): + with pytest.raises(ValueError, match="only one"): + TrainingRecipe( + messages=[_minimal_message_turn(), _minimal_target_turn()], + blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])}, + ) + + +def test_recipe_rejects_neither_messages_nor_blend(): + with pytest.raises(ValueError, match="must set one"): + TrainingRecipe() + + +# ── Blend validation ───────────────────────────────────────────────── + + +def test_blend_must_be_non_empty(): + with pytest.raises(ValueError, match="at least one component"): + TrainingRecipe(blend={}) + + +def test_blend_component_must_define_weight(): + with pytest.raises(ValueError, match="weight"): + TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])}) + + +def test_blend_component_weight_must_be_positive(): + with pytest.raises(ValueError, match="positive weight"): + TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])}) + + +def test_blend_component_must_define_messages(): + # A bare TrainingRecipe(weight=1.0) would itself raise; build it without + # going through __post_init__ to exercise the blend-level validator. + bad = TrainingRecipe.__new__(TrainingRecipe) + bad.messages = None + bad.bindings = None + bad.blend = None + bad.weight = 1.0 + with pytest.raises(ValueError, match="must define messages"): + TrainingRecipe(blend={"a": bad}) + + +def test_blend_components_cannot_themselves_define_a_blend(): + inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])}) + # Force-bypass the inner component's normal validation so the test + # exercises the outer blend's "no nested blends" rule directly. + nested = TrainingRecipe.__new__(TrainingRecipe) + nested.messages = None + nested.bindings = None + nested.blend = inner.blend + nested.weight = 1.0 + with pytest.raises(ValueError, match="cannot itself define a blend"): + TrainingRecipe(blend={"outer": nested}) + + +# ── from_dict / from_yaml round-trips ──────────────────────────────── + + +def test_from_dict_with_nested_blend(): + recipe = TrainingRecipe.from_dict( + { + "blend": { + "a": { + "weight": 1.0, + "messages": [ + {"role": "user", "content": "${task}", "stream": "high_level"}, + {"role": "assistant", "content": "a", "stream": "high_level", "target": True}, + ], + }, + "b": { + "weight": 2.0, + "messages": [ + {"role": "user", "content": "${task}", "stream": "high_level"}, + {"role": "assistant", "content": "b", "stream": "high_level", "target": True}, + ], + }, + } + } + ) + assert recipe.blend is not None + assert set(recipe.blend) == {"a", "b"} + assert recipe.blend["b"].weight == 2.0 + # Inner messages were promoted to MessageTurn instances. + assert isinstance(recipe.blend["a"].messages[0], MessageTurn) + + +def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path): + yaml_text = dedent( + """ + bindings: + custom: "active_at(t, style=subtask)" + messages: + - {role: user, content: "${task}: ${custom}", stream: high_level} + - {role: assistant, content: "ok", stream: high_level, target: true} + """ + ).strip() + path = tmp_path / "recipe.yaml" + path.write_text(yaml_text) + + via_classmethod = TrainingRecipe.from_yaml(path) + via_helper = load_recipe(path) + + assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"} + assert via_classmethod.messages[1].target is True + # ``load_recipe`` is just a wrapper, but assert the two paths agree + # on the structural result so a future divergence is caught here. + assert via_helper.bindings == via_classmethod.bindings + assert len(via_helper.messages) == len(via_classmethod.messages) + + +def test_from_yaml_rejects_non_mapping(tmp_path: Path): + path = tmp_path / "bad.yaml" + path.write_text("- just\n- a\n- list\n") + with pytest.raises(ValueError, match="mapping at the top level"): + TrainingRecipe.from_yaml(path) diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py index 4fccaadfb..a401f8aad 100644 --- a/tests/datasets/test_language_render.py +++ b/tests/datasets/test_language_render.py @@ -342,6 +342,29 @@ def test_resolve_task_explicit_override_beats_rephrasings(): assert rendered["messages"][0]["content"] == "explicit override wins" +def test_render_sample_rejects_non_dict_language_rows(): + """``_normalize_rows`` must surface malformed inputs as TypeError. + + A pipeline that hands the renderer a non-dict (e.g. a stray string) + is a real upstream bug — silent skipping would let it propagate. + """ + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + ] + ) + with pytest.raises(TypeError, match="must be dictionaries"): + render_sample( + recipe=recipe, + persistent=["not a dict"], + events=[], + t=0.0, + sample_idx=0, + task="x", + ) + + def test_low_level_branch_renders_active_subtask(): low_level = TrainingRecipe( blend={