mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00

Compare commits (5 commits)

| SHA1 |
|---|
| c5925399a9 |
| f478ae5bfa |
| b4d40d0228 |
| db5c26f07d |
| 8904768db4 |
@@ -31,6 +31,8 @@
        title: Using Subtasks in the Dataset
      - local: streaming_video_encoding
        title: Streaming Video Encoding
      - local: multi_dataset_training
        title: Multi-Dataset Training
    title: "Datasets"
  - sections:
      - local: act

@@ -0,0 +1,232 @@
# Multi-Dataset Training

This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.

## Overview

Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.

`MultiLeRobotDataset` lets you train on all of them jointly by:

- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization

## Configuration

Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.

### SubDatasetConfig fields

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |

### Example: Three-dataset config

```yaml
dataset:
  type: multi
  use_imagenet_stats: true
  datasets:
    # RoboCasa: 3 cameras, state(16), action(12)
    - repo_id: pepijn223/robocasa_PrepareCoffee
      root: /data/robocasa_PrepareCoffee
      weight: 1.0
      feature_map:
        observation.images.robot0_agentview_left: observation.images.front_left
        observation.images.robot0_agentview_right: observation.images.front_right
        observation.images.robot0_eye_in_hand: observation.images.wrist

    # LIBERO-plus: 2 cameras, state(8), action(7)
    - repo_id: pepijn223/libero_plus_lerobot
      root: /data/libero_plus_lerobot
      weight: 0.5
      feature_map:
        observation.images.front: observation.images.front_left
        observation.images.wrist: observation.images.wrist
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}

    # RoboMME: 2 cameras (non-standard keys), state(8), action(8)
    - repo_id: pepijn223/robomme_data_lerobot
      root: /data/robomme_data_lerobot
      weight: 0.3
      feature_map:
        image: observation.images.front_left
        wrist_image: observation.images.wrist
        state: observation.state
        actions: action
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
```

## Feature Mapping

The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.

After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.

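In pure-Python terms, the renaming step amounts to a dictionary comprehension over each sample. A minimal stand-alone sketch (illustrative, not the library's internal code):

```python
def apply_feature_map(sample: dict, feature_map: dict[str, str]) -> dict:
    """Rename dataset-local keys to unified keys; unmapped keys pass through."""
    return {feature_map.get(key, key): value for key, value in sample.items()}


# Example: a RoboMME-style sample renamed into the unified namespace.
sample = {"image": "front_frame", "wrist_image": "wrist_frame", "timestamp": 0.1}
fmap = {
    "image": "observation.images.front_left",
    "wrist_image": "observation.images.wrist",
}
mapped = apply_feature_map(sample, fmap)
# mapped == {"observation.images.front_left": "front_frame",
#            "observation.images.wrist": "wrist_frame", "timestamp": 0.1}
```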
## Automatic Padding

When a sub-dataset doesn't have a feature that exists in the unified schema:

- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False

A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.

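The padding rule can be sketched as follows. This is an illustrative snippet, not the library's internal code; the unified schema is represented here as a simple key → shape dict:

```python
import torch


def pad_missing_features(sample: dict, schema: dict[str, tuple]) -> dict:
    """Zero-fill any schema feature the sample lacks and tag it with an _is_pad flag."""
    out = dict(sample)
    for key, shape in schema.items():
        if key in out:
            out[f"{key}_is_pad"] = torch.tensor(False)
        else:
            out[key] = torch.zeros(shape)  # black frame / zero vector
            out[f"{key}_is_pad"] = torch.tensor(True)
    return out


schema = {"observation.images.front_right": (3, 224, 224), "observation.state": (16,)}
sample = {"observation.state": torch.randn(16)}
padded = pad_missing_features(sample, schema)
# padded["observation.images.front_right"] is an all-zero (3, 224, 224) tensor,
# and padded["observation.images.front_right_is_pad"] is True.
```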
## Per-Dataset Transforms

Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.

### Built-in transforms

| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |

### Custom transforms

You can register your own transforms in `lerobot/datasets/transforms.py`:

```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform


@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
    def __init__(self, some_param: int):
        self.some_param = some_param

    def __call__(self, sample: dict) -> dict:
        # Modify sample in-place or return a new dict
        sample["action"] = sample["action"] * self.some_param
        return sample
```

Then reference it in the config:

```yaml
transforms:
  - type: my_transform
    kwargs: {some_param: 2}
```

## Weighted Sampling

The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 17%.

This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.

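Concretely, the normalization from relative weights to sampling probabilities is just a division by the total (a sketch; the actual sampler applies these probabilities at the episode level):

```python
def normalize_weights(weights: list[float]) -> list[float]:
    """Turn relative dataset weights into sampling probabilities."""
    total = sum(weights)
    return [w / total for w in weights]


probs = normalize_weights([1.0, 0.5, 0.3])
# probs ≈ [0.556, 0.278, 0.167]
```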
## Stats Aggregation

Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.

The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.

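The weighted aggregation of mean and variance follows the standard pooled-moments formulas. A sketch with per-dataset frame counts as weights (this mirrors the idea, not necessarily the exact code path of `aggregate_stats`):

```python
import numpy as np


def pooled_mean_std(means, stds, counts):
    """Combine per-dataset (mean, std, n) into a global mean/std via pooled second moments."""
    means, stds, counts = map(np.asarray, (means, stds, counts))
    total = counts.sum()
    mean = (counts * means).sum() / total
    # E[X^2] per dataset is std^2 + mean^2; pool those, then subtract the global mean^2.
    second_moment = (counts * (stds**2 + means**2)).sum() / total
    return mean, np.sqrt(second_moment - mean**2)


mean, std = pooled_mean_std(means=[0.0, 2.0], stds=[1.0, 1.0], counts=[100, 300])
# mean == 1.5; std captures both within- and between-dataset spread.
```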
## Training

Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:

```bash
python -m lerobot.scripts.lerobot_train \
    --config_path path/to/multi_dataset_config.yaml
```

## Architecture

The data flow during training with `MultiLeRobotDataset`:

```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx)             │
│                                                         │
│ 1. Map global_idx → (dataset_idx, local_idx)            │
│ 2. Fetch sample from sub-dataset                        │
│ 3. Rename keys via feature_map                          │
│ 4. Apply per-dataset transforms (pad_action, etc.)      │
│ 5. Zero-pad missing features + add _is_pad flags        │
│ 6. Add dataset_index tag                                │
└─────────────────────┬───────────────────────────────────┘
                      │
         ┌────────────▼────────────┐
         │   PyTorch DataLoader    │
         │  (collates into batch)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  LeRobot Preprocessor   │
         │  (normalize, tokenize)  │
         └────────────┬────────────┘
                      │
         ┌────────────▼────────────┐
         │  Policy forward + loss  │
         └─────────────────────────┘
```

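Step 1 of the diagram (global index → sub-dataset and local index) can also be done with a binary search over cumulative frame counts. A small sketch (the shipped `_locate` helper uses a linear scan, which is equivalent for a handful of sub-datasets):

```python
import bisect


def locate(idx: int, cumulative_frames: list[int]) -> tuple[int, int]:
    """Map a global frame index to (dataset_index, local_index)."""
    ds_idx = bisect.bisect_right(cumulative_frames, idx)
    if ds_idx >= len(cumulative_frames):
        raise IndexError(f"Index {idx} out of range")
    prev = cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0
    return ds_idx, idx - prev


# Three sub-datasets with 100, 50, and 30 frames → cumulative [100, 150, 180].
assert locate(0, [100, 150, 180]) == (0, 0)
assert locate(120, [100, 150, 180]) == (1, 20)
```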
## API Reference

### `NewMultiLeRobotDataset`

```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset

dataset = NewMultiLeRobotDataset(
    configs=[...],           # list[SubDatasetConfig]
    image_transforms=None,   # optional image augmentation
    delta_timestamps=None,   # optional temporal neighbors
    tolerance_s=1e-4,        # timestamp tolerance
)

dataset.num_frames       # total frames across all sub-datasets
dataset.num_episodes     # total episodes
dataset.meta             # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights  # list of per-dataset weights
dataset.features         # unified feature dict (union of all mapped features)
dataset.camera_keys      # unified camera key list
```

### `WeightedEpisodeAwareSampler`

```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler

sampler = WeightedEpisodeAwareSampler(
    dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
    dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
    dataset_membership=dataset.meta.episodes["dataset_source"],
    dataset_weights=dataset.dataset_weights,
    shuffle=True,
)
```

### `DatasetTransformPipeline`

```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig

pipeline = DatasetTransformPipeline([
    DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
    DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])

sample = pipeline(sample)  # modifies the sample dict
```

@@ -175,6 +175,14 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
 aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
 pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
 libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
+libero_plus = [
+    "lerobot[transformers-dep]",
+    "libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
+    "lerobot[scipy-dep]",
+]
+robomme = [
+    "robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
+]
 metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]

 # All

@@ -16,18 +16,13 @@

 from dataclasses import dataclass, field

-from lerobot.datasets.transforms import ImageTransformsConfig
+from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
 from lerobot.datasets.video_utils import get_safe_default_codec


 @dataclass
 class DatasetConfig:
-    # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
-    # keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
-    # "dataset_index" into the returned item. The index mapping is made according to the order in which the
-    # datasets are provided.
     repo_id: str
     # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
     root: str | None = None
     episodes: list[int] | None = None
     image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -37,6 +32,32 @@ class DatasetConfig:
     streaming: bool = False


+@dataclass
+class SubDatasetConfig:
+    """Configuration for a single dataset within a MultiDatasetConfig."""
+
+    repo_id: str
+    root: str | None = None
+    episodes: list[int] | None = None
+    revision: str | None = None
+    video_backend: str = field(default_factory=get_safe_default_codec)
+    weight: float = 1.0
+    # Maps dataset-local feature keys to unified policy keys.
+    # Keys not listed pass through unchanged.
+    feature_map: dict[str, str] = field(default_factory=dict)
+    # Per-dataset transforms applied after feature renaming, before cross-dataset padding.
+    transforms: list[DatasetTransformStepConfig] | None = None
+
+
+@dataclass
+class MultiDatasetConfig:
+    """Configuration for training on multiple datasets jointly."""
+
+    datasets: list[SubDatasetConfig] = field(default_factory=list)
+    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+    use_imagenet_stats: bool = True
+
+
 @dataclass
 class WandBConfig:
     enable: bool = False

@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError

 from lerobot import envs
 from lerobot.configs import parser
-from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
+from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.optim import OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"

 @dataclass
 class TrainPipelineConfig(HubMixin):
-    dataset: DatasetConfig
+    dataset: DatasetConfig | MultiDatasetConfig
     env: envs.EnvConfig | None = None
     policy: PreTrainedConfig | None = None
     # Set `dir` to where you would like to save all of the run outputs. If you run another training session
@@ -129,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
             train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/train") / train_dir

-        if isinstance(self.dataset.repo_id, list):
-            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
+        if isinstance(self.dataset, MultiDatasetConfig):
+            if len(self.dataset.datasets) < 1:
+                raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")

         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@@ -143,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
                 "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
             )

-        if self.use_rabc and not self.rabc_progress_path:
-            # Auto-detect from dataset path
+        if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
             repo_id = self.dataset.repo_id
             if self.dataset.root:
                 self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")

@@ -18,13 +18,14 @@ from pprint import pformat

 import torch

+from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.datasets.lerobot_dataset import (
     LeRobotDataset,
     LeRobotDatasetMetadata,
     MultiLeRobotDataset,
 )
+from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
 from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
     return delta_timestamps


-def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
-    """Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
-
-    Args:
-        cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
-
-    Raises:
-        NotImplementedError: The MultiLeRobotDataset is currently deactivated.
+def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
+    """Create a single or multi-dataset depending on the config type.

     Returns:
-        LeRobotDataset | MultiLeRobotDataset
+        LeRobotDataset | NewMultiLeRobotDataset
     """
+    if isinstance(cfg.dataset, MultiDatasetConfig):
+        return _make_multi_dataset(cfg)
+
+    return _make_single_dataset(cfg)
+
+
+def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
+    ds_cfg: DatasetConfig = cfg.dataset  # type: ignore[assignment]
     image_transforms = (
-        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
+        ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
     )
+    ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
+    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)

-    if isinstance(cfg.dataset.repo_id, str):
-        ds_meta = LeRobotDatasetMetadata(
-            cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
-        )
-        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
-        if not cfg.dataset.streaming:
-            dataset = LeRobotDataset(
-                cfg.dataset.repo_id,
-                root=cfg.dataset.root,
-                episodes=cfg.dataset.episodes,
-                delta_timestamps=delta_timestamps,
-                image_transforms=image_transforms,
-                revision=cfg.dataset.revision,
-                video_backend=cfg.dataset.video_backend,
-                tolerance_s=cfg.tolerance_s,
-            )
-        else:
-            dataset = StreamingLeRobotDataset(
-                cfg.dataset.repo_id,
-                root=cfg.dataset.root,
-                episodes=cfg.dataset.episodes,
-                delta_timestamps=delta_timestamps,
-                image_transforms=image_transforms,
-                revision=cfg.dataset.revision,
-                max_num_shards=cfg.num_workers,
-                tolerance_s=cfg.tolerance_s,
-            )
-    else:
-        raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
-        dataset = MultiLeRobotDataset(
-            cfg.dataset.repo_id,
-            # TODO(aliberts): add proper support for multi dataset
-            # delta_timestamps=delta_timestamps,
+    if not ds_cfg.streaming:
+        dataset = LeRobotDataset(
+            ds_cfg.repo_id,
+            root=ds_cfg.root,
+            episodes=ds_cfg.episodes,
+            delta_timestamps=delta_timestamps,
             image_transforms=image_transforms,
-            video_backend=cfg.dataset.video_backend,
+            revision=ds_cfg.revision,
+            video_backend=ds_cfg.video_backend,
             tolerance_s=cfg.tolerance_s,
         )
-        logging.info(
-            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
-            f"{pformat(dataset.repo_id_to_index, indent=2)}"
+    else:
+        dataset = StreamingLeRobotDataset(
+            ds_cfg.repo_id,
+            root=ds_cfg.root,
+            episodes=ds_cfg.episodes,
+            delta_timestamps=delta_timestamps,
+            image_transforms=image_transforms,
+            revision=ds_cfg.revision,
+            max_num_shards=cfg.num_workers,
+            tolerance_s=cfg.tolerance_s,
        )

-    if cfg.dataset.use_imagenet_stats:
+    if ds_cfg.use_imagenet_stats:
         for key in dataset.meta.camera_keys:
-            for stats_type, stats in IMAGENET_STATS.items():
-                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
+            for stats_type, stats_val in IMAGENET_STATS.items():
+                dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)

     return dataset
+
+
+def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
+    multi_cfg: MultiDatasetConfig = cfg.dataset  # type: ignore[assignment]
+    image_transforms = (
+        ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
+    )
+
+    dataset = NewMultiLeRobotDataset(
+        configs=multi_cfg.datasets,
+        image_transforms=image_transforms,
+        tolerance_s=cfg.tolerance_s,
+    )
+
+    logging.info(
+        "MultiLeRobotDataset created with %d sub-datasets:\n%s",
+        len(multi_cfg.datasets),
+        pformat(
+            {i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
+            indent=2,
+        ),
+    )
+
+    if multi_cfg.use_imagenet_stats:
+        for key in dataset.meta.camera_keys:
+            for stats_type, stats_val in IMAGENET_STATS.items():
+                dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
+
+    return dataset

@@ -0,0 +1,364 @@
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.

Supports:
- Per-dataset feature mapping (rename keys to a unified namespace)
- Automatic zero-padding for features missing in some datasets
- Per-dataset transform pipelines
- Weighted sampling via dataset weights
- Aggregated stats across all sub-datasets
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
"""

from __future__ import annotations

import logging
from collections.abc import Callable

import numpy as np
import torch
import torch.utils.data

from lerobot.configs.default import SubDatasetConfig
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import DatasetTransformPipeline


class MultiDatasetMeta:
    """Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.

    Built by aggregating the metadata of multiple sub-datasets after their
    feature keys have been mapped to a unified namespace.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        feature_maps: list[dict[str, str]],
    ):
        self._datasets = datasets
        self._feature_maps = feature_maps

        self._unified_features = self._build_unified_features()
        self._episodes = self._build_episodes()
        self._stats = self._build_stats()

    # ------------------------------------------------------------------
    # Feature union
    # ------------------------------------------------------------------

    def _build_unified_features(self) -> dict[str, dict]:
        """Build feature dict as the *union* of all mapped feature keys."""
        unified: dict[str, dict] = {}
        for ds, fmap in zip(self._datasets, self._feature_maps):
            for original_key, feat_info in ds.meta.features.items():
                mapped_key = fmap.get(original_key, original_key)
                if mapped_key not in unified:
                    unified[mapped_key] = dict(feat_info)
                else:
                    existing_shape = tuple(unified[mapped_key]["shape"])
                    new_shape = tuple(feat_info["shape"])
                    if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
                        logging.warning(
                            "Feature '%s' has shape %s in one dataset but %s in another. "
                            "The larger shape will be used (padding applied automatically).",
                            mapped_key,
                            existing_shape,
                            new_shape,
                        )
                        if np.prod(new_shape) > np.prod(existing_shape):
                            unified[mapped_key] = dict(feat_info)
        return unified

    # ------------------------------------------------------------------
    # Episode metadata (global flat indexing)
    # ------------------------------------------------------------------

    def _build_episodes(self) -> dict[str, list]:
        """Concatenate episode boundaries across sub-datasets with frame offsets.

        Produces the same column structure as ``load_episodes()`` so that
        ``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
        """
        from_indices: list[int] = []
        to_indices: list[int] = []
        dataset_source: list[int] = []

        frame_offset = 0
        for ds_idx, ds in enumerate(self._datasets):
            eps = ds.meta.episodes
            for ep in eps:
                from_indices.append(ep["dataset_from_index"] + frame_offset)
                to_indices.append(ep["dataset_to_index"] + frame_offset)
                dataset_source.append(ds_idx)
            frame_offset += ds.num_frames

        return {
            "dataset_from_index": from_indices,
            "dataset_to_index": to_indices,
            "dataset_source": dataset_source,
        }

    # ------------------------------------------------------------------
    # Stats aggregation
    # ------------------------------------------------------------------

    def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
        """Aggregate stats across sub-datasets using mapped feature keys."""
        mapped_stats_list: list[dict[str, dict]] = []
        for ds, fmap in zip(self._datasets, self._feature_maps):
            reverse_map = {v: k for k, v in fmap.items()}
            mapped: dict[str, dict] = {}
            for unified_key in self._unified_features:
                original_key = reverse_map.get(unified_key, unified_key)
                if original_key in ds.meta.stats:
                    mapped[unified_key] = ds.meta.stats[original_key]
            mapped_stats_list.append(mapped)

        return aggregate_stats(mapped_stats_list)

    # ------------------------------------------------------------------
    # Properties matching LeRobotDatasetMetadata API
    # ------------------------------------------------------------------

    @property
    def features(self) -> dict[str, dict]:
        return self._unified_features

    @property
    def image_keys(self) -> list[str]:
        return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]

    @property
    def video_keys(self) -> list[str]:
        return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]

    @property
    def camera_keys(self) -> list[str]:
        return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]

    @property
    def names(self) -> dict[str, list | dict]:
        return {k: f["names"] for k, f in self._unified_features.items()}

    @property
    def shapes(self) -> dict[str, tuple]:
        return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}

    @property
    def fps(self) -> int:
        fps_values = {ds.meta.fps for ds in self._datasets}
        if len(fps_values) > 1:
            logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
        return self._datasets[0].meta.fps

    @property
    def stats(self) -> dict[str, dict[str, np.ndarray]]:
        return self._stats

    @stats.setter
    def stats(self, value: dict):
        self._stats = value

    @property
    def episodes(self) -> dict[str, list]:
        return self._episodes

    @property
    def total_episodes(self) -> int:
        return sum(ds.meta.total_episodes for ds in self._datasets)

    @property
    def total_frames(self) -> int:
        return sum(ds.meta.total_frames for ds in self._datasets)

    @property
    def total_tasks(self) -> int:
        return sum(ds.meta.total_tasks for ds in self._datasets)

    @property
    def info(self) -> dict:
        return {
            "fps": self.fps,
            "features": self._unified_features,
            "total_episodes": self.total_episodes,
            "total_frames": self.total_frames,
            "total_tasks": self.total_tasks,
            "codebase_version": "v3.0",
        }


class NewMultiLeRobotDataset(torch.utils.data.Dataset):
    """Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.

    Each sub-dataset can have different feature names and shapes. A per-dataset
    ``feature_map`` renames keys into a shared namespace. Features that a given
    sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
    the full unified feature set.
    """

    def __init__(
        self,
        configs: list[SubDatasetConfig],
        image_transforms: Callable | None = None,
        delta_timestamps: dict[str, list[float]] | None = None,
        tolerance_s: float = 1e-4,
    ):
        super().__init__()
        self._configs = configs
        self.image_transforms = image_transforms

        self._datasets: list[LeRobotDataset] = []
        self._feature_maps: list[dict[str, str]] = []
        self._transform_pipelines: list[DatasetTransformPipeline | None] = []
        self._weights: list[float] = []

        for cfg in configs:
            ds = LeRobotDataset(
                repo_id=cfg.repo_id,
                root=cfg.root,
                episodes=cfg.episodes,
                image_transforms=image_transforms,
                delta_timestamps=delta_timestamps,
                tolerance_s=tolerance_s,
                revision=cfg.revision,
                video_backend=cfg.video_backend,
            )
            self._datasets.append(ds)
            self._feature_maps.append(cfg.feature_map or {})
            self._transform_pipelines.append(
                DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
            )
            self._weights.append(cfg.weight)

        self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)

        # Pre-compute cumulative frame counts for fast index mapping.
        self._cumulative_frames: list[int] = []
        total = 0
        for ds in self._datasets:
            total += ds.num_frames
            self._cumulative_frames.append(total)

        # Build reverse maps (unified_key -> original_key) per dataset for padding.
        self._reverse_maps: list[dict[str, str]] = []
        for fmap in self._feature_maps:
            self._reverse_maps.append({v: k for k, v in fmap.items()})

        logging.info(
            "MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
            "%d unified features",
            len(self._datasets),
            self.num_frames,
            self.num_episodes,
            len(self._meta.features),
        )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    @property
    def meta(self) -> MultiDatasetMeta:
        return self._meta

    @property
    def dataset_weights(self) -> list[float]:
        return self._weights

    @property
    def num_frames(self) -> int:
        return self._cumulative_frames[-1] if self._cumulative_frames else 0

    @property
    def num_episodes(self) -> int:
        return sum(ds.num_episodes for ds in self._datasets)

    @property
    def episodes(self) -> list[int] | None:
        return None

    @property
    def fps(self) -> int:
        return self._meta.fps

    @property
    def features(self) -> dict[str, dict]:
        return self._meta.features

    @property
    def camera_keys(self) -> list[str]:
        return self._meta.camera_keys

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    def _locate(self, idx: int) -> tuple[int, int]:
        """Map a global frame index to (dataset_index, local_index)."""
        for ds_idx, cum in enumerate(self._cumulative_frames):
            if idx < cum:
                local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
                return ds_idx, local
        raise IndexError(f"Index {idx} out of range (total {self.num_frames})")

    def __len__(self) -> int:
        return self.num_frames

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        ds_idx, local_idx = self._locate(idx)
        item = self._datasets[ds_idx][local_idx]

        # 1. Rename keys according to feature_map.
        fmap = self._feature_maps[ds_idx]
        if fmap:
            renamed: dict[str, torch.Tensor] = {}
            for key, value in item.items():
                renamed[fmap.get(key, key)] = value
            item = renamed
# 2. Apply per-dataset transform pipeline.
|
||||
pipeline = self._transform_pipelines[ds_idx]
|
||||
if pipeline is not None:
|
||||
item = pipeline(item)
|
||||
|
||||
# 3. Pad missing features with zeros.
|
||||
reverse_map = self._reverse_maps[ds_idx]
|
||||
ds_features = self._datasets[ds_idx].meta.features
|
||||
for unified_key, feat_info in self._meta.features.items():
|
||||
if unified_key in item:
|
||||
continue
|
||||
original_key = reverse_map.get(unified_key, unified_key)
|
||||
if original_key in ds_features:
|
||||
continue
|
||||
shape = tuple(feat_info["shape"])
|
||||
dtype = feat_info["dtype"]
|
||||
if dtype in ("video", "image"):
|
||||
# Camera tensors are (C, H, W) after transforms.
|
||||
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
|
||||
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
|
||||
elif dtype in ("float32", "float64"):
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||
elif dtype in ("int32", "int64"):
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
|
||||
elif dtype == "bool":
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
|
||||
else:
|
||||
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
|
||||
item[f"{unified_key}_is_pad"] = torch.tensor(True)
|
||||
|
||||
# 4. Tag which dataset this sample came from.
|
||||
item["dataset_index"] = torch.tensor(ds_idx)
|
||||
return item
|
||||
|
||||
    def __repr__(self) -> str:
        repo_ids = [c.repo_id for c in self._configs]
        return (
            f"MultiLeRobotDataset(\n"
            f" repo_ids={repo_ids},\n"
            f" num_frames={self.num_frames},\n"
            f" num_episodes={self.num_episodes},\n"
            f" unified_features={list(self._meta.features.keys())},\n"
            f" weights={self._weights},\n"
            f")"
        )
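The index arithmetic in `_locate` is what stitches the sub-datasets into one flat index space for `__getitem__`. A small self-contained sketch of the same mapping; the per-dataset frame counts (100, 250, 50) are made up for illustration:

```python
# Cumulative frame counts, as pre-computed in __init__ (sizes are hypothetical).
cumulative = []
total = 0
for n in (100, 250, 50):
    total += n
    cumulative.append(total)  # -> [100, 350, 400]

def locate(idx: int) -> tuple[int, int]:
    """Map a global frame index to (dataset_index, local_index), mirroring _locate."""
    for ds_idx, cum in enumerate(cumulative):
        if idx < cum:
            local = idx - (cumulative[ds_idx - 1] if ds_idx > 0 else 0)
            return ds_idx, local
    raise IndexError(f"Index {idx} out of range (total {cumulative[-1]})")

print(locate(0), locate(120), locate(399))  # (0, 0) (1, 20) (2, 49)
```

Global index 120 falls past the first dataset's 100 frames, so it resolves to frame 20 of the second dataset.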
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:

    def __len__(self) -> int:
        return len(self.indices)


class WeightedEpisodeAwareSampler:
    """Sampler that draws frames from multiple datasets according to per-dataset weights.

    Each iteration first selects a sub-dataset proportionally to its weight, then
    uniformly samples a frame from that sub-dataset's valid index set. Episode
    boundary information is respected so that dropped frames are excluded.

    Args:
        dataset_from_indices: Start index for each episode (global, flat).
        dataset_to_indices: End index (exclusive) for each episode (global, flat).
        dataset_membership: Which sub-dataset each episode belongs to (integer id).
        dataset_weights: Relative sampling weight per sub-dataset.
        episode_indices_to_use: If given, only episodes in this set are used.
        drop_n_first_frames: Frames to skip at the start of each episode.
        drop_n_last_frames: Frames to skip at the end of each episode.
        shuffle: Whether to shuffle within each epoch.
        num_samples: How many samples per epoch. Defaults to total valid frames.
        generator: Optional torch.Generator for reproducibility.
    """

    def __init__(
        self,
        dataset_from_indices: list[int],
        dataset_to_indices: list[int],
        dataset_membership: list[int],
        dataset_weights: list[float],
        episode_indices_to_use: list | None = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
        num_samples: int | None = None,
        generator: torch.Generator | None = None,
    ):
        n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
        self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]

        episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None

        for ep_idx, (start, end, ds_id) in enumerate(
            zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
        ):
            if episodes_to_use is not None and ep_idx not in episodes_to_use:
                continue
            frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
            self._per_dataset_indices[ds_id].extend(frame_range)

        # Normalise weights (only over datasets that actually have frames).
        raw_weights = list(dataset_weights[:n_datasets])
        self._weights = torch.zeros(n_datasets)
        for i, w in enumerate(raw_weights):
            if len(self._per_dataset_indices[i]) > 0:
                self._weights[i] = w
        total_w = self._weights.sum()
        if total_w > 0:
            self._weights /= total_w

        self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
        self._num_samples = num_samples if num_samples is not None else self._total_frames
        self.shuffle = shuffle
        self._generator = generator

    def __iter__(self) -> Iterator[int]:
        if not self.shuffle:
            for ds_indices in self._per_dataset_indices:
                yield from ds_indices
            return

        for _ in range(self._num_samples):
            ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
            indices = self._per_dataset_indices[ds_id]
            local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
            yield indices[local_idx]

    def __len__(self) -> int:
        return self._num_samples

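Under `shuffle=True`, each draw is a two-stage sample: pick a dataset by weight, then pick a frame uniformly within it. The effect can be checked with a pure-Python stand-in (the index sets and weights below are invented for illustration; the real sampler uses `torch.multinomial` and `torch.randint`):

```python
import random

# Hypothetical per-dataset valid frame index sets and sampling weights.
per_dataset_indices = [list(range(0, 40)), list(range(100, 120)), list(range(200, 240))]
weights = [0.5, 0.3, 0.2]
rng = random.Random(0)

def sample_one() -> int:
    """Select a sub-dataset proportionally to its weight, then a frame uniformly."""
    ds_id = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return rng.choice(per_dataset_indices[ds_id])

draws = [sample_one() for _ in range(2000)]
# With weight 0.5, roughly half of the draws come from the first dataset,
# even though it holds only 40% of the frames.
frac_ds0 = sum(d < 100 for d in draws) / len(draws)
```

This is the point of the weighting: sampling frequency follows the configured weights, not the raw dataset sizes.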
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.nn.functional as F_nn
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
    Transform,
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):

    def forward(self, *inputs: Any) -> Any:
        return self.tf(*inputs)


# Per-dataset transform pipeline (used by MultiLeRobotDataset)


@dataclass
class DatasetTransformStepConfig:
    """Config for a single per-dataset transform step."""

    type: str
    kwargs: dict[str, Any] = field(default_factory=dict)


_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}


def register_dataset_transform(name: str):
    """Decorator to register a DatasetTransformStep by name."""

    def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
        _DATASET_TRANSFORM_REGISTRY[name] = cls
        return cls

    return decorator


class DatasetTransformStep:
    """Base class for a single per-dataset transform applied to a sample dict."""

    def __call__(self, sample: dict) -> dict:
        raise NotImplementedError


@register_dataset_transform("pad_action")
class PadAction(DatasetTransformStep):
    """Zero-pad the ``action`` tensor to *target_dim* along the last axis."""

    def __init__(self, target_dim: int):
        self.target_dim = target_dim

    def __call__(self, sample: dict) -> dict:
        action = sample.get("action")
        if action is None:
            return sample
        current = action.shape[-1]
        if current < self.target_dim:
            sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
        return sample


@register_dataset_transform("pad_state")
class PadState(DatasetTransformStep):
    """Zero-pad ``observation.state`` to *target_dim* along the last axis."""

    def __init__(self, target_dim: int):
        self.target_dim = target_dim

    def __call__(self, sample: dict) -> dict:
        state = sample.get("observation.state")
        if state is None:
            return sample
        current = state.shape[-1]
        if current < self.target_dim:
            sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
        return sample


@register_dataset_transform("resize_images")
class ResizeImages(DatasetTransformStep):
    """Resize all image/video camera tensors to (height, width)."""

    def __init__(self, height: int, width: int):
        self.size = (height, width)

    def __call__(self, sample: dict) -> dict:
        for key in list(sample.keys()):
            if not key.startswith("observation.images."):
                continue
            img = sample[key]
            if not isinstance(img, torch.Tensor) or img.ndim < 3:
                continue
            sample[key] = F.resize(img, self.size, antialias=True)
        return sample


class DatasetTransformPipeline:
    """Sequential pipeline of DatasetTransformStep instances."""

    def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
        self.steps: list[DatasetTransformStep] = []
        if configs:
            for cfg in configs:
                self.steps.append(self._build(cfg))

    @staticmethod
    def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
        cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
        if cls is None:
            raise ValueError(
                f"Unknown dataset transform '{cfg.type}'. "
                f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
            )
        return cls(**cfg.kwargs)

    def __call__(self, sample: dict) -> dict:
        for step in self.steps:
            sample = step(sample)
        return sample

    def __repr__(self) -> str:
        return f"DatasetTransformPipeline(steps={self.steps})"

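The `pad_action`/`pad_state` steps only right-pad along the last axis. The same logic in plain Python, as a list-based stand-in (the real steps call `F_nn.pad` on torch tensors):

```python
def pad_last_dim(vec: list[float], target_dim: int) -> list[float]:
    """Right-pad a vector with zeros up to target_dim; no-op if already long enough."""
    if len(vec) >= target_dim:
        return vec
    return vec + [0.0] * (target_dim - len(vec))

padded = pad_last_dim([0.1, -0.2, 0.3], 5)  # [0.1, -0.2, 0.3, 0.0, 0.0]
```

Note that, like the tensor version, a vector already at or beyond `target_dim` passes through untouched; the steps never truncate.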
@@ -346,6 +346,105 @@ class LiberoEnv(EnvConfig):
        return kwargs


@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
    """Alias config for LIBERO-plus benchmarks.

    LIBERO-plus keeps the same Python package/module names as LIBERO, so this
    config reuses the existing LIBERO env implementation while making intent explicit
    in experiment configs (`env.type=libero_plus`).
    """

    task: str = "libero_spatial"


@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
    """RoboCasa kitchen composite-task environments.

    Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
    action space and a structured pixel + state observation dict.

    Selected benchmark tasks (3 short + 2 long):
        Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
        Long: PrepareCoffee, RestockPantry
    """

    task: str = "PickPlaceCounterToCabinet"
    tasks: list[str] | None = None  # multi-task: list of task names (without robocasa/ prefix)
    fps: int = 20
    episode_length: int = 500
    image_size: int = 128
    split: str = "target"  # "pretrain" or "target"
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            ACTION: ACTION,
            "agentview_left": f"{OBS_IMAGES}.agentview_left",
            "agentview_right": f"{OBS_IMAGES}.agentview_right",
            "eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
            "robot_state": OBS_STATE,
        }
    )

    def __post_init__(self):
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            self.features[cam] = PolicyFeature(
                type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
            )
        self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))

    @property
    def gym_kwargs(self) -> dict:
        return {"split": self.split}


@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
    """RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).

    16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
    Uses BenchmarkEnvBuilder from the robomme package.
    """

    task: str = "PickXtimes"
    fps: int = 10
    episode_length: int = 300
    action_space: str = "joint_angle"
    dataset_split: str = "test"
    task_ids: list[int] | None = None
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
            "front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            "wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            ACTION: ACTION,
            "front_rgb": f"{OBS_IMAGES}.front",
            "wrist_rgb": f"{OBS_IMAGES}.wrist",
            OBS_STATE: OBS_STATE,
        }
    )

    @property
    def gym_kwargs(self) -> dict:
        return {
            "action_space": self.action_space,
            "dataset": self.dataset_split,
        }


@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):

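`RoboCasaEnv.__post_init__` fills the `features` dict with one visual entry per camera at the configured resolution, plus a 16-D state entry. The same pattern with hypothetical stand-ins for `PolicyFeature`/`FeatureType` (the real classes come from lerobot's config types):

```python
from dataclasses import dataclass

@dataclass
class Feature:  # stand-in for PolicyFeature
    type: str
    shape: tuple

image_size = 128
features = {"action": Feature("action", (12,))}
# Mirror of the __post_init__ loop: one visual feature per camera.
for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
    features[cam] = Feature("visual", (image_size, image_size, 3))
features["robot_state"] = Feature("state", (16,))
```

Building the camera entries in `__post_init__` rather than in the field default lets their shapes track the user-supplied `image_size`.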
@@ -20,11 +20,21 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry

from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import (
    AlohaEnv,
    EnvConfig,
    HubEnvConfig,
    IsaaclabArenaEnv,
    LiberoEnv,
    LiberoPlusEnv,
    PushtEnv,
    RoboCasaEnv,
    RoboMMEEnv,
)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline

@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
        return PushtEnv(**kwargs)
    elif env_type == "libero":
        return LiberoEnv(**kwargs)
    elif env_type == "libero_plus":
        return LiberoPlusEnv(**kwargs)
    elif env_type == "robocasa":
        return RoboCasaEnv(**kwargs)
    elif env_type == "robomme":
        return RoboMMEEnv(**kwargs)
    else:
        raise ValueError(f"Env type '{env_type}' is not available.")

@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
        return make_xvla_libero_pre_post_processors()

    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
    if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
        preprocessor_steps.append(LiberoProcessorStep())

    # For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
    if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
        preprocessor_steps.append(RoboCasaProcessorStep())

    # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
    if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
        # Parse comma-separated keys (handle None for state-based policies)

@@ -181,6 +201,33 @@ def make_env(
            control_mode=cfg.control_mode,
            episode_length=cfg.episode_length,
        )
    elif "robocasa" in cfg.type:
        from lerobot.envs.robocasa import create_robocasa_envs

        tasks = cfg.tasks if cfg.tasks else [cfg.task]
        return create_robocasa_envs(
            tasks=tasks,
            n_envs=n_envs,
            image_size=cfg.image_size,
            split=cfg.split,
            episode_length=cfg.episode_length,
            gym_kwargs=cfg.gym_kwargs,
            env_cls=env_cls,
        )

    elif "robomme" in cfg.type:
        from lerobot.envs.robomme import create_robomme_envs

        return create_robomme_envs(
            task=cfg.task,
            n_envs=n_envs,
            action_space_type=cfg.action_space,
            dataset=cfg.dataset_split,
            episode_length=cfg.episode_length,
            task_ids=cfg.task_ids,
            env_cls=env_cls,
        )

    elif "metaworld" in cfg.type:
        from lerobot.envs.metaworld import create_metaworld_envs

@@ -26,8 +26,14 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces

try:
    from libero.libero import benchmark, get_libero_path
    from libero.libero.envs import OffScreenRenderEnv
except ImportError:
    # LIBERO-plus may be installed from source with an extra nested package level.
    from libero.libero.libero import benchmark, get_libero_path
    from libero.libero.libero.envs import OffScreenRenderEnv

from lerobot.processor import RobotObservation

@@ -0,0 +1,273 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

# Action layout (flat 12D, normalized to [-1, 1]):
#   [0:3]   end_effector_position (delta x, y, z)
#   [3:6]   end_effector_rotation (delta roll, pitch, yaw)
#   [6:7]   gripper_close (open=-1, close=+1)
#   [7:11]  base_motion (x, y, theta, torso_height)
#   [11:12] control_mode (arm=-1, base=+1)
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0

# Proprioceptive state layout (flat 16D):
#   [0:2]   gripper_qpos
#   [2:5]   base_position
#   [5:9]   base_rotation (quaternion)
#   [9:12]  end_effector_position_relative
#   [12:16] end_effector_rotation_relative (quaternion)
STATE_DIM = 16

# Obs dict keys from RoboCasaGymEnv.get_observation()
_CAM_KEYS = (
    "video.robot0_agentview_left",
    "video.robot0_agentview_right",
    "video.robot0_eye_in_hand",
)
_STATE_KEYS_ORDERED = (
    "state.gripper_qpos",  # (2,)
    "state.base_position",  # (3,)
    "state.base_rotation",  # (4,)
    "state.end_effector_position_relative",  # (3,)
    "state.end_effector_rotation_relative",  # (4,)
)

# Mapping from video.* key → short image name used in features_map
CAM_KEY_TO_NAME = {
    "video.robot0_agentview_left": "agentview_left",
    "video.robot0_agentview_right": "agentview_right",
    "video.robot0_eye_in_hand": "eye_in_hand",
}


def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
    return {
        "action.end_effector_position": flat[0:3],
        "action.end_effector_rotation": flat[3:6],
        "action.gripper_close": flat[6:7],
        "action.base_motion": flat[7:11],
        "action.control_mode": flat[11:12],
    }


class RoboCasaEnv(gym.Env):
    """Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
    and a structured observation dict compatible with LeRobot policies.

    Observations returned by step/reset:
        {
            "pixels": {
                "agentview_left": (H, W, 3) uint8,
                "agentview_right": (H, W, 3) uint8,
                "eye_in_hand": (H, W, 3) uint8,
            },
            "robot_state": (16,) float32,
        }

    Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        split: str = "target",
        image_size: int = 128,
        render_mode: str = "rgb_array",
        episode_length: int = 500,
        **gym_kwargs: Any,
    ):
        super().__init__()
        # Lazy import — robocasa is optional
        import robocasa.environments  # noqa: F401 — registers all gym envs

        self.task = task
        self.render_mode = render_mode
        self.image_size = image_size
        self._max_episode_steps = episode_length
        self._step_count = 0

        self._env = gym.make(
            f"robocasa/{task}",
            split=split,
            camera_widths=image_size,
            camera_heights=image_size,
            **gym_kwargs,
        )

        # Flat 12D Box action space
        self.action_space = spaces.Box(
            low=ACTION_LOW,
            high=ACTION_HIGH,
            shape=(ACTION_DIM,),
            dtype=np.float32,
        )

        images = {
            name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
            for name in CAM_KEY_TO_NAME.values()
        }
        self.observation_space = spaces.Dict(
            {
                "pixels": spaces.Dict(images),
                "robot_state": spaces.Box(
                    low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
                ),
            }
        )

    def _format_obs(self, raw_obs: dict) -> dict:
        pixels = {
            CAM_KEY_TO_NAME[k]: raw_obs[k]
            for k in _CAM_KEYS
            if k in raw_obs
        }
        state_parts = [
            np.asarray(raw_obs[k], dtype=np.float32)
            for k in _STATE_KEYS_ORDERED
            if k in raw_obs
        ]
        robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
        return {"pixels": pixels, "robot_state": robot_state}

    def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
        super().reset(seed=seed)
        self._step_count = 0
        raw_obs, info = self._env.reset(seed=seed)
        info.setdefault("is_success", False)
        info["task"] = self.task
        return self._format_obs(raw_obs), info

    def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
        if action.ndim != 1 or action.shape[0] != ACTION_DIM:
            raise ValueError(
                f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
            )
        action_dict = _flat_to_action_dict(action)
        raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
        self._step_count += 1

        is_success = bool(info.get("success", False))
        terminated = terminated or is_success
        if self._step_count >= self._max_episode_steps:
            truncated = True

        info.update({"task": self.task, "is_success": is_success})
        obs = self._format_obs(raw_obs)

        if terminated or truncated:
            info["final_info"] = {"task": self.task, "is_success": is_success}

        return obs, reward, terminated, truncated, info

    def render(self) -> np.ndarray | None:
        if self.render_mode == "rgb_array":
            return self._env.render()
        return None

    def close(self) -> None:
        self._env.close()


def _make_env_fns(
    *,
    task: str,
    n_envs: int,
    image_size: int,
    split: str,
    episode_length: int,
    gym_kwargs: dict[str, Any],
) -> list[Callable[[], RoboCasaEnv]]:
    """Build n_envs factory callables for a single task."""

    def _make(episode_index: int) -> RoboCasaEnv:  # noqa: ARG001
        return RoboCasaEnv(
            task=task,
            split=split,
            image_size=image_size,
            episode_length=episode_length,
            **gym_kwargs,
        )

    return [partial(_make, i) for i in range(n_envs)]


def create_robocasa_envs(
    tasks: str | Sequence[str],
    n_envs: int,
    image_size: int = 128,
    split: str = "target",
    episode_length: int = 500,
    gym_kwargs: dict[str, Any] | None = None,
    env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
    """Create vectorized RoboCasa environments.

    Args:
        tasks: A single task name, a comma-separated string, or a list of task
            names (without the "robocasa/" prefix). E.g. "PickPlaceCounterToCabinet"
            or ["BoilPot", "PrepareCoffee"].
        n_envs: Number of parallel envs per task.
        image_size: Square image resolution for all cameras.
        split: RoboCasa dataset split, "pretrain" or "target".
        episode_length: Max steps per episode before truncation.
        gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
        env_cls: Callable to wrap a list of factory fns (SyncVectorEnv or AsyncVectorEnv).

    Returns:
        dict["robocasa"][task_index] -> vec_env, with one entry per task in order.
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    if isinstance(tasks, str):
        task_list = [t.strip() for t in tasks.split(",") if t.strip()]
    else:
        task_list = [str(t).strip() for t in tasks if str(t).strip()]
    if not task_list:
        raise ValueError("`tasks` must contain at least one task name.")

    gym_kwargs = dict(gym_kwargs or {})
    out: dict[str, dict[int, Any]] = defaultdict(dict)

    print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
    for task in task_list:
        fns = _make_env_fns(
            task=task,
            n_envs=n_envs,
            image_size=image_size,
            split=split,
            episode_length=episode_length,
            gym_kwargs=gym_kwargs,
        )
        out["robocasa"][len(out["robocasa"])] = env_cls(fns)
        print(f"  Built vec env | task={task} | n_envs={n_envs}")

    return {suite: dict(task_map) for suite, task_map in out.items()}
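`create_robocasa_envs` accepts either a comma-separated string or a sequence for `tasks`. The normalization logic can be isolated as a sketch mirroring the code above:

```python
def normalize_tasks(tasks) -> list[str]:
    """Accept a single comma-separated string or a sequence of task names."""
    if isinstance(tasks, str):
        return [t.strip() for t in tasks.split(",") if t.strip()]
    return [str(t).strip() for t in tasks if str(t).strip()]

print(normalize_tasks("BoilPot, PrepareCoffee"))  # ['BoilPot', 'PrepareCoffee']
print(normalize_tasks(["PrepareToast"]))          # ['PrepareToast']
```

An empty string or an all-whitespace list normalizes to `[]`, which is why the caller follows up with an explicit "at least one task name" check.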
@@ -0,0 +1,154 @@
"""RoboMME environment wrapper for LeRobot evaluation.

Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.

RoboMME tasks:
    Counting: BinFill, PickXtimes, SwingXtimes, StopCube
    Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
    Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
    Imitation: MoveCube, InsertPeg, PatternLock, RouteStick

Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""

from __future__ import annotations

from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

ROBOMME_TASKS = [
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]


class RoboMMEGymEnv(gym.Env):
    """Thin Gymnasium wrapper around a single RoboMME episode env."""

    metadata = {"render_modes": ["rgb_array"]}

    def __init__(
        self,
        task: str = "PickXtimes",
        action_space_type: str = "joint_angle",
        dataset: str = "test",
        episode_idx: int = 0,
        max_steps: int = 300,
    ):
        super().__init__()
        from robomme.env_record_wrapper import BenchmarkEnvBuilder

        self._task = task
        self._action_space_type = action_space_type
        self._dataset = dataset
        self._episode_idx = episode_idx
        self._max_steps = max_steps

        self._builder = BenchmarkEnvBuilder(
            env_id=task,
            dataset=dataset,
            action_space=action_space_type,
            gui_render=False,
            max_steps=max_steps,
        )
        self._env = None

        action_dim = 8 if action_space_type == "joint_angle" else 7
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
        self.observation_space = spaces.Dict({
            "front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
        })

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._env = self._builder.make_env_for_episode(
            episode_idx=self._episode_idx, max_steps=self._max_steps,
        )
        obs, info = self._env.reset()
        return self._convert_obs(obs), self._convert_info(info)

    def step(self, action):
        obs, reward, terminated, truncated, info = self._env.step(action)

        terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
        truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)

        status = info.get("status", "ongoing")
        is_success = status == "success"
        conv_info = self._convert_info(info)
        conv_info["is_success"] = is_success

        return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info

    def _convert_obs(self, obs: dict) -> dict:
        def last(x):
            # RoboMME returns observation histories as lists; keep the most recent entry.
            return x[-1] if isinstance(x, list) else x

        front_rgb = np.asarray(last(obs["front_rgb_list"]), dtype=np.uint8)
        wrist_rgb = np.asarray(last(obs["wrist_rgb_list"]), dtype=np.uint8)
        joint = np.asarray(last(obs["joint_state_list"]), dtype=np.float32).flatten()[:7]
        gripper = np.asarray(last(obs["gripper_state_list"]), dtype=np.float32).flatten()[:1]
        state = np.concatenate([joint, gripper])

        return {
            "front_rgb": front_rgb,
            "wrist_rgb": wrist_rgb,
            "state": state,
        }

    def _convert_info(self, info: dict) -> dict:
        return {
            "status": info.get("status", "ongoing"),
            "task_goal": info.get("task_goal", ""),
        }


def create_robomme_envs(
|
||||
task: str,
|
||||
n_envs: int = 1,
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_length: int = 300,
|
||||
task_ids: list[int] | None = None,
|
||||
env_cls=None,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Create vectorized RoboMME environments for evaluation.
|
||||
|
||||
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||
"""
|
||||
if env_cls is None:
|
||||
env_cls = gym.vector.SyncVectorEnv
|
||||
|
||||
if task_ids is None:
|
||||
task_ids = [0]
|
||||
|
||||
suite_name = "robomme"
|
||||
envs_by_task = {}
|
||||
|
||||
for task_id in task_ids:
|
||||
def _make_one(ep_idx=task_id):
|
||||
return RoboMMEGymEnv(
|
||||
task=task,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_idx=ep_idx,
|
||||
max_steps=episode_length,
|
||||
)
|
||||
|
||||
vec = env_cls(
|
||||
[_make_one for _ in range(n_envs)],
|
||||
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
|
||||
)
|
||||
envs_by_task[task_id] = vec
|
||||
|
||||
return {suite_name: envs_by_task}
|
||||
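`create_robomme_envs` binds the loop variable into each factory through a default argument (`def _make_one(ep_idx=task_id)`). This detail matters: a plain closure would look up `task_id` lazily, so every factory created in the loop would build the env for the last task id. A standalone sketch of the two behaviors, with toy functions that are not part of robomme or lerobot:

```python
def make_factories_late(ids):
    # BUG: each lambda reads `i` when it is *called*, after the loop is done,
    # so every factory sees the final value of `i`.
    return [lambda: i for i in ids]


def make_factories_bound(ids):
    # FIX: the default argument evaluates `i` at *definition* time,
    # freezing the current loop value into each factory.
    return [lambda i=i: i for i in ids]


late = [f() for f in make_factories_late([0, 1, 2])]
bound = [f() for f in make_factories_bound([0, 1, 2])]
print(late)   # [2, 2, 2]
print(bound)  # [0, 1, 2]
```

The same trick is why `_make_one` receives `ep_idx=task_id` rather than referencing `task_id` directly in its body.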
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
        return result


@dataclass
@ProcessorStepRegistry.register(name="robocasa_processor")
class RoboCasaProcessorStep(ObservationProcessorStep):
    """
    Processes RoboCasa observations into LeRobot format.

    The RoboCasaEnv wrapper returns:
    - ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by vectorenv)
    - ``observation.robot_state``: (B, 16) float32 proprioception

    This step remaps them to:
    - ``observation.images.<cam_name>`` (unchanged tensor)
    - ``observation.state`` (robot_state renamed)
    """

    def _process_observation(self, observation: dict) -> dict:
        processed = {}
        obs_prefix = OBS_PREFIX  # "observation."

        for key, value in observation.items():
            if key.startswith(f"{OBS_IMAGES}."):
                # Already in the right place; pass through
                processed[key] = value
            elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
                # Rename robot_state → observation.state
                processed[OBS_STATE] = value.float() if hasattr(value, "float") else value

        return processed

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features

    def observation(self, observation: dict) -> dict:
        return self._process_observation(observation)


@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):
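The remapping performed by `RoboCasaProcessorStep._process_observation` can be exercised without lerobot installed: image keys pass through unchanged, the proprioception key is renamed, and anything else is dropped. A minimal standalone sketch, with plain strings standing in for lerobot's `OBS_*` constants (their values here follow the docstring above and are an assumption about the library):

```python
# Assumed values of lerobot's constants, per the docstring above.
OBS_PREFIX = "observation."
OBS_IMAGES = "observation.images"
OBS_STATE = "observation.state"


def remap_robocasa_obs(observation: dict) -> dict:
    processed = {}
    for key, value in observation.items():
        if key.startswith(f"{OBS_IMAGES}."):
            # Camera keys are already namespaced correctly; pass through.
            processed[key] = value
        elif key in (OBS_STATE, f"{OBS_PREFIX}robot_state"):
            # Rename robot_state -> observation.state.
            processed[OBS_STATE] = value
    return processed


obs = {
    "observation.images.eye_in_hand": "img-tensor",
    "observation.robot_state": "state-tensor",
    "pixels.leftover": "silently dropped",
}
print(remap_robocasa_obs(obs))
# {'observation.images.eye_in_hand': 'img-tensor', 'observation.state': 'state-tensor'}
```

Keys matching neither branch (like the leftover `pixels.*` key above) are intentionally discarded, which keeps downstream policy inputs to exactly the declared features.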
@@ -29,7 +29,8 @@ from tqdm import tqdm
 from lerobot.configs import parser
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.sampler import EpisodeAwareSampler
+from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
+from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
 from lerobot.datasets.utils import cycle
 from lerobot.envs.factory import make_env, make_env_pre_post_processors
 from lerobot.envs.utils import close_envs
@@ -343,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

     # create dataloader for offline training
-    if hasattr(cfg.policy, "drop_n_last_frames"):
+    drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
+
+    if isinstance(dataset, NewMultiLeRobotDataset):
+        shuffle = False
+        sampler = WeightedEpisodeAwareSampler(
+            dataset.meta.episodes["dataset_from_index"],
+            dataset.meta.episodes["dataset_to_index"],
+            dataset_membership=dataset.meta.episodes["dataset_source"],
+            dataset_weights=dataset.dataset_weights,
+            drop_n_last_frames=drop_n_last,
+            shuffle=True,
+        )
+    elif drop_n_last > 0:
         shuffle = False
         sampler = EpisodeAwareSampler(
             dataset.meta.episodes["dataset_from_index"],
             dataset.meta.episodes["dataset_to_index"],
             episode_indices_to_use=dataset.episodes,
-            drop_n_last_frames=cfg.policy.drop_n_last_frames,
+            drop_n_last_frames=drop_n_last,
             shuffle=True,
         )
     else:
@@ -360,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         dataset,
         num_workers=cfg.num_workers,
         batch_size=cfg.batch_size,
-        shuffle=shuffle and not cfg.dataset.streaming,
+        shuffle=shuffle and not getattr(cfg.dataset, "streaming", False),
         sampler=sampler,
         pin_memory=device.type == "cuda",
         drop_last=False,
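The `WeightedEpisodeAwareSampler` wired into the training loop draws frames in proportion to per-dataset weights. Setting lerobot's exact signature aside, the underlying idea can be sketched in standalone Python; the function and argument names below are illustrative, not the library's API:

```python
import random


def weighted_frame_indices(episode_ranges, memberships, weights, n_samples, seed=0):
    """Draw frame indices so each dataset is sampled in proportion to its weight.

    episode_ranges: list of (from_idx, to_idx) frame ranges, one per episode
    memberships:    source dataset name for each episode
    weights:        {dataset_name: sampling weight}
    """
    rng = random.Random(seed)
    # An episode's weight is its dataset's weight scaled by its frame count,
    # so individual frames (not whole episodes) follow the dataset weights.
    ep_weights = [
        weights[m] * (to - frm) for (frm, to), m in zip(episode_ranges, memberships)
    ]
    picks = rng.choices(range(len(episode_ranges)), weights=ep_weights, k=n_samples)
    # Then sample a frame uniformly inside each chosen episode.
    return [rng.randrange(*episode_ranges[ep]) for ep in picks]


ranges = [(0, 100), (100, 200)]  # two equal-length episodes, one per dataset
members = ["libero", "robocasa"]
idx = weighted_frame_indices(ranges, members, {"libero": 3.0, "robocasa": 1.0}, 10_000)
frac_libero = sum(i < 100 for i in idx) / len(idx)
print(round(frac_libero, 2))  # roughly 0.75, since libero is weighted 3:1
```

The real sampler additionally honors `drop_n_last_frames` (trimming the tail of each episode range) and episode-aware shuffling, which the sketch omits.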
@@ -0,0 +1,176 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RoboCasa LeRobot integration.

Requires: robocasa installed + kitchen assets downloaded.
Tests are skipped automatically if robocasa is not available.
"""
from __future__ import annotations

import numpy as np
import pytest

# Skip entire module if robocasa is not installed or assets are missing
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")

from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs

# The 5 benchmark tasks (3 short + 2 long)
BENCHMARK_TASKS = [
    "PickPlaceCounterToCabinet",  # short
    "PrepareToast",               # short
    "CoffeeSetupMug",             # short
    "PrepareCoffee",              # long
    "RestockPantry",              # long
]
SHORT_TASKS = BENCHMARK_TASKS[:3]
LONG_TASKS = BENCHMARK_TASKS[3:]

IMAGE_SIZE = 64  # small for fast tests


@pytest.fixture(scope="module")
def single_env():
    """Shared env instance for lightweight tests."""
    env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
    yield env
    env.close()


class TestRoboCasaEnvSpaces:
    def test_action_space_is_flat_box(self, single_env):
        import gymnasium as gym

        assert isinstance(single_env.action_space, gym.spaces.Box)
        assert single_env.action_space.shape == (ACTION_DIM,)
        assert single_env.action_space.dtype == np.float32

    def test_action_bounds(self, single_env):
        assert np.all(single_env.action_space.low == -1.0)
        assert np.all(single_env.action_space.high == 1.0)

    def test_observation_space_has_pixels_and_state(self, single_env):
        import gymnasium as gym

        assert isinstance(single_env.observation_space, gym.spaces.Dict)
        assert "pixels" in single_env.observation_space.spaces
        assert "robot_state" in single_env.observation_space.spaces

    def test_observation_space_cameras(self, single_env):
        pixels_space = single_env.observation_space["pixels"]
        expected_cams = set(CAM_KEY_TO_NAME.values())
        assert set(pixels_space.spaces.keys()) == expected_cams

    def test_state_dim(self, single_env):
        state_space = single_env.observation_space["robot_state"]
        assert state_space.shape == (STATE_DIM,)


class TestRoboCasaEnvReset:
    def test_reset_returns_obs_and_info(self, single_env):
        obs, info = single_env.reset()
        assert isinstance(obs, dict)
        assert isinstance(info, dict)

    def test_reset_obs_has_pixels(self, single_env):
        obs, _ = single_env.reset()
        assert "pixels" in obs
        for cam_name in CAM_KEY_TO_NAME.values():
            assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"

    def test_reset_obs_image_shape(self, single_env):
        obs, _ = single_env.reset()
        for cam_name, img in obs["pixels"].items():
            assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
            assert img.dtype == np.uint8

    def test_reset_obs_state_shape(self, single_env):
        obs, _ = single_env.reset()
        assert obs["robot_state"].shape == (STATE_DIM,)
        assert obs["robot_state"].dtype == np.float32

    def test_reset_info_has_task(self, single_env):
        _, info = single_env.reset()
        assert "task" in info
        assert info["task"] == "PickPlaceCounterToCabinet"


class TestRoboCasaEnvStep:
    def test_step_10_random_actions(self, single_env):
        single_env.reset()
        for _ in range(10):
            action = single_env.action_space.sample()
            obs, reward, terminated, truncated, info = single_env.step(action)
            assert obs["robot_state"].shape == (STATE_DIM,)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)

    def test_step_bad_action_raises(self, single_env):
        single_env.reset()
        with pytest.raises(ValueError, match="Expected 1-D action"):
            single_env.step(np.zeros((2, ACTION_DIM)))

    def test_step_info_has_is_success(self, single_env):
        single_env.reset()
        _, _, _, _, info = single_env.step(single_env.action_space.sample())
        assert "is_success" in info


class TestRoboCasaConfig:
    def test_robocasa_env_config(self):
        from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
        from lerobot.configs.types import FeatureType

        cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
        assert cfg.type == "robocasa"
        # action feature
        assert "action" in cfg.features
        assert cfg.features["action"].shape == (ACTION_DIM,)
        # camera features
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert cam in cfg.features
            assert cfg.features[cam].type == FeatureType.VISUAL
            assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
        # state feature
        assert "robot_state" in cfg.features
        assert cfg.features["robot_state"].shape == (STATE_DIM,)

    def test_make_env_config_robocasa(self):
        from lerobot.envs.factory import make_env_config

        cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
        assert cfg.type == "robocasa"


class TestRoboCasaProcessorStep:
    def test_processor_remaps_keys(self):
        import torch
        from lerobot.processor.env_processor import RoboCasaProcessorStep
        from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

        step = RoboCasaProcessorStep()
        B = 2
        obs = {
            f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            "observation.robot_state": torch.zeros(B, STATE_DIM),
        }
        out = step._process_observation(obs)
        assert OBS_STATE in out
        assert out[OBS_STATE].dtype == torch.float32
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert f"{OBS_IMAGES}.{cam}" in out