Compare commits

...

8 Commits

Author SHA1 Message Date
Pepijn e3e9374e2c feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools
Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so
datasets can declare which tools are available (today: just ``say``;
tomorrow: per-dataset extensions). The ``DEFAULT_TOOLS`` constant
fills in for unannotated datasets so chat-template consumers don't
have to special-case anything.

Three pieces:

- ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS``
  constants. Single source of truth — PR 2's writer and PR 3's
  runtime tool registry will both import from here instead of
  duplicating the dict.
- ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property
  reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``.
  Returns deep-copied dicts so callers can mutate the result safely.
- ``docs/source/tools.mdx``: spec page covering the catalog, per-row
  invocations, and the three-step "how to add a new tool" workflow
  (declare schema, implement, register). Linked from the docs
  toctree under the Datasets section.

This lays the groundwork for PR 2's pipeline writing the catalog out
during annotation, and PR 3's ``src/lerobot/tools/`` package shipping
runnable implementations (one file per tool — first up:
``say.py`` wrapping Kyutai's pocket-tts).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 18:44:58 +02:00
Pepijn c1a0c601e2 feat(language): task_aug style + automatic ${task} rephrasing rotation
Adds task-prompt diversity (Xiao 2022 / CAST) without touching
``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved
``task_aug`` as a future style; this lands it now.

- ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and
  ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns
  ``language_persistent`` so PR 2 writers route it correctly.

- ``language_render.py``: ``_resolve_task`` now consults the persistent
  slice for rows of ``style="task_aug", role="user"``. When any exist
  it picks one deterministically by ``sample_idx`` (blake2b-keyed, not
  Python's randomized hash) so an epoch sees every rephrasing of every
  episode while the same sample still resolves identically across
  reruns. Falls back to the canonical ``meta/tasks.parquet`` task when
  no rephrasings are present, so existing datasets and unannotated runs
  keep their behaviour. Explicit ``task=`` overrides still win.

- Tests: rephrasing coverage across samples, determinism on repeat
  ``sample_idx``, fallback when persistent has no ``task_aug`` rows,
  and explicit override priority.

Recipes get this for free: any ``${task}`` placeholder rotates through
the available rephrasings. Recipes that want the literal canonical task
can override the binding.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 16:45:39 +02:00
Pepijn 1ca38d9748 fix(language): drop motion from VIEW_DEPENDENT_STYLES
Motion primitives are described in robot-frame (joint / Cartesian) terms,
not pixel space, so they are camera-agnostic. Only `vqa` (event) and
`trace` (event, pixel-trajectory) are view-dependent.

The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry —
the validator, resolver, and HF feature mapping behave identically across
the two columns regardless of which styles populate `camera` today —
but persistent rows now always have `camera=None` in practice.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:54:12 +02:00
Pepijn 5a6aa64570 feat(language): per-camera tagging on view-dependent styles
Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.

- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
  to both Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
  `is_view_dependent_style` and `validate_camera_field` helpers (camera
  required iff style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
  resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
  `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
  VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
  `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
  the matching image block), keeping the original 0.20 budget and
  documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
  `is_view_dependent_style`, `validate_camera_field` (both required and
  forbidden directions), per-camera `emitted_at` filtering, and the
  ambiguity error when two cameras emit `(vqa, assistant)` at the same
  timestamp without a `camera=` filter. RenderMessagesStep + dataset
  passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
  the per-camera resolver pattern, and the canonical recipe convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-30 10:48:17 +02:00
Pepijn 0b06790da0 feat(language): add motion (persistent) and trace (event-only) styles
Promote the previously-reserved motion/trace styles to first-class core
styles. motion routes to language_persistent (it tracks robot state over
time); trace routes to language_events (single-moment annotations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:21:49 +02:00
Pepijn b43dc39ba4 Add docstrings to all new helpers; revert uv.lock
Covers private helpers in recipe.py, language.py, language_render.py,
and render_messages_processor.py. Also reverts uv.lock to main (it was
re-generated by `uv run` during local checks).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 14:15:03 +02:00
Pepijn 2b71221194 Address review: split persistent/event schemas, drop event timestamps
- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-27 13:38:23 +02:00
Pepijn 8833d735a1 Add extensive language support 2026-04-27 10:56:32 +02:00
29 changed files with 2329 additions and 497 deletions
+4 -2
View File
@@ -31,8 +31,10 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
- local: language_and_recipes
title: Language Columns and Recipes
- local: tools
title: Tools
- local: streaming_video_encoding
title: Streaming Video Encoding
title: "Datasets"
-277
View File
@@ -1,277 +0,0 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor import TokenizerProcessorStep
# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+109
View File
@@ -0,0 +1,109 @@
# Language columns and recipes
LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`.
The two optional columns are:
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):
```text
role: string
content: string | null
style: string | null
timestamp: float64 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa`, and the reserved `motion` /
`trace`) MUST set `camera` to the matching `observation.images.*` feature key.
Rows of every other style MUST leave `camera` as `null`. Pipeline writers and
the validator enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
## Architecture
The language stack has three layers:
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Temporal semantics
Persistent styles are active after emission until replaced:
- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`
Event styles only exist on their exact timestamp:
- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
## View-dependent resolution
For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.
```yaml
ask_vqa_top:
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- { type: image, feature: observation.images.top }
- { type: text, text: "${vqa_query}" }
- { role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa }
```
Add one such sub-recipe per camera the dataset records.
## Recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
Rendered samples use HF-style chat messages plus LeRobot sidecars:
```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
+198
View File
@@ -0,0 +1,198 @@
# Tools
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).
This page covers:
1. Where the tool catalog lives (PR 1).
2. How the annotation pipeline produces tool-call atoms (PR 2).
3. How to add your own tool (PR 3).
## Where tools are declared
Two layers.
**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:
```json
{
"features": { "...": "..." },
"tools": [
{
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": { "type": "string", "description": "The verbatim text to speak." }
},
"required": ["text"]
}
}
}
]
}
```
Read it via the dataset metadata accessor:
```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools # list[dict] — OpenAI tool schemas
```
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:
```python
prompt_str = tokenizer.apply_chat_template(
sample["messages"],
tools=meta.tools, # works either way
add_generation_prompt=False,
tokenize=False,
)
```
**The implementations** — runnable Python — live under
`src/lerobot/tools/`, one file per tool. The `say` implementation
arrives in PR 3 and wraps Kyutai's pocket-tts model.
## Per-row tool *invocations*
The catalog above describes *what can be called*. The actual *call* — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:
```python
{
"role": "assistant",
"content": null,
"style": null,
"timestamp": 12.4,
"camera": null,
"tool_calls": [
{ "type": "function",
"function": { "name": "say", "arguments": { "text": "On it." } } }
]
}
```
Recipes splice these into rendered messages via `tool_calls_from`:
```yaml
user_interjection_response:
bindings:
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${current_plan}", stream: high_level,
target: true, tool_calls_from: speech }
```
The model's training target is one assistant turn that carries both the
plan text *and* the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.
## How to add your own tool
Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.
### Step 1 — declare the schema
Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk *before* running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag (PR 2 —
exact CLI lands with the pipeline change).
```json
{
"tools": [
{ "type": "function", "function": { "name": "say", "...": "..." } },
{
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a high-resolution still image for the user.",
"parameters": {
"type": "object",
"properties": {
"label": { "type": "string", "description": "Short label for the saved image." }
},
"required": ["label"]
}
}
}
]
}
```
The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.
### Step 2 — implement the call
Create `src/lerobot/tools/record_observation.py`:
```python
from .base import Tool
from typing import Any
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
class RecordObservationTool:
name = "record_observation"
schema = RECORD_OBSERVATION_SCHEMA
def __init__(self, schema: dict | None = None, output_dir: str = "."):
self.output_dir = output_dir
def call(self, arguments: dict) -> str:
label = arguments["label"]
# ... save the latest camera frame to <output_dir>/<label>.png ...
return f"saved {label}.png"
```
One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` (PR 3) pulls `pocket-tts`. Users
installing only the tools they need avoid heavy transitive deps.
### Step 3 — register it
Add to `src/lerobot/tools/registry.py` (PR 3):
```python
from .record_observation import RecordObservationTool
TOOL_REGISTRY["record_observation"] = RecordObservationTool
```
That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.
## Where this fits in the three-PR stack
| Layer | PR | What lands |
|---|---|---|
| Catalog storage in `meta/info.json` + `meta.tools` accessor | PR 1 | This page; `SAY_TOOL_SCHEMA`, `DEFAULT_TOOLS` constants in `lerobot.datasets.language`; `LeRobotDatasetMetadata.tools` property |
| Annotation pipeline writes `tools` to meta after a run; honors anything users pre-populated | PR 2 | `lerobot-annotate` ensures `meta/info.json["tools"]` includes the canonical `say` and merges any user-declared tools |
| Runnable implementations under `src/lerobot/tools/`; runtime dispatcher; `say.py` wired to Kyutai's pocket-tts | PR 3 | One file per tool; `Tool` protocol; `TOOL_REGISTRY`; optional `[tools]` extra in `pyproject.toml` |
If you want to use a tool *without* writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to *generate* the call. Steps 2 and 3 are only
needed to actually *execute* it at inference.
+1 -1
View File
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
"datasets>=4.0.0,<5.0.0",
"datasets>=4.7.0,<5.0.0",
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
+4
View File
@@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .recipe import MessageTurn, TrainingRecipe, load_recipe
from .types import (
FeatureType,
NormalizationMode,
@@ -41,7 +42,10 @@ __all__ = [
# Config classes
"DatasetConfig",
"EvalConfig",
"MessageTurn",
"PeftConfig",
"PreTrainedConfig",
"TrainingRecipe",
"WandBConfig",
"load_recipe",
]
+193
View File
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
DEFAULT_BINDINGS = {
"subtask": "active_at(t, style=subtask)",
"memory": "active_at(t, style=memory)",
"plan": "active_at(t, style=plan)",
"speech": "emitted_at(t, role=assistant, tool_name=say)",
"interjection": "emitted_at(t, style=interjection)",
"vqa": "emitted_at(t, style=vqa, role=assistant)",
"vqa_query": "emitted_at(t, style=vqa, role=user)",
}
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
target: bool = False
if_present: str | None = None
tool_calls_from: str | None = None
def __post_init__(self) -> None:
"""Validate role, stream, and content after dataclass construction."""
if self.role not in _VALID_ROLES:
raise ValueError(f"Unsupported message role: {self.role!r}")
if self.stream is not None and self.stream not in _VALID_STREAMS:
raise ValueError(f"Unsupported message stream: {self.stream!r}")
if self.content is None and self.tool_calls_from is None:
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
if self.content is not None and not isinstance(self.content, (str, list)):
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
if isinstance(self.content, list):
for block in self.content:
if not isinstance(block, dict) or "type" not in block:
raise ValueError(
"Multimodal content blocks must be HF-style dictionaries with a type key."
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
"""Construct a :class:`MessageTurn` from a plain dictionary."""
return cls(**data)
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
weight: float | None = None
def __post_init__(self) -> None:
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
if self.messages is not None and self.blend is not None:
raise ValueError("TrainingRecipe must set only one of messages or blend.")
if self.messages is None and self.blend is None:
raise ValueError("TrainingRecipe must set one of messages or blend.")
if self.messages is not None:
self._validate_message_recipe()
if self.blend is not None:
self._validate_blend_recipe()
@classmethod
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
data = dict(data)
if data.get("messages") is not None:
data["messages"] = [
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
for turn in data["messages"]
]
if data.get("blend") is not None:
data["blend"] = {
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
for name, recipe in data["blend"].items()
}
return cls(**data)
@classmethod
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
import yaml # type: ignore[import-untyped]
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
for turn in self.messages:
missing = self._referenced_bindings(turn) - known_bindings
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""
assert self.blend is not None
if not self.blend:
raise ValueError("Blend recipes must contain at least one component.")
for name, recipe in self.blend.items():
if recipe.blend is not None:
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
if recipe.messages is None:
raise ValueError(f"Blend component {name!r} must define messages.")
if recipe.weight is None:
raise ValueError(f"Blend component {name!r} must define weight.")
if recipe.weight <= 0:
raise ValueError(f"Blend component {name!r} must have a positive weight.")
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
"""Return the binding names that ``turn`` references via placeholders or attributes."""
names: set[str] = set()
if turn.if_present is not None:
names.add(turn.if_present)
if turn.tool_calls_from is not None:
names.add(turn.tool_calls_from)
names.update(_placeholders_in_content(turn.content))
return names
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
if content is None:
return set()
if isinstance(content, str):
return set(_PLACEHOLDER_RE.findall(content))
names: set[str] = set()
for block in content:
for value in block.values():
if isinstance(value, str):
names.update(_PLACEHOLDER_RE.findall(value))
return names
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)
@@ -0,0 +1,74 @@
blend:
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.16
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
high_level_subtask:
weight: 0.15
bindings:
next_subtask: "nth_next(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
low_level_execution:
weight: 0.35
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
# VQA is view-dependent: bbox / keypoint / count answers only make sense for
# the camera they were grounded against. Each camera gets its own sub-recipe
# so the resolver can disambiguate via `camera=...` and the user-turn carries
# the matching image block. Adjust the camera keys (and add more sub-recipes)
# to match the cameras present on your dataset.
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.top}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+14
View File
@@ -37,6 +37,14 @@ from .dataset_tools import (
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
column_for_style,
)
from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -53,10 +61,15 @@ __all__ = [
"CODEBASE_VERSION",
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"PERSISTENT_STYLES",
"STYLE_REGISTRY",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"add_features",
@@ -66,6 +79,7 @@ __all__ = [
"convert_image_to_video_dataset",
"create_initial_features",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
"get_feature_stats",
"load_episodes",
+1 -1
View File
@@ -512,7 +512,7 @@ def compute_episode_stats(
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
if features[key]["dtype"] in {"string", "language"}:
continue
if features[key]["dtype"] in ["image", "video"]:
+22 -3
View File
@@ -34,7 +34,6 @@ from .io_utils import (
load_episodes,
load_info,
load_stats,
load_subtasks,
load_tasks,
write_info,
write_json,
@@ -177,7 +176,6 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -320,6 +318,28 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset.
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
can mutate the result safely. Falls back to
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
``say`` schema) when the dataset doesn't declare any — that way
unannotated datasets and chat-template consumers
(``apply_chat_template(messages, tools=meta.tools)``) keep
working out of the box.
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
declared = self.info.get("tools")
if isinstance(declared, list) and declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
@@ -635,7 +655,6 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
-5
View File
@@ -295,9 +295,4 @@ class DatasetReader:
task_idx = item["task_index"].item()
item["task"] = self._meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
return item
+15 -1
View File
@@ -22,6 +22,12 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -45,7 +51,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
if is_language_column(key):
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
hf_features[key] = datasets.Image()
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+8 -11
View File
@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
DEFAULT_TASKS_PATH,
EPISODES_DIR,
INFO_PATH,
@@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
@@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
if key in {"language_persistent", "language_events"}:
continue
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
elif first_item is None or isinstance(first_item, dict):
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in [
"task",
"language_persistent",
"language_events",
]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
+236
View File
@@ -0,0 +1,236 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Literal
import datasets
import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
CORE_STYLES = {
"subtask",
"plan",
"memory",
"motion",
"interjection",
"vqa",
"trace",
"task_aug",
}
EXTENDED_STYLES = set()
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
# is intentionally NOT in this set: motion primitives are described in
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
# view-dependent. The ``camera`` field nevertheless lives on
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
# behave symmetrically across the two columns; persistent rows simply
# always have ``camera=None`` in practice today.
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
LanguageColumn = Literal["language_persistent", "language_events"]
def _json_arrow_type() -> pa.DataType:
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
}
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def is_view_dependent_style(style: str | None) -> bool:
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
return style in VIEW_DEPENDENT_STYLES
def validate_camera_field(style: str | None, camera: str | None) -> None:
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
a non-view-dependent style carries one. Pipeline writers and the validator
should call this on every emitted row.
"""
if is_view_dependent_style(style):
if not camera:
raise ValueError(
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
f"field referencing an 'observation.images.*' feature key."
)
elif camera is not None:
raise ValueError(
f"Rows of style {style!r} must have camera=None; got camera={camera!r}."
)
# --- Tool registry --------------------------------------------------------
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
# of OpenAI-style function schemas. The runtime / training stack reads them
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
# fallback when the dataset doesn't declare any). Implementations live
# under :mod:`lerobot.tools` (one file per tool); see
# ``docs/source/tools.mdx`` for the authoring guide.
SAY_TOOL_SCHEMA: dict = {
"type": "function",
"function": {
"name": "say",
"description": "Speak a short utterance to the user via the TTS executor.",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The verbatim text to speak.",
}
},
"required": ["text"],
},
},
}
"""Canonical schema for the ``say`` tool emitted by the steerable
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
writer, PR 3's runtime tool registry, and the dataset visualizer all
import this constant rather than duplicating the dict."""
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
chat-template consumers (``apply_chat_template(messages, tools=...)``)
keep working out of the box."""
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
raise ValueError(f"Unknown language style: {style!r}")
+593
View File
@@ -0,0 +1,593 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import hashlib
import re
from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
from .language import (
EVENT_ONLY_STYLES,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
column_for_style,
)
LanguageRow = dict[str, Any]
RenderedMessages = dict[str, list[Any]]
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
def active_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
``camera`` selector. ``events`` is accepted for resolver-signature
uniformity but is not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style)
matches = _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
matches = [row for row in matches if _timestamp(row) <= t]
return _select_latest(
matches, style=style, role=role, tool_name=tool_name, camera=camera
)
def emitted_at(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
style: str | None = None,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``. ``camera``
filters by the row's ``camera`` field — required to disambiguate when
multiple view-dependent rows share ``(t, role)`` across cameras.
"""
column = column_for_style(style)
if column == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(
persistent, style=style, role=role, tool_name=tool_name, camera=camera
)
if _timestamp(row) == t
]
return _select_one(
matches,
style=style,
role=role,
tool_name=tool_name,
camera=camera,
sort_key=_persistent_sort_key,
)
matches = _matching_rows(
events, style=style, role=role, tool_name=tool_name, camera=camera
)
return _select_one(
matches,
style=style,
role=role,
tool_name=tool_name,
camera=camera,
sort_key=_event_sort_key,
)
def nth_prev(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions before the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=-offset,
role=role,
tool_name=tool_name,
camera=camera,
resolver_name="nth_prev",
)
def nth_next(
t: float,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow] | None = None,
style: str | None = None,
offset: int = 1,
role: str | None = None,
tool_name: str | None = None,
camera: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
one ``offset`` positions after the row active at ``t``. Only valid for
persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
style=style,
offset=offset,
role=role,
tool_name=tool_name,
camera=camera,
resolver_name="nth_next",
)
def render_sample(
*,
recipe: TrainingRecipe,
persistent: Sequence[LanguageRow] | None,
events: Sequence[LanguageRow] | None,
t: float,
sample_idx: int,
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
return _render_message_recipe(selected_recipe, bindings)
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
return recipe
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
if total_weight <= 0:
raise ValueError("Blend weights must sum to a positive value.")
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
cumulative = 0.0
last_component: TrainingRecipe | None = None
for component in recipe.blend.values():
last_component = component
cumulative += component.weight or 0.0
if draw < cumulative:
return component
assert last_component is not None
return last_component
def _resolve_bindings(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> dict[str, LanguageRow | str | None]:
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
bindings: dict[str, LanguageRow | str | None] = {
"task": _resolve_task(
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
),
}
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
for name, spec in specs.items():
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
return bindings
def _resolve_task(
task: str | None,
dataset_ctx: Any | None,
*,
persistent: Sequence[LanguageRow] = (),
sample_idx: int = 0,
) -> str | None:
"""Return the task string for ``sample_idx``.
Resolution order:
1. Explicit ``task`` override (caller-supplied) wins.
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
deterministically pick one by ``sample_idx`` so each frame of an
episode rotates through the available rephrasings across an epoch.
This realizes Xiao 2022 / CAST-style task-prompt diversity without
changing ``meta/tasks.parquet`` and without forcing recipes to opt
in: ``${task}`` automatically picks a rephrasing when one exists,
and falls back to the canonical task otherwise. Recipes that want
the literal canonical task can override the binding.
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
backed by ``meta/tasks.parquet``).
"""
if task is not None:
return task
aug_rows = [
r
for r in persistent
if r.get("style") == "task_aug" and r.get("role") == "user"
]
if aug_rows:
# Deterministic, blake2b-based pick keyed on sample_idx so the
# rotation is reproducible across runs (Python's built-in ``hash``
# is process-randomized).
digest = hashlib.blake2b(
f"task_aug:{sample_idx}".encode(), digest_size=8
).digest()
idx = int.from_bytes(digest, "big") % len(aug_rows)
chosen = aug_rows[idx].get("content")
if chosen:
return str(chosen)
if dataset_ctx is None:
return None
if isinstance(dataset_ctx, dict):
return dataset_ctx.get("task")
return getattr(dataset_ctx, "task", None)
def _resolve_spec(
spec: str,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
) -> LanguageRow | None:
"""Parse a single binding's resolver expression and dispatch to its function."""
match = _RESOLVER_RE.match(spec.strip())
if match is None:
raise ValueError(f"Invalid resolver expression: {spec!r}")
name = match.group("name")
kwargs = _parse_resolver_args(match.group("args"))
kwargs.pop("t_arg", None)
resolvers = {
"active_at": active_at,
"emitted_at": emitted_at,
"nth_prev": nth_prev,
"nth_next": nth_next,
}
if name not in resolvers:
raise ValueError(f"Unknown language resolver: {name!r}")
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
def _parse_resolver_args(args: str) -> dict[str, Any]:
"""Parse a comma-separated resolver argument list into a kwargs dict."""
kwargs: dict[str, Any] = {}
if not args.strip():
return kwargs
parts = [part.strip() for part in args.split(",") if part.strip()]
for part in parts:
if part == "t":
kwargs["t_arg"] = True
continue
if "=" not in part:
raise ValueError(f"Invalid resolver argument: {part!r}")
key, value = (item.strip() for item in part.split("=", 1))
if key == "offset":
kwargs[key] = int(value)
else:
kwargs[key] = value.strip("\"'")
return kwargs
def _render_message_recipe(
recipe: TrainingRecipe,
bindings: dict[str, LanguageRow | str | None],
) -> RenderedMessages | None:
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
assert recipe.messages is not None
messages: list[dict[str, Any]] = []
streams: list[str | None] = []
target_indices: list[int] = []
for turn in recipe.messages:
if turn.if_present is not None and bindings.get(turn.if_present) is None:
continue
message = {"role": turn.role}
if turn.content is not None:
message["content"] = _render_content(turn.content, bindings)
if turn.tool_calls_from is not None:
row = bindings.get(turn.tool_calls_from)
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
if tool_calls:
message["tool_calls"] = copy.deepcopy(tool_calls)
message_idx = len(messages)
messages.append(message)
streams.append(turn.stream)
if turn.target:
target_indices.append(message_idx)
if not target_indices:
return None
rendered = {
"messages": messages,
"message_streams": streams,
"target_message_indices": target_indices,
}
_validate_rendered(rendered)
return rendered
def _render_content(
content: str | list[dict[str, Any]],
bindings: dict[str, LanguageRow | str | None],
) -> str | list[dict[str, Any]]:
"""Substitute bindings into a string or each string field of multimodal blocks."""
if isinstance(content, str):
return _substitute(content, bindings)
rendered_blocks = []
for block in content:
rendered_block = copy.deepcopy(block)
for key, value in rendered_block.items():
if isinstance(value, str):
rendered_block[key] = _substitute(value, bindings)
rendered_blocks.append(rendered_block)
return rendered_blocks
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
def replace(match: re.Match[str]) -> str:
"""Resolve a single ``${name}`` match to its bound string value."""
name = match.group(1)
if name not in bindings:
raise ValueError(f"Unknown template binding: {name!r}")
value = bindings[name]
if value is None:
return ""
if isinstance(value, dict):
content = value.get("content")
return "" if content is None else str(content)
return str(value)
return _PLACEHOLDER_RE.sub(replace, template)
def _validate_rendered(rendered: RenderedMessages) -> None:
"""Sanity-check the rendered output for stream/target alignment."""
messages = rendered["messages"]
streams = rendered["message_streams"]
target_indices = rendered["target_message_indices"]
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")
for idx, stream in enumerate(streams):
if stream is None:
raise ValueError(f"Rendered message {idx} has no stream.")
def _nth_relative(
t: float,
*,
persistent: Sequence[LanguageRow],
style: str | None,
offset: int,
role: str | None,
tool_name: str | None,
camera: str | None,
resolver_name: str,
) -> LanguageRow | None:
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
_validate_persistent_resolver(resolver_name, style)
if abs(offset) < 1:
raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
key=_persistent_sort_key,
)
if not rows:
return None
anchor_idx = None
for idx, row in enumerate(rows):
if _timestamp(row) <= t:
anchor_idx = idx
else:
break
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
return None
return rows[target_idx]
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
if style is None:
raise ValueError(f"{resolver_name} requires a persistent style.")
if style in EVENT_ONLY_STYLES:
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
if style not in PERSISTENT_STYLES:
column_for_style(style)
def _matching_rows(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> list[LanguageRow]:
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
return [
row
for row in rows
if (style is None or row.get("style") == style)
and (role is None or row.get("role") == role)
and (tool_name is None or _row_has_tool_name(row, tool_name))
and (camera is None or row.get("camera") == camera)
]
def _select_latest(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
) -> LanguageRow | None:
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
if not rows:
return None
rows = sorted(rows, key=_persistent_sort_key)
latest_ts = _timestamp(rows[-1])
return _select_one(
[row for row in rows if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
camera=camera,
sort_key=_persistent_sort_key,
)
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
camera: str | None,
sort_key: Any,
) -> LanguageRow | None:
"""Return the single matching row, or raise if the selectors are ambiguous."""
if not rows:
return None
if len(rows) > 1 and role is None and tool_name is None and camera is None:
raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
f"or camera=... to disambiguate."
)
return sorted(rows, key=sort_key)[0]
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
return (row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
for tool_call in row.get("tool_calls") or []:
if isinstance(tool_call, str):
continue
function = tool_call.get("function") if isinstance(tool_call, dict) else None
if isinstance(function, dict) and function.get("name") == tool_name:
return True
return False
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
normalized = []
for row in rows:
if row is None:
continue
if hasattr(row, "as_py"):
row = row.as_py()
if not isinstance(row, dict):
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
normalized.append(dict(row))
return normalized
-1
View File
@@ -83,7 +83,6 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
+2
View File
@@ -93,6 +93,7 @@ from .relative_action_processor import (
to_relative_actions,
)
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .render_messages_processor import RenderMessagesStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
__all__ = [
@@ -128,6 +129,7 @@ __all__ = [
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"RelativeActionsProcessorStep",
"RenderMessagesStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NewLineTaskProcessorStep",
+18
View File
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(
+25 -2
View File
@@ -167,12 +167,35 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
language_persistent_key = (
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
)
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
target_message_indices_key = (
{"target_message_indices": batch["target_message_indices"]}
if "target_message_indices" in batch
else {}
)
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
return {
**pad_keys,
**task_key,
**index_key,
**task_index_key,
**episode_index_key,
**timestamp_key,
**language_persistent_key,
**language_events_key,
**messages_key,
**message_streams_key,
**target_message_indices_key,
}
def create_transition(
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value
+2
View File
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=lerobot_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
+54
View File
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
preserved = {
key: [sample[key] for sample in batch if key in sample]
for key in _PYTHON_LIST_KEYS
if any(key in sample for sample in batch)
}
tensorizable = [
{
key: value
for key, value in sample.items()
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
}
for sample in batch
]
collated = default_collate(tensorizable)
collated.update(preserved)
return collated
+32
View File
@@ -0,0 +1,32 @@
#!/usr/bin/env python
from pathlib import Path
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
def test_message_recipe_validates_unknown_binding():
with pytest.raises(ValueError, match="unknown binding"):
TrainingRecipe(
messages=[
MessageTurn(role="user", content="${missing}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
def test_canonical_recipe_loads():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
assert recipe.blend is not None
assert set(recipe.blend) == {
"memory_update",
"user_interjection_response",
"high_level_subtask",
"low_level_execution",
"ask_vqa_top",
"ask_vqa_wrist",
}
assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96)
+152
View File
@@ -0,0 +1,152 @@
#!/usr/bin/env python
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.io_utils import write_info
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
PERSISTENT_STYLES,
STYLE_REGISTRY,
VIEW_DEPENDENT_STYLES,
column_for_style,
is_view_dependent_style,
language_events_arrow_type,
language_feature_info,
language_persistent_arrow_type,
validate_camera_field,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields():
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == [
"role",
"content",
"style",
"timestamp",
"camera",
"tool_calls",
]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("plan") == LANGUAGE_PERSISTENT
assert column_for_style("memory") == LANGUAGE_PERSISTENT
assert column_for_style("motion") == LANGUAGE_PERSISTENT
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
assert column_for_style("interjection") == LANGUAGE_EVENTS
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style("trace") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS
def test_view_dependent_styles():
# motion lives in PERSISTENT_STYLES and is described in robot-frame
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
# (event) and trace (event, pixel-trajectory) carry a camera tag.
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
assert is_view_dependent_style("vqa")
assert is_view_dependent_style("trace")
assert not is_view_dependent_style("motion")
assert not is_view_dependent_style("subtask")
assert not is_view_dependent_style("plan")
assert not is_view_dependent_style("interjection")
assert not is_view_dependent_style(None)
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
validate_camera_field("vqa", "observation.images.top")
validate_camera_field("trace", "observation.images.front")
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("vqa", None)
with pytest.raises(ValueError, match="view-dependent"):
validate_camera_field("trace", "")
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
validate_camera_field("subtask", None)
validate_camera_field("plan", None)
validate_camera_field("memory", None)
validate_camera_field("motion", None)
validate_camera_field("interjection", None)
validate_camera_field(None, None)
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("subtask", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("motion", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field("interjection", "observation.images.top")
with pytest.raises(ValueError, match="must have camera=None"):
validate_camera_field(None, "observation.images.top")
def test_unknown_style_rejected():
with pytest.raises(ValueError, match="Unknown language style"):
column_for_style("surprise")
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
root = tmp_path / "language_dataset"
dataset = empty_lerobot_dataset_factory(
root=root,
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
use_videos=False,
)
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
dataset.save_episode()
dataset.finalize()
persistent = [
{
"role": "assistant",
"content": "reach for the cup",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
]
event = {
"role": "user",
"content": "what is visible?",
"style": "vqa",
"camera": "observation.images.top",
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
df = pd.read_parquet(data_path)
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
df[LANGUAGE_EVENTS] = [[event], []]
df.to_parquet(data_path)
info = dataset.meta.info
info["features"].update(language_feature_info())
write_info(info, root)
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
first = reloaded[0]
second = reloaded[1]
assert first[LANGUAGE_PERSISTENT] == persistent
assert first[LANGUAGE_EVENTS] == [event]
assert second[LANGUAGE_PERSISTENT] == persistent
assert second[LANGUAGE_EVENTS] == []
+388
View File
@@ -0,0 +1,388 @@
#!/usr/bin/env python
from pathlib import Path
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"timestamp": timestamp,
"camera": camera,
"tool_calls": tool_calls,
}
def event_row(role, content, style, tool_calls=None, camera=None):
return {
"role": role,
"content": content,
"style": style,
"camera": camera,
"tool_calls": tool_calls,
}
PERSISTENT = [
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
def test_persistent_relative_resolvers_reject_event_styles():
with pytest.raises(ValueError, match="event-only"):
active_at(1.0, persistent=PERSISTENT, style="vqa")
with pytest.raises(ValueError, match="event-only"):
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
def test_nth_prev_and_next():
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
def test_substitution_if_present_multimodal_and_tool_calls():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${task}: ${interjection}"},
],
stream="high_level",
if_present="interjection",
),
MessageTurn(
role="assistant",
content="${plan}",
stream="high_level",
target=True,
tool_calls_from="speech",
),
],
bindings={"plan": "active_at(t, style=plan)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
assert rendered["messages"][1]["content"] == "plan 0"
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
assert rendered["message_streams"] == ["high_level", "high_level"]
assert rendered["target_message_indices"] == [1]
def test_exact_event_miss_returns_none_when_target_skips():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
]
)
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
recipe = TrainingRecipe(
blend={
"a": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
],
),
"b": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
],
),
}
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second
def test_emitted_at_filters_vqa_by_camera():
top = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.top",
)
wrist = emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
camera="observation.images.wrist",
)
assert top["content"] == '{"count": 3}'
assert wrist["content"] == '{"count": 1}'
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
with pytest.raises(ValueError, match="Ambiguous resolver"):
emitted_at(
3.0,
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
style="vqa",
role="assistant",
)
def test_per_camera_blend_renders_both_views():
recipe = TrainingRecipe(
blend={
"top": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
),
},
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.top"},
{"type": "text", "text": "${vqa_query}"},
],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
"wrist": TrainingRecipe(
weight=1.0,
bindings={
"vqa_query": (
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
),
"vqa": (
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
),
},
messages=[
MessageTurn(
role="user",
content=[
{"type": "image", "feature": "observation.images.wrist"},
{"type": "text", "text": "${vqa_query}"},
],
stream="high_level",
if_present="vqa_query",
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
}
)
rendered_top = render_sample(
recipe=recipe.blend["top"],
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
rendered_wrist = render_sample(
recipe=recipe.blend["wrist"],
persistent=PERSISTENT,
events=EVENTS_AT_3_TWO_CAMERAS,
t=3.0,
sample_idx=0,
)
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
rephrasings = [
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
# No explicit task override → resolver consults persistent rows.
seen: set[str] = set()
for sample_idx in range(64):
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=sample_idx,
dataset_ctx={"task": "canonical kitchen task"},
)
seen.add(rendered["messages"][0]["content"])
# Every rephrasing should be reachable across enough samples.
assert seen == {r["content"] for r in rephrasings}
# Same sample_idx → same pick (determinism).
a = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
dataset_ctx={"task": "canonical"},
)
b = render_sample(
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
dataset_ctx={"task": "canonical"},
)
assert a["messages"][0]["content"] == b["messages"][0]["content"]
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT, # no task_aug rows
events=[],
t=0.0,
sample_idx=0,
dataset_ctx={"task": "clean the kitchen"},
)
assert rendered["messages"][0]["content"] == "clean the kitchen"
def test_resolve_task_explicit_override_beats_rephrasings():
rephrasings = [
persistent_row("user", "rephrased one", "task_aug", 0.0),
persistent_row("user", "rephrased two", "task_aug", 0.0),
]
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
]
)
rendered = render_sample(
recipe=recipe,
persistent=rephrasings,
events=[],
t=0.0,
sample_idx=0,
task="explicit override wins",
dataset_ctx={"task": "canonical"},
)
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
rendered = render_sample(
recipe=low_level,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
assert rendered["message_streams"][-1] == "low_level"
assert rendered["target_message_indices"] == [1]
-193
View File
@@ -1,193 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pytest
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use lerobot/pusht-subtask dataset with episode 1
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask
@@ -0,0 +1,56 @@
#!/usr/bin/env python
import torch
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.processor.converters import create_transition
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.types import TransitionKey
def test_render_messages_step_noops_without_language_columns():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(complementary_data={"task": "do it"})
assert RenderMessagesStep(recipe)(transition) == transition
def test_render_messages_step_renders_and_drops_raw_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
]
)
transition = create_transition(
complementary_data={
"task": "do it",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [
{
"role": "assistant",
"content": "reach carefully",
"style": "subtask",
"timestamp": 0.0,
"camera": None,
"tool_calls": None,
}
],
"language_events": [],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert "language_persistent" not in data
assert "language_events" not in data
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
+36
View File
@@ -0,0 +1,36 @@
#!/usr/bin/env python
import torch
from lerobot.utils.collate import lerobot_collate_fn
def test_lerobot_collate_preserves_messages_and_drops_raw_language():
batch = [
{
"index": torch.tensor(0),
"messages": [{"role": "assistant", "content": "a"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
{
"index": torch.tensor(1),
"messages": [{"role": "assistant", "content": "b"}],
"message_streams": ["low_level"],
"target_message_indices": [0],
"language_persistent": [{"content": "raw"}],
"language_events": [],
},
]
out = lerobot_collate_fn(batch)
assert out["index"].tolist() == [0, 1]
assert out["messages"][0][0]["content"] == "a"
assert out["messages"][1][0]["content"] == "b"
assert out["message_streams"] == [["low_level"], ["low_level"]]
assert out["target_message_indices"] == [[0], [0]]
assert "language_persistent" not in out
assert "language_events" not in out