From 8833d735a1e6a4d65693962b59d9dced42e6fbcf Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 27 Apr 2026 10:56:32 +0200 Subject: [PATCH] Add extensive language support --- docs/source/_toctree.yml | 4 +- docs/source/dataset_subtask.mdx | 277 ----------- docs/source/language_and_recipes.mdx | 75 +++ pyproject.toml | 2 +- src/lerobot/configs/__init__.py | 4 + src/lerobot/configs/recipe.py | 167 +++++++ src/lerobot/configs/recipes/pi05_hirobot.yaml | 47 ++ src/lerobot/datasets/__init__.py | 14 + src/lerobot/datasets/compute_stats.py | 2 +- src/lerobot/datasets/dataset_metadata.py | 5 +- src/lerobot/datasets/dataset_reader.py | 5 - src/lerobot/datasets/feature_utils.py | 7 +- src/lerobot/datasets/io_utils.py | 19 +- src/lerobot/datasets/language.py | 96 ++++ src/lerobot/datasets/language_render.py | 445 ++++++++++++++++++ src/lerobot/datasets/utils.py | 1 - src/lerobot/processor/__init__.py | 2 + src/lerobot/processor/batch_processor.py | 18 + src/lerobot/processor/converters.py | 27 +- .../processor/render_messages_processor.py | 81 ++++ src/lerobot/scripts/lerobot_train.py | 2 + src/lerobot/utils/collate.py | 48 ++ tests/configs/test_recipe.py | 31 ++ tests/datasets/test_language.py | 95 ++++ tests/datasets/test_language_render.py | 164 +++++++ tests/datasets/test_subtask_dataset.py | 193 -------- .../test_render_messages_processor.py | 55 +++ tests/utils/test_collate.py | 36 ++ uv.lock | 40 +- 29 files changed, 1445 insertions(+), 517 deletions(-) delete mode 100644 docs/source/dataset_subtask.mdx create mode 100644 docs/source/language_and_recipes.mdx create mode 100644 src/lerobot/configs/recipe.py create mode 100644 src/lerobot/configs/recipes/pi05_hirobot.yaml create mode 100644 src/lerobot/datasets/language.py create mode 100644 src/lerobot/datasets/language_render.py create mode 100644 src/lerobot/processor/render_messages_processor.py create mode 100644 src/lerobot/utils/collate.py create mode 100644 tests/configs/test_recipe.py create mode 100644 
tests/datasets/test_language.py create mode 100644 tests/datasets/test_language_render.py delete mode 100644 tests/datasets/test_subtask_dataset.py create mode 100644 tests/processor/test_render_messages_processor.py create mode 100644 tests/utils/test_collate.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f5e1129f3..5ca449145 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -31,8 +31,8 @@ title: Porting Large Datasets - local: using_dataset_tools title: Using the Dataset Tools - - local: dataset_subtask - title: Using Subtasks in the Dataset + - local: language_and_recipes + title: Language Columns and Recipes - local: streaming_video_encoding title: Streaming Video Encoding title: "Datasets" diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx deleted file mode 100644 index 6264aca22..000000000 --- a/docs/source/dataset_subtask.mdx +++ /dev/null @@ -1,277 +0,0 @@ -# Using Subtasks in LeRobot Datasets - -Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for: - -- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time -- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models) -- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps - -LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks. - -## What are Subtasks? - -While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps: - -1. "Approach the apple" -2. "Grasp the apple" -3. "Lift the apple" -4. "Move to basket" -5. 
"Release the apple" - -Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages. - -An overview of subtask annotation showing how frames are labeled with intermediate subtask stages - -

- Figure: Overview of subtask annotation. -

