docs: add Multi-Dataset Training guide

Covers feature mapping, auto-padding, per-dataset transforms,
weighted sampling, stats aggregation, and full config examples
for training across RoboCasa, LIBERO-plus, and RoboMME datasets.

Made-with: Cursor
This commit is contained in:
pepijn
2026-03-13 04:37:01 +00:00
parent b4d40d0228
commit f478ae5bfa
2 changed files with 234 additions and 0 deletions
+2
View File
@@ -31,6 +31,8 @@
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
- local: multi_dataset_training
title: Multi-Dataset Training
title: "Datasets"
- sections:
- local: act
+232
View File
@@ -0,0 +1,232 @@
# Multi-Dataset Training
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
## Overview
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
`MultiLeRobotDataset` lets you train on all of them jointly by:
- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization
## Configuration
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
### SubDatasetConfig fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
### Example: Three-dataset config
```yaml
dataset:
type: multi
use_imagenet_stats: true
datasets:
# RoboCasa: 3 cameras, state(16), action(12)
- repo_id: pepijn223/robocasa_PrepareCoffee
root: /data/robocasa_PrepareCoffee
weight: 1.0
feature_map:
observation.images.robot0_agentview_left: observation.images.front_left
observation.images.robot0_agentview_right: observation.images.front_right
observation.images.robot0_eye_in_hand: observation.images.wrist
# LIBERO-plus: 2 cameras, state(8), action(7)
- repo_id: pepijn223/libero_plus_lerobot
root: /data/libero_plus_lerobot
weight: 0.5
feature_map:
observation.images.front: observation.images.front_left
observation.images.wrist: observation.images.wrist
transforms:
- type: pad_action
kwargs: {target_dim: 12}
- type: pad_state
kwargs: {target_dim: 16}
# RoboMME: 2 cameras (non-standard keys), state(8), action(8)
- repo_id: pepijn223/robomme_data_lerobot
root: /data/robomme_data_lerobot
weight: 0.3
feature_map:
image: observation.images.front_left
wrist_image: observation.images.wrist
state: observation.state
actions: action
transforms:
- type: pad_action
kwargs: {target_dim: 12}
- type: pad_state
kwargs: {target_dim: 16}
```
## Feature Mapping
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
## Automatic Padding
When a sub-dataset doesn't have a feature that exists in the unified schema:
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
## Per-Dataset Transforms
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
### Built-in transforms
| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
### Custom transforms
You can register your own transforms in `lerobot/datasets/transforms.py`:
```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform
@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
def __init__(self, some_param: int):
self.some_param = some_param
def __call__(self, sample: dict) -> dict:
# Modify sample in-place or return a new dict
sample["action"] = sample["action"] * self.some_param
return sample
```
Then reference it in the config:
```yaml
transforms:
- type: my_transform
kwargs: {some_param: 2}
```
## Weighted Sampling
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 16%.
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
## Stats Aggregation
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
## Training
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
```bash
python -m lerobot.scripts.lerobot_train \
--config_path path/to/multi_dataset_config.yaml
```
## Architecture
The data flow during training with `MultiLeRobotDataset`:
```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx) │
│ │
│ 1. Map global_idx → (dataset_idx, local_idx) │
│ 2. Fetch sample from sub-dataset │
│ 3. Rename keys via feature_map │
│ 4. Apply per-dataset transforms (pad_action, etc.) │
│ 5. Zero-pad missing features + add _is_pad flags │
│ 6. Add dataset_index tag │
└─────────────────────┬───────────────────────────────────┘
┌────────────▼────────────┐
│ PyTorch DataLoader │
│ (collates into batch) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ LeRobot Preprocessor │
│ (normalize, tokenize) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ Policy forward + loss │
└─────────────────────────┘
```
## API Reference
### `NewMultiLeRobotDataset`
```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
dataset = NewMultiLeRobotDataset(
configs=[...], # list[SubDatasetConfig]
image_transforms=None, # optional image augmentation
delta_timestamps=None, # optional temporal neighbors
tolerance_s=1e-4, # timestamp tolerance
)
dataset.num_frames # total frames across all sub-datasets
dataset.num_episodes # total episodes
dataset.meta # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights # list of per-dataset weights
dataset.features # unified feature dict (union of all mapped features)
dataset.camera_keys # unified camera key list
```
### `WeightedEpisodeAwareSampler`
```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler
sampler = WeightedEpisodeAwareSampler(
dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
dataset_membership=dataset.meta.episodes["dataset_source"],
dataset_weights=dataset.dataset_weights,
shuffle=True,
)
```
### `DatasetTransformPipeline`
```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig
pipeline = DatasetTransformPipeline([
DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])
sample = pipeline(sample) # modifies the sample dict
```