Compare commits

...

5 Commits

Author SHA1 Message Date
pepijn c5925399a9 style: remove decorative comment separator in transforms.py
Made-with: Cursor
2026-03-13 04:43:31 +00:00
pepijn f478ae5bfa docs: add Multi-Dataset Training guide
Covers feature mapping, auto-padding, per-dataset transforms,
weighted sampling, stats aggregation, and full config examples
for training across RoboCasa, LIBERO-plus, and RoboMME datasets.

Made-with: Cursor
2026-03-13 04:37:01 +00:00
pepijn b4d40d0228 feat: add MultiLeRobotDataset with weighted sampling and RoboMME env integration
Multi-dataset training support:
- NewMultiLeRobotDataset with per-dataset feature mapping, auto-padding,
  per-dataset transform pipelines, and weighted sampling
- MultiDatasetMeta shim compatible with EpisodeAwareSampler and make_policy
- WeightedEpisodeAwareSampler for proportional cross-dataset sampling
- SubDatasetConfig / MultiDatasetConfig in training configs
- DatasetTransformPipeline with built-in PadAction, PadState, ResizeImages
- Factory and training script wired up for multi-dataset path

RoboMME environment integration:
- RoboMMEEnv config and Gymnasium wrapper (robomme.py)
- robomme optional dependency in pyproject.toml

Made-with: Cursor
2026-03-13 04:31:35 +00:00
pepijn db5c26f07d feat(envs): add LIBERO-plus integration for evaluation benchmarks
Add LiberoPlusEnv config (subclass of LiberoEnv), register libero_plus
env type in factory, add import fallbacks for LIBERO-plus package
structure, and add libero_plus optional dependency group in pyproject.toml.

Made-with: Cursor
2026-03-12 04:31:09 +00:00
Pepijn 8904768db4 feat(envs): add RoboCasa composite-task benchmark integration
Integrates 5 selected RoboCasa kitchen tasks (3 short + 2 long) as a
LeRobot benchmark environment, following the same pattern as Libero.

Selected tasks:
  Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
  Long:  PrepareCoffee, RestockPantry

Changes:
- envs/robocasa.py: RoboCasaEnv wrapper with flat 12D Box action space,
  3-camera pixel obs, and 16D proprioceptive state
- envs/configs.py: RoboCasaEnv config with features_map
- envs/factory.py: wire robocasa into make_env + make_env_pre_post_processors
- processor/env_processor.py: RoboCasaProcessorStep for obs key remapping
- tests/test_robocasa_env.py: full test suite (auto-skips if assets missing)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-09 17:08:32 +01:00
17 changed files with 1711 additions and 72 deletions
+2
@@ -31,6 +31,8 @@
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
- local: multi_dataset_training
title: Multi-Dataset Training
title: "Datasets"
- sections:
- local: act
+232
@@ -0,0 +1,232 @@
# Multi-Dataset Training
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
## Overview
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
`MultiLeRobotDataset` lets you train on all of them jointly by:
- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization
## Configuration
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
### SubDatasetConfig fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
### Example: Three-dataset config
```yaml
dataset:
  type: multi
  use_imagenet_stats: true
  datasets:
    # RoboCasa: 3 cameras, state(16), action(12)
    - repo_id: pepijn223/robocasa_PrepareCoffee
      root: /data/robocasa_PrepareCoffee
      weight: 1.0
      feature_map:
        observation.images.robot0_agentview_left: observation.images.front_left
        observation.images.robot0_agentview_right: observation.images.front_right
        observation.images.robot0_eye_in_hand: observation.images.wrist
    # LIBERO-plus: 2 cameras, state(8), action(7)
    - repo_id: pepijn223/libero_plus_lerobot
      root: /data/libero_plus_lerobot
      weight: 0.5
      feature_map:
        observation.images.front: observation.images.front_left
        observation.images.wrist: observation.images.wrist
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
    # RoboMME: 2 cameras (non-standard keys), state(8), action(8)
    - repo_id: pepijn223/robomme_data_lerobot
      root: /data/robomme_data_lerobot
      weight: 0.3
      feature_map:
        image: observation.images.front_left
        wrist_image: observation.images.wrist
        state: observation.state
        actions: action
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
```
## Feature Mapping
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
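To make the mapping step concrete, here is a standalone sketch (not the library's implementation; `apply_feature_map` is a hypothetical helper) of how renaming works on a dict-shaped sample:

```python
def apply_feature_map(sample: dict, feature_map: dict[str, str]) -> dict:
    # Rename mapped keys; keys not listed in the map pass through unchanged.
    return {feature_map.get(key, key): value for key, value in sample.items()}

# A RoboMME-style sample with bare keys, as in the config example above.
robomme_sample = {"image": "<front frame>", "wrist_image": "<wrist frame>", "state": [0.0] * 8}
feature_map = {
    "image": "observation.images.front_left",
    "wrist_image": "observation.images.wrist",
    "state": "observation.state",
}
mapped = apply_feature_map(robomme_sample, feature_map)
# mapped now uses the unified key names shared by all sub-datasets
```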
## Automatic Padding
When a sub-dataset doesn't have a feature that exists in the unified schema:
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
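The padding rule reduces to "zero-fill what is missing and flag it". A minimal sketch (hypothetical `pad_missing` helper using plain lists, not the library's tensor-based code):

```python
def pad_missing(sample: dict, unified_dims: dict[str, int]) -> dict:
    # unified_dims: unified feature key -> flat dimension (illustrative schema).
    for key, dim in unified_dims.items():
        if key not in sample:
            sample[key] = [0.0] * dim        # zero padding for the missing feature
            sample[f"{key}_is_pad"] = True   # flag so the policy can mask it out
    return sample

sample = {"action": [0.1] * 7}
padded = pad_missing(sample, {"action": 7, "observation.state": 16})
# padded["observation.state"] is 16 zeros; real features get no _is_pad flag
```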
## Per-Dataset Transforms
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
### Built-in transforms
| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
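For example, `pad_action` with `target_dim: 12` turns a 7-D LIBERO action into a 12-D vector. A one-function sketch of that behavior (hypothetical `pad_to_dim` helper, not the built-in transform itself):

```python
def pad_to_dim(vec: list[float], target_dim: int) -> list[float]:
    # Zero-pad on the right; vectors already at target_dim pass through unchanged.
    return vec + [0.0] * (target_dim - len(vec))

action = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]  # 7-D LIBERO action
padded_action = pad_to_dim(action, 12)        # 12-D; last five entries are zero
```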
### Custom transforms
You can register your own transforms in `lerobot/datasets/transforms.py`:
```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform


@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
    def __init__(self, some_param: int):
        self.some_param = some_param

    def __call__(self, sample: dict) -> dict:
        # Modify sample in-place or return a new dict
        sample["action"] = sample["action"] * self.some_param
        return sample
```
Then reference it in the config:
```yaml
transforms:
  - type: my_transform
    kwargs: {some_param: 2}
```
## Weighted Sampling
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 17%.
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
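The normalization itself is just a division by the weight sum:

```python
weights = [1.0, 0.5, 0.3]          # per-dataset sampling weights from the config
total = sum(weights)
probs = [w / total for w in weights]
# probs rounds to [0.556, 0.278, 0.167] -- the sampling proportions quoted above
```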
## Stats Aggregation
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
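As a sketch of the underlying idea, the standard parallel (pooled) mean/variance combination for two disjoint samples with frame counts `n_a` and `n_b` looks like this (the Chan et al. update formula; a simplified illustration, not the library's exact aggregation code):

```python
def combine_stats(n_a, mean_a, var_a, n_b, mean_b, var_b):
    # Pooled mean and variance of two disjoint samples (parallel variance algorithm).
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = var_a * n_a + var_b * n_b + delta**2 * n_a * n_b / n
    return n, mean, m2 / n

# The larger dataset (300 frames vs 100) pulls the pooled mean toward its own.
n, mean, var = combine_stats(100, 0.0, 1.0, 300, 2.0, 1.0)
```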
## Training
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
```bash
python -m lerobot.scripts.lerobot_train \
    --config_path path/to/multi_dataset_config.yaml
```
## Architecture
The data flow during training with `MultiLeRobotDataset`:
```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx)             │
│                                                         │
│ 1. Map global_idx → (dataset_idx, local_idx)            │
│ 2. Fetch sample from sub-dataset                        │
│ 3. Rename keys via feature_map                          │
│ 4. Apply per-dataset transforms (pad_action, etc.)      │
│ 5. Zero-pad missing features + add _is_pad flags        │
│ 6. Add dataset_index tag                                │
└─────────────────────┬───────────────────────────────────┘
                      │
         ┌────────────▼────────────┐
         │   PyTorch DataLoader    │
         │  (collates into batch)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  LeRobot Preprocessor   │
         │  (normalize, tokenize)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  Policy forward + loss  │
         └─────────────────────────┘
```
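Step 1, the global-to-local index mapping, is a cumulative-sum lookup. A standalone sketch of the idea (hypothetical `locate` helper over made-up frame counts):

```python
import bisect

frame_counts = [500, 300, 200]  # frames per sub-dataset (illustrative)
cumulative = []
total = 0
for count in frame_counts:
    total += count
    cumulative.append(total)    # [500, 800, 1000]

def locate(global_idx: int) -> tuple[int, int]:
    # Find the first cumulative boundary strictly greater than global_idx.
    ds_idx = bisect.bisect_right(cumulative, global_idx)
    prev = cumulative[ds_idx - 1] if ds_idx > 0 else 0
    return ds_idx, global_idx - prev

print(locate(499))  # (0, 499)
print(locate(500))  # (1, 0)
print(locate(950))  # (2, 150)
```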
## API Reference
### `NewMultiLeRobotDataset`
```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset

dataset = NewMultiLeRobotDataset(
    configs=[...],           # list[SubDatasetConfig]
    image_transforms=None,   # optional image augmentation
    delta_timestamps=None,   # optional temporal neighbors
    tolerance_s=1e-4,        # timestamp tolerance
)

dataset.num_frames       # total frames across all sub-datasets
dataset.num_episodes     # total episodes
dataset.meta             # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights  # list of per-dataset weights
dataset.features         # unified feature dict (union of all mapped features)
dataset.camera_keys      # unified camera key list
```
### `WeightedEpisodeAwareSampler`
```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler

sampler = WeightedEpisodeAwareSampler(
    dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
    dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
    dataset_membership=dataset.meta.episodes["dataset_source"],
    dataset_weights=dataset.dataset_weights,
    shuffle=True,
)
```
```
### `DatasetTransformPipeline`
```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig

pipeline = DatasetTransformPipeline([
    DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
    DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])

sample = pipeline(sample)  # modifies the sample dict
```
+8
@@ -175,6 +175,14 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero_plus = [
    "lerobot[transformers-dep]",
    "libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
    "lerobot[scipy-dep]",
]
robomme = [
    "robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
+27 -6
@@ -16,18 +16,13 @@
from dataclasses import dataclass, field

from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec


@dataclass
class DatasetConfig:
    # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
    # keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
    # "dataset_index" into the returned item. The index mapping is made according to the order in which the
    # datasets are provided.
    repo_id: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | None = None
    episodes: list[int] | None = None
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -37,6 +32,32 @@ class DatasetConfig:
    streaming: bool = False


@dataclass
class SubDatasetConfig:
    """Configuration for a single dataset within a MultiDatasetConfig."""

    repo_id: str
    root: str | None = None
    episodes: list[int] | None = None
    revision: str | None = None
    video_backend: str = field(default_factory=get_safe_default_codec)
    weight: float = 1.0
    # Maps dataset-local feature keys to unified policy keys.
    # Keys not listed pass through unchanged.
    feature_map: dict[str, str] = field(default_factory=dict)
    # Per-dataset transforms applied after feature renaming, before cross-dataset padding.
    transforms: list[DatasetTransformStepConfig] | None = None


@dataclass
class MultiDatasetConfig:
    """Configuration for training on multiple datasets jointly."""

    datasets: list[SubDatasetConfig] = field(default_factory=list)
    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
    use_imagenet_stats: bool = True


@dataclass
class WandBConfig:
    enable: bool = False
+6 -6
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"


@dataclass
class TrainPipelineConfig(HubMixin):
    dataset: DatasetConfig
    dataset: DatasetConfig | MultiDatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `dir` to where you would like to save all of the run outputs. If you run another training session
@@ -129,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        if isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
        if isinstance(self.dataset, MultiDatasetConfig):
            if len(self.dataset.datasets) < 1:
                raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@@ -143,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
                "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
            )

        if self.use_rabc and not self.rabc_progress_path:
        # Auto-detect from dataset path
        if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
            repo_id = self.dataset.repo_id
            if self.dataset.root:
                self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
+67 -51
@@ -18,13 +18,14 @@ from pprint import pformat
import torch

from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.lerobot_dataset import (
    LeRobotDataset,
    LeRobotDatasetMetadata,
    MultiLeRobotDataset,
)
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
    return delta_timestamps


def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.

    Args:
        cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.

    Raises:
        NotImplementedError: The MultiLeRobotDataset is currently deactivated.
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
    """Create a single or multi-dataset depending on the config type.

    Returns:
        LeRobotDataset | MultiLeRobotDataset
        LeRobotDataset | NewMultiLeRobotDataset
    """
    if isinstance(cfg.dataset, MultiDatasetConfig):
        return _make_multi_dataset(cfg)
    return _make_single_dataset(cfg)


def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
    ds_cfg: DatasetConfig = cfg.dataset  # type: ignore[assignment]
    image_transforms = (
        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
        ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
    )
    ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
    if isinstance(cfg.dataset.repo_id, str):
        ds_meta = LeRobotDatasetMetadata(
            cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
        )
        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
        if not cfg.dataset.streaming:
            dataset = LeRobotDataset(
                cfg.dataset.repo_id,
                root=cfg.dataset.root,
                episodes=cfg.dataset.episodes,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
                revision=cfg.dataset.revision,
                video_backend=cfg.dataset.video_backend,
                tolerance_s=cfg.tolerance_s,
            )
        else:
            dataset = StreamingLeRobotDataset(
                cfg.dataset.repo_id,
                root=cfg.dataset.root,
                episodes=cfg.dataset.episodes,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
                revision=cfg.dataset.revision,
                max_num_shards=cfg.num_workers,
                tolerance_s=cfg.tolerance_s,
            )
    else:
        raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
        dataset = MultiLeRobotDataset(
            cfg.dataset.repo_id,
            # TODO(aliberts): add proper support for multi dataset
            # delta_timestamps=delta_timestamps,
    if not ds_cfg.streaming:
        dataset = LeRobotDataset(
            ds_cfg.repo_id,
            root=ds_cfg.root,
            episodes=ds_cfg.episodes,
            delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
            video_backend=cfg.dataset.video_backend,
            revision=ds_cfg.revision,
            video_backend=ds_cfg.video_backend,
            tolerance_s=cfg.tolerance_s,
        )
        logging.info(
            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
            f"{pformat(dataset.repo_id_to_index, indent=2)}"
    else:
        dataset = StreamingLeRobotDataset(
            ds_cfg.repo_id,
            root=ds_cfg.root,
            episodes=ds_cfg.episodes,
            delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
            revision=ds_cfg.revision,
            max_num_shards=cfg.num_workers,
            tolerance_s=cfg.tolerance_s,
        )
    if cfg.dataset.use_imagenet_stats:
    if ds_cfg.use_imagenet_stats:
        for key in dataset.meta.camera_keys:
            for stats_type, stats in IMAGENET_STATS.items():
                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
            for stats_type, stats_val in IMAGENET_STATS.items():
                dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
    return dataset


def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
    multi_cfg: MultiDatasetConfig = cfg.dataset  # type: ignore[assignment]
    image_transforms = (
        ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
    )
    dataset = NewMultiLeRobotDataset(
        configs=multi_cfg.datasets,
        image_transforms=image_transforms,
        tolerance_s=cfg.tolerance_s,
    )
    logging.info(
        "MultiLeRobotDataset created with %d sub-datasets:\n%s",
        len(multi_cfg.datasets),
        pformat(
            {i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
            indent=2,
        ),
    )
    if multi_cfg.use_imagenet_stats:
        for key in dataset.meta.camera_keys:
            for stats_type, stats_val in IMAGENET_STATS.items():
                dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
    return dataset
+364
@@ -0,0 +1,364 @@
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.
Supports:
- Per-dataset feature mapping (rename keys to a unified namespace)
- Automatic zero-padding for features missing in some datasets
- Per-dataset transform pipelines
- Weighted sampling via dataset weights
- Aggregated stats across all sub-datasets
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
import torch
import torch.utils.data
from lerobot.configs.default import SubDatasetConfig
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import DatasetTransformPipeline
class MultiDatasetMeta:
"""Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.
Built by aggregating the metadata of multiple sub-datasets after their
feature keys have been mapped to a unified namespace.
"""
def __init__(
self,
datasets: list[LeRobotDataset],
feature_maps: list[dict[str, str]],
):
self._datasets = datasets
self._feature_maps = feature_maps
self._unified_features = self._build_unified_features()
self._episodes = self._build_episodes()
self._stats = self._build_stats()
# ------------------------------------------------------------------
# Feature union
# ------------------------------------------------------------------
def _build_unified_features(self) -> dict[str, dict]:
"""Build feature dict as the *union* of all mapped feature keys."""
unified: dict[str, dict] = {}
for ds, fmap in zip(self._datasets, self._feature_maps):
for original_key, feat_info in ds.meta.features.items():
mapped_key = fmap.get(original_key, original_key)
if mapped_key not in unified:
unified[mapped_key] = dict(feat_info)
else:
existing_shape = tuple(unified[mapped_key]["shape"])
new_shape = tuple(feat_info["shape"])
if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
logging.warning(
"Feature '%s' has shape %s in one dataset but %s in another. "
"The larger shape will be used (padding applied automatically).",
mapped_key,
existing_shape,
new_shape,
)
if np.prod(new_shape) > np.prod(existing_shape):
unified[mapped_key] = dict(feat_info)
return unified
# ------------------------------------------------------------------
# Episode metadata (global flat indexing)
# ------------------------------------------------------------------
def _build_episodes(self) -> dict[str, list]:
"""Concatenate episode boundaries across sub-datasets with frame offsets.
Produces the same column structure as ``load_episodes()`` so that
``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
"""
from_indices: list[int] = []
to_indices: list[int] = []
dataset_source: list[int] = []
frame_offset = 0
for ds_idx, ds in enumerate(self._datasets):
eps = ds.meta.episodes
for ep in eps:
from_indices.append(ep["dataset_from_index"] + frame_offset)
to_indices.append(ep["dataset_to_index"] + frame_offset)
dataset_source.append(ds_idx)
frame_offset += ds.num_frames
return {
"dataset_from_index": from_indices,
"dataset_to_index": to_indices,
"dataset_source": dataset_source,
}
# ------------------------------------------------------------------
# Stats aggregation
# ------------------------------------------------------------------
def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats across sub-datasets using mapped feature keys."""
mapped_stats_list: list[dict[str, dict]] = []
for ds, fmap in zip(self._datasets, self._feature_maps):
reverse_map = {v: k for k, v in fmap.items()}
mapped: dict[str, dict] = {}
for unified_key in self._unified_features:
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds.meta.stats:
mapped[unified_key] = ds.meta.stats[original_key]
mapped_stats_list.append(mapped)
return aggregate_stats(mapped_stats_list)
# ------------------------------------------------------------------
# Properties matching LeRobotDatasetMetadata API
# ------------------------------------------------------------------
@property
def features(self) -> dict[str, dict]:
return self._unified_features
@property
def image_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]
@property
def names(self) -> dict[str, list | dict]:
return {k: f["names"] for k, f in self._unified_features.items()}
@property
def shapes(self) -> dict[str, tuple]:
return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}
@property
def fps(self) -> int:
fps_values = {ds.meta.fps for ds in self._datasets}
if len(fps_values) > 1:
logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
return self._datasets[0].meta.fps
@property
def stats(self) -> dict[str, dict[str, np.ndarray]]:
return self._stats
@stats.setter
def stats(self, value: dict):
self._stats = value
@property
def episodes(self) -> dict[str, list]:
return self._episodes
@property
def total_episodes(self) -> int:
return sum(ds.meta.total_episodes for ds in self._datasets)
@property
def total_frames(self) -> int:
return sum(ds.meta.total_frames for ds in self._datasets)
@property
def total_tasks(self) -> int:
return sum(ds.meta.total_tasks for ds in self._datasets)
@property
def info(self) -> dict:
return {
"fps": self.fps,
"features": self._unified_features,
"total_episodes": self.total_episodes,
"total_frames": self.total_frames,
"total_tasks": self.total_tasks,
"codebase_version": "v3.0",
}
class NewMultiLeRobotDataset(torch.utils.data.Dataset):
"""Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.
Each sub-dataset can have different feature names and shapes. A per-dataset
``feature_map`` renames keys into a shared namespace. Features that a given
sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
the full unified feature set.
"""
def __init__(
self,
configs: list[SubDatasetConfig],
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
):
super().__init__()
self._configs = configs
self.image_transforms = image_transforms
self._datasets: list[LeRobotDataset] = []
self._feature_maps: list[dict[str, str]] = []
self._transform_pipelines: list[DatasetTransformPipeline | None] = []
self._weights: list[float] = []
for cfg in configs:
ds = LeRobotDataset(
repo_id=cfg.repo_id,
root=cfg.root,
episodes=cfg.episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
revision=cfg.revision,
video_backend=cfg.video_backend,
)
self._datasets.append(ds)
self._feature_maps.append(cfg.feature_map or {})
self._transform_pipelines.append(
DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
)
self._weights.append(cfg.weight)
self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)
# Pre-compute cumulative frame counts for fast index mapping.
self._cumulative_frames: list[int] = []
total = 0
for ds in self._datasets:
total += ds.num_frames
self._cumulative_frames.append(total)
# Build reverse maps (unified_key -> original_key) per dataset for padding.
self._reverse_maps: list[dict[str, str]] = []
for fmap in self._feature_maps:
self._reverse_maps.append({v: k for k, v in fmap.items()})
logging.info(
"MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
"%d unified features",
len(self._datasets),
self.num_frames,
self.num_episodes,
len(self._meta.features),
)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
@property
def meta(self) -> MultiDatasetMeta:
return self._meta
@property
def dataset_weights(self) -> list[float]:
return self._weights
@property
def num_frames(self) -> int:
return self._cumulative_frames[-1] if self._cumulative_frames else 0
@property
def num_episodes(self) -> int:
return sum(ds.num_episodes for ds in self._datasets)
@property
def episodes(self) -> list[int] | None:
return None
@property
def fps(self) -> int:
return self._meta.fps
@property
def features(self) -> dict[str, dict]:
return self._meta.features
@property
def camera_keys(self) -> list[str]:
return self._meta.camera_keys
# ------------------------------------------------------------------
# Indexing
# ------------------------------------------------------------------
def _locate(self, idx: int) -> tuple[int, int]:
"""Map a global frame index to (dataset_index, local_index)."""
for ds_idx, cum in enumerate(self._cumulative_frames):
if idx < cum:
local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
return ds_idx, local
raise IndexError(f"Index {idx} out of range (total {self.num_frames})")
def __len__(self) -> int:
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
ds_idx, local_idx = self._locate(idx)
item = self._datasets[ds_idx][local_idx]
# 1. Rename keys according to feature_map.
fmap = self._feature_maps[ds_idx]
if fmap:
renamed: dict[str, torch.Tensor] = {}
for key, value in item.items():
renamed[fmap.get(key, key)] = value
item = renamed
# 2. Apply per-dataset transform pipeline.
pipeline = self._transform_pipelines[ds_idx]
if pipeline is not None:
item = pipeline(item)
# 3. Pad missing features with zeros.
reverse_map = self._reverse_maps[ds_idx]
ds_features = self._datasets[ds_idx].meta.features
for unified_key, feat_info in self._meta.features.items():
if unified_key in item:
continue
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds_features:
continue
shape = tuple(feat_info["shape"])
dtype = feat_info["dtype"]
if dtype in ("video", "image"):
# Camera tensors are (C, H, W) after transforms.
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
elif dtype in ("float32", "float64"):
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
elif dtype in ("int32", "int64"):
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
elif dtype == "bool":
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
else:
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
item[f"{unified_key}_is_pad"] = torch.tensor(True)
# 4. Tag which dataset this sample came from.
item["dataset_index"] = torch.tensor(ds_idx)
return item
def __repr__(self) -> str:
repo_ids = [c.repo_id for c in self._configs]
return (
f"NewMultiLeRobotDataset(\n"
f" repo_ids={repo_ids},\n"
f" num_frames={self.num_frames},\n"
f" num_episodes={self.num_episodes},\n"
f" unified_features={list(self._meta.features.keys())},\n"
f" weights={self._weights},\n"
f")"
)
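The zero-padding in step 3 of `__getitem__` can be reproduced standalone. This is a minimal sketch assuming a simplified `features` dict of `{"shape": ..., "dtype": ...}` entries; it mirrors the dtype dispatch above:

```python
import torch

# Sketch of step 3: fill features absent from a sample with zeros and mark
# them as padded, following the same dtype dispatch as `__getitem__`.
def pad_missing(item: dict, features: dict) -> dict:
    for key, info in features.items():
        if key in item:
            continue
        shape, dtype = tuple(info["shape"]), info["dtype"]
        if dtype in ("video", "image"):
            # Camera tensors are (C, H, W) after transforms.
            c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, *shape)
            item[key] = torch.zeros(c, h, w, dtype=torch.float32)
        elif dtype in ("int32", "int64"):
            item[key] = torch.zeros(shape, dtype=torch.int64)
        elif dtype == "bool":
            item[key] = torch.zeros(shape, dtype=torch.bool)
        else:  # float32/float64 and unknown dtypes fall back to float32
            item[key] = torch.zeros(shape, dtype=torch.float32)
        item[f"{key}_is_pad"] = torch.tensor(True)
    return item

sample = pad_missing(
    {"action": torch.ones(7)},
    {"action": {"shape": (7,), "dtype": "float32"},
     "observation.state": {"shape": (16,), "dtype": "float32"}},
)
```

Present features are left untouched (no `_is_pad` flag), so downstream losses can mask only the synthetic zeros.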
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler:
"""Sampler that draws frames from multiple datasets according to per-dataset weights.
Each iteration first selects a sub-dataset proportionally to its weight, then
uniformly samples a frame from that sub-dataset's valid index set. Episode
boundary information is respected so that dropped frames are excluded. When
``shuffle`` is False, all valid frames are yielded sequentially and the
weights and ``num_samples`` are ignored.
Args:
dataset_from_indices: Start index for each episode (global, flat).
dataset_to_indices: End index (exclusive) for each episode (global, flat).
dataset_membership: Which sub-dataset each episode belongs to (integer id).
dataset_weights: Relative sampling weight per sub-dataset.
episode_indices_to_use: If given, only episodes in this set are used.
drop_n_first_frames: Frames to skip at the start of each episode.
drop_n_last_frames: Frames to skip at the end of each episode.
shuffle: Whether to shuffle within each epoch.
num_samples: How many samples per epoch. Defaults to total valid frames.
generator: Optional torch.Generator for reproducibility.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
dataset_membership: list[int],
dataset_weights: list[float],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
num_samples: int | None = None,
generator: torch.Generator | None = None,
):
n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]
episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None
for ep_idx, (start, end, ds_id) in enumerate(
zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
):
if episodes_to_use is not None and ep_idx not in episodes_to_use:
continue
frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
self._per_dataset_indices[ds_id].extend(frame_range)
# Normalise weights (only over datasets that actually have frames).
raw_weights = list(dataset_weights[:n_datasets])
self._weights = torch.zeros(n_datasets)
for i, w in enumerate(raw_weights):
if len(self._per_dataset_indices[i]) > 0:
self._weights[i] = w
total_w = self._weights.sum()
if total_w > 0:
self._weights /= total_w
self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
self._num_samples = num_samples if num_samples is not None else self._total_frames
self.shuffle = shuffle
self._generator = generator
def __iter__(self) -> Iterator[int]:
if not self.shuffle:
for ds_indices in self._per_dataset_indices:
yield from ds_indices
return
for _ in range(self._num_samples):
ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
indices = self._per_dataset_indices[ds_id]
local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
yield indices[local_idx]
def __len__(self) -> int:
return self._num_samples
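The two-stage weighted draw in `__iter__` can be checked in isolation. A minimal sketch (not the full sampler) with a fixed seed, showing that dataset 0 is drawn roughly in proportion to its weight:

```python
import torch

# One weighted draw: pick a sub-dataset by weight, then a frame uniformly
# from that sub-dataset's valid global indices.
def draw(weights: torch.Tensor, per_dataset_indices: list[list[int]],
         generator: torch.Generator) -> int:
    ds_id = int(torch.multinomial(weights, 1, generator=generator).item())
    indices = per_dataset_indices[ds_id]
    local = int(torch.randint(len(indices), (1,), generator=generator).item())
    return indices[local]

g = torch.Generator().manual_seed(0)
weights = torch.tensor([0.75, 0.25])     # dataset 0 sampled ~3x as often
indices = [[0, 1, 2, 3], [100, 101]]     # global frame ids per dataset
samples = [draw(weights, indices, g) for _ in range(1000)]
frac_ds0 = sum(s < 100 for s in samples) / len(samples)
```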
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.nn.functional as F_nn
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
Transform,
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
# Per-dataset transform pipeline (used by MultiLeRobotDataset)
@dataclass
class DatasetTransformStepConfig:
"""Config for a single per-dataset transform step."""
type: str
kwargs: dict[str, Any] = field(default_factory=dict)
_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}
def register_dataset_transform(name: str):
"""Decorator to register a DatasetTransformStep by name."""
def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
_DATASET_TRANSFORM_REGISTRY[name] = cls
return cls
return decorator
class DatasetTransformStep:
"""Base class for a single per-dataset transform applied to a sample dict."""
def __call__(self, sample: dict) -> dict:
raise NotImplementedError
@register_dataset_transform("pad_action")
class PadAction(DatasetTransformStep):
"""Zero-pad the ``action`` tensor to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
action = sample.get("action")
if action is None:
return sample
current = action.shape[-1]
if current < self.target_dim:
sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
return sample
@register_dataset_transform("pad_state")
class PadState(DatasetTransformStep):
"""Zero-pad ``observation.state`` to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
state = sample.get("observation.state")
if state is None:
return sample
current = state.shape[-1]
if current < self.target_dim:
sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
return sample
@register_dataset_transform("resize_images")
class ResizeImages(DatasetTransformStep):
"""Resize all image/video camera tensors to (height, width)."""
def __init__(self, height: int, width: int):
self.size = (height, width)
def __call__(self, sample: dict) -> dict:
for key in list(sample.keys()):
if not key.startswith("observation.images."):
continue
img = sample[key]
if not isinstance(img, torch.Tensor) or img.ndim < 3:
continue
sample[key] = v2.functional.resize(img, list(self.size), antialias=True)
return sample
class DatasetTransformPipeline:
"""Sequential pipeline of DatasetTransformStep instances."""
def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
self.steps: list[DatasetTransformStep] = []
if configs:
for cfg in configs:
self.steps.append(self._build(cfg))
@staticmethod
def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
if cls is None:
raise ValueError(
f"Unknown dataset transform '{cfg.type}'. "
f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
)
return cls(**cfg.kwargs)
def __call__(self, sample: dict) -> dict:
for step in self.steps:
sample = step(sample)
return sample
def __repr__(self) -> str:
return f"DatasetTransformPipeline(steps={self.steps})"
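PadAction and PadState rely on `torch.nn.functional.pad` padding the last axis when given a two-element `(left, right)` pad tuple; a minimal standalone check of that mechanic:

```python
import torch
import torch.nn.functional as F_nn

# A (left, right) pad tuple pads the LAST axis, which is how PadAction and
# PadState zero-pad a sample to the unified action/state dimension.
action = torch.ones(7)
target_dim = 12
padded = F_nn.pad(action, (0, target_dim - action.shape[-1]))
print(padded.shape)  # torch.Size([12]) -- appended entries are zeros
```

The same call works unchanged for a batched `(T, 7)` action chunk, since only the last axis is padded.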
@@ -346,6 +346,105 @@ class LiberoEnv(EnvConfig):
return kwargs
@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
"""Alias config for LIBERO-plus benchmarks.
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
config reuses the existing LIBERO env implementation while making intent explicit
in experiment configs (`env.type=libero_plus`).
"""
task: str = "libero_spatial"
@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
"""RoboCasa kitchen composite-task environments.
Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
action space and a structured pixel + state observation dict.
Selected benchmark tasks (3 short + 2 long):
Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
Long: PrepareCoffee, RestockPantry
"""
task: str = "PickPlaceCounterToCabinet"
tasks: list[str] | None = None # multi-task: list of task names (without robocasa/ prefix)
fps: int = 20
episode_length: int = 500
image_size: int = 128
split: str = "target" # "pretrain" or "target"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"agentview_left": f"{OBS_IMAGES}.agentview_left",
"agentview_right": f"{OBS_IMAGES}.agentview_right",
"eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
"robot_state": OBS_STATE,
}
)
def __post_init__(self):
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
self.features[cam] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
)
self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
@property
def gym_kwargs(self) -> dict:
return {"split": self.split}
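The `__post_init__` above derives per-camera feature shapes from `image_size` at construction time. A standalone sketch of the pattern (class and field names here are illustrative, not the real config):

```python
from dataclasses import dataclass, field

# Sketch of the config pattern: __post_init__ fills camera features whose
# shape depends on another field (image_size), so overrides propagate.
@dataclass
class CamConfig:
    image_size: int = 128
    features: dict = field(default_factory=dict)

    def __post_init__(self):
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            self.features[cam] = (self.image_size, self.image_size, 3)

cfg = CamConfig(image_size=64)
print(cfg.features["eye_in_hand"])  # (64, 64, 3)
```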
@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
Uses BenchmarkEnvBuilder from the robomme package.
"""
task: str = "PickXtimes"
fps: int = 10
episode_length: int = 300
action_space: str = "joint_angle"
dataset_split: str = "test"
task_ids: list[int] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"front_rgb": f"{OBS_IMAGES}.front",
"wrist_rgb": f"{OBS_IMAGES}.wrist",
OBS_STATE: OBS_STATE,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"action_space": self.action_space,
"dataset": self.dataset_split,
}
@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):
@@ -20,11 +20,21 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv
from lerobot.envs.configs import (
AlohaEnv,
EnvConfig,
HubEnvConfig,
IsaaclabArenaEnv,
LiberoEnv,
LiberoPlusEnv,
PushtEnv,
RoboCasaEnv,
RoboMMEEnv,
)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
elif env_type == "libero_plus":
return LiberoPlusEnv(**kwargs)
elif env_type == "robocasa":
return RoboCasaEnv(**kwargs)
elif env_type == "robomme":
return RoboMMEEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
preprocessor_steps.append(RoboCasaProcessorStep())
# For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
# Parse comma-separated keys (handle None for state-based policies)
@@ -181,6 +201,33 @@ def make_env(
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "robocasa" in cfg.type:
from lerobot.envs.robocasa import create_robocasa_envs
tasks = cfg.tasks if cfg.tasks else [cfg.task]
return create_robocasa_envs(
tasks=tasks,
n_envs=n_envs,
image_size=cfg.image_size,
split=cfg.split,
episode_length=cfg.episode_length,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
elif "robomme" in cfg.type:
from lerobot.envs.robomme import create_robomme_envs
return create_robomme_envs(
task=cfg.task,
n_envs=n_envs,
action_space_type=cfg.action_space,
dataset=cfg.dataset_split,
episode_length=cfg.episode_length,
task_ids=cfg.task_ids,
env_cls=env_cls,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
@@ -26,8 +26,14 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
try:
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
except ImportError:
# LIBERO-plus may be installed from source with an extra nested package level.
from libero.libero.libero import benchmark, get_libero_path
from libero.libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Action layout (flat 12D, normalized to [-1, 1]):
# [0:3] end_effector_position (delta x, y, z)
# [3:6] end_effector_rotation (delta roll, pitch, yaw)
# [6:7] gripper_close (open=-1, close=+1)
# [7:11] base_motion (x, y, theta, torso_height)
# [11:12] control_mode (arm=-1, base=+1)
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
# Proprioceptive state layout (flat 16D):
# [0:2] gripper_qpos
# [2:5] base_position
# [5:9] base_rotation (quaternion)
# [9:12] end_effector_position_relative
# [12:16] end_effector_rotation_relative (quaternion)
STATE_DIM = 16
# Obs dict keys from RoboCasaGymEnv.get_observation()
_CAM_KEYS = (
"video.robot0_agentview_left",
"video.robot0_agentview_right",
"video.robot0_eye_in_hand",
)
_STATE_KEYS_ORDERED = (
"state.gripper_qpos", # (2,)
"state.base_position", # (3,)
"state.base_rotation", # (4,)
"state.end_effector_position_relative", # (3,)
"state.end_effector_rotation_relative", # (4,)
)
# Mapping from video.* key → short image name used in features_map
CAM_KEY_TO_NAME = {
"video.robot0_agentview_left": "agentview_left",
"video.robot0_agentview_right": "agentview_right",
"video.robot0_eye_in_hand": "eye_in_hand",
}
def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
"""Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
return {
"action.end_effector_position": flat[0:3],
"action.end_effector_rotation": flat[3:6],
"action.gripper_close": flat[6:7],
"action.base_motion": flat[7:11],
"action.control_mode": flat[11:12],
}
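The five slices above partition the 12-D vector exactly; a quick round-trip check of the layout (standalone sketch of `_flat_to_action_dict`):

```python
import numpy as np

# Slice a 12-D flat action into its named sub-actions and verify that
# concatenating them reproduces the original vector.
flat = np.arange(12, dtype=np.float32)
parts = {
    "action.end_effector_position": flat[0:3],
    "action.end_effector_rotation": flat[3:6],
    "action.gripper_close": flat[6:7],
    "action.base_motion": flat[7:11],
    "action.control_mode": flat[11:12],
}
reassembled = np.concatenate(list(parts.values()))
print(np.array_equal(reassembled, flat))  # True
```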
class RoboCasaEnv(gym.Env):
"""Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
and a structured observation dict compatible with LeRobot policies.
Observations returned by step/reset:
{
"pixels": {
"agentview_left": (H, W, 3) uint8,
"agentview_right": (H, W, 3) uint8,
"eye_in_hand": (H, W, 3) uint8,
},
"robot_state": (16,) float32,
}
Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
"""
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
def __init__(
self,
task: str,
split: str = "target",
image_size: int = 128,
render_mode: str = "rgb_array",
episode_length: int = 500,
**gym_kwargs: Any,
):
super().__init__()
# Lazy import — robocasa is optional
import robocasa.environments # noqa: F401 — registers all gym envs
self.task = task
self.render_mode = render_mode
self.image_size = image_size
self._max_episode_steps = episode_length
self._step_count = 0
self._env = gym.make(
f"robocasa/{task}",
split=split,
camera_widths=image_size,
camera_heights=image_size,
**gym_kwargs,
)
# Flat 12D Box action space
self.action_space = spaces.Box(
low=ACTION_LOW,
high=ACTION_HIGH,
shape=(ACTION_DIM,),
dtype=np.float32,
)
images = {
name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
for name in CAM_KEY_TO_NAME.values()
}
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Box(
low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
),
}
)
def _format_obs(self, raw_obs: dict) -> dict:
pixels = {
CAM_KEY_TO_NAME[k]: raw_obs[k]
for k in _CAM_KEYS
if k in raw_obs
}
state_parts = [
np.asarray(raw_obs[k], dtype=np.float32)
for k in _STATE_KEYS_ORDERED
if k in raw_obs
]
robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
return {"pixels": pixels, "robot_state": robot_state}
def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
super().reset(seed=seed)
self._step_count = 0
raw_obs, info = self._env.reset(seed=seed)
info.setdefault("is_success", False)
info["task"] = self.task
return self._format_obs(raw_obs), info
def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(
f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
)
action_dict = _flat_to_action_dict(action)
raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
self._step_count += 1
is_success = bool(info.get("success", False))
terminated = terminated or is_success
if self._step_count >= self._max_episode_steps:
truncated = True
info.update({"task": self.task, "is_success": is_success})
obs = self._format_obs(raw_obs)
if terminated or truncated:
info["final_info"] = {"task": self.task, "is_success": is_success}
return obs, reward, terminated, truncated, info
def render(self) -> np.ndarray | None:
if self.render_mode == "rgb_array":
return self._env.render()
return None
def close(self) -> None:
self._env.close()
def _make_env_fns(
*,
task: str,
n_envs: int,
image_size: int,
split: str,
episode_length: int,
gym_kwargs: dict[str, Any],
) -> list[Callable[[], RoboCasaEnv]]:
"""Build n_envs factory callables for a single task."""
def _make(episode_index: int) -> RoboCasaEnv: # noqa: ARG001
return RoboCasaEnv(
task=task,
split=split,
image_size=image_size,
episode_length=episode_length,
**gym_kwargs,
)
return [partial(_make, i) for i in range(n_envs)]
def create_robocasa_envs(
tasks: str | Sequence[str],
n_envs: int,
image_size: int = 128,
split: str = "target",
episode_length: int = 500,
gym_kwargs: dict[str, Any] | None = None,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboCasa environments.
Args:
tasks: A single task name or list of task names (without "robocasa/" prefix).
E.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
n_envs: Number of parallel envs per task.
image_size: Square image resolution for all cameras.
split: RoboCasa dataset split — "pretrain" or "target".
episode_length: Max steps per episode before truncation.
gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
env_cls: Callable to wrap list of factory fns (SyncVectorEnv or AsyncVectorEnv).
Returns:
dict keyed by suite name ("robocasa"), mapping sequential integer task ids
to vec envs (one per task, in input order): {"robocasa": {0: vec_env, ...}}
"""
if env_cls is None or not callable(env_cls):
raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
if not isinstance(n_envs, int) or n_envs <= 0:
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
if isinstance(tasks, str):
task_list = [t.strip() for t in tasks.split(",") if t.strip()]
else:
task_list = [str(t).strip() for t in tasks if str(t).strip()]
if not task_list:
raise ValueError("`tasks` must contain at least one task name.")
gym_kwargs = dict(gym_kwargs or {})
out: dict[str, dict[int, Any]] = defaultdict(dict)
print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
for task in task_list:
fns = _make_env_fns(
task=task,
n_envs=n_envs,
image_size=image_size,
split=split,
episode_length=episode_length,
gym_kwargs=gym_kwargs,
)
out["robocasa"][len(out["robocasa"])] = env_cls(fns)
print(f" Built vec env | task={task} | n_envs={n_envs}")
return {suite: dict(task_map) for suite, task_map in out.items()}
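The task normalization at the top of `create_robocasa_envs` accepts either a comma-separated string or a sequence; a standalone sketch of that logic:

```python
# Sketch of the task normalization: a comma-separated string or a sequence
# both collapse to a clean list of stripped, non-empty task names.
def normalize_tasks(tasks) -> list[str]:
    if isinstance(tasks, str):
        return [t.strip() for t in tasks.split(",") if t.strip()]
    return [str(t).strip() for t in tasks if str(t).strip()]

print(normalize_tasks("PrepareToast, CoffeeSetupMug"))
print(normalize_tasks(["PrepareCoffee", " RestockPantry "]))
```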
@@ -0,0 +1,154 @@
"""RoboMME environment wrapper for LeRobot evaluation.
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.
RoboMME tasks:
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""
from __future__ import annotations
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
ROBOMME_TASKS = [
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]
class RoboMMEGymEnv(gym.Env):
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
metadata = {"render_modes": ["rgb_array"]}
def __init__(
self,
task: str = "PickXtimes",
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_idx: int = 0,
max_steps: int = 300,
):
super().__init__()
from robomme.env_record_wrapper import BenchmarkEnvBuilder
self._task = task
self._action_space_type = action_space_type
self._dataset = dataset
self._episode_idx = episode_idx
self._max_steps = max_steps
self._builder = BenchmarkEnvBuilder(
env_id=task,
dataset=dataset,
action_space=action_space_type,
gui_render=False,
max_steps=max_steps,
)
self._env = None
action_dim = 8 if action_space_type == "joint_angle" else 7
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
self.observation_space = spaces.Dict({
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
})
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
self._env = self._builder.make_env_for_episode(
episode_idx=self._episode_idx, max_steps=self._max_steps,
)
obs, info = self._env.reset()
return self._convert_obs(obs), self._convert_info(info)
def step(self, action):
obs, reward, terminated, truncated, info = self._env.step(action)
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
status = info.get("status", "ongoing")
is_success = status == "success"
conv_info = self._convert_info(info)
conv_info["is_success"] = is_success
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
def _convert_obs(self, obs: dict) -> dict:
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
state = np.concatenate([joint, gripper])
return {
"front_rgb": front_rgb,
"wrist_rgb": wrist_rgb,
"state": state,
}
def _convert_info(self, info: dict) -> dict:
return {
"status": info.get("status", "ongoing"),
"task_goal": info.get("task_goal", ""),
}
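`_convert_obs` keeps only the most recent frame of each RoboMME history list, while tolerating values that are already bare arrays. A minimal standalone sketch of that convention:

```python
import numpy as np

# RoboMME returns per-modality history lists; keep only the latest frame,
# passing through values that are already bare arrays.
def last_frame(value):
    return value[-1] if isinstance(value, list) else value

obs = {
    "front_rgb_list": [np.zeros((256, 256, 3), np.uint8)] * 3,  # history list
    "joint_state_list": np.ones(7, np.float32),                 # bare array
}
frame = last_frame(obs["front_rgb_list"])
joints = last_frame(obs["joint_state_list"])
print(frame.shape, joints.shape)  # (256, 256, 3) (7,)
```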
def create_robomme_envs(
task: str,
n_envs: int = 1,
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_length: int = 300,
task_ids: list[int] | None = None,
env_cls=None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create vectorized RoboMME environments for evaluation.
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
"""
if env_cls is None:
env_cls = gym.vector.SyncVectorEnv
if task_ids is None:
task_ids = [0]
suite_name = "robomme"
envs_by_task = {}
for task_id in task_ids:
def _make_one(ep_idx=task_id):
return RoboMMEGymEnv(
task=task,
action_space_type=action_space_type,
dataset=dataset,
episode_idx=ep_idx,
max_steps=episode_length,
)
vec = env_cls(
[_make_one for _ in range(n_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
envs_by_task[task_id] = vec
return {suite_name: envs_by_task}
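`_make_one` binds `task_id` as a default argument deliberately: a plain closure over the loop variable would make every factory build the final episode index. A minimal sketch of the difference:

```python
# Default-argument binding vs. late-binding closures in a loop.
late, bound = [], []
for i in range(3):
    late.append(lambda: i)       # closes over the variable `i` (late binding)
    bound.append(lambda i=i: i)  # snapshots the current value of `i`

print([f() for f in late])   # [2, 2, 2] -- all see the final loop value
print([f() for f in bound])  # [0, 1, 2] -- each keeps its own index
```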
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
return result
@dataclass
@ProcessorStepRegistry.register(name="robocasa_processor")
class RoboCasaProcessorStep(ObservationProcessorStep):
"""
Processes RoboCasa observations into LeRobot format.
By the time this step runs, upstream conversion (via the env's features_map) has produced:
- ``observation.images.<cam_name>``: (B, C, H, W) float32 images
- ``observation.robot_state``: (B, 16) float32 proprioception
This step remaps them to:
- ``observation.images.<cam_name>`` (unchanged tensor)
- ``observation.state`` (robot_state renamed)
"""
def _process_observation(self, observation: dict) -> dict:
processed = {}
obs_prefix = OBS_PREFIX # "observation."
for key, value in observation.items():
if key.startswith(f"{OBS_IMAGES}."):
# Already in the right place; pass through
processed[key] = value
elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
# Rename robot_state → observation.state
processed[OBS_STATE] = value.float() if hasattr(value, "float") else value
return processed
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def observation(self, observation: dict) -> dict:
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
@@ -29,7 +29,8 @@ from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
@@ -343,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
if isinstance(dataset, NewMultiLeRobotDataset):
shuffle = False
sampler = WeightedEpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
dataset_membership=dataset.meta.episodes["dataset_source"],
dataset_weights=dataset.dataset_weights,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
elif drop_n_last > 0:
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
else:
@@ -360,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle and not cfg.dataset.streaming,
shuffle=shuffle and not getattr(cfg.dataset, "streaming", False),
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
@@ -0,0 +1,176 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RoboCasa LeRobot integration.
Requires: robocasa installed + kitchen assets downloaded.
Tests are skipped automatically if robocasa is not available.
"""
from __future__ import annotations
import numpy as np
import pytest
# Skip the entire module if robocasa is not installed (importorskip does not check assets)
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs
# The 5 benchmark tasks (3 short + 2 long)
BENCHMARK_TASKS = [
"PickPlaceCounterToCabinet", # short
"PrepareToast", # short
"CoffeeSetupMug", # short
"PrepareCoffee", # long
"RestockPantry", # long
]
SHORT_TASKS = BENCHMARK_TASKS[:3]
LONG_TASKS = BENCHMARK_TASKS[3:]
IMAGE_SIZE = 64 # small for fast tests
@pytest.fixture(scope="module")
def single_env():
"""Shared env instance for lightweight tests."""
env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
yield env
env.close()
class TestRoboCasaEnvSpaces:
def test_action_space_is_flat_box(self, single_env):
import gymnasium as gym
assert isinstance(single_env.action_space, gym.spaces.Box)
assert single_env.action_space.shape == (ACTION_DIM,)
assert single_env.action_space.dtype == np.float32
def test_action_bounds(self, single_env):
assert np.all(single_env.action_space.low == -1.0)
assert np.all(single_env.action_space.high == 1.0)
def test_observation_space_has_pixels_and_state(self, single_env):
import gymnasium as gym
assert isinstance(single_env.observation_space, gym.spaces.Dict)
assert "pixels" in single_env.observation_space.spaces
assert "robot_state" in single_env.observation_space.spaces
def test_observation_space_cameras(self, single_env):
pixels_space = single_env.observation_space["pixels"]
expected_cams = set(CAM_KEY_TO_NAME.values())
assert set(pixels_space.spaces.keys()) == expected_cams
def test_state_dim(self, single_env):
state_space = single_env.observation_space["robot_state"]
assert state_space.shape == (STATE_DIM,)

class TestRoboCasaEnvReset:
    def test_reset_returns_obs_and_info(self, single_env):
        obs, info = single_env.reset()
        assert isinstance(obs, dict)
        assert isinstance(info, dict)

    def test_reset_obs_has_pixels(self, single_env):
        obs, _ = single_env.reset()
        assert "pixels" in obs
        for cam_name in CAM_KEY_TO_NAME.values():
            assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"

    def test_reset_obs_image_shape(self, single_env):
        obs, _ = single_env.reset()
        for cam_name, img in obs["pixels"].items():
            assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
            assert img.dtype == np.uint8

    def test_reset_obs_state_shape(self, single_env):
        obs, _ = single_env.reset()
        assert obs["robot_state"].shape == (STATE_DIM,)
        assert obs["robot_state"].dtype == np.float32

    def test_reset_info_has_task(self, single_env):
        _, info = single_env.reset()
        assert "task" in info
        assert info["task"] == "PickPlaceCounterToCabinet"

class TestRoboCasaEnvStep:
    def test_step_10_random_actions(self, single_env):
        single_env.reset()
        for _ in range(10):
            action = single_env.action_space.sample()
            obs, reward, terminated, truncated, info = single_env.step(action)
            assert obs["robot_state"].shape == (STATE_DIM,)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)

    def test_step_bad_action_raises(self, single_env):
        single_env.reset()
        with pytest.raises(ValueError, match="Expected 1-D action"):
            single_env.step(np.zeros((2, ACTION_DIM)))

    def test_step_info_has_is_success(self, single_env):
        single_env.reset()
        _, _, _, _, info = single_env.step(single_env.action_space.sample())
        assert "is_success" in info

class TestRoboCasaConfig:
    def test_robocasa_env_config(self):
        from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
        from lerobot.configs.types import FeatureType

        cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
        assert cfg.type == "robocasa"
        # action feature
        assert "action" in cfg.features
        assert cfg.features["action"].shape == (ACTION_DIM,)
        # camera features
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert cam in cfg.features
            assert cfg.features[cam].type == FeatureType.VISUAL
            assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
        # state feature
        assert "robot_state" in cfg.features
        assert cfg.features["robot_state"].shape == (STATE_DIM,)

    def test_make_env_config_robocasa(self):
        from lerobot.envs.factory import make_env_config

        cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
        assert cfg.type == "robocasa"

class TestRoboCasaProcessorStep:
    def test_processor_remaps_keys(self):
        import torch

        from lerobot.processor.env_processor import RoboCasaProcessorStep
        from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

        step = RoboCasaProcessorStep()
        B = 2
        obs = {
            f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            "observation.robot_state": torch.zeros(B, STATE_DIM),
        }
        out = step._process_observation(obs)

        assert OBS_STATE in out
        assert out[OBS_STATE].dtype == torch.float32
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert f"{OBS_IMAGES}.{cam}" in out