# Multi-Dataset Training

This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.

## Overview

Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.

`MultiLeRobotDataset` lets you train on all of them jointly by:

- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization

## Configuration

Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.

### SubDatasetConfig fields

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |

### Example: Three-dataset config

```yaml
dataset:
  type: multi
  use_imagenet_stats: true
  datasets:
    # RoboCasa: 3 cameras, state(16), action(12)
    - repo_id: pepijn223/robocasa_PrepareCoffee
      root: /data/robocasa_PrepareCoffee
      weight: 1.0
      feature_map:
        observation.images.robot0_agentview_left: observation.images.front_left
        observation.images.robot0_agentview_right: observation.images.front_right
        observation.images.robot0_eye_in_hand: observation.images.wrist

    # LIBERO-plus: 2 cameras, state(8), action(7)
    - repo_id: pepijn223/libero_plus_lerobot
      root: /data/libero_plus_lerobot
      weight: 0.5
      feature_map:
        observation.images.front: observation.images.front_left
        observation.images.wrist: observation.images.wrist
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}

    # RoboMME: 2 cameras (non-standard keys), state(8), action(8)
    - repo_id: pepijn223/robomme_data_lerobot
      root: /data/robomme_data_lerobot
      weight: 0.3
      feature_map:
        image: observation.images.front_left
        wrist_image: observation.images.wrist
        state: observation.state
        actions: action
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
```

## Feature Mapping

The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.

After mapping, the **union** of all features across datasets defines the unified schema.
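
The renaming and union steps described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: `rename_keys` and `unified_keys` are hypothetical helpers, and short strings stand in for image tensors.

```python
def rename_keys(sample: dict, feature_map: dict[str, str]) -> dict:
    """Return a copy of `sample` with keys renamed; unlisted keys pass through."""
    return {feature_map.get(key, key): value for key, value in sample.items()}


def unified_keys(per_dataset_keys: list[set[str]]) -> set[str]:
    """Union of the mapped keys across all sub-datasets."""
    return set().union(*per_dataset_keys)


# Hypothetical LIBERO-style sample (strings stand in for image tensors).
libero_map = {"observation.images.front": "observation.images.front_left"}
libero_sample = {
    "observation.images.front": "front-frame",
    "observation.images.wrist": "wrist-frame",
    "action": [0.0] * 7,
}
mapped = rename_keys(libero_sample, libero_map)

# Keys a RoboCasa-style dataset would expose after its own mapping.
robocasa_keys = {
    "observation.images.front_left",
    "observation.images.front_right",
    "observation.images.wrist",
    "action",
}

schema = unified_keys([set(mapped), robocasa_keys])
missing = schema - set(mapped)  # features this dataset would need padded
```

In this sketch, `missing` contains only `observation.images.front_right`, which is the kind of feature that the padding step has to fill in for the LIBERO-style dataset.
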
If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.

## Automatic Padding

When a sub-dataset doesn't have a feature that exists in the unified schema:

- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False

A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.

## Per-Dataset Transforms

Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.

### Built-in transforms

| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |

### Custom transforms

You can register your own transforms in `lerobot/datasets/transforms.py`:

```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform


@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
    def __init__(self, some_param: int):
        self.some_param = some_param

    def __call__(self, sample: dict) -> dict:
        # Modify sample in-place or return a new dict
        sample["action"] = sample["action"] * self.some_param
        return sample
```

Then reference it in the config:

```yaml
transforms:
  - type: my_transform
    kwargs: {some_param: 2}
```

## Weighted Sampling

The `weight` field on each sub-dataset controls how often it is sampled during training.
Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]` (sum 1.8), the first dataset is sampled roughly 56% of the time, the second 28%, and the third 17%.

This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.

## Stats Aggregation

Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.

The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.

## Training

Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:

```bash
python -m lerobot.scripts.lerobot_train \
    --config_path path/to/multi_dataset_config.yaml
```

## Architecture

The data flow during training with `MultiLeRobotDataset`:

```
┌─────────────────────────────────────────────────────────┐
│  MultiLeRobotDataset.__getitem__(global_idx)            │
│                                                         │
│  1. Map global_idx → (dataset_idx, local_idx)           │
│  2. Fetch sample from sub-dataset                       │
│  3. Rename keys via feature_map                         │
│  4. Apply per-dataset transforms (pad_action, etc.)     │
│  5. Zero-pad missing features + add _is_pad flags       │
│  6. Add dataset_index tag                               │
└─────────────────────┬───────────────────────────────────┘
                      │
         ┌────────────▼────────────┐
         │   PyTorch DataLoader    │
         │  (collates into batch)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  LeRobot Preprocessor   │
         │  (normalize, tokenize)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  Policy forward + loss  │
         └─────────────────────────┘
```

## API Reference

### `NewMultiLeRobotDataset`

```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset

dataset = NewMultiLeRobotDataset(
    configs=[...],            # list[SubDatasetConfig]
    image_transforms=None,    # optional image augmentation
    delta_timestamps=None,    # optional temporal neighbors
    tolerance_s=1e-4,         # timestamp tolerance
)

dataset.num_frames       # total frames across all sub-datasets
dataset.num_episodes     # total episodes
dataset.meta             # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights  # list of per-dataset weights
dataset.features         # unified feature dict (union of all mapped features)
dataset.camera_keys      # unified camera key list
```

### `WeightedEpisodeAwareSampler`

```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler

sampler = WeightedEpisodeAwareSampler(
    dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
    dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
    dataset_membership=dataset.meta.episodes["dataset_source"],
    dataset_weights=dataset.dataset_weights,
    shuffle=True,
)
```

### `DatasetTransformPipeline`

```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig

pipeline = DatasetTransformPipeline([
    DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
    DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])

sample = pipeline(sample)  # modifies the sample dict
```
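
To make the effect of a `pad_action`-style step concrete, here is a rough, library-free sketch of zero-padding an action vector to a target dimension. It is an illustration only: `pad_vector` and `pad_action_sketch` are hypothetical names, plain Python lists stand in for tensors, and the real transform may behave differently (for example in how it treats inputs that are already longer than `target_dim`).

```python
def pad_vector(values: list[float], target_dim: int) -> list[float]:
    """Zero-pad a 1-D vector on the right up to `target_dim` entries."""
    if len(values) > target_dim:
        # Assumption for this sketch: refuse to truncate silently.
        raise ValueError(f"vector of length {len(values)} exceeds target_dim={target_dim}")
    return list(values) + [0.0] * (target_dim - len(values))


def pad_action_sketch(sample: dict, target_dim: int) -> dict:
    """Return a shallow copy of `sample` with `action` padded to `target_dim`."""
    padded = dict(sample)
    padded["action"] = pad_vector(sample["action"], target_dim)
    return padded


# A LIBERO-style sample with a 7-dimensional action, padded to 12.
libero_sample = {"action": [0.1] * 7, "observation.state": [0.0] * 8}
out = pad_action_sketch(libero_sample, target_dim=12)
```

After this step the action has 12 entries, with the original 7 values first and zeros after them, so samples from datasets with different action sizes can be stacked into one batch.
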