Compare commits

...

4 Commits

Author SHA1 Message Date
Pepijn 0b06790da0 feat(language): add motion (persistent) and trace (event-only) styles
Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:21:49 +02:00
Pepijn b43dc39ba4 Add docstrings to all new helpers; revert uv.lock
Covers private helpers in recipe.py, language.py, language_render.py,
and render_messages_processor.py. Also reverts uv.lock to main (it was
re-generated by `uv run` during local checks).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:15:03 +02:00
Pepijn 2b71221194 Address review: split persistent/event schemas, drop event timestamps
- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 13:38:23 +02:00
Pepijn 8833d735a1 Add extensive language support 2026-04-27 10:56:32 +02:00
28 changed files with 1611 additions and 497 deletions
+2 -2
View File
@@ -31,8 +31,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: language_and_recipes
title: Language Columns and Recipes
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
-277
View File
@@ -1,277 +0,0 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+75
View File
@@ -0,0 +1,75 @@
# Language columns and recipes
LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`.
The two optional columns are:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape:
```text
role: string
content: string | null
style: string | null
timestamp: float64
tool_calls: list[Json] | null
```
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
## Architecture
The language stack has three layers:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
## Recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
+1 -1
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.0.0,<5.0.0",
"datasets>=4.7.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
+4
View File
@@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -41,7 +42,10 @@ __all__ = [
# Config classes
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
]
+193
View File
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
if self.stream is not None and self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(_PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(_PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)
@@ -0,0 +1,47 @@
blend:
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.16
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
high_level_subtask:
weight: 0.15
bindings:
next_subtask: "nth_next(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
low_level_execution:
weight: 0.35
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
ask_vqa:
weight: 0.20
messages:
- {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+14
View File
@@ -37,6 +37,14 @@ from .dataset_tools import (
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -53,10 +61,15 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"add_features",
@@ -66,6 +79,7 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
+1 -1
View File
@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
if features[key]["dtype"] in {"string", "language"}:
continue
if features[key]["dtype"] in ["image", "video"]:
-3
View File
@@ -34,7 +34,6 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_json,
@@ -177,7 +176,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -635,7 +633,6 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
-5
View File
@@ -295,9 +295,4 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item
+15 -1
View File
@@ -22,6 +22,12 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -45,7 +51,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+8 -11
View File
@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in {"language_persistent", "language_events"}:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
elif first_item is None or isinstance(first_item, dict):
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in [
"task",
"language_persistent",
"language_events",
]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
+150
View File
@@ -0,0 +1,150 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
EXTENDED_STYLES = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")
+511
View File
@@ -0,0 +1,511 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
)
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
selector. ``events`` is accepted for resolver-signature uniformity but is
not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
matches = [row for row in matches if _timestamp(row) <= t]
return _select_latest(matches, style=style, role=role, tool_name=tool_name)
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``.
"""
column = column_for_style(style)
if column == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
if _timestamp(row) == t
]
return _select_one(
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
)
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions before the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=-offset,
role=role,
tool_name=tool_name,
resolver_name="nth_prev",
)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions after the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=offset,
role=role,
tool_name=tool_name,
resolver_name="nth_next",
)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``."""
if task is not None:
return task
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
resolvers = {
"active_at": active_at,
"emitted_at": emitted_at,
"nth_prev": nth_prev,
"nth_next": nth_next,
}
if name not in resolvers:
raise ValueError(f"Unknown language resolver: {name!r}")
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return _PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
for idx, stream in enumerate(streams):
if stream is None:
raise ValueError(f"Rendered message {idx} has no stream.")
def _nth_relative(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
resolver_name: str,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(resolver_name, style)
if abs(offset) < 1:
raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
key=_persistent_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{resolver_name} requires a persistent style.")
if style in EVENT_ONLY_STYLES:
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
if style not in PERSISTENT_STYLES:
column_for_style(style)
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
]
def _select_latest(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
) -> LanguageRow | None:
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
if not rows:
return None
rows = sorted(rows, key=_persistent_sort_key)
latest_ts = _timestamp(rows[-1])
return _select_one(
[row for row in rows if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
sort_key=_persistent_sort_key,
)
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
sort_key: Any,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the selectors are ambiguous."""
if not rows:
return None
if len(rows) > 1 and role is None and tool_name is None:
raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
)
return sorted(rows, key=sort_key)[0]
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
return (row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized
-1
View File
@@ -83,7 +83,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
+2
View File
@@ -93,6 +93,7 @@ from .relative_action_processor import (
to_relative_actions,
)
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .render_messages_processor import RenderMessagesStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
__all__ = [
@@ -128,6 +129,7 @@ __all__ = [
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"RelativeActionsProcessorStep",
"RenderMessagesStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NewLineTaskProcessorStep",
+18
View File
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(
+25 -2
View File
@@ -167,12 +167,35 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
language_persistent_key = (
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
)
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
target_message_indices_key = (
{"target_message_indices": batch["target_message_indices"]}
if "target_message_indices" in batch
else {}
)
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
return {
**pad_keys,
**task_key,
**index_key,
**task_index_key,
**episode_index_key,
**timestamp_key,
**language_persistent_key,
**language_events_key,
**messages_key,
**message_streams_key,
**target_message_indices_key,
}
def create_transition(
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value
+2
View File
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=lerobot_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
+54
View File
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
preserved = {
key: [sample[key] for sample in batch if key in sample]
for key in _PYTHON_LIST_KEYS
if any(key in sample for sample in batch)
}
tensorizable = [
{
key: value
for key, value in sample.items()
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
}
for sample in batch
]
collated = default_collate(tensorizable)
collated.update(preserved)
return collated
+31
View File
@@ -0,0 +1,31 @@
#!/usr/bin/env python
from pathlib import Path
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
def test_message_recipe_validates_unknown_binding():
with pytest.raises(ValueError, match="unknown binding"):
TrainingRecipe(
messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
def test_canonical_recipe_loads():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
assert recipe.blend is not None
assert set(recipe.blend) == {
"memory_update",
"user_interjection_response",
"high_level_subtask",
"low_level_execution",
"ask_vqa",
}
assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96)
+99
View File
@@ -0,0 +1,99 @@
#!/usr/bin/env python
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.io_utils import write_info
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields():
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "tool_calls"]
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS
def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise")
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
root = tmp_path / "language_dataset"
dataset = empty_lerobot_dataset_factory(
root=root,
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
use_videos=False,
)
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
dataset.save_episode()
dataset.finalize()
persistent = [
{
"role": "assistant",
"content": "reach for the cup",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
]
event = {
"role": "user",
"content": "what is visible?",
"style": "vqa",
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
df = pd.read_parquet(data_path)
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
df[LANGUAGE_EVENTS] = [[event], []]
df.to_parquet(data_path)
info = dataset.meta.info
info["features"].update(language_feature_info())
write_info(info, root)
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
first = reloaded[0]
second = reloaded[1]
assert first[LANGUAGE_PERSISTENT] == persistent
assert first[LANGUAGE_EVENTS] == [event]
assert second[LANGUAGE_PERSISTENT] == persistent
assert second[LANGUAGE_EVENTS] == []
+176
View File
@@ -0,0 +1,176 @@
#!/usr/bin/env python
from pathlib import Path
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
def persistent_row(role, content, style, timestamp, tool_calls=None):
return {
"role": role,
"content": content,
"style": style,
"timestamp": timestamp,
"tool_calls": tool_calls,
}
def event_row(role, content, style, tool_calls=None):
return {
"role": role,
"content": content,
"style": style,
"tool_calls": tool_calls,
}
PERSISTENT = [
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa"),
event_row("assistant", '{"count": 2}', "vqa"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
def test_persistent_relative_resolvers_reject_event_styles():
with pytest.raises(ValueError, match="event-only"):
active_at(1.0, persistent=PERSISTENT, style="vqa")
with pytest.raises(ValueError, match="event-only"):
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
def test_nth_prev_and_next():
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
def test_substitution_if_present_multimodal_and_tool_calls():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${task}: ${interjection}"},
],
stream="high_level",
if_present="interjection",
),
MessageTurn(
role="assistant",
content="${plan}",
stream="high_level",
target=True,
tool_calls_from="speech",
),
],
bindings={"plan": "active_at(t, style=plan)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
assert rendered["messages"][1]["content"] == "plan 0"
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
assert rendered["message_streams"] == ["high_level", "high_level"]
assert rendered["target_message_indices"] == [1]
def test_exact_event_miss_returns_none_when_target_skips():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
]
)
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
recipe = TrainingRecipe(
blend={
"a": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
],
),
"b": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
],
),
}
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second
def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
rendered = render_sample(
recipe=low_level,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
assert rendered["message_streams"][-1] == "low_level"
assert rendered["target_message_indices"] == [1]
-193
View File
@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use lerobot/pusht-subtask dataset with episode 1
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask
@@ -0,0 +1,55 @@
#!/usr/bin/env python
import torch
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.processor.converters import create_transition
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.types import TransitionKey
def test_render_messages_step_noops_without_language_columns():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(complementary_data={"task": "do it"})
assert RenderMessagesStep(recipe)(transition) == transition
def test_render_messages_step_renders_and_drops_raw_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(
complementary_data={
"task": "do it",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [
{
"role": "assistant",
"content": "reach carefully",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
"language_events": [],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "language_persistent" not in data
assert "language_events" not in data
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
+36
View File
@@ -0,0 +1,36 @@
#!/usr/bin/env python
import torch
from lerobot.utils.collate import lerobot_collate_fn
def test_lerobot_collate_preserves_messages_and_drops_raw_language():
batch = [
{
"index": torch.tensor(0),
"messages": [{"role": "assistant", "content": "a"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
{
"index": torch.tensor(1),
"messages": [{"role": "assistant", "content": "b"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
]
out = lerobot_collate_fn(batch)
assert out["index"].tolist() == [0, 1]
assert out["messages"][0][0]["content"] == "a"
assert out["messages"][1][0]["content"] == "b"
assert out["message_streams"] == [["low_level"], ["low_level"]]
assert out["target_message_indices"] == [[0], [0]]
assert "language_persistent" not in out
assert "language_events" not in out