Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-29 23:49:43 +00:00)

Compare commits: 4 commits

| SHA1 |
|---|
| 0b06790da0 |
| b43dc39ba4 |
| 2b71221194 |
| 8833d735a1 |
@@ -31,8 +31,8 @@
       title: Porting Large Datasets
     - local: using_dataset_tools
       title: Using the Dataset Tools
-    - local: dataset_subtask
-      title: Using Subtasks in the Dataset
+    - local: language_and_recipes
+      title: Language Columns and Recipes
     - local: streaming_video_encoding
       title: Streaming Video Encoding
   title: "Datasets"
@@ -1,277 +0,0 @@
|
|||||||
# Using Subtasks in LeRobot Datasets
|
|
||||||
|
|
||||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
|
||||||
|
|
||||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
|
||||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
|
||||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
|
||||||
|
|
||||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
|
||||||
|
|
||||||
## What are Subtasks?
|
|
||||||
|
|
||||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
|
||||||
|
|
||||||
1. "Approach the apple"
|
|
||||||
2. "Grasp the apple"
|
|
||||||
3. "Lift the apple"
|
|
||||||
4. "Move to basket"
|
|
||||||
5. "Release the apple"
|
|
||||||
|
|
||||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
|
||||||
|
|
||||||
<img
|
|
||||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
|
||||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
|
||||||
width="80%"
|
|
||||||
/>
|
|
||||||
|
|
||||||
<p>
|
|
||||||
<em>Figure: Overview of subtask annotation.</em>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
|
||||||
|
|
||||||
## Dataset Structure
|
|
||||||
|
|
||||||
Subtask information is stored in the dataset metadata:
|
|
||||||
|
|
||||||
```
|
|
||||||
my-dataset/
|
|
||||||
├── data/
|
|
||||||
│ └── ...
|
|
||||||
├── meta/
|
|
||||||
│ ├── info.json
|
|
||||||
│ ├── stats.json
|
|
||||||
│ ├── tasks.parquet
|
|
||||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
|
||||||
│ └── episodes/
|
|
||||||
│ └── ...
|
|
||||||
└── videos/
|
|
||||||
└── ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Subtasks Parquet File
|
|
||||||
|
|
||||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
|
||||||
|
|
||||||
| subtask_index | subtask (index column) |
|
|
||||||
| ------------- | ---------------------- |
|
|
||||||
| 0 | "Approach the apple" |
|
|
||||||
| 1 | "Grasp the apple" |
|
|
||||||
| 2 | "Lift the apple" |
|
|
||||||
| ... | ... |
|
|
||||||
|
|
||||||
### Frame-Level Annotations
|
|
||||||
|
|
||||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Example frame data in the parquet file
|
|
||||||
{
|
|
||||||
"index": 42,
|
|
||||||
"timestamp": 1.4,
|
|
||||||
"episode_index": 0,
|
|
||||||
"task_index": 0,
|
|
||||||
"subtask_index": 2, # References "Lift the apple"
|
|
||||||
"observation.state": [...],
|
|
||||||
"action": [...],
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Annotating Datasets with Subtasks
|
|
||||||
|
|
||||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
|
||||||
|
|
||||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
|
||||||
|
|
||||||
After completing your annotation:
|
|
||||||
|
|
||||||
1. Click "Push to Hub" to upload your annotated dataset
|
|
||||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
|
||||||
|
|
||||||
## Loading Datasets with Subtasks
|
|
||||||
|
|
||||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
|
||||||
|
|
||||||
# Load a dataset with subtask annotations
|
|
||||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
|
||||||
|
|
||||||
# Access a sample
|
|
||||||
sample = dataset[100]
|
|
||||||
|
|
||||||
# The sample includes both task and subtask information
|
|
||||||
print(sample["task"]) # "Collect the fruit"
|
|
||||||
print(sample["subtask"]) # "Grasp the apple"
|
|
||||||
print(sample["task_index"]) # tensor(0)
|
|
||||||
print(sample["subtask_index"]) # tensor(2)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Checking for Subtask Support
|
|
||||||
|
|
||||||
You can check if a dataset has subtask annotations:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Check if subtasks are available
|
|
||||||
has_subtasks = (
|
|
||||||
"subtask_index" in dataset.features
|
|
||||||
and dataset.meta.subtasks is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
if has_subtasks:
|
|
||||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
|
||||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
|
||||||
```
|
|
||||||
|
|
||||||
## Using Subtasks for Training
|
|
||||||
|
|
||||||
### With the Tokenizer Processor
|
|
||||||
|
|
||||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.processor import TokenizerProcessorStep
|
|
||||||
|
|
||||||
# Create a tokenizer processor step
|
|
||||||
tokenizer_processor = TokenizerProcessorStep(
|
|
||||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=64,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The processor will automatically tokenize subtasks if present in the batch
|
|
||||||
# and add them to the observation under:
|
|
||||||
# - "observation.subtask.tokens"
|
|
||||||
# - "observation.subtask.attention_mask"
|
|
||||||
```
|
|
||||||
|
|
||||||
When subtasks are available in the batch, the tokenizer processor adds:
|
|
||||||
|
|
||||||
- `observation.subtask.tokens`: Tokenized subtask text
|
|
||||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
|
||||||
|
|
||||||
### DataLoader with Subtasks
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
|
||||||
|
|
||||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=16,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for batch in dataloader:
|
|
||||||
# Access subtask information in the batch
|
|
||||||
subtasks = batch["subtask"] # List of subtask strings
|
|
||||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
|
||||||
|
|
||||||
# Use for training hierarchical policies or reward models
|
|
||||||
print(f"Batch subtasks: {set(subtasks)}")
|
|
||||||
```
|
|
||||||
|
|
||||||
## Example Datasets with Subtask Annotations
|
|
||||||
|
|
||||||
Try loading a dataset with subtask annotations:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
|
||||||
|
|
||||||
# Example dataset with subtask annotations
|
|
||||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
|
||||||
|
|
||||||
# Explore the subtasks
|
|
||||||
print("Available subtasks:")
|
|
||||||
for subtask_name in dataset.meta.subtasks.index:
|
|
||||||
print(f" - {subtask_name}")
|
|
||||||
|
|
||||||
# Get subtask distribution
|
|
||||||
subtask_counts = {}
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
sample = dataset[i]
|
|
||||||
subtask = sample["subtask"]
|
|
||||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
|
||||||
|
|
||||||
print("\nSubtask distribution:")
|
|
||||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
|
||||||
print(f" {subtask}: {count} frames")
|
|
||||||
```
|
|
||||||
|
|
||||||
## Use Cases
|
|
||||||
|
|
||||||
### 1. Hierarchical Policy Training
|
|
||||||
|
|
||||||
Train policies that predict both actions and current subtask:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class HierarchicalPolicy(nn.Module):
|
|
||||||
def __init__(self, num_subtasks):
|
|
||||||
super().__init__()
|
|
||||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
|
||||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
|
||||||
|
|
||||||
def forward(self, observations):
|
|
||||||
features = self.encoder(observations)
|
|
||||||
actions = self.action_head(features)
|
|
||||||
subtask_logits = self.subtask_head(features)
|
|
||||||
return actions, subtask_logits
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Stage-Aware Reward Modeling (SARM)
|
|
||||||
|
|
||||||
Build reward models that understand task progression:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# SARM predicts:
|
|
||||||
# - Stage: Which subtask is being executed (discrete)
|
|
||||||
# - Progress: How far along the subtask (continuous 0-1)
|
|
||||||
|
|
||||||
class SARMRewardModel(nn.Module):
|
|
||||||
def forward(self, observations):
|
|
||||||
features = self.encoder(observations)
|
|
||||||
stage_logits = self.stage_classifier(features)
|
|
||||||
progress = self.progress_regressor(features)
|
|
||||||
return stage_logits, progress
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Progress Visualization
|
|
||||||
|
|
||||||
Monitor robot execution by tracking subtask progression:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def visualize_execution(model, observations):
|
|
||||||
for t, obs in enumerate(observations):
|
|
||||||
action, subtask_logits = model(obs)
|
|
||||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
|
||||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Reference
|
|
||||||
|
|
||||||
### LeRobotDataset Properties
|
|
||||||
|
|
||||||
| Property | Type | Description |
|
|
||||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
|
||||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
|
||||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
|
||||||
|
|
||||||
### Sample Keys
|
|
||||||
|
|
||||||
When subtasks are available, each sample includes:
|
|
||||||
|
|
||||||
| Key | Type | Description |
|
|
||||||
| --------------- | -------------- | ------------------------------------ |
|
|
||||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
|
||||||
| `subtask` | `str` | Natural language subtask description |
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
|
||||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
|
||||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
|
||||||
@@ -0,0 +1,75 @@
# Language columns and recipes

LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`.
The two optional columns are:

- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.

Both columns share the same row shape:

```text
role: string
content: string | null
style: string | null
timestamp: float64
tool_calls: list[Json] | null
```
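
As a rough illustration of that row shape, the sketch below models one row as a Python dataclass and builds one persistent row and one event row. The class name `LanguageRow`, the role chosen for the subtask row, and the layout of the `tool_calls` payload are assumptions made for this example; they are not taken from `lerobot.datasets.language`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class LanguageRow:
    """One entry of `language_persistent` or `language_events` (hypothetical helper)."""

    role: str
    timestamp: float
    content: str | None = None
    style: str | None = None
    tool_calls: list[dict[str, Any]] | None = None


# A persistent row: a subtask that becomes active at t=0.0.
subtask_row = LanguageRow(
    role="assistant", content="Grasp the apple", style="subtask", timestamp=0.0
)

# An event row: a speech tool call stamped on one exact frame.
# The payload layout is an assumption, not the documented Json schema.
speech_row = LanguageRow(
    role="assistant",
    timestamp=1.4,
    tool_calls=[{"name": "say", "arguments": {"text": "On it."}}],
)
```

Note that `content` and `style` may both be null on a row that only carries `tool_calls`, which is why they default to `None` here.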

`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.

## Architecture

The language stack has three layers:

1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.

`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.

## Temporal semantics

Persistent styles are active after emission until replaced:

- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`

Event styles only exist on their exact timestamp:

- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user)`
- `emitted_at(t, role=assistant, tool_name=say)`

Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
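
The difference between the two lookups can be sketched in plain Python. This is an illustrative reimplementation over lists of dicts, not the code in `lerobot.datasets.language_render`; in particular, treating a persistent row as active from its own timestamp onward (inclusive) is an assumption.

```python
from __future__ import annotations


def active_at(rows: list[dict], t: float, style: str) -> dict | None:
    """Return the most recent row of `style` emitted at or before `t`."""
    candidates = [r for r in rows if r["style"] == style and r["timestamp"] <= t]
    return max(candidates, key=lambda r: r["timestamp"], default=None)


def emitted_at(rows: list[dict], t: float, style: str) -> dict | None:
    """Return the row of `style` stamped exactly at `t` (no tolerance window)."""
    for r in rows:
        if r["style"] == style and r["timestamp"] == t:
            return r
    return None


persistent = [
    {"style": "subtask", "timestamp": 0.0, "content": "Approach the apple"},
    {"style": "subtask", "timestamp": 1.0, "content": "Grasp the apple"},
]
events = [{"style": "interjection", "timestamp": 0.3, "content": "Use the left arm"}]

print(active_at(persistent, 0.9, "subtask")["content"])     # Approach the apple
print(active_at(persistent, 1.0, "subtask")["content"])     # Grasp the apple
print(emitted_at(events, 0.3, "interjection") is not None)  # True
print(emitted_at(events, 0.1 + 0.2, "interjection"))        # None
```

The last call shows why event rows should reuse the stored frame timestamps: `0.1 + 0.2` is not bit-identical to `0.3` in floating point, so a recomputed timestamp silently misses the event.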

## Recipe anatomy

Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.

```yaml
messages:
  - { role: user, content: "${task}", stream: high_level }
  - { role: assistant, content: "${subtask}", stream: low_level, target: true }
```

Rendered samples use HF-style chat messages plus LeRobot sidecars:

```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```

The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone.
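
To make the mapping from recipe turns to rendered samples concrete, here is a minimal sketch of placeholder substitution. It reuses the `${name}` pattern from the recipe config module, but the `render` function, the plain-dict turns, and the exact layout of `message_streams` (one entry per message) and `target_message_indices` (a list of positions) are assumptions for illustration rather than the behavior of `RenderMessagesStep`.

```python
import re

# Same `${name}` pattern as `_PLACEHOLDER_RE` in the recipe config module.
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def render(turns, values):
    """Substitute resolved binding values into turn templates (hypothetical helper)."""
    messages, streams, targets = [], [], []
    for turn in turns:
        content = PLACEHOLDER_RE.sub(lambda m: str(values[m.group(1)]), turn["content"])
        if turn.get("target", False):
            # Record the position of this turn as a training target.
            targets.append(len(messages))
        messages.append({"role": turn["role"], "content": content})
        streams.append(turn.get("stream"))
    return {
        "messages": messages,
        "message_streams": streams,
        "target_message_indices": targets,
    }


turns = [
    {"role": "user", "content": "${task}", "stream": "high_level"},
    {"role": "assistant", "content": "${subtask}", "stream": "low_level", "target": True},
]
sample = render(turns, {"task": "Collect the fruit", "subtask": "Grasp the apple"})

print(sample["messages"][1])             # {'role': 'assistant', 'content': 'Grasp the apple'}
print(sample["message_streams"])         # ['high_level', 'low_level']
print(sample["target_message_indices"])  # [1]
```

The output stays a list of role/content dicts, which leaves the choice of chat template to whichever policy processor consumes it.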

## Blends

Blend recipes select one weighted sub-recipe deterministically from the sample index.
The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA.
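
One way to get a deterministic, weighted choice from a sample index is to hash the index into the unit interval and walk the cumulative weights. The sketch below does that with the component names and weights from the blend recipe in this change; the hashing scheme and the `pick_component` helper are assumptions, since the selection code itself is not shown here.

```python
import hashlib

# Component weights as listed in the blend recipe; they sum to 0.96,
# so the sketch scales the draw by the total instead of assuming 1.0.
WEIGHTS = {
    "memory_update": 0.10,
    "user_interjection_response": 0.16,
    "high_level_subtask": 0.15,
    "low_level_execution": 0.35,
    "ask_vqa": 0.20,
}


def pick_component(sample_index: int, weights: dict) -> str:
    """Map a sample index to one component name, reproducibly (hypothetical helper)."""
    digest = hashlib.sha256(str(sample_index).encode()).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64  # uniform-looking value in [0, 1)
    threshold = u * sum(weights.values())
    cumulative = 0.0
    for name, weight in weights.items():
        cumulative += weight
        if threshold < cumulative:
            return name
    return next(reversed(weights))  # guard against float rounding at the upper edge


# The same index always yields the same component.
print(pick_component(42, WEIGHTS) == pick_component(42, WEIGHTS))  # True
```

Because the choice depends only on the index, two workers or two epochs see the same sub-recipe for the same sample, which keeps runs reproducible without sharing random state.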

## Graceful absence

If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
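
A loader-side retry could look like the sketch below. How a real loader chooses the next candidate is not specified above, so stepping to the following index with a bounded number of attempts is an assumption, and `first_renderable` and `fake_render` are hypothetical names used only for this example.

```python
from __future__ import annotations

from typing import Callable


def first_renderable(
    render: Callable[[int], dict | None],
    start: int,
    num_samples: int,
    max_tries: int = 10,
) -> tuple[int, dict]:
    """Return the first index at or after `start` whose rendering is not None."""
    for offset in range(max_tries):
        index = (start + offset) % num_samples
        rendered = render(index)
        if rendered is not None:
            return index, rendered
    raise RuntimeError(f"No renderable sample found in {max_tries} tries from index {start}")


# Stand-in renderer: only frame 3 carries the event row this branch needs.
frames_with_events = {3}


def fake_render(index: int) -> dict | None:
    if index not in frames_with_events:
        return None
    return {"messages": [{"role": "user", "content": "Use the left arm"}]}


index, sample = first_renderable(fake_render, start=1, num_samples=5)
print(index)  # 3
```

Bounding the retries avoids an endless loop on datasets where the selected branch can never be satisfied.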
|
||||||
+1
-1
@@ -95,7 +95,7 @@ dependencies = [
 
 # ── Feature-scoped extras ──────────────────────────────────
 dataset = [
-    "datasets>=4.0.0,<5.0.0",
+    "datasets>=4.7.0,<5.0.0",
     "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
     "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
     "lerobot[av-dep]",
@@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
 
 from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
 from .policies import PreTrainedConfig
+from .recipe import MessageTurn, TrainingRecipe, load_recipe
 from .types import (
     FeatureType,
     NormalizationMode,
@@ -41,7 +42,10 @@ __all__ = [
     # Config classes
     "DatasetConfig",
     "EvalConfig",
+    "MessageTurn",
     "PeftConfig",
     "PreTrainedConfig",
+    "TrainingRecipe",
     "WandBConfig",
+    "load_recipe",
 ]
@@ -0,0 +1,193 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, get_args

MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]

DEFAULT_BINDINGS = {
    "subtask": "active_at(t, style=subtask)",
    "memory": "active_at(t, style=memory)",
    "plan": "active_at(t, style=plan)",
    "speech": "emitted_at(t, role=assistant, tool_name=say)",
    "interjection": "emitted_at(t, style=interjection)",
    "vqa": "emitted_at(t, style=vqa, role=assistant)",
    "vqa_query": "emitted_at(t, style=vqa, role=user)",
}

_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))


@dataclass
class MessageTurn:
    """A single chat-style turn in a recipe template.

    ``content`` may be a plain string, a list of HF-style multimodal blocks, or
    ``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
    ``stream`` tags the turn for downstream filtering, ``target`` flags it as a
    training target, and ``if_present`` skips the turn when the named binding
    resolves to ``None``.
    """

    role: MessageRole
    content: str | list[dict[str, Any]] | None = None
    stream: MessageStream | None = None
    target: bool = False
    if_present: str | None = None
    tool_calls_from: str | None = None

    def __post_init__(self) -> None:
        """Validate role, stream, and content after dataclass construction."""
        if self.role not in _VALID_ROLES:
            raise ValueError(f"Unsupported message role: {self.role!r}")
        if self.stream is not None and self.stream not in _VALID_STREAMS:
            raise ValueError(f"Unsupported message stream: {self.stream!r}")
        if self.content is None and self.tool_calls_from is None:
            raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
        if self.content is not None and not isinstance(self.content, (str, list)):
            raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
        if isinstance(self.content, list):
            for block in self.content:
                if not isinstance(block, dict) or "type" not in block:
                    raise ValueError(
                        "Multimodal content blocks must be HF-style dictionaries with a type key."
                    )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
        """Construct a :class:`MessageTurn` from a plain dictionary."""
        return cls(**data)


@dataclass
class TrainingRecipe:
    """A recipe describing how to render training samples from language rows.

    A recipe is either a *message recipe* (``messages`` plus optional
    ``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
    sub-recipes). ``weight`` is only meaningful inside a blend.
    """

    messages: list[MessageTurn] | None = None
    bindings: dict[str, str] | None = None
    blend: dict[str, TrainingRecipe] | None = None
    weight: float | None = None

    def __post_init__(self) -> None:
        """Validate that exactly one of ``messages`` or ``blend`` is set."""
        if self.messages is not None and self.blend is not None:
            raise ValueError("TrainingRecipe must set only one of messages or blend.")
        if self.messages is None and self.blend is None:
            raise ValueError("TrainingRecipe must set one of messages or blend.")

        if self.messages is not None:
            self._validate_message_recipe()
        if self.blend is not None:
            self._validate_blend_recipe()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
        """Construct a :class:`TrainingRecipe` from a nested dictionary."""
        data = dict(data)
        if data.get("messages") is not None:
            data["messages"] = [
                turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
                for turn in data["messages"]
            ]
        if data.get("blend") is not None:
            data["blend"] = {
                name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
                for name, recipe in data["blend"].items()
            }
        return cls(**data)

    @classmethod
    def from_yaml(cls, path: str | Path) -> TrainingRecipe:
        """Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
        import yaml  # type: ignore[import-untyped]

        with open(path) as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
        return cls.from_dict(data)

    def _validate_message_recipe(self) -> None:
        """Ensure every templated binding is known and at least one turn is a target."""
        assert self.messages is not None
        known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}

        for turn in self.messages:
            missing = self._referenced_bindings(turn) - known_bindings
            if missing:
                raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")

        if not any(turn.target for turn in self.messages):
            raise ValueError("Message recipes must contain at least one target turn.")

    def _validate_blend_recipe(self) -> None:
        """Ensure each blend component is a non-empty, weighted message recipe."""
        assert self.blend is not None
        if not self.blend:
            raise ValueError("Blend recipes must contain at least one component.")

        for name, recipe in self.blend.items():
            if recipe.blend is not None:
                raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
            if recipe.messages is None:
                raise ValueError(f"Blend component {name!r} must define messages.")
            if recipe.weight is None:
                raise ValueError(f"Blend component {name!r} must define weight.")
            if recipe.weight <= 0:
                raise ValueError(f"Blend component {name!r} must have a positive weight.")

    def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
        """Return the binding names that ``turn`` references via placeholders or attributes."""
        names: set[str] = set()
        if turn.if_present is not None:
            names.add(turn.if_present)
        if turn.tool_calls_from is not None:
            names.add(turn.tool_calls_from)
        names.update(_placeholders_in_content(turn.content))
        return names


def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
    """Return the set of ``${name}`` placeholders found anywhere in ``content``."""
    if content is None:
        return set()
    if isinstance(content, str):
        return set(_PLACEHOLDER_RE.findall(content))

    names: set[str] = set()
    for block in content:
        for value in block.values():
            if isinstance(value, str):
                names.update(_PLACEHOLDER_RE.findall(value))
    return names


def load_recipe(path: str | Path) -> TrainingRecipe:
    """Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
    return TrainingRecipe.from_yaml(path)
@@ -0,0 +1,47 @@
blend:

  memory_update:
    weight: 0.10
    bindings:
      prior_memory: "nth_prev(style=memory, offset=1)"
      current_memory: "emitted_at(t, style=memory)"
      completed_subtask: "nth_prev(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
      - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
      - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}

  user_interjection_response:
    weight: 0.16
    bindings:
      prior_plan: "nth_prev(style=plan, offset=1)"
      current_plan: "emitted_at(t, style=plan)"
      interjection: "emitted_at(t, style=interjection)"
      speech: "emitted_at(t, role=assistant, tool_name=say)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
      - {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
      - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}

  high_level_subtask:
    weight: 0.15
    bindings:
      next_subtask: "nth_next(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
      - {role: assistant, content: "${next_subtask}", stream: high_level, target: true}

  low_level_execution:
    weight: 0.35
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: assistant, content: "${subtask}", stream: low_level, target: true}

  ask_vqa:
    weight: 0.20
    messages:
      - {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query}
      - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
@@ -37,6 +37,14 @@ from .dataset_tools import (
 from .factory import make_dataset, resolve_delta_timestamps
 from .image_writer import safe_stop_image_writer
 from .io_utils import load_episodes, write_stats
+from .language import (
+    EVENT_ONLY_STYLES,
+    LANGUAGE_EVENTS,
+    LANGUAGE_PERSISTENT,
+    PERSISTENT_STYLES,
+    STYLE_REGISTRY,
+    column_for_style,
+)
 from .lerobot_dataset import LeRobotDataset
 from .multi_dataset import MultiLeRobotDataset
 from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -53,10 +61,15 @@ __all__ = [
     "CODEBASE_VERSION",
     "DEFAULT_EPISODES_PATH",
     "DEFAULT_QUANTILES",
+    "EVENT_ONLY_STYLES",
     "EpisodeAwareSampler",
+    "LANGUAGE_EVENTS",
+    "LANGUAGE_PERSISTENT",
     "LeRobotDataset",
     "LeRobotDatasetMetadata",
     "MultiLeRobotDataset",
+    "PERSISTENT_STYLES",
+    "STYLE_REGISTRY",
     "StreamingLeRobotDataset",
     "VideoEncodingManager",
     "add_features",
@@ -66,6 +79,7 @@ __all__ = [
     "convert_image_to_video_dataset",
     "create_initial_features",
     "create_lerobot_dataset_card",
+    "column_for_style",
     "delete_episodes",
     "get_feature_stats",
     "load_episodes",
@@ -512,7 +512,7 @@ def compute_episode_stats(
 
     ep_stats = {}
     for key, data in episode_data.items():
-        if features[key]["dtype"] == "string":
+        if features[key]["dtype"] in {"string", "language"}:
             continue
 
         if features[key]["dtype"] in ["image", "video"]:
@@ -34,7 +34,6 @@ from .io_utils import (
     load_episodes,
     load_info,
     load_stats,
-    load_subtasks,
     load_tasks,
     write_info,
     write_json,
@@ -177,7 +176,6 @@ class LeRobotDatasetMetadata:
|
|||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||||
self.tasks = load_tasks(self.root)
|
self.tasks = load_tasks(self.root)
|
||||||
self.subtasks = load_subtasks(self.root)
|
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
self.stats = load_stats(self.root)
|
self.stats = load_stats(self.root)
|
||||||
|
|
||||||
@@ -635,7 +633,6 @@ class LeRobotDatasetMetadata:
|
|||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
|
|
||||||
obj.tasks = None
|
obj.tasks = None
|
||||||
obj.subtasks = None
|
|
||||||
obj.episodes = None
|
obj.episodes = None
|
||||||
obj.stats = None
|
obj.stats = None
|
||||||
obj.info = create_empty_dataset_info(
|
obj.info = create_empty_dataset_info(
|
||||||
|
|||||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
|||||||
task_idx = item["task_index"].item()
|
task_idx = item["task_index"].item()
|
||||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||||
|
|
||||||
# add subtask information if available
|
|
||||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
|
||||||
subtask_idx = item["subtask_index"].item()
|
|
||||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|||||||
@@ -22,6 +22,12 @@ from PIL import Image as PILImage
|
|||||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||||
|
|
||||||
|
from .language import (
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
is_language_column,
|
||||||
|
language_events_column_feature,
|
||||||
|
language_persistent_column_feature,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
@@ -45,7 +51,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
     """
     hf_features = {}
     for key, ft in features.items():
-        if ft["dtype"] == "video":
+        if is_language_column(key):
+            hf_features[key] = (
+                language_persistent_column_feature()
+                if key == LANGUAGE_PERSISTENT
+                else language_events_column_feature()
+            )
+        elif ft["dtype"] == "video":
             continue
         elif ft["dtype"] == "image":
             hf_features[key] = datasets.Image()
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
         return validate_feature_image_or_video(name, expected_shape, value)
     elif expected_dtype == "string":
         return validate_feature_string(name, value)
+    elif expected_dtype == "language":
+        return ""
     else:
         raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
 
|
|||||||
@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
DEFAULT_SUBTASKS_PATH,
|
|
||||||
DEFAULT_TASKS_PATH,
|
DEFAULT_TASKS_PATH,
|
||||||
EPISODES_DIR,
|
EPISODES_DIR,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
@@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
|||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
|
||||||
"""Load subtasks from subtasks.parquet if it exists."""
|
|
||||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
|
||||||
if subtasks_path.exists():
|
|
||||||
return pd.read_parquet(subtasks_path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||||
This function writes episode-level metadata to a single parquet file.
|
This function writes episode-level metadata to a single parquet file.
|
||||||
@@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
         dict: The batch with items converted to torch tensors.
     """
     for key in items_dict:
+        if key in {"language_persistent", "language_events"}:
+            continue
         first_item = items_dict[key][0]
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
-        elif first_item is None:
+        elif first_item is None or isinstance(first_item, dict):
             pass
         else:
             items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
@@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict:
         dict: Dictionary with all tensor-like items converted to torch.Tensor.
     """
     for key, val in item.items():
-        if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
+        if isinstance(val, (np.ndarray | list)) and key not in [
+            "task",
+            "language_persistent",
+            "language_events",
+        ]:
             # Convert numpy arrays and lists to torch tensors
             item[key] = torch.tensor(val)
     return item
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
LANGUAGE_PERSISTENT = "language_persistent"
|
||||||
|
LANGUAGE_EVENTS = "language_events"
|
||||||
|
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||||
|
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
|
||||||
|
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
|
||||||
|
|
||||||
|
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
|
||||||
|
EXTENDED_STYLES = set()
|
||||||
|
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||||
|
|
||||||
|
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
|
||||||
|
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||||
|
|
||||||
|
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||||
|
|
||||||
|
|
||||||
|
def _json_arrow_type() -> pa.DataType:
|
||||||
|
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||||
|
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||||
|
|
||||||
|
|
||||||
|
def _json_feature() -> object:
|
||||||
|
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||||
|
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||||
|
|
||||||
|
|
||||||
|
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||||
|
"""Return the Arrow struct type for a single persistent language row.
|
||||||
|
|
||||||
|
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||||
|
that became active at a specific moment and remains active until superseded.
|
||||||
|
"""
|
||||||
|
return pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("role", pa.string(), nullable=False),
|
||||||
|
pa.field("content", pa.string(), nullable=True),
|
||||||
|
pa.field("style", pa.string(), nullable=True),
|
||||||
|
pa.field("timestamp", pa.float64(), nullable=False),
|
||||||
|
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def language_event_row_arrow_type() -> pa.StructType:
|
||||||
|
"""Return the Arrow struct type for a single event language row.
|
||||||
|
|
||||||
|
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||||
|
row whose frame timestamp is the event's firing time.
|
||||||
|
"""
|
||||||
|
return pa.struct(
|
||||||
|
[
|
||||||
|
pa.field("role", pa.string(), nullable=False),
|
||||||
|
pa.field("content", pa.string(), nullable=True),
|
||||||
|
pa.field("style", pa.string(), nullable=True),
|
||||||
|
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def language_persistent_arrow_type() -> pa.ListType:
|
||||||
|
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||||
|
return pa.list_(language_persistent_row_arrow_type())
|
||||||
|
|
||||||
|
|
||||||
|
def language_events_arrow_type() -> pa.ListType:
|
||||||
|
"""Return the Arrow list type for the ``language_events`` column."""
|
||||||
|
return pa.list_(language_event_row_arrow_type())
|
||||||
|
|
||||||
|
|
||||||
|
def language_persistent_row_feature() -> dict[str, object]:
|
||||||
|
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||||
|
return {
|
||||||
|
"role": datasets.Value("string"),
|
||||||
|
"content": datasets.Value("string"),
|
||||||
|
"style": datasets.Value("string"),
|
||||||
|
"timestamp": datasets.Value("float64"),
|
||||||
|
"tool_calls": datasets.List(_json_feature()),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def language_event_row_feature() -> dict[str, object]:
|
||||||
|
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||||
|
return {
|
||||||
|
"role": datasets.Value("string"),
|
||||||
|
"content": datasets.Value("string"),
|
||||||
|
"style": datasets.Value("string"),
|
||||||
|
"tool_calls": datasets.List(_json_feature()),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def language_persistent_column_feature() -> datasets.List:
|
||||||
|
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||||
|
return datasets.List(language_persistent_row_feature())
|
||||||
|
|
||||||
|
|
||||||
|
def language_events_column_feature() -> datasets.List:
|
||||||
|
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||||
|
return datasets.List(language_event_row_feature())
|
||||||
|
|
||||||
|
|
||||||
|
def language_feature_info() -> dict[str, dict]:
|
||||||
|
"""Return the ``info["features"]`` entries for both language columns."""
|
||||||
|
return {
|
||||||
|
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||||
|
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_language_column(key: str) -> bool:
|
||||||
|
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||||
|
return key in LANGUAGE_COLUMNS
|
||||||
|
|
||||||
|
|
||||||
|
def column_for_style(style: str | None) -> LanguageColumn:
|
||||||
|
"""Map a language style to the column where rows of that style are stored.
|
||||||
|
|
||||||
|
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||||
|
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||||
|
to :data:`LANGUAGE_EVENTS`.
|
||||||
|
"""
|
||||||
|
if style is None:
|
||||||
|
return LANGUAGE_EVENTS
|
||||||
|
if style in PERSISTENT_STYLES:
|
||||||
|
return LANGUAGE_PERSISTENT
|
||||||
|
if style in EVENT_ONLY_STYLES:
|
||||||
|
return LANGUAGE_EVENTS
|
||||||
|
raise ValueError(f"Unknown language style: {style!r}")
|
||||||
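The style routing added in this file can be tried in isolation. The sketch below copies the constants and `column_for_style` from the diff rather than importing them, so it runs without lerobot installed. `route_rows` is a hypothetical helper added only for illustration and is not part of the change.

```python
# Standalone copy of the routing logic from the diff above.
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}


def column_for_style(style):
    """Map a style (or None) to the column that stores rows of that style."""
    if style is None:
        return LANGUAGE_EVENTS
    if style in PERSISTENT_STYLES:
        return LANGUAGE_PERSISTENT
    if style in EVENT_ONLY_STYLES:
        return LANGUAGE_EVENTS
    raise ValueError(f"Unknown language style: {style!r}")


def route_rows(rows):
    """Hypothetical helper: split mixed rows into the two column buckets."""
    buckets = {LANGUAGE_PERSISTENT: [], LANGUAGE_EVENTS: []}
    for row in rows:
        buckets[column_for_style(row.get("style"))].append(row)
    return buckets


rows = [
    {"role": "assistant", "style": "subtask", "content": "grasp the cup"},
    {"role": "user", "style": "interjection", "content": "wait, use the left arm"},
    {"role": "user", "style": None, "content": "untyped message"},
]
buckets = route_rows(rows)
```

A style outside both sets raises `ValueError`, so a typo in a style name fails at routing time rather than landing in the wrong column.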
@@ -0,0 +1,511 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import hashlib
|
||||||
|
import re
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||||
|
|
||||||
|
from .language import (
|
||||||
|
EVENT_ONLY_STYLES,
|
||||||
|
LANGUAGE_PERSISTENT,
|
||||||
|
PERSISTENT_STYLES,
|
||||||
|
column_for_style,
|
||||||
|
)
|
||||||
|
|
||||||
|
LanguageRow = dict[str, Any]
|
||||||
|
RenderedMessages = dict[str, list[Any]]
|
||||||
|
|
||||||
|
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||||
|
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||||
|
|
||||||
|
|
||||||
|
def active_at(
|
||||||
|
t: float,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow] | None = None,
|
||||||
|
style: str | None = None,
|
||||||
|
role: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||||
|
|
||||||
|
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||||
|
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
|
||||||
|
selector. ``events`` is accepted for resolver-signature uniformity but is
|
||||||
|
not consulted: only persistent styles are valid here.
|
||||||
|
"""
|
||||||
|
_validate_persistent_resolver("active_at", style)
|
||||||
|
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
|
||||||
|
matches = [row for row in matches if _timestamp(row) <= t]
|
||||||
|
return _select_latest(matches, style=style, role=role, tool_name=tool_name)
|
||||||
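The "active at `t`" rule in the docstring above amounts to "latest row of the requested style whose timestamp is at or before `t`". The following is a simplified sketch of that rule only: `latest_at_or_before` is a hypothetical name, and it omits the `role`/`tool_name` selectors and the ambiguity check that the real `active_at` delegates to `_select_latest`. The sample rows are invented.

```python
def latest_at_or_before(rows, t, style):
    """Simplified stand-in for active_at: no selectors, no ambiguity check."""
    candidates = [r for r in rows if r["style"] == style and r["timestamp"] <= t]
    if not candidates:
        return None
    return max(candidates, key=lambda r: r["timestamp"])


# Invented persistent rows for one episode.
persistent = [
    {"role": "assistant", "style": "subtask", "content": "reach for the cup", "timestamp": 0.0},
    {"role": "assistant", "style": "subtask", "content": "grasp the cup", "timestamp": 1.5},
    {"role": "assistant", "style": "plan", "content": "cup, then plate", "timestamp": 0.0},
]

before_switch = latest_at_or_before(persistent, 1.0, "subtask")
at_switch = latest_at_or_before(persistent, 1.5, "subtask")
too_early = latest_at_or_before(persistent, -1.0, "subtask")
```

Because the comparison is `<=`, a row becomes active on the exact frame carrying its timestamp.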
|
|
||||||
|
|
||||||
|
def emitted_at(
|
||||||
|
t: float,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow],
|
||||||
|
style: str | None = None,
|
||||||
|
role: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||||
|
|
||||||
|
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||||
|
equals ``t``. For event styles, the ``events`` list is assumed to come from
|
||||||
|
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
|
||||||
|
so all matching event rows are considered emitted at ``t``.
|
||||||
|
"""
|
||||||
|
column = column_for_style(style)
|
||||||
|
if column == LANGUAGE_PERSISTENT:
|
||||||
|
matches = [
|
||||||
|
row
|
||||||
|
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
|
||||||
|
if _timestamp(row) == t
|
||||||
|
]
|
||||||
|
return _select_one(
|
||||||
|
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
|
||||||
|
)
|
||||||
|
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
|
||||||
|
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
|
||||||
|
|
||||||
|
|
||||||
|
def nth_prev(
|
||||||
|
t: float,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow] | None = None,
|
||||||
|
style: str | None = None,
|
||||||
|
offset: int = 1,
|
||||||
|
role: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||||
|
|
||||||
|
Walks back through chronologically sorted persistent rows of ``style``
|
||||||
|
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
|
||||||
|
positions before the row active at ``t``. Only valid for persistent styles.
|
||||||
|
"""
|
||||||
|
return _nth_relative(
|
||||||
|
t,
|
||||||
|
persistent=persistent,
|
||||||
|
style=style,
|
||||||
|
offset=-offset,
|
||||||
|
role=role,
|
||||||
|
tool_name=tool_name,
|
||||||
|
resolver_name="nth_prev",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def nth_next(
|
||||||
|
t: float,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow] | None = None,
|
||||||
|
style: str | None = None,
|
||||||
|
offset: int = 1,
|
||||||
|
role: str | None = None,
|
||||||
|
tool_name: str | None = None,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||||
|
|
||||||
|
Walks forward through chronologically sorted persistent rows of ``style``
|
||||||
|
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
|
||||||
|
positions after the row active at ``t``. Only valid for persistent styles.
|
||||||
|
"""
|
||||||
|
return _nth_relative(
|
||||||
|
t,
|
||||||
|
persistent=persistent,
|
||||||
|
style=style,
|
||||||
|
offset=offset,
|
||||||
|
role=role,
|
||||||
|
tool_name=tool_name,
|
||||||
|
resolver_name="nth_next",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_sample(
|
||||||
|
*,
|
||||||
|
recipe: TrainingRecipe,
|
||||||
|
persistent: Sequence[LanguageRow] | None,
|
||||||
|
events: Sequence[LanguageRow] | None,
|
||||||
|
t: float,
|
||||||
|
sample_idx: int,
|
||||||
|
task: str | None = None,
|
||||||
|
dataset_ctx: Any | None = None,
|
||||||
|
) -> RenderedMessages | None:
|
||||||
|
"""Render the chat-style messages for a single dataset sample.
|
||||||
|
|
||||||
|
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||||
|
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||||
|
Returns ``None`` if the resolved sample contains no target message.
|
||||||
|
"""
|
||||||
|
persistent_rows = _normalize_rows(persistent or [])
|
||||||
|
event_rows = _normalize_rows(events or [])
|
||||||
|
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||||
|
bindings = _resolve_bindings(
|
||||||
|
selected_recipe,
|
||||||
|
persistent=persistent_rows,
|
||||||
|
events=event_rows,
|
||||||
|
t=t,
|
||||||
|
task=task,
|
||||||
|
dataset_ctx=dataset_ctx,
|
||||||
|
)
|
||||||
|
return _render_message_recipe(selected_recipe, bindings)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||||
|
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||||
|
if recipe.blend is None:
|
||||||
|
return recipe
|
||||||
|
|
||||||
|
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||||
|
if total_weight <= 0:
|
||||||
|
raise ValueError("Blend weights must sum to a positive value.")
|
||||||
|
|
||||||
|
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||||
|
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||||
|
cumulative = 0.0
|
||||||
|
last_component: TrainingRecipe | None = None
|
||||||
|
for component in recipe.blend.values():
|
||||||
|
last_component = component
|
||||||
|
cumulative += component.weight or 0.0
|
||||||
|
if draw < cumulative:
|
||||||
|
return component
|
||||||
|
assert last_component is not None
|
||||||
|
return last_component
|
||||||
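The blend selection above hashes the sample index instead of calling a random number generator, so the same index always maps to the same component. A minimal sketch of that draw, using a plain `name -> weight` dict in place of `TrainingRecipe.blend`; `pick_component` and the blend names are hypothetical.

```python
import hashlib


def pick_component(weights, sample_idx):
    """Deterministically pick a key from ``weights`` (name -> weight)."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("Blend weights must sum to a positive value.")
    # 8-byte digest of the index, scaled into [0, total).
    digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
    draw = int.from_bytes(digest, "big") / 2**64 * total
    cumulative = 0.0
    last = None
    for name, weight in weights.items():
        last = name
        cumulative += weight
        if draw < cumulative:
            return name
    return last  # fallback for floating-point round-off at the upper edge


weights = {"subtask_prediction": 3.0, "vqa": 1.0}  # hypothetical blend names
counts = {name: 0 for name in weights}
for idx in range(10_000):
    counts[pick_component(weights, idx)] += 1
```

Over many indices the counts should track the weights approximately, and re-running the loop reproduces the same assignment because nothing depends on global random state.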
|
|
||||||
|
|
||||||
|
def _resolve_bindings(
|
||||||
|
recipe: TrainingRecipe,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow],
|
||||||
|
t: float,
|
||||||
|
task: str | None,
|
||||||
|
dataset_ctx: Any | None,
|
||||||
|
) -> dict[str, LanguageRow | str | None]:
|
||||||
|
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||||
|
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
|
||||||
|
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||||
|
for name, spec in specs.items():
|
||||||
|
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||||
|
return bindings
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
|
||||||
|
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``."""
|
||||||
|
if task is not None:
|
||||||
|
return task
|
||||||
|
if dataset_ctx is None:
|
||||||
|
return None
|
||||||
|
if isinstance(dataset_ctx, dict):
|
||||||
|
return dataset_ctx.get("task")
|
||||||
|
return getattr(dataset_ctx, "task", None)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_spec(
|
||||||
|
spec: str,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
events: Sequence[LanguageRow],
|
||||||
|
t: float,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||||
|
match = _RESOLVER_RE.match(spec.strip())
|
||||||
|
if match is None:
|
||||||
|
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||||
|
name = match.group("name")
|
||||||
|
kwargs = _parse_resolver_args(match.group("args"))
|
||||||
|
kwargs.pop("t_arg", None)
|
||||||
|
|
||||||
|
resolvers = {
|
||||||
|
"active_at": active_at,
|
||||||
|
"emitted_at": emitted_at,
|
||||||
|
"nth_prev": nth_prev,
|
||||||
|
"nth_next": nth_next,
|
||||||
|
}
|
||||||
|
if name not in resolvers:
|
||||||
|
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||||
|
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||||
|
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if not args.strip():
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||||
|
for part in parts:
|
||||||
|
if part == "t":
|
||||||
|
kwargs["t_arg"] = True
|
||||||
|
continue
|
||||||
|
if "=" not in part:
|
||||||
|
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||||
|
key, value = (item.strip() for item in part.split("=", 1))
|
||||||
|
if key == "offset":
|
||||||
|
kwargs[key] = int(value)
|
||||||
|
else:
|
||||||
|
kwargs[key] = value.strip("\"'")
|
||||||
|
return kwargs
|
||||||
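To see what `_resolve_spec` and `_parse_resolver_args` accept, the sketch below condenses both into one standalone function using the same regular expression. `parse_spec` is a hypothetical name, and the example expressions are illustrative strings of the accepted shape, not values taken from `DEFAULT_BINDINGS`.

```python
import re

RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")


def parse_spec(spec):
    """Split 'name(t, key=value, ...)' into (name, kwargs)."""
    match = RESOLVER_RE.match(spec.strip())
    if match is None:
        raise ValueError(f"Invalid resolver expression: {spec!r}")
    kwargs = {}
    for part in (p.strip() for p in match.group("args").split(",")):
        if not part or part == "t":
            continue  # the bare 't' marker is dropped; t is passed positionally
        if "=" not in part:
            raise ValueError(f"Invalid resolver argument: {part!r}")
        key, value = (item.strip() for item in part.split("=", 1))
        # Only 'offset' is converted to int; other values stay strings.
        kwargs[key] = int(value) if key == "offset" else value.strip("\"'")
    return match.group("name"), kwargs


simple = parse_spec("active_at(t, style=subtask)")
with_offset = parse_spec('nth_prev(t, style="subtask", offset=2)')
```

Arguments are split on every comma, so a quoted value that itself contains a comma would be broken apart. The same holds for the parser in the diff.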
|
|
||||||
|
|
||||||
|
def _render_message_recipe(
|
||||||
|
recipe: TrainingRecipe,
|
||||||
|
bindings: dict[str, LanguageRow | str | None],
|
||||||
|
) -> RenderedMessages | None:
|
||||||
|
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||||
|
assert recipe.messages is not None
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
streams: list[str | None] = []
|
||||||
|
target_indices: list[int] = []
|
||||||
|
|
||||||
|
for turn in recipe.messages:
|
||||||
|
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
message = {"role": turn.role}
|
||||||
|
if turn.content is not None:
|
||||||
|
message["content"] = _render_content(turn.content, bindings)
|
||||||
|
|
||||||
|
if turn.tool_calls_from is not None:
|
||||||
|
row = bindings.get(turn.tool_calls_from)
|
||||||
|
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||||
|
if tool_calls:
|
||||||
|
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||||
|
|
||||||
|
message_idx = len(messages)
|
||||||
|
messages.append(message)
|
||||||
|
streams.append(turn.stream)
|
||||||
|
if turn.target:
|
||||||
|
target_indices.append(message_idx)
|
||||||
|
|
||||||
|
if not target_indices:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rendered = {
|
||||||
|
"messages": messages,
|
||||||
|
"message_streams": streams,
|
||||||
|
"target_message_indices": target_indices,
|
||||||
|
}
|
||||||
|
_validate_rendered(rendered)
|
||||||
|
return rendered
|
||||||
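The control flow of `_render_message_recipe` (skip optional turns, record target indices, return `None` without a target) can be sketched with a small stand-in for the recipe's turn objects. `Turn`, `render`, and the stream names are hypothetical; template substitution and `tool_calls` handling are left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Turn:  # hypothetical stand-in for a recipe message turn
    role: str
    content: str
    stream: str
    target: bool = False
    if_present: Optional[str] = None


def render(turns, bindings):
    messages, streams, targets = [], [], []
    for turn in turns:
        if turn.if_present is not None and bindings.get(turn.if_present) is None:
            continue  # optional turn whose binding did not resolve
        if turn.target:
            targets.append(len(messages))  # index in the *rendered* list
        messages.append({"role": turn.role, "content": turn.content})
        streams.append(turn.stream)
    if not targets:
        return None
    return {
        "messages": messages,
        "message_streams": streams,
        "target_message_indices": targets,
    }


turns = [
    Turn("user", "tidy the table", "high_level"),
    Turn("assistant", "memory note", "high_level", if_present="memory"),
    Turn("assistant", "grasp the cup", "low_level", target=True),
]
without_memory = render(turns, {"memory": None})
with_memory = render(turns, {"memory": {"content": "cup already moved"}})
```

Target indices refer to positions after skipping, which is why the same recipe yields index 1 in one case and index 2 in the other.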
|
|
||||||
|
|
||||||
|
def _render_content(
|
||||||
|
content: str | list[dict[str, Any]],
|
||||||
|
bindings: dict[str, LanguageRow | str | None],
|
||||||
|
) -> str | list[dict[str, Any]]:
|
||||||
|
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return _substitute(content, bindings)
|
||||||
|
|
||||||
|
rendered_blocks = []
|
||||||
|
for block in content:
|
||||||
|
rendered_block = copy.deepcopy(block)
|
||||||
|
for key, value in rendered_block.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
rendered_block[key] = _substitute(value, bindings)
|
||||||
|
rendered_blocks.append(rendered_block)
|
||||||
|
return rendered_blocks
|
||||||
|
|
||||||
|
|
||||||
|
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||||
|
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||||
|
|
||||||
|
def replace(match: re.Match[str]) -> str:
|
||||||
|
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||||
|
name = match.group(1)
|
||||||
|
if name not in bindings:
|
||||||
|
raise ValueError(f"Unknown template binding: {name!r}")
|
||||||
|
value = bindings[name]
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, dict):
|
||||||
|
content = value.get("content")
|
||||||
|
return "" if content is None else str(content)
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
return _PLACEHOLDER_RE.sub(replace, template)
|
||||||
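The placeholder rules in `_substitute` are easiest to see on a concrete template. This standalone sketch reuses the same regular expression and replacement logic; the binding names and values are invented for illustration.

```python
import re

PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def substitute(template, bindings):
    """Replace ${name} with the bound string, a row's 'content', or ''."""

    def replace(match):
        name = match.group(1)
        if name not in bindings:
            raise ValueError(f"Unknown template binding: {name!r}")
        value = bindings[name]
        if value is None:
            return ""
        if isinstance(value, dict):
            content = value.get("content")
            return "" if content is None else str(content)
        return str(value)

    return PLACEHOLDER_RE.sub(replace, template)


bindings = {
    "task": "tidy the table",  # plain string binding
    "subtask": {"role": "assistant", "style": "subtask", "content": "grasp the cup"},
    "memory": None,  # resolved to nothing at this frame
}
rendered = substitute("Task: ${task}. Now: ${subtask}.${memory}", bindings)
```

A binding that resolved to `None` renders as an empty string, while a name that was never bound is treated as a recipe error and raises.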
|
|
||||||
|
|
||||||
|
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||||
|
"""Sanity-check the rendered output for stream/target alignment."""
|
||||||
|
messages = rendered["messages"]
|
||||||
|
streams = rendered["message_streams"]
|
||||||
|
target_indices = rendered["target_message_indices"]
|
||||||
|
|
||||||
|
if len(streams) != len(messages):
|
||||||
|
raise ValueError("message_streams must be aligned with messages.")
|
||||||
|
if not target_indices:
|
||||||
|
raise ValueError("Rendered samples must contain at least one target message.")
|
||||||
|
for idx in target_indices:
|
||||||
|
if idx < 0 or idx >= len(messages):
|
||||||
|
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||||
|
for idx, stream in enumerate(streams):
|
||||||
|
if stream is None:
|
||||||
|
raise ValueError(f"Rendered message {idx} has no stream.")
|
||||||
|
|
||||||
|
|
||||||
|
def _nth_relative(
|
||||||
|
t: float,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow],
|
||||||
|
style: str | None,
|
||||||
|
offset: int,
|
||||||
|
role: str | None,
|
||||||
|
tool_name: str | None,
|
||||||
|
resolver_name: str,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||||
|
_validate_persistent_resolver(resolver_name, style)
|
||||||
|
if abs(offset) < 1:
|
||||||
|
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
||||||
|
|
||||||
|
rows = sorted(
|
||||||
|
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
|
||||||
|
key=_persistent_sort_key,
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
anchor_idx = None
|
||||||
|
for idx, row in enumerate(rows):
|
||||||
|
if _timestamp(row) <= t:
|
||||||
|
anchor_idx = idx
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||||
|
|
||||||
|
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||||
|
return None
|
||||||
|
return rows[target_idx]
|
||||||
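The index arithmetic in `_nth_relative` has a few edge cases (no row active yet, stepping off either end) that are easier to check on bare timestamps. The sketch below reproduces only that arithmetic on an already sorted list; `nth_relative_index` is a hypothetical name and returns an index rather than a row.

```python
def nth_relative_index(timestamps, t, offset):
    """Return the index selected by a signed offset, or None if out of range."""
    if offset == 0:
        raise ValueError("offset must be non-zero")
    anchor = None
    for idx, ts in enumerate(timestamps):
        if ts <= t:
            anchor = idx  # last row that started at or before t
        else:
            break
    if anchor is None:
        # Nothing is active yet, so only forward lookups can succeed:
        # offset=1 selects the first row, offset=2 the second, and so on.
        target = offset - 1 if offset > 0 else None
    else:
        target = anchor + offset
    if target is None or not 0 <= target < len(timestamps):
        return None
    return target


starts = [0.0, 2.0, 5.0]  # invented subtask start times
previous = nth_relative_index(starts, 3.0, -1)
following = nth_relative_index(starts, 3.0, 1)
```

At `t = 3.0` the active row is index 1, so the previous row is index 0 and the next is index 2. `nth_prev` reaches this code by negating its `offset` argument.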
|
|
||||||
|
|
||||||
|
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
|
||||||
|
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||||
|
if style is None:
|
||||||
|
raise ValueError(f"{resolver_name} requires a persistent style.")
|
||||||
|
if style in EVENT_ONLY_STYLES:
|
||||||
|
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
|
||||||
|
if style not in PERSISTENT_STYLES:
|
||||||
|
column_for_style(style)
|
||||||
|
|
||||||
|
|
||||||
|
def _matching_rows(
|
||||||
|
rows: Sequence[LanguageRow],
|
||||||
|
*,
|
||||||
|
style: str | None,
|
||||||
|
role: str | None,
|
||||||
|
tool_name: str | None,
|
||||||
|
) -> list[LanguageRow]:
|
||||||
|
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name`` selectors."""
|
||||||
|
return [
|
||||||
|
row
|
||||||
|
for row in rows
|
||||||
|
if (style is None or row.get("style") == style)
|
||||||
|
and (role is None or row.get("role") == role)
|
||||||
|
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _select_latest(
|
||||||
|
rows: Sequence[LanguageRow],
|
||||||
|
*,
|
||||||
|
style: str | None,
|
||||||
|
role: str | None,
|
||||||
|
tool_name: str | None,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
rows = sorted(rows, key=_persistent_sort_key)
|
||||||
|
latest_ts = _timestamp(rows[-1])
|
||||||
|
return _select_one(
|
||||||
|
[row for row in rows if _timestamp(row) == latest_ts],
|
||||||
|
style=style,
|
||||||
|
role=role,
|
||||||
|
tool_name=tool_name,
|
||||||
|
sort_key=_persistent_sort_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_one(
|
||||||
|
rows: Sequence[LanguageRow],
|
||||||
|
*,
|
||||||
|
style: str | None,
|
||||||
|
role: str | None,
|
||||||
|
tool_name: str | None,
|
||||||
|
sort_key: Any,
|
||||||
|
) -> LanguageRow | None:
|
||||||
|
"""Return the single matching row, or raise if the selectors are ambiguous."""
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
if len(rows) > 1 and role is None and tool_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
|
||||||
|
)
|
||||||
|
return sorted(rows, key=sort_key)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||||
|
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
|
||||||
|
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
|
||||||
|
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
|
||||||
|
return (row.get("style") or "", row.get("role") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def _timestamp(row: LanguageRow) -> float:
|
||||||
|
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||||
|
value = row["timestamp"]
|
||||||
|
return float(value.item() if hasattr(value, "item") else value)
|
||||||
|
|
||||||
|
|
||||||
|
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||||
|
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||||
|
for tool_call in row.get("tool_calls") or []:
|
||||||
|
if isinstance(tool_call, str):
|
||||||
|
continue
|
||||||
|
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||||
|
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||||
|
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||||
|
normalized = []
|
||||||
|
for row in rows:
|
||||||
|
if row is None:
|
||||||
|
continue
|
||||||
|
if hasattr(row, "as_py"):
|
||||||
|
row = row.as_py()
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||||
|
normalized.append(dict(row))
|
||||||
|
return normalized
|
||||||
@@ -83,7 +83,6 @@ VIDEO_DIR = "videos"
 
 CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
 DEFAULT_TASKS_PATH = "meta/tasks.parquet"
-DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
 DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
 DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
 DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -93,6 +93,7 @@ from .relative_action_processor import (
     to_relative_actions,
 )
 from .rename_processor import RenameObservationsProcessorStep, rename_stats
+from .render_messages_processor import RenderMessagesStep
 from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
 
 __all__ = [
@@ -128,6 +129,7 @@ __all__ = [
     "make_default_robot_observation_processor",
     "AbsoluteActionsProcessorStep",
     "RelativeActionsProcessorStep",
+    "RenderMessagesStep",
     "MapDeltaActionToRobotActionStep",
     "MapTensorToDeltaActionDictStep",
     "NewLineTaskProcessorStep",
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
             task_index_value = complementary_data["task_index"]
             if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
                 complementary_data["task_index"] = task_index_value.unsqueeze(0)
+
+        complementary_data.pop("language_persistent", None)
+        complementary_data.pop("language_events", None)
+
+        if "messages" in complementary_data:
+            messages = complementary_data["messages"]
+            if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
+                complementary_data["messages"] = [messages]
+
+        if "message_streams" in complementary_data:
+            streams = complementary_data["message_streams"]
+            if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
+                complementary_data["message_streams"] = [streams]
+
+        if "target_message_indices" in complementary_data:
+            indices = complementary_data["target_message_indices"]
+            if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
+                complementary_data["target_message_indices"] = [indices]
         return complementary_data
 
     def transform_features(
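The added lines in this hunk wrap an unbatched `messages` list into a batch of size one and leave an already batched list alone. A minimal sketch of that check, assuming plain dicts as input; the helper name `add_batch_dim_to_messages` is hypothetical and only covers the `messages` key.

```python
from typing import Any


def add_batch_dim_to_messages(complementary_data: dict[str, Any]) -> dict[str, Any]:
    """Wrap an unbatched ``messages`` list so it becomes a batch of size one."""
    out = dict(complementary_data)
    messages = out.get("messages")
    # Unbatched: a flat list of message dicts, or an empty list.
    # Already batched: a list whose first element is itself a list.
    if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
        out["messages"] = [messages]
    return out


turn = {"role": "user", "content": "tidy"}
single = add_batch_dim_to_messages({"messages": [turn]})
batched = add_batch_dim_to_messages({"messages": [[turn]]})
empty = add_batch_dim_to_messages({"messages": []})
```

One consequence of the `not messages` branch is that an empty list is treated as a single sample with no messages, so it comes back as `[[]]`.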
@@ -167,12 +167,35 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
     """
     pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
     task_key = {"task": batch["task"]} if "task" in batch else {}
-    subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
     index_key = {"index": batch["index"]} if "index" in batch else {}
     task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
     episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
+    timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
+    language_persistent_key = (
+        {"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
+    )
+    language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
+    messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
+    message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
+    target_message_indices_key = (
+        {"target_message_indices": batch["target_message_indices"]}
+        if "target_message_indices" in batch
+        else {}
+    )
 
-    return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
+    return {
+        **pad_keys,
+        **task_key,
+        **index_key,
+        **task_index_key,
+        **episode_index_key,
+        **timestamp_key,
+        **language_persistent_key,
+        **language_events_key,
+        **messages_key,
+        **message_streams_key,
+        **target_message_indices_key,
+    }
 
 
 def create_transition(
@@ -0,0 +1,92 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey

from .pipeline import ProcessorStep, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
    """Processor step that turns raw language columns into rendered chat messages.

    Reads ``language_persistent`` and ``language_events`` from the transition's
    complementary data, renders them through ``recipe`` at the sample timestamp,
    and replaces the raw columns with the resulting ``messages`` /
    ``message_streams`` / ``target_message_indices`` keys.
    """

    recipe: TrainingRecipe
    dataset_ctx: Any | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition | None:
        """Render messages for a single transition; return ``None`` to drop it."""
        complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
        events = complementary_data.get(LANGUAGE_EVENTS) or []

        if not persistent and not events:
            return transition

        timestamp = complementary_data.get("timestamp")
        if timestamp is None:
            raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")

        sample_idx = complementary_data.get("index", 0)
        rendered = render_sample(
            recipe=self.recipe,
            persistent=persistent,
            events=events,
            t=_scalar(timestamp),
            sample_idx=int(_scalar(sample_idx)),
            task=complementary_data.get("task"),
            dataset_ctx=self.dataset_ctx,
        )
        if rendered is None:
            return None

        new_transition = transition.copy()
        new_complementary_data = dict(complementary_data)
        new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
        new_complementary_data.pop(LANGUAGE_EVENTS, None)
        new_complementary_data.update(rendered)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
        return new_transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Pass features through unchanged; rendering only touches complementary data."""
        return features


def _scalar(value: Any) -> float | int:
    """Unwrap a tensor/array/single-element list into a Python scalar."""
    if hasattr(value, "item"):
        return value.item()
    if isinstance(value, list) and len(value) == 1:
        return _scalar(value[0])
    return value
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
 from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
 from lerobot.optim.factory import make_optimizer_and_scheduler
 from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
+from lerobot.utils.collate import lerobot_collate_fn
 from lerobot.utils.import_utils import register_third_party_plugins
 from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
 from lerobot.utils.random_utils import set_seed
@@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
         sampler=sampler,
         pin_memory=device.type == "cuda",
         drop_last=False,
+        collate_fn=lerobot_collate_fn,
         prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
         persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
     )
@@ -0,0 +1,54 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from torch.utils.data._utils.collate import default_collate

from lerobot.datasets.language import LANGUAGE_COLUMNS

_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}


def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
    """Collate function that keeps rendered-message fields as Python lists.

    Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
    the rendered-message fields as plain Python lists, leaves the raw language
    columns out of the batch, and delegates every other key to PyTorch's
    ``default_collate``.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None

    preserved = {
        key: [sample[key] for sample in batch if key in sample]
        for key in _PYTHON_LIST_KEYS
        if any(key in sample for sample in batch)
    }
    tensorizable = [
        {
            key: value
            for key, value in sample.items()
            if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
        }
        for sample in batch
    ]
    collated = default_collate(tensorizable)
    collated.update(preserved)
    return collated
@@ -0,0 +1,31 @@
#!/usr/bin/env python

from pathlib import Path

import pytest

from lerobot.configs.recipe import MessageTurn, TrainingRecipe


def test_message_recipe_validates_unknown_binding():
    with pytest.raises(ValueError, match="unknown binding"):
        TrainingRecipe(
            messages=[
                MessageTurn(role="user", content="${missing}", stream="high_level"),
                MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
            ]
        )


def test_canonical_recipe_loads():
    recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))

    assert recipe.blend is not None
    assert set(recipe.blend) == {
        "memory_update",
        "user_interjection_response",
        "high_level_subtask",
        "low_level_execution",
        "ask_vqa",
    }
    assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96)
@@ -0,0 +1,99 @@
#!/usr/bin/env python

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

from lerobot.datasets import LeRobotDataset
from lerobot.datasets.io_utils import write_info
from lerobot.datasets.language import (
    EVENT_ONLY_STYLES,
    LANGUAGE_EVENTS,
    LANGUAGE_PERSISTENT,
    PERSISTENT_STYLES,
    STYLE_REGISTRY,
    column_for_style,
    language_events_arrow_type,
    language_feature_info,
    language_persistent_arrow_type,
)
from lerobot.datasets.utils import DEFAULT_DATA_PATH


def test_language_arrow_schema_has_expected_fields():
    persistent_row_type = language_persistent_arrow_type().value_type
    event_row_type = language_events_arrow_type().value_type

    assert isinstance(persistent_row_type, pa.StructType)
    assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]

    assert isinstance(event_row_type, pa.StructType)
    assert event_row_type.names == ["role", "content", "style", "tool_calls"]


def test_style_registry_routes_columns():
    assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES
    assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
    assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY

    assert column_for_style("subtask") == LANGUAGE_PERSISTENT
    assert column_for_style("plan") == LANGUAGE_PERSISTENT
    assert column_for_style("memory") == LANGUAGE_PERSISTENT
    assert column_for_style("motion") == LANGUAGE_PERSISTENT
    assert column_for_style("interjection") == LANGUAGE_EVENTS
    assert column_for_style("vqa") == LANGUAGE_EVENTS
    assert column_for_style("trace") == LANGUAGE_EVENTS
    assert column_for_style(None) == LANGUAGE_EVENTS


def test_unknown_style_rejected():
    with pytest.raises(ValueError, match="Unknown language style"):
        column_for_style("surprise")


def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
    root = tmp_path / "language_dataset"
    dataset = empty_lerobot_dataset_factory(
        root=root,
        features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
        use_videos=False,
    )
    dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
    dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
    dataset.save_episode()
    dataset.finalize()

    persistent = [
        {
            "role": "assistant",
            "content": "reach for the cup",
            "style": "subtask",
            "timestamp": 0.0,
            "tool_calls": None,
        }
    ]
    event = {
        "role": "user",
        "content": "what is visible?",
        "style": "vqa",
        "tool_calls": None,
    }
    data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
    df = pd.read_parquet(data_path)
    df[LANGUAGE_PERSISTENT] = [persistent, persistent]
    df[LANGUAGE_EVENTS] = [[event], []]
    df.to_parquet(data_path)

    info = dataset.meta.info
    info["features"].update(language_feature_info())
    write_info(info, root)

    reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)

    first = reloaded[0]
    second = reloaded[1]
    assert first[LANGUAGE_PERSISTENT] == persistent
    assert first[LANGUAGE_EVENTS] == [event]
    assert second[LANGUAGE_PERSISTENT] == persistent
    assert second[LANGUAGE_EVENTS] == []
@@ -0,0 +1,176 @@
#!/usr/bin/env python

from pathlib import Path

import pytest

from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample


def persistent_row(role, content, style, timestamp, tool_calls=None):
    return {
        "role": role,
        "content": content,
        "style": style,
        "timestamp": timestamp,
        "tool_calls": tool_calls,
    }


def event_row(role, content, style, tool_calls=None):
    return {
        "role": role,
        "content": content,
        "style": style,
        "tool_calls": tool_calls,
    }


PERSISTENT = [
    persistent_row("assistant", "plan 0", "plan", 0.0),
    persistent_row("assistant", "memory 0", "memory", 0.0),
    persistent_row("assistant", "subtask 0", "subtask", 0.0),
    persistent_row("assistant", "memory 1", "memory", 1.0),
    persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS_AT_1 = [
    event_row("user", "what is visible?", "vqa"),
    event_row("assistant", '{"count": 2}', "vqa"),
]
EVENTS_AT_2 = [
    event_row("user", "skip wiping", "interjection"),
    event_row(
        "assistant",
        None,
        None,
        [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
    ),
]


def test_resolver_temporal_semantics():
    assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
    assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
    assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
    assert (
        emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
        == '{"count": 2}'
    )


def test_persistent_relative_resolvers_reject_event_styles():
    with pytest.raises(ValueError, match="event-only"):
        active_at(1.0, persistent=PERSISTENT, style="vqa")
    with pytest.raises(ValueError, match="event-only"):
        nth_prev(1.0, persistent=PERSISTENT, style="interjection")


def test_nth_prev_and_next():
    assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
    assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"


def test_substitution_if_present_multimodal_and_tool_calls():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(
                role="user",
                content=[
                    {"type": "image", "feature": "observation.images.top"},
                    {"type": "text", "text": "${task}: ${interjection}"},
                ],
                stream="high_level",
                if_present="interjection",
            ),
            MessageTurn(
                role="assistant",
                content="${plan}",
                stream="high_level",
                target=True,
                tool_calls_from="speech",
            ),
        ],
        bindings={"plan": "active_at(t, style=plan)"},
    )

    rendered = render_sample(
        recipe=recipe,
        persistent=PERSISTENT,
        events=EVENTS_AT_2,
        t=2.0,
        sample_idx=0,
        task="clean kitchen",
    )

    assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
    assert rendered["messages"][1]["content"] == "plan 0"
    assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
    assert rendered["message_streams"] == ["high_level", "high_level"]
    assert rendered["target_message_indices"] == [1]


def test_exact_event_miss_returns_none_when_target_skips():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
            MessageTurn(
                role="assistant",
                content="${vqa}",
                stream="high_level",
                target=True,
                if_present="vqa",
            ),
        ]
    )

    assert (
        render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
    )


def test_deterministic_blend_sampling():
    recipe = TrainingRecipe(
        blend={
            "a": TrainingRecipe(
                weight=1.0,
                messages=[
                    MessageTurn(role="user", content="${task}", stream="high_level"),
                    MessageTurn(role="assistant", content="a", stream="high_level", target=True),
                ],
            ),
            "b": TrainingRecipe(
                weight=1.0,
                messages=[
                    MessageTurn(role="user", content="${task}", stream="high_level"),
                    MessageTurn(role="assistant", content="b", stream="high_level", target=True),
                ],
            ),
        }
    )

    first = render_sample(
        recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
    )
    second = render_sample(
        recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
    )
    assert first == second


def test_canonical_recipe_can_render_low_level_branch():
    recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
    low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})

    rendered = render_sample(
        recipe=low_level,
        persistent=PERSISTENT,
        events=[],
        t=0.5,
        sample_idx=0,
        task="clean kitchen",
    )

    assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
    assert rendered["message_streams"][-1] == "low_level"
    assert rendered["target_message_indices"] == [1]
@@ -1,193 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Tests for subtask functionality in LeRobotDataset.
|
|
||||||
|
|
||||||
These tests verify that:
|
|
||||||
- Subtask information is correctly loaded from datasets that have subtask data
|
|
||||||
- The __getitem__ method correctly adds subtask strings to returned items
|
|
||||||
- Subtask handling gracefully handles missing data
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
|
||||||
|
|
||||||
import pandas as pd # noqa: E402
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
|
|
||||||
|
|
||||||
class TestSubtaskDataset:
    """Tests for subtask handling in LeRobotDataset."""

    @pytest.fixture
    def subtask_dataset(self):
        """Load the test subtask dataset from the hub."""
        # Use the lerobot/pusht-subtask dataset, episodes 1 through 11
        return LeRobotDataset(
            repo_id="lerobot/pusht-subtask",
            episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
        )

    def test_subtask_dataset_loads(self, subtask_dataset):
        """Test that the subtask dataset loads successfully."""
        assert subtask_dataset is not None
        assert len(subtask_dataset) > 0

    def test_subtask_metadata_loaded(self, subtask_dataset):
        """Test that subtask metadata is loaded when present in dataset."""
        # The dataset should have subtasks metadata loaded
        assert subtask_dataset.meta.subtasks is not None
        assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)

    def test_subtask_index_in_features(self, subtask_dataset):
        """Test that subtask_index is a feature when dataset has subtasks."""
        assert "subtask_index" in subtask_dataset.features

    def test_getitem_returns_subtask_string(self, subtask_dataset):
        """Test that __getitem__ correctly adds subtask string to returned item."""
        item = subtask_dataset[0]

        # Subtask should be present in the returned item
        assert "subtask" in item
        assert isinstance(item["subtask"], str)
        assert len(item["subtask"]) > 0  # Should not be empty

    def test_getitem_has_subtask_index(self, subtask_dataset):
        """Test that __getitem__ includes subtask_index."""
        item = subtask_dataset[0]

        assert "subtask_index" in item
        assert isinstance(item["subtask_index"], torch.Tensor)

    def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
        """Test that subtask_index correctly maps to a subtask in metadata."""
        item = subtask_dataset[0]

        subtask_idx = item["subtask_index"].item()
        subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name

        assert item["subtask"] == subtask_from_metadata

    def test_all_items_have_subtask(self, subtask_dataset):
        """Test that all items in the dataset have subtask information."""
        for i in range(min(len(subtask_dataset), 5)):  # Check first 5 items
            item = subtask_dataset[i]
            assert "subtask" in item
            assert isinstance(item["subtask"], str)

    def test_task_and_subtask_coexist(self, subtask_dataset):
        """Test that both task and subtask are present in returned items."""
        item = subtask_dataset[0]

        # Both task and subtask should be present
        assert "task" in item
        assert "subtask" in item
        assert isinstance(item["task"], str)
        assert isinstance(item["subtask"], str)


class TestSubtaskDatasetMissing:
    """Tests for graceful handling when subtask data is missing."""

    @pytest.fixture
    def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
        """Create a dataset without subtask information."""
        features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
        dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)

        # Add some frames and save
        for _ in range(5):
            dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
        dataset.save_episode()
        dataset.finalize()

        # Reload the dataset
        return LeRobotDataset(dataset.repo_id, root=dataset.root)

    def test_no_subtask_in_features(self, dataset_without_subtasks):
        """Test that subtask_index is not in features when not provided."""
        assert "subtask_index" not in dataset_without_subtasks.features

    def test_getitem_without_subtask(self, dataset_without_subtasks):
        """Test that __getitem__ works when subtask is not present."""
        item = dataset_without_subtasks[0]

        # Item should still be retrievable
        assert item is not None
        assert "state" in item
        assert "task" in item

        # Subtask should NOT be present
        assert "subtask" not in item

    def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
        """Test that subtasks metadata is None when not present."""
        assert dataset_without_subtasks.meta.subtasks is None


class TestSubtaskEdgeCases:
    """Edge case tests for subtask handling."""

    def test_subtask_with_multiple_episodes(self):
        """Test subtask handling with multiple episodes if available."""
        try:
            dataset = LeRobotDataset(
                repo_id="lerobot/pusht-subtask",
                episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            )
        except Exception:
            pytest.skip("Could not load lerobot/pusht-subtask dataset")

        # Check first and last items have valid subtasks
        first_item = dataset[0]
        last_item = dataset[len(dataset) - 1]

        assert "subtask" in first_item
        assert "subtask" in last_item
        assert isinstance(first_item["subtask"], str)
        assert isinstance(last_item["subtask"], str)

    def test_subtask_index_consistency(self):
        """Test that same subtask_index returns same subtask string."""
        try:
            dataset = LeRobotDataset(
                repo_id="lerobot/pusht-subtask",
                episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            )
        except Exception:
            pytest.skip("Could not load lerobot/pusht-subtask dataset")

        if len(dataset) < 2:
            pytest.skip("Dataset too small for this test")

        # Collect subtask_index to subtask mappings
        subtask_map = {}
        for i in range(min(len(dataset), 10)):
            item = dataset[i]
            idx = item["subtask_index"].item()
            subtask = item["subtask"]

            if idx in subtask_map:
                # Same index should always return same subtask
                assert subtask_map[idx] == subtask, (
                    f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
                )
            else:
                subtask_map[idx] = subtask
@@ -0,0 +1,55 @@
#!/usr/bin/env python

import torch

from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.processor.converters import create_transition
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.types import TransitionKey


def test_render_messages_step_noops_without_language_columns():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    transition = create_transition(complementary_data={"task": "do it"})

    assert RenderMessagesStep(recipe)(transition) == transition


def test_render_messages_step_renders_and_drops_raw_language():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    transition = create_transition(
        complementary_data={
            "task": "do it",
            "timestamp": torch.tensor(0.0),
            "index": torch.tensor(7),
            "language_persistent": [
                {
                    "role": "assistant",
                    "content": "reach carefully",
                    "style": "subtask",
                    "timestamp": 0.0,
                    "tool_calls": None,
                }
            ],
            "language_events": [],
        }
    )

    out = RenderMessagesStep(recipe)(transition)
    data = out[TransitionKey.COMPLEMENTARY_DATA]

    assert "language_persistent" not in data
    assert "language_events" not in data
    assert data["messages"][-1]["content"] == "reach carefully"
    assert data["message_streams"] == ["high_level", "low_level"]
    assert data["target_message_indices"] == [1]
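To make the expectations in the rendering test above concrete, here is a minimal standalone sketch of `${...}` template rendering that produces the same three output keys the test asserts on. It is not the `RenderMessagesStep` implementation: `render_turns`, its plain-dict turn format, and the `bindings` argument are hypothetical names chosen for illustration, and the sketch assumes the placeholders behave like Python's `string.Template` syntax.

```python
from string import Template
from typing import Any


def render_turns(turns: list[dict[str, Any]], bindings: dict[str, str]) -> dict[str, list[Any]]:
    """Hypothetical helper: fill ${name} placeholders and record stream/target bookkeeping."""
    messages: list[dict[str, str]] = []
    streams: list[str] = []
    target_indices: list[int] = []
    for position, turn in enumerate(turns):
        # Template.substitute raises KeyError if a placeholder has no binding.
        content = Template(turn["content"]).substitute(bindings)
        messages.append({"role": turn["role"], "content": content})
        streams.append(turn["stream"])
        if turn.get("target", False):
            target_indices.append(position)
    return {
        "messages": messages,
        "message_streams": streams,
        "target_message_indices": target_indices,
    }


# Example input mirroring the recipe used in the test (plain dicts instead of MessageTurn).
example_turns = [
    {"role": "user", "content": "${task}", "stream": "high_level"},
    {"role": "assistant", "content": "${subtask}", "stream": "low_level", "target": True},
]
rendered = render_turns(example_turns, {"task": "do it", "subtask": "reach carefully"})
```

In this sketch the `subtask` binding is passed in directly; how the real step derives it from the `language_persistent` rows is not shown in the diff and is not modeled here.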
@@ -0,0 +1,36 @@
#!/usr/bin/env python

import torch

from lerobot.utils.collate import lerobot_collate_fn


def test_lerobot_collate_preserves_messages_and_drops_raw_language():
    batch = [
        {
            "index": torch.tensor(0),
            "messages": [{"role": "assistant", "content": "a"}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
        {
            "index": torch.tensor(1),
            "messages": [{"role": "assistant", "content": "b"}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
    ]

    out = lerobot_collate_fn(batch)

    assert out["index"].tolist() == [0, 1]
    assert out["messages"][0][0]["content"] == "a"
    assert out["messages"][1][0]["content"] == "b"
    assert out["message_streams"] == [["low_level"], ["low_level"]]
    assert out["target_message_indices"] == [[0], [0]]
    assert "language_persistent" not in out
    assert "language_events" not in out
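The behavior the collate test checks (per-sample message lists kept as lists, raw language columns dropped) can be sketched in plain Python as follows. This is an illustration only, not the code of `lerobot_collate_fn`: `sketch_collate` and `DROPPED_KEYS` are hypothetical names, and the sketch leaves every value as a per-sample list rather than stacking tensors the way a real collate function presumably would.

```python
from typing import Any

# Hypothetical constant: the raw language columns the test expects to be absent from the batch.
DROPPED_KEYS = frozenset({"language_persistent", "language_events"})


def sketch_collate(samples: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Group per-sample values into one list per key, skipping raw language columns."""
    batch: dict[str, list[Any]] = {}
    for sample in samples:
        for key, value in sample.items():
            if key in DROPPED_KEYS:
                continue
            batch.setdefault(key, []).append(value)
    return batch


# Example input mirroring the test, with plain ints standing in for tensors.
example_batch = sketch_collate(
    [
        {
            "index": 0,
            "messages": [{"role": "assistant", "content": "a"}],
            "message_streams": ["low_level"],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
        {
            "index": 1,
            "messages": [{"role": "assistant", "content": "b"}],
            "message_streams": ["low_level"],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
    ]
)
```

Keeping variable-length message lists as nested Python lists avoids forcing them into a fixed tensor shape, which is consistent with the nested-list assertions in the test.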