From f478ae5bfaa6c6d1b913a8c8fbea30c7a57982e0 Mon Sep 17 00:00:00 2001 From: pepijn Date: Fri, 13 Mar 2026 04:37:01 +0000 Subject: [PATCH] docs: add Multi-Dataset Training guide Covers feature mapping, auto-padding, per-dataset transforms, weighted sampling, stats aggregation, and full config examples for training across RoboCasa, LIBERO-plus, and RoboMME datasets. Made-with: Cursor --- docs/source/_toctree.yml | 2 + docs/source/multi_dataset_training.mdx | 232 +++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 docs/source/multi_dataset_training.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1055975d7..2fbd89ff2 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -31,6 +31,8 @@ title: Using Subtasks in the Dataset - local: streaming_video_encoding title: Streaming Video Encoding + - local: multi_dataset_training + title: Multi-Dataset Training title: "Datasets" - sections: - local: act diff --git a/docs/source/multi_dataset_training.mdx b/docs/source/multi_dataset_training.mdx new file mode 100644 index 000000000..42abf34b8 --- /dev/null +++ b/docs/source/multi_dataset_training.mdx @@ -0,0 +1,232 @@ +# Multi-Dataset Training + +This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`. + +## Overview + +Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ. 
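To make the mismatch concrete, the key renaming that resolves it (covered in detail under Feature Mapping below) can be sketched in a few lines of plain Python. This is a simplified stand-in, not the actual `MultiLeRobotDataset` implementation, and the sample values are placeholders:

```python
# Simplified sketch (not the real MultiLeRobotDataset code): a feature_map
# sends dataset-local keys to shared names; unmapped keys pass through.
def apply_feature_map(sample: dict, feature_map: dict[str, str]) -> dict:
    return {feature_map.get(key, key): value for key, value in sample.items()}

# A RoboMME-style sample with bare keys (values are placeholders):
sample = {
    "image": "front_cam_frame",
    "wrist_image": "wrist_cam_frame",
    "state": [0.0] * 8,
    "actions": [0.0] * 8,
}
feature_map = {
    "image": "observation.images.front_left",
    "wrist_image": "observation.images.wrist",
    "state": "observation.state",
    "actions": "action",
}
unified = apply_feature_map(sample, feature_map)
# unified now uses the shared key names, e.g. "observation.images.front_left"
```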
+ +`MultiLeRobotDataset` lets you train on all of them jointly by: + +- **Mapping** each dataset's feature keys into a shared namespace +- **Padding** features that a dataset doesn't have with zeros +- **Weighting** how often each dataset is sampled +- **Transforming** samples per-dataset (e.g. padding actions to a common dimension) +- **Aggregating** statistics across all sub-datasets for normalization + +## Configuration + +Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`. + +### SubDatasetConfig fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name | +| `root` | `str \| None` | `None` | Local root directory for the dataset | +| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use | +| `revision` | `str \| None` | `None` | Dataset version / revision | +| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) 
| +| `weight` | `float` | `1.0` | Relative sampling weight for this dataset | +| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys | +| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) | + +### Example: Three-dataset config + +```yaml +dataset: + type: multi + use_imagenet_stats: true + datasets: + # RoboCasa: 3 cameras, state(16), action(12) + - repo_id: pepijn223/robocasa_PrepareCoffee + root: /data/robocasa_PrepareCoffee + weight: 1.0 + feature_map: + observation.images.robot0_agentview_left: observation.images.front_left + observation.images.robot0_agentview_right: observation.images.front_right + observation.images.robot0_eye_in_hand: observation.images.wrist + + # LIBERO-plus: 2 cameras, state(8), action(7) + - repo_id: pepijn223/libero_plus_lerobot + root: /data/libero_plus_lerobot + weight: 0.5 + feature_map: + observation.images.front: observation.images.front_left + observation.images.wrist: observation.images.wrist + transforms: + - type: pad_action + kwargs: {target_dim: 12} + - type: pad_state + kwargs: {target_dim: 16} + + # RoboMME: 2 cameras (non-standard keys), state(8), action(8) + - repo_id: pepijn223/robomme_data_lerobot + root: /data/robomme_data_lerobot + weight: 0.3 + feature_map: + image: observation.images.front_left + wrist_image: observation.images.wrist + state: observation.state + actions: action + transforms: + - type: pad_action + kwargs: {target_dim: 12} + - type: pad_state + kwargs: {target_dim: 16} +``` + +## Feature Mapping + +The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally. + +After mapping, the **union** of all features across datasets defines the unified schema. 
If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features. + +## Automatic Padding + +When a sub-dataset doesn't have a feature that exists in the unified schema: + +- **Images/videos**: padded with a black frame (zeros) matching the expected resolution +- **Float tensors** (state, action): padded with zeros +- **Integer/bool tensors**: padded with zeros / False + +A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding. + +## Per-Dataset Transforms + +Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch. + +### Built-in transforms + +| Name | Description | Parameters | +|------|-------------|------------| +| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` | +| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` | +| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` | + +### Custom transforms + +You can register your own transforms in `lerobot/datasets/transforms.py`: + +```python +from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform + +@register_dataset_transform("my_transform") +class MyTransform(DatasetTransformStep): + def __init__(self, some_param: int): + self.some_param = some_param + + def __call__(self, sample: dict) -> dict: + # Modify sample in-place or return a new dict + sample["action"] = sample["action"] * self.some_param + return sample +``` + +Then reference it in the config: + +```yaml +transforms: + - type: my_transform + kwargs: {some_param: 2} +``` + +## Weighted Sampling + +The `weight` field on each sub-dataset controls how often it is sampled 
during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 16%. + +This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally. + +## Stats Aggregation + +Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics. + +The aggregated stats are used by the standard LeRobot preprocessor for normalization during training. + +## Training + +Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler: + +```bash +python -m lerobot.scripts.lerobot_train \ + --config_path path/to/multi_dataset_config.yaml +``` + +## Architecture + +The data flow during training with `MultiLeRobotDataset`: + +``` +┌─────────────────────────────────────────────────────────┐ +│ MultiLeRobotDataset.__getitem__(global_idx) │ +│ │ +│ 1. Map global_idx → (dataset_idx, local_idx) │ +│ 2. Fetch sample from sub-dataset │ +│ 3. Rename keys via feature_map │ +│ 4. Apply per-dataset transforms (pad_action, etc.) │ +│ 5. Zero-pad missing features + add _is_pad flags │ +│ 6. 
Add dataset_index tag │ +└─────────────────────┬───────────────────────────────────┘ + │ + ┌────────────▼────────────┐ + │ PyTorch DataLoader │ + │ (collates into batch) │ + └────────────┬────────────┘ + │ + ┌────────────▼────────────┐ + │ LeRobot Preprocessor │ + │ (normalize, tokenize) │ + └────────────┬────────────┘ + │ + ┌────────────▼────────────┐ + │ Policy forward + loss │ + └─────────────────────────┘ +``` + +## API Reference + +### `NewMultiLeRobotDataset` + +```python +from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset + +dataset = NewMultiLeRobotDataset( + configs=[...], # list[SubDatasetConfig] + image_transforms=None, # optional image augmentation + delta_timestamps=None, # optional temporal neighbors + tolerance_s=1e-4, # timestamp tolerance +) + +dataset.num_frames # total frames across all sub-datasets +dataset.num_episodes # total episodes +dataset.meta # MultiDatasetMeta (stats, features, episodes) +dataset.dataset_weights # list of per-dataset weights +dataset.features # unified feature dict (union of all mapped features) +dataset.camera_keys # unified camera key list +``` + +### `WeightedEpisodeAwareSampler` + +```python +from lerobot.datasets.sampler import WeightedEpisodeAwareSampler + +sampler = WeightedEpisodeAwareSampler( + dataset_from_indices=dataset.meta.episodes["dataset_from_index"], + dataset_to_indices=dataset.meta.episodes["dataset_to_index"], + dataset_membership=dataset.meta.episodes["dataset_source"], + dataset_weights=dataset.dataset_weights, + shuffle=True, +) +``` + +### `DatasetTransformPipeline` + +```python +from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig + +pipeline = DatasetTransformPipeline([ + DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}), + DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}), +]) + +sample = pipeline(sample) # modifies the sample dict +```
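As a closing illustration, the padding behavior of the `pad_action` / `pad_state` steps used throughout this guide can be mimicked in plain Python. This is a hypothetical helper for intuition only, not the LeRobot implementation (which operates on tensors):

```python
def zero_pad(vector: list[float], target_dim: int) -> list[float]:
    """Zero-pad a 1-D vector up to target_dim (error if already longer)."""
    if len(vector) > target_dim:
        raise ValueError(f"dim {len(vector)} exceeds target {target_dim}")
    return list(vector) + [0.0] * (target_dim - len(vector))

# A LIBERO-plus action (dim 7) padded to the unified action dim 12,
# matching the `pad_action` step with target_dim: 12 from the example config:
action = [0.1, -0.2, 0.3, 0.0, 0.5, -0.1, 0.9]
padded = zero_pad(action, 12)
# padded has length 12; the original values are preserved, the tail is zeros
```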