- -**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022. - -## Dataset Structure - -Subtask information is stored in the dataset metadata: - -``` -my-dataset/ -├── data/ -│ └── ... -├── meta/ -│ ├── info.json -│ ├── stats.json -│ ├── tasks.parquet -│ ├── subtasks.parquet # Subtask index → subtask string mapping -│ └── episodes/ -│ └── ... -└── videos/ - └── ... -``` - -### Subtasks Parquet File - -The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions: - -| subtask_index | subtask (index column) | -| ------------- | ---------------------- | -| 0 | "Approach the apple" | -| 1 | "Grasp the apple" | -| 2 | "Lift the apple" | -| ... | ... | - -### Frame-Level Annotations - -Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file: - -```python -# Example frame data in the parquet file -{ - "index": 42, - "timestamp": 1.4, - "episode_index": 0, - "task_index": 0, - "subtask_index": 2, # References "Lift the apple" - "observation.state": [...], - "action": [...], -} -``` - -## Annotating Datasets with Subtasks - -We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks: - -**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)** - -After completing your annotation: - -1. Click "Push to Hub" to upload your annotated dataset -2. 
You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate) - -## Loading Datasets with Subtasks - -When you load a dataset with subtask annotations, the subtask information is automatically available: - -```python -from lerobot.datasets import LeRobotDataset - -# Load a dataset with subtask annotations -dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") - -# Access a sample -sample = dataset[100] - -# The sample includes both task and subtask information -print(sample["task"]) # "Collect the fruit" -print(sample["subtask"]) # "Grasp the apple" -print(sample["task_index"]) # tensor(0) -print(sample["subtask_index"]) # tensor(2) -``` - -### Checking for Subtask Support - -You can check if a dataset has subtask annotations: - -```python -# Check if subtasks are available -has_subtasks = ( - "subtask_index" in dataset.features - and dataset.meta.subtasks is not None -) - -if has_subtasks: - print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks") - print("Subtasks:", list(dataset.meta.subtasks.index)) -``` - -## Using Subtasks for Training - -### With the Tokenizer Processor - -The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models: - -```python -from lerobot.processor import TokenizerProcessorStep - -# Create a tokenizer processor step -tokenizer_processor = TokenizerProcessorStep( - tokenizer_name_or_path="google/paligemma-3b-pt-224", - padding="max_length", - max_length=64, -) - -# The processor will automatically tokenize subtasks if present in the batch -# and add them to the observation under: -# - "observation.subtask.tokens" -# - "observation.subtask.attention_mask" -``` - -When subtasks are available in the batch, the tokenizer processor adds: - -- `observation.subtask.tokens`: Tokenized subtask text -- `observation.subtask.attention_mask`: Attention mask for the subtask 
tokens - -### DataLoader with Subtasks - -```python -import torch -from lerobot.datasets import LeRobotDataset - -dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") - -dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=16, - shuffle=True, -) - -for batch in dataloader: - # Access subtask information in the batch - subtasks = batch["subtask"] # List of subtask strings - subtask_indices = batch["subtask_index"] # Tensor of subtask indices - - # Use for training hierarchical policies or reward models - print(f"Batch subtasks: {set(subtasks)}") -``` - -## Example Datasets with Subtask Annotations - -Try loading a dataset with subtask annotations: - -```python -from lerobot.datasets import LeRobotDataset - -# Example dataset with subtask annotations -dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") - -# Explore the subtasks -print("Available subtasks:") -for subtask_name in dataset.meta.subtasks.index: - print(f" - {subtask_name}") - -# Get subtask distribution -subtask_counts = {} -for i in range(len(dataset)): - sample = dataset[i] - subtask = sample["subtask"] - subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1 - -print("\nSubtask distribution:") -for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]): - print(f" {subtask}: {count} frames") -``` - -## Use Cases - -### 1. Hierarchical Policy Training - -Train policies that predict both actions and current subtask: - -```python -class HierarchicalPolicy(nn.Module): - def __init__(self, num_subtasks): - super().__init__() - self.action_head = nn.Linear(hidden_dim, action_dim) - self.subtask_head = nn.Linear(hidden_dim, num_subtasks) - - def forward(self, observations): - features = self.encoder(observations) - actions = self.action_head(features) - subtask_logits = self.subtask_head(features) - return actions, subtask_logits -``` - -### 2. 
Stage-Aware Reward Modeling (SARM) - -Build reward models that understand task progression: - -```python -# SARM predicts: -# - Stage: Which subtask is being executed (discrete) -# - Progress: How far along the subtask (continuous 0-1) - -class SARMRewardModel(nn.Module): - def forward(self, observations): - features = self.encoder(observations) - stage_logits = self.stage_classifier(features) - progress = self.progress_regressor(features) - return stage_logits, progress -``` - -### 3. Progress Visualization - -Monitor robot execution by tracking subtask progression: - -```python -def visualize_execution(model, observations): - for t, obs in enumerate(observations): - action, subtask_logits = model(obs) - predicted_subtask = subtask_names[subtask_logits.argmax()] - print(f"t={t}: Executing '{predicted_subtask}'") -``` - -## API Reference - -### LeRobotDataset Properties - -| Property | Type | Description | -| --------------------------- | ---------------------- | ------------------------------------------ | -| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices | -| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present | - -### Sample Keys - -When subtasks are available, each sample includes: - -| Key | Type | Description | -| --------------- | -------------- | ------------------------------------ | -| `subtask_index` | `torch.Tensor` | Integer index of the current subtask | -| `subtask` | `str` | Natural language subtask description | - -## Related Resources - -- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation -- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool -- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation diff --git a/docs/source/language_and_recipes.mdx b/docs/source/language_and_recipes.mdx new file mode 100644 index 000000000..135aa6301 --- /dev/null 
+++ b/docs/source/language_and_recipes.mdx @@ -0,0 +1,75 @@ +# Language columns and recipes + +LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`. +The two optional columns are: + +- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`. +- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls. + +Both columns share the same row shape: + +```text +role: string +content: string | null +style: string | null +timestamp: float64 +tool_calls: list[Json] | null +``` + +`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations. + +## Architecture + +The language stack has three layers: + +1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`. +2. `lerobot.datasets.language_render` resolves rows and renders messages. +3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`. + +`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior. + +## Temporal semantics + +Persistent styles are active after emission until replaced: + +- `active_at(t, style=subtask)` +- `nth_prev(style=memory, offset=1)` +- `nth_next(style=subtask, offset=1)` + +Event styles only exist on their exact timestamp: + +- `emitted_at(t, style=interjection)` +- `emitted_at(t, style=vqa, role=user)` +- `emitted_at(t, role=assistant, tool_name=say)` + +Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data. 
+ +## Recipe anatomy + +Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. + +```yaml +messages: + - { role: user, content: "${task}", stream: high_level } + - { role: assistant, content: "${subtask}", stream: low_level, target: true } +``` + +Rendered samples use HF-style chat messages plus LeRobot sidecars: + +```python +sample["messages"] +sample["message_streams"] +sample["target_message_indices"] +``` + +The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone. + +## Blends + +Blend recipes select one weighted sub-recipe deterministically from the sample index. +The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA. + +## Graceful absence + +If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op. +If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample. 
diff --git a/pyproject.toml b/pyproject.toml index 790c7f2d9..0790db6fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ # ── Feature-scoped extras ────────────────────────────────── dataset = [ - "datasets>=4.0.0,<5.0.0", + "datasets>=4.7.0,<5.0.0", "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets "lerobot[av-dep]", diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py index 3ddaec1af..a2cb1d72d 100644 --- a/src/lerobot/configs/__init__.py +++ b/src/lerobot/configs/__init__.py @@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig +from .recipe import MessageTurn, TrainingRecipe, load_recipe from .types import ( FeatureType, NormalizationMode, @@ -41,7 +42,10 @@ __all__ = [ # Config classes "DatasetConfig", "EvalConfig", + "MessageTurn", "PeftConfig", "PreTrainedConfig", + "TrainingRecipe", "WandBConfig", + "load_recipe", ] diff --git a/src/lerobot/configs/recipe.py b/src/lerobot/configs/recipe.py new file mode 100644 index 000000000..e01a96a79 --- /dev/null +++ b/src/lerobot/configs/recipe.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +MessageRole = Literal["user", "assistant", "system", "tool"] +MessageStream = Literal["high_level", "low_level"] + +DEFAULT_BINDINGS = { + "subtask": "active_at(t, style=subtask)", + "memory": "active_at(t, style=memory)", + "plan": "active_at(t, style=plan)", + "speech": "emitted_at(t, role=assistant, tool_name=say)", + "interjection": "emitted_at(t, style=interjection)", + "vqa": "emitted_at(t, style=vqa, role=assistant)", + "vqa_query": "emitted_at(t, style=vqa, role=user)", +} + +_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") +_VALID_ROLES = {"user", "assistant", "system", "tool"} +_VALID_STREAMS = {"high_level", "low_level"} + + +@dataclass +class MessageTurn: + role: MessageRole + content: str | list[dict[str, Any]] | None = None + stream: MessageStream | None = None + target: bool = False + if_present: str | None = None + tool_calls_from: str | None = None + + def __post_init__(self) -> None: + if self.role not in _VALID_ROLES: + raise ValueError(f"Unsupported message role: {self.role!r}") + if self.stream is not None and self.stream not in _VALID_STREAMS: + raise ValueError(f"Unsupported message stream: {self.stream!r}") + if self.content is None and self.tool_calls_from is None: + raise ValueError("MessageTurn.content is required unless tool_calls_from is set.") + if self.content is not None and not isinstance(self.content, (str, list)): + raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.") + if isinstance(self.content, list): + for block in self.content: + if not isinstance(block, dict) or "type" not in block: + raise ValueError( + "Multimodal content blocks must be HF-style dictionaries with a type key." 
+ ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MessageTurn: + return cls(**data) + + +@dataclass +class TrainingRecipe: + messages: list[MessageTurn] | None = None + bindings: dict[str, str] | None = None + blend: dict[str, TrainingRecipe] | None = None + weight: float | None = None + + def __post_init__(self) -> None: + if self.messages is not None and self.blend is not None: + raise ValueError("TrainingRecipe must set only one of messages or blend.") + if self.messages is None and self.blend is None: + raise ValueError("TrainingRecipe must set one of messages or blend.") + + if self.messages is not None: + self._validate_message_recipe() + if self.blend is not None: + self._validate_blend_recipe() + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe: + data = dict(data) + if data.get("messages") is not None: + data["messages"] = [ + turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn) + for turn in data["messages"] + ] + if data.get("blend") is not None: + data["blend"] = { + name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe) + for name, recipe in data["blend"].items() + } + return cls(**data) + + @classmethod + def from_yaml(cls, path: str | Path) -> TrainingRecipe: + import yaml # type: ignore[import-untyped] + + with open(path) as f: + data = yaml.safe_load(f) + if not isinstance(data, dict): + raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}") + return cls.from_dict(data) + + def _validate_message_recipe(self) -> None: + assert self.messages is not None + known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"} + + for turn in self.messages: + missing = self._referenced_bindings(turn) - known_bindings + if missing: + raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}") + + if not any(turn.target for turn in self.messages): + raise ValueError("Message recipes must contain at least one target 
turn.") + + def _validate_blend_recipe(self) -> None: + assert self.blend is not None + if not self.blend: + raise ValueError("Blend recipes must contain at least one component.") + + for name, recipe in self.blend.items(): + if recipe.blend is not None: + raise ValueError(f"Blend component {name!r} cannot itself define a blend.") + if recipe.messages is None: + raise ValueError(f"Blend component {name!r} must define messages.") + if recipe.weight is None: + raise ValueError(f"Blend component {name!r} must define weight.") + if recipe.weight <= 0: + raise ValueError(f"Blend component {name!r} must have a positive weight.") + + def _referenced_bindings(self, turn: MessageTurn) -> set[str]: + names: set[str] = set() + if turn.if_present is not None: + names.add(turn.if_present) + if turn.tool_calls_from is not None: + names.add(turn.tool_calls_from) + names.update(_placeholders_in_content(turn.content)) + return names + + +def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]: + if content is None: + return set() + if isinstance(content, str): + return set(_PLACEHOLDER_RE.findall(content)) + + names: set[str] = set() + for block in content: + for value in block.values(): + if isinstance(value, str): + names.update(_PLACEHOLDER_RE.findall(value)) + return names + + +def load_recipe(path: str | Path) -> TrainingRecipe: + return TrainingRecipe.from_yaml(path) diff --git a/src/lerobot/configs/recipes/pi05_hirobot.yaml b/src/lerobot/configs/recipes/pi05_hirobot.yaml new file mode 100644 index 000000000..3dbfb44be --- /dev/null +++ b/src/lerobot/configs/recipes/pi05_hirobot.yaml @@ -0,0 +1,47 @@ +blend: + + memory_update: + weight: 0.10 + bindings: + prior_memory: "nth_prev(style=memory, offset=1)" + current_memory: "emitted_at(t, style=memory)" + completed_subtask: "nth_prev(style=subtask, offset=1)" + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "Previous memory: ${prior_memory}", stream: 
high_level, if_present: prior_memory} + - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask} + - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory} + + user_interjection_response: + weight: 0.16 + bindings: + prior_plan: "nth_prev(style=plan, offset=1)" + current_plan: "emitted_at(t, style=plan)" + interjection: "emitted_at(t, style=interjection)" + speech: "emitted_at(t, role=assistant, tool_name=say)" + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan} + - {role: user, content: "${interjection}", stream: high_level, if_present: interjection} + - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech} + + high_level_subtask: + weight: 0.15 + bindings: + next_subtask: "nth_next(style=subtask, offset=1)" + messages: + - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} + - {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask} + - {role: assistant, content: "${next_subtask}", stream: high_level, target: true} + + low_level_execution: + weight: 0.35 + messages: + - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} + - {role: assistant, content: "${subtask}", stream: low_level, target: true} + + ask_vqa: + weight: 0.20 + messages: + - {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 6c42959a5..8be3609f3 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -37,6 +37,14 @@ from .dataset_tools import 
( from .factory import make_dataset, resolve_delta_timestamps from .image_writer import safe_stop_image_writer from .io_utils import load_episodes, write_stats +from .language import ( + EVENT_ONLY_STYLES, + LANGUAGE_EVENTS, + LANGUAGE_PERSISTENT, + PERSISTENT_STYLES, + STYLE_REGISTRY, + column_for_style, +) from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features @@ -53,10 +61,15 @@ __all__ = [ "CODEBASE_VERSION", "DEFAULT_EPISODES_PATH", "DEFAULT_QUANTILES", + "EVENT_ONLY_STYLES", "EpisodeAwareSampler", + "LANGUAGE_EVENTS", + "LANGUAGE_PERSISTENT", "LeRobotDataset", "LeRobotDatasetMetadata", "MultiLeRobotDataset", + "PERSISTENT_STYLES", + "STYLE_REGISTRY", "StreamingLeRobotDataset", "VideoEncodingManager", "add_features", @@ -66,6 +79,7 @@ __all__ = [ "convert_image_to_video_dataset", "create_initial_features", "create_lerobot_dataset_card", + "column_for_style", "delete_episodes", "get_feature_stats", "load_episodes", diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index f489c84a7..438ac7fba 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -512,7 +512,7 @@ def compute_episode_stats( ep_stats = {} for key, data in episode_data.items(): - if features[key]["dtype"] == "string": + if features[key]["dtype"] in {"string", "language"}: continue if features[key]["dtype"] in ["image", "video"]: diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 8bf67fa39..4dd34c758 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -34,7 +34,6 @@ from .io_utils import ( load_episodes, load_info, load_stats, - load_subtasks, load_tasks, write_info, write_json, @@ -52,7 +51,7 @@ from .utils import ( ) from .video_utils import get_video_info -CODEBASE_VERSION = "v3.0" 
+CODEBASE_VERSION = "v3.1" class LeRobotDatasetMetadata: @@ -177,7 +176,6 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) - self.subtasks = load_subtasks(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -635,7 +633,6 @@ class LeRobotDatasetMetadata: _validate_feature_names(features) obj.tasks = None - obj.subtasks = None obj.episodes = None obj.stats = None obj.info = create_empty_dataset_info( diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index bd1298590..59aaa40e5 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -295,9 +295,4 @@ class DatasetReader: task_idx = item["task_index"].item() item["task"] = self._meta.tasks.iloc[task_idx].name - # add subtask information if available - if "subtask_index" in self._meta.features and self._meta.subtasks is not None: - subtask_idx = item["subtask_index"].item() - item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name - return item diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index b05dbf2cc..43775b3a1 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -22,6 +22,7 @@ from PIL import Image as PILImage from lerobot.utils.constants import DEFAULT_FEATURES from lerobot.utils.utils import is_valid_numpy_dtype_string +from .language import is_language_column, language_column_feature from .utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -45,7 +46,9 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: """ hf_features = {} for key, ft in features.items(): - if ft["dtype"] == "video": + if is_language_column(key): + hf_features[key] = language_column_feature() + elif ft["dtype"] == "video": continue elif ft["dtype"] == "image": 
hf_features[key] = datasets.Image() @@ -242,6 +245,8 @@ def validate_feature_dtype_and_shape( return validate_feature_image_or_video(name, expected_shape, value) elif expected_dtype == "string": return validate_feature_string(name, value) + elif expected_dtype == "language": + return "" else: raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index 2ee859e97..f416cf46b 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di from .utils import ( DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_EPISODES_PATH, - DEFAULT_SUBTASKS_PATH, DEFAULT_TASKS_PATH, EPISODES_DIR, INFO_PATH, @@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame: return tasks -def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: - """Load subtasks from subtasks.parquet if it exists.""" - subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH - if subtasks_path.exists(): - return pd.read_parquet(subtasks_path) - return None - - def write_episodes(episodes: Dataset, local_dir: Path) -> None: """Write episode metadata to a parquet file in the LeRobot v3.0 format. This function writes episode-level metadata to a single parquet file. @@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to dict: The batch with items converted to torch tensors. 
""" for key in items_dict: + if key in {"language_persistent", "language_events"}: + continue first_item = items_dict[key][0] if isinstance(first_item, PILImage.Image): to_tensor = transforms.ToTensor() items_dict[key] = [to_tensor(img) for img in items_dict[key]] - elif first_item is None: + elif first_item is None or isinstance(first_item, dict): pass else: items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] @@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict: dict: Dictionary with all tensor-like items converted to torch.Tensor. """ for key, val in item.items(): - if isinstance(val, (np.ndarray | list)) and key not in ["task"]: + if isinstance(val, (np.ndarray | list)) and key not in [ + "task", + "language_persistent", + "language_events", + ]: # Convert numpy arrays and lists to torch tensors item[key] = torch.tensor(val) return item diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py new file mode 100644 index 000000000..64fb0bcf3 --- /dev/null +++ b/src/lerobot/datasets/language.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Literal + +import datasets +import pyarrow as pa + +LANGUAGE_PERSISTENT = "language_persistent" +LANGUAGE_EVENTS = "language_events" +LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) +LANGUAGE_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls") + +CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"} +EXTENDED_STYLES = set() +RESERVED_STYLES = {"motion", "trace"} +STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES | RESERVED_STYLES + +PERSISTENT_STYLES = {"subtask", "plan", "memory"} +EVENT_ONLY_STYLES = {"interjection", "vqa"} + +LanguageColumn = Literal["language_persistent", "language_events"] + + +def language_row_arrow_type() -> pa.StructType: + json_type = pa.json_() if hasattr(pa, "json_") else pa.string() + return pa.struct( + [ + pa.field("role", pa.string(), nullable=False), + pa.field("content", pa.string(), nullable=True), + pa.field("style", pa.string(), nullable=True), + pa.field("timestamp", pa.float64(), nullable=False), + pa.field("tool_calls", pa.list_(json_type), nullable=True), + ] + ) + + +def language_persistent_arrow_type() -> pa.ListType: + return pa.list_(language_row_arrow_type()) + + +def language_events_arrow_type() -> pa.ListType: + return pa.list_(language_row_arrow_type()) + + +def language_row_feature() -> dict[str, object]: + json_feature = datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string") + return { + "role": datasets.Value("string"), + "content": datasets.Value("string"), + "style": datasets.Value("string"), + "timestamp": datasets.Value("float64"), + "tool_calls": datasets.List(json_feature), + } + + +def language_column_feature() -> datasets.List: + return datasets.List(language_row_feature()) + + +def language_feature_info() -> dict[str, dict]: + return { + LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None}, + LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None}, + } + + 
+def is_language_column(key: str) -> bool: + return key in LANGUAGE_COLUMNS + + +def column_for_style(style: str | None) -> LanguageColumn: + if style is None: + return LANGUAGE_EVENTS + if style in PERSISTENT_STYLES: + return LANGUAGE_PERSISTENT + if style in EVENT_ONLY_STYLES: + return LANGUAGE_EVENTS + if style in RESERVED_STYLES: + raise ValueError(f"Style {style!r} is registered but has no storage column yet.") + raise ValueError(f"Unknown language style: {style!r}") diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py new file mode 100644 index 000000000..98db669cc --- /dev/null +++ b/src/lerobot/datasets/language_render.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+import copy
+import hashlib
+import re
+from collections.abc import Sequence
+from typing import Any
+
+from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
+
+from .language import (
+    EVENT_ONLY_STYLES,
+    LANGUAGE_PERSISTENT,
+    PERSISTENT_STYLES,
+    column_for_style,
+)
+
+LanguageRow = dict[str, Any]
+RenderedMessages = dict[str, list[Any]]
+
+_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
+_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
+
+
+def active_at(
+    t: float,
+    *,
+    persistent: Sequence[LanguageRow],
+    events: Sequence[LanguageRow] | None = None,
+    style: str | None = None,
+    role: str | None = None,
+    tool_name: str | None = None,
+) -> LanguageRow | None:
+    _validate_persistent_resolver("active_at", style)
+    matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
+    matches = [row for row in matches if _timestamp(row) <= t]
+    return _select_latest(matches, style=style, role=role, tool_name=tool_name)
+
+
+def emitted_at(
+    t: float,
+    *,
+    persistent: Sequence[LanguageRow],
+    events: Sequence[LanguageRow],
+    style: str | None = None,
+    role: str | None = None,
+    tool_name: str | None = None,
+) -> LanguageRow | None:
+    column = column_for_style(style)
+    rows = persistent if column == LANGUAGE_PERSISTENT else events
+    matches = [
+        row
+        for row in _matching_rows(rows, style=style, role=role, tool_name=tool_name)
+        if _timestamp(row) == t
+    ]
+    return _select_exact(matches, style=style, role=role, tool_name=tool_name)
+
+
+def nth_prev(
+    t: float,
+    *,
+    persistent: Sequence[LanguageRow],
+    events: Sequence[LanguageRow] | None = None,
+    style: str | None = None,
+    offset: int = 1,
+    role: str | None = None,
+    tool_name: str | None = None,
+) -> LanguageRow | None:
+    return _nth_relative(
+        t,
+        persistent=persistent,
+        style=style,
+        offset=-offset,
+        role=role,
+        tool_name=tool_name,
+        resolver_name="nth_prev",
+    )
+
+
+def nth_next(
+ t: float, + *, + persistent: Sequence[LanguageRow], + events: Sequence[LanguageRow] | None = None, + style: str | None = None, + offset: int = 1, + role: str | None = None, + tool_name: str | None = None, +) -> LanguageRow | None: + return _nth_relative( + t, + persistent=persistent, + style=style, + offset=offset, + role=role, + tool_name=tool_name, + resolver_name="nth_next", + ) + + +def render_sample( + *, + recipe: TrainingRecipe, + persistent: Sequence[LanguageRow] | None, + events: Sequence[LanguageRow] | None, + t: float, + sample_idx: int, + task: str | None = None, + dataset_ctx: Any | None = None, +) -> RenderedMessages | None: + persistent_rows = _normalize_rows(persistent or []) + event_rows = _normalize_rows(events or []) + selected_recipe = _select_recipe(recipe, sample_idx) + bindings = _resolve_bindings( + selected_recipe, + persistent=persistent_rows, + events=event_rows, + t=t, + task=task, + dataset_ctx=dataset_ctx, + ) + return _render_message_recipe(selected_recipe, bindings) + + +def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe: + if recipe.blend is None: + return recipe + + total_weight = sum(component.weight or 0.0 for component in recipe.blend.values()) + if total_weight <= 0: + raise ValueError("Blend weights must sum to a positive value.") + + digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest() + draw = int.from_bytes(digest, "big") / 2**64 * total_weight + cumulative = 0.0 + last_component: TrainingRecipe | None = None + for component in recipe.blend.values(): + last_component = component + cumulative += component.weight or 0.0 + if draw < cumulative: + return component + assert last_component is not None + return last_component + + +def _resolve_bindings( + recipe: TrainingRecipe, + *, + persistent: Sequence[LanguageRow], + events: Sequence[LanguageRow], + t: float, + task: str | None, + dataset_ctx: Any | None, +) -> dict[str, LanguageRow | str | None]: + bindings: dict[str, 
LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)} + specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})} + for name, spec in specs.items(): + bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t) + return bindings + + +def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None: + if task is not None: + return task + if dataset_ctx is None: + return None + if isinstance(dataset_ctx, dict): + return dataset_ctx.get("task") + return getattr(dataset_ctx, "task", None) + + +def _resolve_spec( + spec: str, + *, + persistent: Sequence[LanguageRow], + events: Sequence[LanguageRow], + t: float, +) -> LanguageRow | None: + match = _RESOLVER_RE.match(spec.strip()) + if match is None: + raise ValueError(f"Invalid resolver expression: {spec!r}") + name = match.group("name") + kwargs = _parse_resolver_args(match.group("args")) + kwargs.pop("t_arg", None) + + resolvers = { + "active_at": active_at, + "emitted_at": emitted_at, + "nth_prev": nth_prev, + "nth_next": nth_next, + } + if name not in resolvers: + raise ValueError(f"Unknown language resolver: {name!r}") + return resolvers[name](t, persistent=persistent, events=events, **kwargs) + + +def _parse_resolver_args(args: str) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if not args.strip(): + return kwargs + + parts = [part.strip() for part in args.split(",") if part.strip()] + for part in parts: + if part == "t": + kwargs["t_arg"] = True + continue + if "=" not in part: + raise ValueError(f"Invalid resolver argument: {part!r}") + key, value = (item.strip() for item in part.split("=", 1)) + if key == "offset": + kwargs[key] = int(value) + else: + kwargs[key] = value.strip("\"'") + return kwargs + + +def _render_message_recipe( + recipe: TrainingRecipe, + bindings: dict[str, LanguageRow | str | None], +) -> RenderedMessages | None: + assert recipe.messages is not None + messages: list[dict[str, Any]] = [] + streams: list[str | None] = [] + target_indices: 
list[int] = [] + + for turn in recipe.messages: + if turn.if_present is not None and bindings.get(turn.if_present) is None: + continue + + message = {"role": turn.role} + if turn.content is not None: + message["content"] = _render_content(turn.content, bindings) + + if turn.tool_calls_from is not None: + row = bindings.get(turn.tool_calls_from) + tool_calls = row.get("tool_calls") if isinstance(row, dict) else None + if tool_calls: + message["tool_calls"] = copy.deepcopy(tool_calls) + + message_idx = len(messages) + messages.append(message) + streams.append(turn.stream) + if turn.target: + target_indices.append(message_idx) + + if not target_indices: + return None + + rendered = { + "messages": messages, + "message_streams": streams, + "target_message_indices": target_indices, + } + _validate_rendered(rendered) + return rendered + + +def _render_content( + content: str | list[dict[str, Any]], + bindings: dict[str, LanguageRow | str | None], +) -> str | list[dict[str, Any]]: + if isinstance(content, str): + return _substitute(content, bindings) + + rendered_blocks = [] + for block in content: + rendered_block = copy.deepcopy(block) + for key, value in rendered_block.items(): + if isinstance(value, str): + rendered_block[key] = _substitute(value, bindings) + rendered_blocks.append(rendered_block) + return rendered_blocks + + +def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str: + def replace(match: re.Match[str]) -> str: + name = match.group(1) + if name not in bindings: + raise ValueError(f"Unknown template binding: {name!r}") + value = bindings[name] + if value is None: + return "" + if isinstance(value, dict): + content = value.get("content") + return "" if content is None else str(content) + return str(value) + + return _PLACEHOLDER_RE.sub(replace, template) + + +def _validate_rendered(rendered: RenderedMessages) -> None: + messages = rendered["messages"] + streams = rendered["message_streams"] + target_indices = 
rendered["target_message_indices"] + + if len(streams) != len(messages): + raise ValueError("message_streams must be aligned with messages.") + if not target_indices: + raise ValueError("Rendered samples must contain at least one target message.") + for idx in target_indices: + if idx < 0 or idx >= len(messages): + raise ValueError(f"Target message index {idx} is out of bounds.") + for idx, stream in enumerate(streams): + if stream is None: + raise ValueError(f"Rendered message {idx} has no stream.") + + +def _nth_relative( + t: float, + *, + persistent: Sequence[LanguageRow], + style: str | None, + offset: int, + role: str | None, + tool_name: str | None, + resolver_name: str, +) -> LanguageRow | None: + _validate_persistent_resolver(resolver_name, style) + if abs(offset) < 1: + raise ValueError(f"{resolver_name} offset must be non-zero.") + + rows = _sort_rows(_matching_rows(persistent, style=style, role=role, tool_name=tool_name)) + if not rows: + return None + + anchor_idx = None + for idx, row in enumerate(rows): + if _timestamp(row) <= t: + anchor_idx = idx + else: + break + + target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset + + if target_idx is None or target_idx < 0 or target_idx >= len(rows): + return None + return rows[target_idx] + + +def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None: + if style is None: + raise ValueError(f"{resolver_name} requires a persistent style.") + if style in EVENT_ONLY_STYLES: + raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.") + if style not in PERSISTENT_STYLES: + column_for_style(style) + + +def _matching_rows( + rows: Sequence[LanguageRow], + *, + style: str | None, + role: str | None, + tool_name: str | None, +) -> list[LanguageRow]: + return [ + row + for row in rows + if (style is None or row.get("style") == style) + and (role is None or row.get("role") == role) + and (tool_name is None or 
_row_has_tool_name(row, tool_name)) + ] + + +def _select_latest( + rows: Sequence[LanguageRow], + *, + style: str | None, + role: str | None, + tool_name: str | None, +) -> LanguageRow | None: + if not rows: + return None + rows = _sort_rows(rows) + latest_ts = _timestamp(rows[-1]) + return _select_exact( + [row for row in rows if _timestamp(row) == latest_ts], + style=style, + role=role, + tool_name=tool_name, + ) + + +def _select_exact( + rows: Sequence[LanguageRow], + *, + style: str | None, + role: str | None, + tool_name: str | None, +) -> LanguageRow | None: + if not rows: + return None + if len(rows) > 1 and role is None and tool_name is None: + raise ValueError( + f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate." + ) + return _sort_rows(rows)[0] + + +def _sort_rows(rows: Sequence[LanguageRow]) -> list[LanguageRow]: + return sorted(rows, key=lambda row: (_timestamp(row), row.get("style") or "", row.get("role") or "")) + + +def _timestamp(row: LanguageRow) -> float: + value = row["timestamp"] + return float(value.item() if hasattr(value, "item") else value) + + +def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool: + for tool_call in row.get("tool_calls") or []: + if isinstance(tool_call, str): + continue + function = tool_call.get("function") if isinstance(tool_call, dict) else None + if isinstance(function, dict) and function.get("name") == tool_name: + return True + return False + + +def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]: + normalized = [] + for row in rows: + if row is None: + continue + if hasattr(row, "as_py"): + row = row.as_py() + if not isinstance(row, dict): + raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.") + normalized.append(dict(row)) + return normalized diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index c6815e0f5..922a2d3ce 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ 
-83,7 +83,6 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" -DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 3688a4b8c..5cac04faa 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -93,6 +93,7 @@ from .relative_action_processor import ( to_relative_actions, ) from .rename_processor import RenameObservationsProcessorStep, rename_stats +from .render_messages_processor import RenderMessagesStep from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep __all__ = [ @@ -128,6 +129,7 @@ __all__ = [ "make_default_robot_observation_processor", "AbsoluteActionsProcessorStep", "RelativeActionsProcessorStep", + "RenderMessagesStep", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", "NewLineTaskProcessorStep", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index eb7db255a..669c68a0a 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): task_index_value = complementary_data["task_index"] if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: complementary_data["task_index"] = task_index_value.unsqueeze(0) + + complementary_data.pop("language_persistent", None) + complementary_data.pop("language_events", None) + + if "messages" in complementary_data: + messages = complementary_data["messages"] + if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)): + 
complementary_data["messages"] = [messages] + + if "message_streams" in complementary_data: + streams = complementary_data["message_streams"] + if isinstance(streams, list) and (not streams or isinstance(streams[0], str)): + complementary_data["message_streams"] = [streams] + + if "target_message_indices" in complementary_data: + indices = complementary_data["target_message_indices"] + if isinstance(indices, list) and (not indices or isinstance(indices[0], int)): + complementary_data["target_message_indices"] = [indices] return complementary_data def transform_features( diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index ffdf0098c..ed4bb78d2 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -171,8 +171,33 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} + timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {} + language_persistent_key = ( + {"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {} + ) + language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {} + messages_key = {"messages": batch["messages"]} if "messages" in batch else {} + message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {} + target_message_indices_key = ( + {"target_message_indices": batch["target_message_indices"]} + if "target_message_indices" in batch + else {} + ) - return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key} + return { + **pad_keys, + **task_key, + **subtask_key, + **index_key, + **task_index_key, + 
**episode_index_key, + **timestamp_key, + **language_persistent_key, + **language_events_key, + **messages_key, + **message_streams_key, + **target_message_indices_key, + } def create_transition( diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py new file mode 100644 index 000000000..b6b6b2340 --- /dev/null +++ b/src/lerobot/processor/render_messages_processor.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.configs.recipe import TrainingRecipe +from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT +from lerobot.datasets.language_render import render_sample +from lerobot.types import EnvTransition, TransitionKey + +from .pipeline import ProcessorStep, ProcessorStepRegistry + + +@dataclass +@ProcessorStepRegistry.register(name="render_messages_processor") +class RenderMessagesStep(ProcessorStep): + recipe: TrainingRecipe + dataset_ctx: Any | None = None + + def __call__(self, transition: EnvTransition) -> EnvTransition | None: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} + persistent = complementary_data.get(LANGUAGE_PERSISTENT) or [] + events = complementary_data.get(LANGUAGE_EVENTS) or [] + + if not persistent and not events: + return transition + + timestamp = complementary_data.get("timestamp") + if timestamp is None: + raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.") + + sample_idx = complementary_data.get("index", 0) + rendered = render_sample( + recipe=self.recipe, + persistent=persistent, + events=events, + t=_scalar(timestamp), + sample_idx=int(_scalar(sample_idx)), + task=complementary_data.get("task"), + dataset_ctx=self.dataset_ctx, + ) + if rendered is None: + return None + + new_transition = transition.copy() + new_complementary_data = dict(complementary_data) + new_complementary_data.pop(LANGUAGE_PERSISTENT, None) + new_complementary_data.pop(LANGUAGE_EVENTS, None) + new_complementary_data.update(rendered) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +def 
_scalar(value: Any) -> float | int: + if hasattr(value, "item"): + return value.item() + if isinstance(value, list) and len(value) == 1: + return _scalar(value[0]) + return value diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 856006507..ec923bcc1 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors +from lerobot.utils.collate import lerobot_collate_fn from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed @@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): sampler=sampler, pin_memory=device.type == "cuda", drop_last=False, + collate_fn=lerobot_collate_fn, prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None, persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, ) diff --git a/src/lerobot/utils/collate.py b/src/lerobot/utils/collate.py new file mode 100644 index 000000000..6915f4ed1 --- /dev/null +++ b/src/lerobot/utils/collate.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +from torch.utils.data._utils.collate import default_collate + +from lerobot.datasets.language import LANGUAGE_COLUMNS + +_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"} + + +def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None: + batch = [sample for sample in batch if sample is not None] + if not batch: + return None + + preserved = { + key: [sample[key] for sample in batch if key in sample] + for key in _PYTHON_LIST_KEYS + if any(key in sample for sample in batch) + } + tensorizable = [ + { + key: value + for key, value in sample.items() + if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS + } + for sample in batch + ] + collated = default_collate(tensorizable) + collated.update(preserved) + return collated diff --git a/tests/configs/test_recipe.py b/tests/configs/test_recipe.py new file mode 100644 index 000000000..e03f27a75 --- /dev/null +++ b/tests/configs/test_recipe.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python + +from pathlib import Path + +import pytest + +from lerobot.configs.recipe import MessageTurn, TrainingRecipe + + +def test_message_recipe_validates_unknown_binding(): + with pytest.raises(ValueError, match="unknown binding"): + TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${missing}", stream="high_level"), + MessageTurn(role="assistant", content="ok", stream="high_level", target=True), + ] + ) + + +def test_canonical_recipe_loads(): + recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) + + assert recipe.blend is not None + assert set(recipe.blend) == { + "memory_update", + "user_interjection_response", + "high_level_subtask", + "low_level_execution", + "ask_vqa", + } + assert sum(component.weight for component in recipe.blend.values()) == 
pytest.approx(0.96) diff --git a/tests/datasets/test_language.py b/tests/datasets/test_language.py new file mode 100644 index 000000000..08173bc0c --- /dev/null +++ b/tests/datasets/test_language.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +from lerobot.datasets import LeRobotDataset +from lerobot.datasets.io_utils import write_info +from lerobot.datasets.language import ( + EVENT_ONLY_STYLES, + LANGUAGE_EVENTS, + LANGUAGE_PERSISTENT, + PERSISTENT_STYLES, + STYLE_REGISTRY, + column_for_style, + language_events_arrow_type, + language_feature_info, + language_persistent_arrow_type, +) +from lerobot.datasets.utils import DEFAULT_DATA_PATH + + +def test_language_arrow_schema_has_expected_fields(): + row_type = language_persistent_arrow_type().value_type + + assert isinstance(row_type, pa.StructType) + assert row_type.names == ["role", "content", "style", "timestamp", "tool_calls"] + assert language_events_arrow_type().value_type == row_type + + +def test_style_registry_routes_columns(): + assert {"subtask", "plan", "memory"} == PERSISTENT_STYLES + assert {"interjection", "vqa"} == EVENT_ONLY_STYLES + assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY + + assert column_for_style("subtask") == LANGUAGE_PERSISTENT + assert column_for_style("plan") == LANGUAGE_PERSISTENT + assert column_for_style("memory") == LANGUAGE_PERSISTENT + assert column_for_style("interjection") == LANGUAGE_EVENTS + assert column_for_style("vqa") == LANGUAGE_EVENTS + assert column_for_style(None) == LANGUAGE_EVENTS + + +def test_unknown_style_rejected(): + with pytest.raises(ValueError, match="Unknown language style"): + column_for_style("surprise") + + +def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory): + root = tmp_path / "language_dataset" + dataset = empty_lerobot_dataset_factory( + root=root, + features={"state": {"dtype": "float32", "shape": (2,), 
"names": None}}, + use_videos=False, + ) + dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"}) + dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"}) + dataset.save_episode() + dataset.finalize() + + persistent = [ + { + "role": "assistant", + "content": "reach for the cup", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + } + ] + event = { + "role": "user", + "content": "what is visible?", + "style": "vqa", + "timestamp": 0.0, + "tool_calls": None, + } + data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) + df = pd.read_parquet(data_path) + df[LANGUAGE_PERSISTENT] = [persistent, persistent] + df[LANGUAGE_EVENTS] = [[event], []] + df.to_parquet(data_path) + + info = dataset.meta.info + info["features"].update(language_feature_info()) + write_info(info, root) + + reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root) + + first = reloaded[0] + second = reloaded[1] + assert first[LANGUAGE_PERSISTENT] == persistent + assert first[LANGUAGE_EVENTS] == [event] + assert second[LANGUAGE_PERSISTENT] == persistent + assert second[LANGUAGE_EVENTS] == [] diff --git a/tests/datasets/test_language_render.py b/tests/datasets/test_language_render.py new file mode 100644 index 000000000..5b4904c9a --- /dev/null +++ b/tests/datasets/test_language_render.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python + +from pathlib import Path + +import pytest + +from lerobot.configs.recipe import MessageTurn, TrainingRecipe +from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample + + +def row(role, content, style, timestamp, tool_calls=None): + return { + "role": role, + "content": content, + "style": style, + "timestamp": timestamp, + "tool_calls": tool_calls, + } + + +PERSISTENT = [ + row("assistant", "plan 0", "plan", 0.0), + row("assistant", "memory 0", "memory", 0.0), + row("assistant", "subtask 0", "subtask", 0.0), + row("assistant", "memory 
1", "memory", 1.0), + row("assistant", "subtask 1", "subtask", 1.0), +] +EVENTS = [ + row("user", "what is visible?", "vqa", 1.0), + row("assistant", '{"count": 2}', "vqa", 1.0), + row("user", "skip wiping", "interjection", 2.0), + row( + "assistant", + None, + None, + 2.0, + [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}], + ), +] + + +def test_resolver_temporal_semantics(): + assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0" + assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1" + assert emitted_at(0.5, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant") is None + assert ( + emitted_at(1.0, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")["content"] + == '{"count": 2}' + ) + + +def test_persistent_relative_resolvers_reject_event_styles(): + with pytest.raises(ValueError, match="event-only"): + active_at(1.0, persistent=PERSISTENT, style="vqa") + with pytest.raises(ValueError, match="event-only"): + nth_prev(1.0, persistent=PERSISTENT, style="interjection") + + +def test_nth_prev_and_next(): + assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0" + assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1" + + +def test_substitution_if_present_multimodal_and_tool_calls(): + recipe = TrainingRecipe( + messages=[ + MessageTurn( + role="user", + content=[ + {"type": "image", "feature": "observation.images.top"}, + {"type": "text", "text": "${task}: ${interjection}"}, + ], + stream="high_level", + if_present="interjection", + ), + MessageTurn( + role="assistant", + content="${plan}", + stream="high_level", + target=True, + tool_calls_from="speech", + ), + ], + bindings={"plan": "active_at(t, style=plan)"}, + ) + + rendered = render_sample( + recipe=recipe, + persistent=PERSISTENT, + events=EVENTS, + t=2.0, + sample_idx=0, + 
task="clean kitchen", + ) + + assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping" + assert rendered["messages"][1]["content"] == "plan 0" + assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say" + assert rendered["message_streams"] == ["high_level", "high_level"] + assert rendered["target_message_indices"] == [1] + + +def test_exact_event_miss_returns_none_when_target_skips(): + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"), + MessageTurn( + role="assistant", + content="${vqa}", + stream="high_level", + target=True, + if_present="vqa", + ), + ] + ) + + assert render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=0) is None + + +def test_deterministic_blend_sampling(): + recipe = TrainingRecipe( + blend={ + "a": TrainingRecipe( + weight=1.0, + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="a", stream="high_level", target=True), + ], + ), + "b": TrainingRecipe( + weight=1.0, + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="b", stream="high_level", target=True), + ], + ), + } + ) + + first = render_sample( + recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x" + ) + second = render_sample( + recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x" + ) + assert first == second + + +def test_canonical_recipe_can_render_low_level_branch(): + recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")) + low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]}) + + rendered = render_sample( + recipe=low_level, + persistent=PERSISTENT, + events=[], + t=0.5, + sample_idx=0, + task="clean kitchen", + ) + + assert rendered["messages"][-1] == {"role": 
"assistant", "content": "subtask 0"} + assert rendered["message_streams"][-1] == "low_level" + assert rendered["target_message_indices"] == [1] diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py deleted file mode 100644 index bb77b77d1..000000000 --- a/tests/datasets/test_subtask_dataset.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Tests for subtask functionality in LeRobotDataset. 
- -These tests verify that: -- Subtask information is correctly loaded from datasets that have subtask data -- The __getitem__ method correctly adds subtask strings to returned items -- Subtask handling gracefully handles missing data -""" - -import pytest - -pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])") - -import pandas as pd # noqa: E402 -import torch - -from lerobot.datasets.lerobot_dataset import LeRobotDataset - - -class TestSubtaskDataset: - """Tests for subtask handling in LeRobotDataset.""" - - @pytest.fixture - def subtask_dataset(self): - """Load the test subtask dataset from the hub.""" - # Use lerobot/pusht-subtask dataset with episode 1 - return LeRobotDataset( - repo_id="lerobot/pusht-subtask", - episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ) - - def test_subtask_dataset_loads(self, subtask_dataset): - """Test that the subtask dataset loads successfully.""" - assert subtask_dataset is not None - assert len(subtask_dataset) > 0 - - def test_subtask_metadata_loaded(self, subtask_dataset): - """Test that subtask metadata is loaded when present in dataset.""" - # The dataset should have subtasks metadata loaded - assert subtask_dataset.meta.subtasks is not None - assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame) - - def test_subtask_index_in_features(self, subtask_dataset): - """Test that subtask_index is a feature when dataset has subtasks.""" - assert "subtask_index" in subtask_dataset.features - - def test_getitem_returns_subtask_string(self, subtask_dataset): - """Test that __getitem__ correctly adds subtask string to returned item.""" - item = subtask_dataset[0] - - # Subtask should be present in the returned item - assert "subtask" in item - assert isinstance(item["subtask"], str) - assert len(item["subtask"]) > 0 # Should not be empty - - def test_getitem_has_subtask_index(self, subtask_dataset): - """Test that __getitem__ includes subtask_index.""" - item = subtask_dataset[0] - - assert 
"subtask_index" in item - assert isinstance(item["subtask_index"], torch.Tensor) - - def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset): - """Test that subtask_index correctly maps to a subtask in metadata.""" - item = subtask_dataset[0] - - subtask_idx = item["subtask_index"].item() - subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name - - assert item["subtask"] == subtask_from_metadata - - def test_all_items_have_subtask(self, subtask_dataset): - """Test that all items in the dataset have subtask information.""" - for i in range(min(len(subtask_dataset), 5)): # Check first 5 items - item = subtask_dataset[i] - assert "subtask" in item - assert isinstance(item["subtask"], str) - - def test_task_and_subtask_coexist(self, subtask_dataset): - """Test that both task and subtask are present in returned items.""" - item = subtask_dataset[0] - - # Both task and subtask should be present - assert "task" in item - assert "subtask" in item - assert isinstance(item["task"], str) - assert isinstance(item["subtask"], str) - - -class TestSubtaskDatasetMissing: - """Tests for graceful handling when subtask data is missing.""" - - @pytest.fixture - def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory): - """Create a dataset without subtask information.""" - features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features) - - # Add some frames and save - for _ in range(5): - dataset.add_frame({"state": torch.randn(2), "task": "Test task"}) - dataset.save_episode() - dataset.finalize() - - # Reload the dataset - return LeRobotDataset(dataset.repo_id, root=dataset.root) - - def test_no_subtask_in_features(self, dataset_without_subtasks): - """Test that subtask_index is not in features when not provided.""" - assert "subtask_index" not in dataset_without_subtasks.features - - def test_getitem_without_subtask(self, 
dataset_without_subtasks): - """Test that __getitem__ works when subtask is not present.""" - item = dataset_without_subtasks[0] - - # Item should still be retrievable - assert item is not None - assert "state" in item - assert "task" in item - - # Subtask should NOT be present - assert "subtask" not in item - - def test_subtasks_metadata_is_none(self, dataset_without_subtasks): - """Test that subtasks metadata is None when not present.""" - assert dataset_without_subtasks.meta.subtasks is None - - -class TestSubtaskEdgeCases: - """Edge case tests for subtask handling.""" - - def test_subtask_with_multiple_episodes(self): - """Test subtask handling with multiple episodes if available.""" - try: - dataset = LeRobotDataset( - repo_id="lerobot/pusht-subtask", - episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ) - except Exception: - pytest.skip("Could not load test-subtask dataset") - - # Check first and last items have valid subtasks - first_item = dataset[0] - last_item = dataset[len(dataset) - 1] - - assert "subtask" in first_item - assert "subtask" in last_item - assert isinstance(first_item["subtask"], str) - assert isinstance(last_item["subtask"], str) - - def test_subtask_index_consistency(self): - """Test that same subtask_index returns same subtask string.""" - try: - dataset = LeRobotDataset( - repo_id="lerobot/pusht-subtask", - episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - ) - except Exception: - pytest.skip("Could not load test-subtask dataset") - - if len(dataset) < 2: - pytest.skip("Dataset too small for this test") - - # Collect subtask_index to subtask mappings - subtask_map = {} - for i in range(min(len(dataset), 10)): - item = dataset[i] - idx = item["subtask_index"].item() - subtask = item["subtask"] - - if idx in subtask_map: - # Same index should always return same subtask - assert subtask_map[idx] == subtask, ( - f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'" - ) - else: - subtask_map[idx] = subtask diff --git 
a/tests/processor/test_render_messages_processor.py b/tests/processor/test_render_messages_processor.py new file mode 100644 index 000000000..c218b9152 --- /dev/null +++ b/tests/processor/test_render_messages_processor.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python + +import torch + +from lerobot.configs.recipe import MessageTurn, TrainingRecipe +from lerobot.processor.converters import create_transition +from lerobot.processor.render_messages_processor import RenderMessagesStep +from lerobot.types import TransitionKey + + +def test_render_messages_step_noops_without_language_columns(): + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True), + ] + ) + transition = create_transition(complementary_data={"task": "do it"}) + + assert RenderMessagesStep(recipe)(transition) == transition + + +def test_render_messages_step_renders_and_drops_raw_language(): + recipe = TrainingRecipe( + messages=[ + MessageTurn(role="user", content="${task}", stream="high_level"), + MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True), + ] + ) + transition = create_transition( + complementary_data={ + "task": "do it", + "timestamp": torch.tensor(0.0), + "index": torch.tensor(7), + "language_persistent": [ + { + "role": "assistant", + "content": "reach carefully", + "style": "subtask", + "timestamp": 0.0, + "tool_calls": None, + } + ], + "language_events": [], + } + ) + + out = RenderMessagesStep(recipe)(transition) + data = out[TransitionKey.COMPLEMENTARY_DATA] + + assert "language_persistent" not in data + assert "language_events" not in data + assert data["messages"][-1]["content"] == "reach carefully" + assert data["message_streams"] == ["high_level", "low_level"] + assert data["target_message_indices"] == [1] diff --git a/tests/utils/test_collate.py b/tests/utils/test_collate.py new file mode 100644 index 
000000000..d53858079 --- /dev/null +++ b/tests/utils/test_collate.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +import torch + +from lerobot.utils.collate import lerobot_collate_fn + + +def test_lerobot_collate_preserves_messages_and_drops_raw_language(): + batch = [ + { + "index": torch.tensor(0), + "messages": [{"role": "assistant", "content": "a"}], + "message_streams": ["low_level"], + "target_message_indices": [0], + "language_persistent": [{"content": "raw"}], + "language_events": [], + }, + { + "index": torch.tensor(1), + "messages": [{"role": "assistant", "content": "b"}], + "message_streams": ["low_level"], + "target_message_indices": [0], + "language_persistent": [{"content": "raw"}], + "language_events": [], + }, + ] + + out = lerobot_collate_fn(batch) + + assert out["index"].tolist() == [0, 1] + assert out["messages"][0][0]["content"] == "a" + assert out["messages"][1][0]["content"] == "b" + assert out["message_streams"] == [["low_level"], ["low_level"]] + assert out["target_message_indices"] == [[0], [0]] + assert "language_persistent" not in out + assert "language_events" not in out diff --git a/uv.lock b/uv.lock index 6e141f11d..e7dc8882d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", @@ -951,7 +951,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, 
] wheels = [ { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, @@ -1038,7 +1038,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l') or sys_platform != 'linux'" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x') or (platform_machine != 's390x' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -2993,7 +2993,7 @@ requires-dist = [ { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" }, { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" }, { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" }, - { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" }, + { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" }, { name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" }, { name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" }, { name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" }, @@ -4039,7 +4039,7 @@ name = 
"nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -4050,7 +4050,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -4077,9 +4077,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 
'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -4090,7 +4090,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -4933,10 +4933,10 @@ name = "pyobjc-framework-applicationservices" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-coretext", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-coretext", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-quartz", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/be/6a/d4e613c8e926a5744fc47a9e9fea08384a510dc4f27d844f7ad7a2d793bd/pyobjc_framework_applicationservices-12.1.tar.gz", hash = "sha256:c06abb74f119bc27aeb41bf1aef8102c0ae1288aec1ac8665ea186a067a8945b", size = 103247, upload-time = 
"2025-11-14T10:08:52.18Z" } wheels = [ @@ -4952,7 +4952,7 @@ name = "pyobjc-framework-cocoa" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" } wheels = [ @@ -4968,9 +4968,9 @@ name = "pyobjc-framework-coretext" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-quartz", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/29/da/682c9c92a39f713bd3c56e7375fa8f1b10ad558ecb075258ab6f1cdd4a6d/pyobjc_framework_coretext-12.1.tar.gz", hash = 
"sha256:e0adb717738fae395dc645c9e8a10bb5f6a4277e73cba8fa2a57f3b518e71da5", size = 90124, upload-time = "2025-11-14T10:14:38.596Z" } wheels = [ @@ -4986,8 +4986,8 @@ name = "pyobjc-framework-quartz" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "pyobjc-framework-cocoa", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" } wheels = [