docs: add Multi-Dataset Training guide

Covers feature mapping, auto-padding, per-dataset transforms, weighted sampling,
stats aggregation, and full config examples for training across RoboCasa,
LIBERO-plus, and RoboMME datasets.

Made-with: Cursor
@@ -31,6 +31,8 @@
       title: Using Subtasks in the Dataset
     - local: streaming_video_encoding
       title: Streaming Video Encoding
+    - local: multi_dataset_training
+      title: Multi-Dataset Training
   title: "Datasets"
 - sections:
   - local: act
@@ -0,0 +1,232 @@

# Multi-Dataset Training

This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
## Overview

Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.

`MultiLeRobotDataset` lets you train on all of them jointly by:

- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization
## Configuration

Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.

### SubDatasetConfig fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list \| None` | `None` | Per-dataset transform steps (applied per sample) |
### Example: Three-dataset config

```yaml
dataset:
  type: multi
  use_imagenet_stats: true
  datasets:
    # RoboCasa: 3 cameras, state(16), action(12)
    - repo_id: pepijn223/robocasa_PrepareCoffee
      root: /data/robocasa_PrepareCoffee
      weight: 1.0
      feature_map:
        observation.images.robot0_agentview_left: observation.images.front_left
        observation.images.robot0_agentview_right: observation.images.front_right
        observation.images.robot0_eye_in_hand: observation.images.wrist

    # LIBERO-plus: 2 cameras, state(8), action(7)
    - repo_id: pepijn223/libero_plus_lerobot
      root: /data/libero_plus_lerobot
      weight: 0.5
      feature_map:
        observation.images.front: observation.images.front_left
        observation.images.wrist: observation.images.wrist
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}

    # RoboMME: 2 cameras (non-standard keys), state(8), action(8)
    - repo_id: pepijn223/robomme_data_lerobot
      root: /data/robomme_data_lerobot
      weight: 0.3
      feature_map:
        image: observation.images.front_left
        wrist_image: observation.images.wrist
        state: observation.state
        actions: action
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
```
## Feature Mapping

The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.

After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
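Conceptually, applying a `feature_map` is a plain dictionary remap over each sample's keys. A minimal sketch of the idea (the `apply_feature_map` helper is hypothetical, not the library's API):

```python
def apply_feature_map(sample: dict, feature_map: dict[str, str]) -> dict:
    """Rename dataset-local keys to unified keys; unmapped keys pass through."""
    return {feature_map.get(key, key): value for key, value in sample.items()}

# RoboMME-style sample with bare keys, remapped to the unified namespace
sample = {"image": "front_frame", "wrist_image": "wrist_frame", "state": [0.0] * 8}
feature_map = {
    "image": "observation.images.front_left",
    "wrist_image": "observation.images.wrist",
    "state": "observation.state",
}
print(apply_feature_map(sample, feature_map))
```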
## Automatic Padding

When a sub-dataset doesn't have a feature that exists in the unified schema:

- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / `False`

A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
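The padding rule can be sketched as a single pass over a schema of per-key defaults. `pad_missing_features` and the `schema` shape below are illustrative assumptions, not lerobot internals:

```python
def pad_missing_features(sample: dict, schema: dict) -> dict:
    """Fill features missing from `sample` with a zero default from `schema`
    and tag every schema feature with a `{key}_is_pad` flag."""
    out = dict(sample)
    for key, default in schema.items():
        if key in out:
            out[f"{key}_is_pad"] = False
        else:
            out[key] = default  # zeros tensor / black frame in the real pipeline
            out[f"{key}_is_pad"] = True
    return out

# A dataset that lacks the wrist camera gets a placeholder plus a pad flag
schema = {"observation.state": [0.0] * 16, "observation.images.wrist": "zeros(3,224,224)"}
padded = pad_missing_features({"observation.state": [0.1] * 16}, schema)
print(padded["observation.images.wrist_is_pad"])  # -> True
```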
## Per-Dataset Transforms

Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.

### Built-in transforms
| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
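A `pad_action`-style transform amounts to appending zeros up to `target_dim`. The helper below (`pad_to_dim`, a hypothetical name) shows the idea on a plain list rather than a tensor:

```python
def pad_to_dim(vec: list[float], target_dim: int) -> list[float]:
    """Zero-pad a flat vector to target_dim; longer vectors pass through unchanged."""
    if len(vec) >= target_dim:
        return vec
    return vec + [0.0] * (target_dim - len(vec))

# LIBERO's 7-dim actions would be padded like this to RoboCasa's 12 dims
print(pad_to_dim([0.1, 0.2], 4))  # -> [0.1, 0.2, 0.0, 0.0]
```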
### Custom transforms

You can register your own transforms in `lerobot/datasets/transforms.py`:
```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform


@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
    def __init__(self, some_param: int):
        self.some_param = some_param

    def __call__(self, sample: dict) -> dict:
        # Modify the sample in place or return a new dict
        sample["action"] = sample["action"] * self.some_param
        return sample
```
Then reference it in the config:

```yaml
transforms:
  - type: my_transform
    kwargs: {some_param: 2}
```
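Under the hood, a decorator like this is usually backed by a string-keyed registry that the config loader consults. The sketch below is the generic pattern, not lerobot's actual implementation; `_TRANSFORM_REGISTRY`, `make_transform`, and `scale_action` are invented for illustration:

```python
_TRANSFORM_REGISTRY: dict[str, type] = {}

def register_dataset_transform(name: str):
    """Decorator that stores a transform class under a string name."""
    def decorator(cls):
        _TRANSFORM_REGISTRY[name] = cls
        return cls
    return decorator

def make_transform(step_type: str, kwargs: dict):
    """Instantiate a registered transform from its config entry."""
    return _TRANSFORM_REGISTRY[step_type](**kwargs)

@register_dataset_transform("scale_action")
class ScaleAction:
    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, sample: dict) -> dict:
        sample["action"] = [a * self.factor for a in sample["action"]]
        return sample

t = make_transform("scale_action", {"factor": 2.0})
print(t({"action": [1.0, 2.0]}))  # -> {'action': [2.0, 4.0]}
```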
## Weighted Sampling

The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 17%.

This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
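The normalization of weights to sampling probabilities is a single division by the sum:

```python
weights = [1.0, 0.5, 0.3]
probs = [w / sum(weights) for w in weights]
print([round(p, 3) for p in probs])  # -> [0.556, 0.278, 0.167]
```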
## Stats Aggregation

Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.

The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
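The parallel variance idea lets per-dataset means and variances be combined without revisiting raw frames. A sketch of the pairwise update (Chan et al.'s formula; `combine_stats` is an illustrative helper, not the library function):

```python
def combine_stats(n_a, mean_a, var_a, n_b, mean_b, var_b):
    """Pool two (count, mean, population variance) summaries into one."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # m2 = pooled sum of squared deviations (parallel-variance update)
    m2 = var_a * n_a + var_b * n_b + delta ** 2 * n_a * n_b / n
    return n, mean, m2 / n

# Two "datasets": [0, 2] and [4, 4, 4]; pooled stats over [0, 2, 4, 4, 4]
n, mean, var = combine_stats(2, 1.0, 1.0, 3, 4.0, 0.0)
print(n, mean, round(var, 2))  # -> 5 2.8 2.56
```

Because the counts `n_a` and `n_b` weight each side, a dataset with more frames pulls the pooled statistics toward its own, matching the "contribute proportionally" behavior described above.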
## Training

Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:

```bash
python -m lerobot.scripts.lerobot_train \
  --config_path path/to/multi_dataset_config.yaml
```
## Architecture

The data flow during training with `MultiLeRobotDataset`:

```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx)             │
│                                                         │
│  1. Map global_idx → (dataset_idx, local_idx)           │
│  2. Fetch sample from sub-dataset                       │
│  3. Rename keys via feature_map                         │
│  4. Apply per-dataset transforms (pad_action, etc.)     │
│  5. Zero-pad missing features + add _is_pad flags       │
│  6. Add dataset_index tag                               │
└─────────────────────┬───────────────────────────────────┘
                      │
         ┌────────────▼────────────┐
         │   PyTorch DataLoader    │
         │  (collates into batch)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │   LeRobot Preprocessor  │
         │  (normalize, tokenize)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  Policy forward + loss  │
         └─────────────────────────┘
```
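Step 1 of the diagram, mapping a global frame index to a `(dataset_idx, local_idx)` pair, is typically done with cumulative frame counts. A sketch under that assumption (the `locate` helper is hypothetical):

```python
import bisect

def locate(global_idx: int, frames_per_dataset: list[int]) -> tuple[int, int]:
    """Map a global frame index to (dataset_idx, local_idx) via cumulative counts."""
    cum = []
    total = 0
    for n in frames_per_dataset:
        total += n
        cum.append(total)
    dataset_idx = bisect.bisect_right(cum, global_idx)
    prev = cum[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, global_idx - prev

print(locate(0, [100, 50, 30]))    # -> (0, 0)
print(locate(120, [100, 50, 30]))  # -> (1, 20)
```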
## API Reference

### `NewMultiLeRobotDataset`
```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset

dataset = NewMultiLeRobotDataset(
    configs=[...],          # list[SubDatasetConfig]
    image_transforms=None,  # optional image augmentation
    delta_timestamps=None,  # optional temporal neighbors
    tolerance_s=1e-4,       # timestamp tolerance
)

dataset.num_frames       # total frames across all sub-datasets
dataset.num_episodes     # total episodes
dataset.meta             # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights  # list of per-dataset weights
dataset.features         # unified feature dict (union of all mapped features)
dataset.camera_keys      # unified camera key list
```
### `WeightedEpisodeAwareSampler`

```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler

sampler = WeightedEpisodeAwareSampler(
    dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
    dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
    dataset_membership=dataset.meta.episodes["dataset_source"],
    dataset_weights=dataset.dataset_weights,
    shuffle=True,
)
```
### `DatasetTransformPipeline`

```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig

pipeline = DatasetTransformPipeline([
    DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
    DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])

sample = pipeline(sample)  # modifies the sample dict
```