Compare commits

...

4 Commits

Author SHA1 Message Date
pepijn c5925399a9 style: remove decorative comment separator in transforms.py
Made-with: Cursor
2026-03-13 04:43:31 +00:00
pepijn f478ae5bfa docs: add Multi-Dataset Training guide
Covers feature mapping, auto-padding, per-dataset transforms,
weighted sampling, stats aggregation, and full config examples
for training across RoboCasa, LIBERO-plus, and RoboMME datasets.

Made-with: Cursor
2026-03-13 04:37:01 +00:00
pepijn b4d40d0228 feat: add MultiLeRobotDataset with weighted sampling and RoboMME env integration
Multi-dataset training support:
- NewMultiLeRobotDataset with per-dataset feature mapping, auto-padding,
  per-dataset transform pipelines, and weighted sampling
- MultiDatasetMeta shim compatible with EpisodeAwareSampler and make_policy
- WeightedEpisodeAwareSampler for proportional cross-dataset sampling
- SubDatasetConfig / MultiDatasetConfig in training configs
- DatasetTransformPipeline with built-in PadAction, PadState, ResizeImages
- Factory and training script wired up for multi-dataset path

RoboMME environment integration:
- RoboMMEEnv config and Gymnasium wrapper (robomme.py)
- robomme optional dependency in pyproject.toml

Made-with: Cursor
2026-03-13 04:31:35 +00:00
pepijn db5c26f07d feat(envs): add LIBERO-plus integration for evaluation benchmarks
Add LiberoPlusEnv config (subclass of LiberoEnv), register libero_plus
env type in factory, add import fallbacks for LIBERO-plus package
structure, and add libero_plus optional dependency group in pyproject.toml.

Made-with: Cursor
2026-03-12 04:31:09 +00:00
14 changed files with 1157 additions and 71 deletions
+2
@@ -31,6 +31,8 @@
title: Using Subtasks in the Dataset
- local: streaming_video_encoding
title: Streaming Video Encoding
- local: multi_dataset_training
title: Multi-Dataset Training
title: "Datasets"
- sections:
- local: act
+232
@@ -0,0 +1,232 @@
# Multi-Dataset Training
This guide covers how to train a single policy on multiple heterogeneous datasets using `MultiLeRobotDataset`.
## Overview
Real-world robot learning datasets come from different environments, robots, and camera setups. A RoboCasa dataset might have three cameras named `robot0_agentview_left`, `robot0_agentview_right`, and `robot0_eye_in_hand`, while a LIBERO dataset uses `observation.images.front` and `observation.images.wrist`, and a RoboMME dataset uses bare `image` and `wrist_image` keys. State and action dimensions also differ.
`MultiLeRobotDataset` lets you train on all of them jointly by:
- **Mapping** each dataset's feature keys into a shared namespace
- **Padding** features that a dataset doesn't have with zeros
- **Weighting** how often each dataset is sampled
- **Transforming** samples per-dataset (e.g. padding actions to a common dimension)
- **Aggregating** statistics across all sub-datasets for normalization
## Configuration
Multi-dataset training is configured via `MultiDatasetConfig` in a YAML config file. Instead of a single `dataset.repo_id`, you provide a `datasets` list where each entry is a `SubDatasetConfig`.
### SubDatasetConfig fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `repo_id` | `str` | required | HuggingFace repo ID or local dataset name |
| `root` | `str \| None` | `None` | Local root directory for the dataset |
| `episodes` | `list[int] \| None` | `None` | Subset of episode indices to use |
| `revision` | `str \| None` | `None` | Dataset version / revision |
| `video_backend` | `str` | auto | Video decoding backend (`pyav`, `torchcodec`, etc.) |
| `weight` | `float` | `1.0` | Relative sampling weight for this dataset |
| `feature_map` | `dict[str, str]` | `{}` | Maps dataset keys to unified policy keys |
| `transforms` | `list` | `None` | Per-dataset transform steps (applied per sample) |
### Example: Three-dataset config
```yaml
dataset:
  type: multi
  use_imagenet_stats: true
  datasets:
    # RoboCasa: 3 cameras, state(16), action(12)
    - repo_id: pepijn223/robocasa_PrepareCoffee
      root: /data/robocasa_PrepareCoffee
      weight: 1.0
      feature_map:
        observation.images.robot0_agentview_left: observation.images.front_left
        observation.images.robot0_agentview_right: observation.images.front_right
        observation.images.robot0_eye_in_hand: observation.images.wrist
    # LIBERO-plus: 2 cameras, state(8), action(7)
    - repo_id: pepijn223/libero_plus_lerobot
      root: /data/libero_plus_lerobot
      weight: 0.5
      feature_map:
        observation.images.front: observation.images.front_left
        observation.images.wrist: observation.images.wrist
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
    # RoboMME: 2 cameras (non-standard keys), state(8), action(8)
    - repo_id: pepijn223/robomme_data_lerobot
      root: /data/robomme_data_lerobot
      weight: 0.3
      feature_map:
        image: observation.images.front_left
        wrist_image: observation.images.wrist
        state: observation.state
        actions: action
      transforms:
        - type: pad_action
          kwargs: {target_dim: 12}
        - type: pad_state
          kwargs: {target_dim: 16}
```
## Feature Mapping
The `feature_map` dictionary renames dataset-local keys into a shared namespace. Keys not listed pass through unchanged. In the example above, all three datasets end up with the same camera key names (`observation.images.front_left`, `observation.images.wrist`) even though they use different conventions internally.
After mapping, the **union** of all features across datasets defines the unified schema. If a feature exists in some datasets but not others, it is automatically zero-padded for datasets that lack it, and a boolean `{key}_is_pad` flag is added to the sample so the policy can optionally mask padded features.
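As a sketch, the renaming step amounts to a dictionary lookup with pass-through for unlisted keys (`rename_keys` and the toy tensors below are illustrative, not the library's internals):

```python
import torch

def rename_keys(sample: dict, fmap: dict[str, str]) -> dict:
    """Rename dataset-local keys to the unified namespace; unlisted keys pass through."""
    return {fmap.get(k, k): v for k, v in sample.items()}

# RoboMME-style sample with bare keys (toy shapes for illustration).
sample = {
    "image": torch.zeros(3, 128, 128),
    "state": torch.zeros(8),
    "actions": torch.zeros(8),
}
fmap = {
    "image": "observation.images.front_left",
    "state": "observation.state",
    "actions": "action",
}
mapped = rename_keys(sample, fmap)
print(sorted(mapped))  # ['action', 'observation.images.front_left', 'observation.state']
```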
## Automatic Padding
When a sub-dataset doesn't have a feature that exists in the unified schema:
- **Images/videos**: padded with a black frame (zeros) matching the expected resolution
- **Float tensors** (state, action): padded with zeros
- **Integer/bool tensors**: padded with zeros / False
A companion `{key}_is_pad = True` tensor is added so the model can distinguish real data from padding.
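A simplified version of this padding step, assuming a toy schema dict in the same `{key: {"dtype", "shape"}}` shape as the unified features (an illustrative sketch, not the library implementation):

```python
import torch

def pad_missing(sample: dict, unified_features: dict[str, dict]) -> dict:
    """Zero-pad features absent from this sample and flag them with {key}_is_pad."""
    for key, feat in unified_features.items():
        if key in sample:
            continue
        shape = tuple(feat["shape"])
        if feat["dtype"] in ("video", "image"):
            # Camera tensors are (C, H, W) after image transforms; shape stored as (H, W, C).
            sample[key] = torch.zeros(3, *shape[:2], dtype=torch.float32)
        elif feat["dtype"] == "bool":
            sample[key] = torch.zeros(shape, dtype=torch.bool)
        else:
            sample[key] = torch.zeros(shape, dtype=torch.float32)
        sample[key + "_is_pad"] = torch.tensor(True)
    return sample

schema = {
    "observation.images.wrist": {"dtype": "video", "shape": (128, 128, 3)},
    "action": {"dtype": "float32", "shape": (12,)},
}
out = pad_missing({"action": torch.zeros(12)}, schema)
print(out["observation.images.wrist"].shape)  # torch.Size([3, 128, 128])
```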
## Per-Dataset Transforms
Each sub-dataset can have its own `transforms` pipeline that runs after feature renaming but before cross-dataset padding. This is useful for making shapes compatible before PyTorch's collate function stacks the batch.
### Built-in transforms
| Name | Description | Parameters |
|------|-------------|------------|
| `pad_action` | Zero-pad `action` to a target dimension | `target_dim: int` |
| `pad_state` | Zero-pad `observation.state` to a target dimension | `target_dim: int` |
| `resize_images` | Resize all `observation.images.*` tensors | `height: int`, `width: int` |
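For intuition, the behavior of `pad_action` can be approximated as follows (a hedged sketch of the built-in step, not its actual source):

```python
import torch

class PadAction:
    """Zero-pad the trailing dimension of `action` up to target_dim (no-op if already large enough)."""

    def __init__(self, target_dim: int):
        self.target_dim = target_dim

    def __call__(self, sample: dict) -> dict:
        action = sample["action"]
        missing = self.target_dim - action.shape[-1]
        if missing > 0:
            pad = torch.zeros(*action.shape[:-1], missing, dtype=action.dtype)
            sample["action"] = torch.cat([action, pad], dim=-1)
        return sample

# A 7-dim LIBERO-style action padded to the 12-dim unified action space.
sample = PadAction(target_dim=12)({"action": torch.ones(7)})
print(sample["action"].shape)  # torch.Size([12])
```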
### Custom transforms
You can register your own transforms in `lerobot/datasets/transforms.py`:
```python
from lerobot.datasets.transforms import DatasetTransformStep, register_dataset_transform

@register_dataset_transform("my_transform")
class MyTransform(DatasetTransformStep):
    def __init__(self, some_param: int):
        self.some_param = some_param

    def __call__(self, sample: dict) -> dict:
        # Modify sample in-place or return a new dict
        sample["action"] = sample["action"] * self.some_param
        return sample
```
Then reference it in the config:
```yaml
transforms:
  - type: my_transform
    kwargs: {some_param: 2}
```
## Weighted Sampling
The `weight` field on each sub-dataset controls how often it is sampled during training. Weights are relative and automatically normalized to probabilities. For example, with weights `[1.0, 0.5, 0.3]`, the first dataset is sampled roughly 56% of the time, the second 28%, and the third 17%.
This uses `WeightedEpisodeAwareSampler`, which respects episode boundaries (so `drop_n_last_frames` and similar policy settings work correctly) while sampling across datasets proportionally.
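The normalization can be checked in a few lines (a standalone sketch, not the sampler code itself):

```python
import torch

weights = torch.tensor([1.0, 0.5, 0.3])
probs = weights / weights.sum()
print([round(p, 3) for p in probs.tolist()])  # [0.556, 0.278, 0.167]

# Each sampler draw first picks a sub-dataset index with these probabilities,
# then a uniform frame within that sub-dataset's valid index set.
g = torch.Generator().manual_seed(0)
draws = torch.multinomial(probs, num_samples=10_000, replacement=True, generator=g)
freqs = torch.bincount(draws, minlength=3).float() / 10_000
# freqs is empirically close to probs (within ~1-2% at this sample count)
```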
## Stats Aggregation
Normalization statistics (mean, std, min, max, quantiles) are automatically aggregated across all sub-datasets using the mapped feature keys. The aggregation uses a weighted parallel variance algorithm so that datasets with more frames contribute proportionally to the global statistics.
The aggregated stats are used by the standard LeRobot preprocessor for normalization during training.
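For two datasets, the frame-weighted combination of means and variances follows the standard parallel-variance algorithm; a minimal sketch (the library's `aggregate_stats` handles the general multi-dataset case):

```python
import numpy as np

def combine(mean_a, var_a, n_a, mean_b, var_b, n_b):
    """Combine per-dataset mean/variance, weighting by frame counts (parallel variance)."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # M2 = sum of squared deviations per dataset; combine, then renormalize.
    m2 = var_a * n_a + var_b * n_b + delta**2 * n_a * n_b / n
    return mean, m2 / n

# Two synthetic "datasets" with different sizes and distributions.
x_a = np.random.default_rng(0).normal(size=(1000, 3))
x_b = np.random.default_rng(1).normal(loc=2.0, size=(500, 3))
mean, var = combine(x_a.mean(0), x_a.var(0), 1000, x_b.mean(0), x_b.var(0), 500)

# Matches the stats computed over the concatenated data.
full = np.concatenate([x_a, x_b])
print(np.allclose(mean, full.mean(0)), np.allclose(var, full.var(0)))  # True True
```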
## Training
Launch training the same way as single-dataset training. The factory and training script automatically detect `MultiDatasetConfig` and set up the weighted sampler:
```bash
python -m lerobot.scripts.lerobot_train \
    --config_path path/to/multi_dataset_config.yaml
```
## Architecture
The data flow during training with `MultiLeRobotDataset`:
```
┌─────────────────────────────────────────────────────────┐
│ MultiLeRobotDataset.__getitem__(global_idx) │
│ │
│ 1. Map global_idx → (dataset_idx, local_idx) │
│ 2. Fetch sample from sub-dataset │
│ 3. Rename keys via feature_map │
│ 4. Apply per-dataset transforms (pad_action, etc.) │
│ 5. Zero-pad missing features + add _is_pad flags │
│ 6. Add dataset_index tag │
└─────────────────────┬───────────────────────────────────┘
┌────────────▼────────────┐
│ PyTorch DataLoader │
│ (collates into batch) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ LeRobot Preprocessor │
│ (normalize, tokenize) │
└────────────┬────────────┘
┌────────────▼────────────┐
│ Policy forward + loss │
└─────────────────────────┘
```
## API Reference
### `NewMultiLeRobotDataset`
```python
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
dataset = NewMultiLeRobotDataset(
    configs=[...],           # list[SubDatasetConfig]
    image_transforms=None,   # optional image augmentation
    delta_timestamps=None,   # optional temporal neighbors
    tolerance_s=1e-4,        # timestamp tolerance
)

dataset.num_frames       # total frames across all sub-datasets
dataset.num_episodes     # total episodes
dataset.meta             # MultiDatasetMeta (stats, features, episodes)
dataset.dataset_weights  # list of per-dataset weights
dataset.features         # unified feature dict (union of all mapped features)
dataset.camera_keys      # unified camera key list
```
### `WeightedEpisodeAwareSampler`
```python
from lerobot.datasets.sampler import WeightedEpisodeAwareSampler
sampler = WeightedEpisodeAwareSampler(
    dataset_from_indices=dataset.meta.episodes["dataset_from_index"],
    dataset_to_indices=dataset.meta.episodes["dataset_to_index"],
    dataset_membership=dataset.meta.episodes["dataset_source"],
    dataset_weights=dataset.dataset_weights,
    shuffle=True,
)
```
### `DatasetTransformPipeline`
```python
from lerobot.datasets.transforms import DatasetTransformPipeline, DatasetTransformStepConfig
pipeline = DatasetTransformPipeline([
    DatasetTransformStepConfig(type="pad_action", kwargs={"target_dim": 12}),
    DatasetTransformStepConfig(type="pad_state", kwargs={"target_dim": 16}),
])
sample = pipeline(sample) # modifies the sample dict
```
+8
@@ -175,6 +175,14 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero_plus = [
"lerobot[transformers-dep]",
"libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'",
"lerobot[scipy-dep]",
]
robomme = [
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
+27 -6
@@ -16,18 +16,13 @@
from dataclasses import dataclass, field
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
@dataclass
class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
@@ -37,6 +32,32 @@ class DatasetConfig:
streaming: bool = False
@dataclass
class SubDatasetConfig:
"""Configuration for a single dataset within a MultiDatasetConfig."""
repo_id: str
root: str | None = None
episodes: list[int] | None = None
revision: str | None = None
video_backend: str = field(default_factory=get_safe_default_codec)
weight: float = 1.0
# Maps dataset-local feature keys to unified policy keys.
# Keys not listed pass through unchanged.
feature_map: dict[str, str] = field(default_factory=dict)
# Per-dataset transforms applied after feature renaming, before cross-dataset padding.
transforms: list[DatasetTransformStepConfig] | None = None
@dataclass
class MultiDatasetConfig:
"""Configuration for training on multiple datasets jointly."""
datasets: list[SubDatasetConfig] = field(default_factory=list)
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
use_imagenet_stats: bool = True
@dataclass
class WandBConfig:
enable: bool = False
+6 -6
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
@@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
dataset: DatasetConfig | MultiDatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
@@ -129,8 +129,9 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if isinstance(self.dataset, MultiDatasetConfig):
if len(self.dataset.datasets) < 1:
raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@@ -143,8 +144,7 @@ class TrainPipelineConfig(HubMixin):
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig):
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
+67 -51
@@ -18,13 +18,14 @@ from pprint import pformat
import torch
from lerobot.configs.default import DatasetConfig, MultiDatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
@@ -68,66 +69,81 @@ def resolve_delta_timestamps(
return delta_timestamps
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset:
"""Handles the logic of setting up delta timestamps and image transforms before creating a dataset.
Args:
cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig.
Raises:
NotImplementedError: The MultiLeRobotDataset is currently deactivated.
def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset:
"""Create a single or multi-dataset depending on the config type.
Returns:
LeRobotDataset | MultiLeRobotDataset
LeRobotDataset | NewMultiLeRobotDataset
"""
if isinstance(cfg.dataset, MultiDatasetConfig):
return _make_multi_dataset(cfg)
return _make_single_dataset(cfg)
def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset:
ds_cfg: DatasetConfig = cfg.dataset # type: ignore[assignment]
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None
)
ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
if not cfg.dataset.streaming:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
tolerance_s=cfg.tolerance_s,
)
else:
dataset = StreamingLeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
if not ds_cfg.streaming:
dataset = LeRobotDataset(
ds_cfg.repo_id,
root=ds_cfg.root,
episodes=ds_cfg.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
revision=ds_cfg.revision,
video_backend=ds_cfg.video_backend,
tolerance_s=cfg.tolerance_s,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
else:
dataset = StreamingLeRobotDataset(
ds_cfg.repo_id,
root=ds_cfg.root,
episodes=ds_cfg.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=ds_cfg.revision,
max_num_shards=cfg.num_workers,
tolerance_s=cfg.tolerance_s,
)
if cfg.dataset.use_imagenet_stats:
if ds_cfg.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
for stats_type, stats_val in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
return dataset
def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset:
multi_cfg: MultiDatasetConfig = cfg.dataset # type: ignore[assignment]
image_transforms = (
ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None
)
dataset = NewMultiLeRobotDataset(
configs=multi_cfg.datasets,
image_transforms=image_transforms,
tolerance_s=cfg.tolerance_s,
)
logging.info(
"MultiLeRobotDataset created with %d sub-datasets:\n%s",
len(multi_cfg.datasets),
pformat(
{i: c.repo_id for i, c in enumerate(multi_cfg.datasets)},
indent=2,
),
)
if multi_cfg.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats_val in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32)
return dataset
+364
@@ -0,0 +1,364 @@
"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets.
Supports:
- Per-dataset feature mapping (rename keys to a unified namespace)
- Automatic zero-padding for features missing in some datasets
- Per-dataset transform pipelines
- Weighted sampling via dataset weights
- Aggregated stats across all sub-datasets
- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy
"""
from __future__ import annotations
import logging
from collections.abc import Callable
import numpy as np
import torch
import torch.utils.data
from lerobot.configs.default import SubDatasetConfig
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import DatasetTransformPipeline
class MultiDatasetMeta:
"""Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``.
Built by aggregating the metadata of multiple sub-datasets after their
feature keys have been mapped to a unified namespace.
"""
def __init__(
self,
datasets: list[LeRobotDataset],
feature_maps: list[dict[str, str]],
):
self._datasets = datasets
self._feature_maps = feature_maps
self._unified_features = self._build_unified_features()
self._episodes = self._build_episodes()
self._stats = self._build_stats()
# ------------------------------------------------------------------
# Feature union
# ------------------------------------------------------------------
def _build_unified_features(self) -> dict[str, dict]:
"""Build feature dict as the *union* of all mapped feature keys."""
unified: dict[str, dict] = {}
for ds, fmap in zip(self._datasets, self._feature_maps):
for original_key, feat_info in ds.meta.features.items():
mapped_key = fmap.get(original_key, original_key)
if mapped_key not in unified:
unified[mapped_key] = dict(feat_info)
else:
existing_shape = tuple(unified[mapped_key]["shape"])
new_shape = tuple(feat_info["shape"])
if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]:
logging.warning(
"Feature '%s' has shape %s in one dataset but %s in another. "
"The larger shape will be used (padding applied automatically).",
mapped_key,
existing_shape,
new_shape,
)
if np.prod(new_shape) > np.prod(existing_shape):
unified[mapped_key] = dict(feat_info)
return unified
# ------------------------------------------------------------------
# Episode metadata (global flat indexing)
# ------------------------------------------------------------------
def _build_episodes(self) -> dict[str, list]:
"""Concatenate episode boundaries across sub-datasets with frame offsets.
Produces the same column structure as ``load_episodes()`` so that
``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it.
"""
from_indices: list[int] = []
to_indices: list[int] = []
dataset_source: list[int] = []
frame_offset = 0
for ds_idx, ds in enumerate(self._datasets):
eps = ds.meta.episodes
for ep in eps:
from_indices.append(ep["dataset_from_index"] + frame_offset)
to_indices.append(ep["dataset_to_index"] + frame_offset)
dataset_source.append(ds_idx)
frame_offset += ds.num_frames
return {
"dataset_from_index": from_indices,
"dataset_to_index": to_indices,
"dataset_source": dataset_source,
}
# ------------------------------------------------------------------
# Stats aggregation
# ------------------------------------------------------------------
def _build_stats(self) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats across sub-datasets using mapped feature keys."""
mapped_stats_list: list[dict[str, dict]] = []
for ds, fmap in zip(self._datasets, self._feature_maps):
reverse_map = {v: k for k, v in fmap.items()}
mapped: dict[str, dict] = {}
for unified_key in self._unified_features:
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds.meta.stats:
mapped[unified_key] = ds.meta.stats[original_key]
mapped_stats_list.append(mapped)
return aggregate_stats(mapped_stats_list)
# ------------------------------------------------------------------
# Properties matching LeRobotDatasetMetadata API
# ------------------------------------------------------------------
@property
def features(self) -> dict[str, dict]:
return self._unified_features
@property
def image_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")]
@property
def names(self) -> dict[str, list | dict]:
return {k: f["names"] for k, f in self._unified_features.items()}
@property
def shapes(self) -> dict[str, tuple]:
return {k: tuple(f["shape"]) for k, f in self._unified_features.items()}
@property
def fps(self) -> int:
fps_values = {ds.meta.fps for ds in self._datasets}
if len(fps_values) > 1:
logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values)
return self._datasets[0].meta.fps
@property
def stats(self) -> dict[str, dict[str, np.ndarray]]:
return self._stats
@stats.setter
def stats(self, value: dict):
self._stats = value
@property
def episodes(self) -> dict[str, list]:
return self._episodes
@property
def total_episodes(self) -> int:
return sum(ds.meta.total_episodes for ds in self._datasets)
@property
def total_frames(self) -> int:
return sum(ds.meta.total_frames for ds in self._datasets)
@property
def total_tasks(self) -> int:
return sum(ds.meta.total_tasks for ds in self._datasets)
@property
def info(self) -> dict:
return {
"fps": self.fps,
"features": self._unified_features,
"total_episodes": self.total_episodes,
"total_frames": self.total_frames,
"total_tasks": self.total_tasks,
"codebase_version": "v3.0",
}
class NewMultiLeRobotDataset(torch.utils.data.Dataset):
"""Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding.
Each sub-dataset can have different feature names and shapes. A per-dataset
``feature_map`` renames keys into a shared namespace. Features that a given
sub-dataset does not provide are zero-padded so every ``__getitem__`` returns
the full unified feature set.
"""
def __init__(
self,
configs: list[SubDatasetConfig],
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
):
super().__init__()
self._configs = configs
self.image_transforms = image_transforms
self._datasets: list[LeRobotDataset] = []
self._feature_maps: list[dict[str, str]] = []
self._transform_pipelines: list[DatasetTransformPipeline | None] = []
self._weights: list[float] = []
for cfg in configs:
ds = LeRobotDataset(
repo_id=cfg.repo_id,
root=cfg.root,
episodes=cfg.episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
revision=cfg.revision,
video_backend=cfg.video_backend,
)
self._datasets.append(ds)
self._feature_maps.append(cfg.feature_map or {})
self._transform_pipelines.append(
DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None
)
self._weights.append(cfg.weight)
self._meta = MultiDatasetMeta(self._datasets, self._feature_maps)
# Pre-compute cumulative frame counts for fast index mapping.
self._cumulative_frames: list[int] = []
total = 0
for ds in self._datasets:
total += ds.num_frames
self._cumulative_frames.append(total)
# Build reverse maps (unified_key -> original_key) per dataset for padding.
self._reverse_maps: list[dict[str, str]] = []
for fmap in self._feature_maps:
self._reverse_maps.append({v: k for k, v in fmap.items()})
logging.info(
"MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, "
"%d unified features",
len(self._datasets),
self.num_frames,
self.num_episodes,
len(self._meta.features),
)
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
@property
def meta(self) -> MultiDatasetMeta:
return self._meta
@property
def dataset_weights(self) -> list[float]:
return self._weights
@property
def num_frames(self) -> int:
return self._cumulative_frames[-1] if self._cumulative_frames else 0
@property
def num_episodes(self) -> int:
return sum(ds.num_episodes for ds in self._datasets)
@property
def episodes(self) -> list[int] | None:
return None
@property
def fps(self) -> int:
return self._meta.fps
@property
def features(self) -> dict[str, dict]:
return self._meta.features
@property
def camera_keys(self) -> list[str]:
return self._meta.camera_keys
# ------------------------------------------------------------------
# Indexing
# ------------------------------------------------------------------
def _locate(self, idx: int) -> tuple[int, int]:
"""Map a global frame index to (dataset_index, local_index)."""
for ds_idx, cum in enumerate(self._cumulative_frames):
if idx < cum:
local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0)
return ds_idx, local
raise IndexError(f"Index {idx} out of range (total {self.num_frames})")
def __len__(self) -> int:
return self.num_frames
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
ds_idx, local_idx = self._locate(idx)
item = self._datasets[ds_idx][local_idx]
# 1. Rename keys according to feature_map.
fmap = self._feature_maps[ds_idx]
if fmap:
renamed: dict[str, torch.Tensor] = {}
for key, value in item.items():
renamed[fmap.get(key, key)] = value
item = renamed
# 2. Apply per-dataset transform pipeline.
pipeline = self._transform_pipelines[ds_idx]
if pipeline is not None:
item = pipeline(item)
# 3. Pad missing features with zeros.
reverse_map = self._reverse_maps[ds_idx]
ds_features = self._datasets[ds_idx].meta.features
for unified_key, feat_info in self._meta.features.items():
if unified_key in item:
continue
original_key = reverse_map.get(unified_key, unified_key)
if original_key in ds_features:
continue
shape = tuple(feat_info["shape"])
dtype = feat_info["dtype"]
if dtype in ("video", "image"):
# Camera tensors are (C, H, W) after transforms.
c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1])
item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32)
elif dtype in ("float32", "float64"):
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
elif dtype in ("int32", "int64"):
item[unified_key] = torch.zeros(shape, dtype=torch.int64)
elif dtype == "bool":
item[unified_key] = torch.zeros(shape, dtype=torch.bool)
else:
item[unified_key] = torch.zeros(shape, dtype=torch.float32)
item[f"{unified_key}_is_pad"] = torch.tensor(True)
# 4. Tag which dataset this sample came from.
item["dataset_index"] = torch.tensor(ds_idx)
return item
def __repr__(self) -> str:
repo_ids = [c.repo_id for c in self._configs]
return (
f"NewMultiLeRobotDataset(\n"
f" repo_ids={repo_ids},\n"
f" num_frames={self.num_frames},\n"
f" num_episodes={self.num_episodes},\n"
f" unified_features={list(self._meta.features.keys())},\n"
f" weights={self._weights},\n"
f")"
)
+77
@@ -59,3 +59,80 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler:
"""Sampler that draws frames from multiple datasets according to per-dataset weights.
Each iteration first selects a sub-dataset proportionally to its weight, then
uniformly samples a frame from that sub-dataset's valid index set. Episode
boundary information is respected so that dropped frames are excluded.
Args:
dataset_from_indices: Start index for each episode (global, flat).
dataset_to_indices: End index (exclusive) for each episode (global, flat).
dataset_membership: Which sub-dataset each episode belongs to (integer id).
dataset_weights: Relative sampling weight per sub-dataset.
episode_indices_to_use: If given, only episodes in this set are used.
drop_n_first_frames: Frames to skip at the start of each episode.
drop_n_last_frames: Frames to skip at the end of each episode.
shuffle: Whether to shuffle within each epoch.
num_samples: How many samples per epoch. Defaults to total valid frames.
generator: Optional torch.Generator for reproducibility.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
dataset_membership: list[int],
dataset_weights: list[float],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
num_samples: int | None = None,
generator: torch.Generator | None = None,
):
n_datasets = max(dataset_membership) + 1 if dataset_membership else 0
self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)]
episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None
for ep_idx, (start, end, ds_id) in enumerate(
zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True)
):
if episodes_to_use is not None and ep_idx not in episodes_to_use:
continue
frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames)
self._per_dataset_indices[ds_id].extend(frame_range)
# Normalise weights (only over datasets that actually have frames).
raw_weights = list(dataset_weights[:n_datasets])
self._weights = torch.zeros(n_datasets)
for i, w in enumerate(raw_weights):
if len(self._per_dataset_indices[i]) > 0:
self._weights[i] = w
total_w = self._weights.sum()
if total_w > 0:
self._weights /= total_w
self._total_frames = sum(len(idx) for idx in self._per_dataset_indices)
self._num_samples = num_samples if num_samples is not None else self._total_frames
self.shuffle = shuffle
self._generator = generator
    def __iter__(self) -> Iterator[int]:
        if not self.shuffle:
            # Sequential fallback: weights and num_samples are ignored, and every
            # valid frame is yielded exactly once, grouped by sub-dataset.
            for ds_indices in self._per_dataset_indices:
                yield from ds_indices
            return
for _ in range(self._num_samples):
ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item())
indices = self._per_dataset_indices[ds_id]
local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item())
yield indices[local_idx]
def __len__(self) -> int:
return self._num_samples
+113
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.nn.functional as F_nn
from torchvision.transforms import v2
from torchvision.transforms.v2 import (
Transform,
@@ -258,3 +260,114 @@ class ImageTransforms(Transform):
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)
# Per-dataset transform pipeline (used by MultiLeRobotDataset)
@dataclass
class DatasetTransformStepConfig:
"""Config for a single per-dataset transform step."""
type: str
kwargs: dict[str, Any] = field(default_factory=dict)
_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {}
def register_dataset_transform(name: str):
"""Decorator to register a DatasetTransformStep by name."""
def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]:
_DATASET_TRANSFORM_REGISTRY[name] = cls
return cls
return decorator
class DatasetTransformStep:
"""Base class for a single per-dataset transform applied to a sample dict."""
def __call__(self, sample: dict) -> dict:
raise NotImplementedError
@register_dataset_transform("pad_action")
class PadAction(DatasetTransformStep):
"""Zero-pad the ``action`` tensor to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
action = sample.get("action")
if action is None:
return sample
current = action.shape[-1]
if current < self.target_dim:
sample["action"] = F_nn.pad(action, (0, self.target_dim - current))
return sample
@register_dataset_transform("pad_state")
class PadState(DatasetTransformStep):
"""Zero-pad ``observation.state`` to *target_dim* along the last axis."""
def __init__(self, target_dim: int):
self.target_dim = target_dim
def __call__(self, sample: dict) -> dict:
state = sample.get("observation.state")
if state is None:
return sample
current = state.shape[-1]
if current < self.target_dim:
sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current))
return sample
@register_dataset_transform("resize_images")
class ResizeImages(DatasetTransformStep):
"""Resize all image/video camera tensors to (height, width)."""
def __init__(self, height: int, width: int):
self.size = (height, width)
def __call__(self, sample: dict) -> dict:
for key in list(sample.keys()):
if not key.startswith("observation.images."):
continue
img = sample[key]
if not isinstance(img, torch.Tensor) or img.ndim < 3:
continue
sample[key] = F.resize(img, self.size, antialias=True)
return sample
class DatasetTransformPipeline:
"""Sequential pipeline of DatasetTransformStep instances."""
def __init__(self, configs: list[DatasetTransformStepConfig] | None = None):
self.steps: list[DatasetTransformStep] = []
if configs:
for cfg in configs:
self.steps.append(self._build(cfg))
@staticmethod
def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep:
cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type)
if cls is None:
raise ValueError(
f"Unknown dataset transform '{cfg.type}'. "
f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}"
)
return cls(**cfg.kwargs)
def __call__(self, sample: dict) -> dict:
for step in self.steps:
sample = step(sample)
return sample
def __repr__(self) -> str:
return f"DatasetTransformPipeline(steps={self.steps})"
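A miniature of the registry-plus-pipeline mechanics above, with plain lists standing in for tensors (class and function names here are illustrative, not the real module's API):

```python
_REGISTRY = {}

def register(name):
    """Register a transform step class under a string name."""
    def deco(cls):
        _REGISTRY[name] = cls
        return cls
    return deco

@register("pad_action")
class PadActionSketch:
    def __init__(self, target_dim):
        self.target_dim = target_dim
    def __call__(self, sample):
        action = sample.get("action")
        if action is not None and len(action) < self.target_dim:
            sample["action"] = list(action) + [0.0] * (self.target_dim - len(action))
        return sample

def build_pipeline(step_configs):
    """Each config is {'type': name, 'kwargs': {...}}, as in DatasetTransformStepConfig."""
    steps = []
    for cfg in step_configs:
        cls = _REGISTRY.get(cfg["type"])
        if cls is None:
            raise ValueError(f"Unknown dataset transform {cfg['type']!r}")
        steps.append(cls(**cfg.get("kwargs", {})))
    def run(sample):
        for step in steps:
            sample = step(sample)
        return sample
    return run
```

Because the registry is keyed by string, per-dataset pipelines can be declared entirely in config files and resolved at load time, which is what makes heterogeneous datasets composable.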
+53
@@ -346,6 +346,19 @@ class LiberoEnv(EnvConfig):
return kwargs
@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
"""Alias config for LIBERO-plus benchmarks.
LIBERO-plus keeps the same Python package/module names as LIBERO, so this
config reuses the existing LIBERO env implementation while making intent explicit
in experiment configs (`env.type=libero_plus`).
"""
task: str = "libero_spatial"
@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
@@ -392,6 +405,46 @@ class RoboCasaEnv(EnvConfig):
return {"split": self.split}
@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
Uses BenchmarkEnvBuilder from the robomme package.
"""
task: str = "PickXtimes"
fps: int = 10
episode_length: int = 300
action_space: str = "joint_angle"
dataset_split: str = "test"
task_ids: list[int] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
"front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
"wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
"front_rgb": f"{OBS_IMAGES}.front",
"wrist_rgb": f"{OBS_IMAGES}.wrist",
OBS_STATE: OBS_STATE,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"action_space": self.action_space,
"dataset": self.dataset_split,
}
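What a `features_map` accomplishes downstream can be sketched as a key-renaming pass over raw observations (`remap_observation` is a hypothetical helper for illustration, not the factory's actual code path): env-native keys are rewritten to LeRobot's canonical `observation.*` namespace, and unmapped keys pass through unchanged.

```python
def remap_observation(raw_obs, features_map):
    """Rename raw env observation keys to their canonical LeRobot names."""
    return {features_map.get(key, key): value for key, value in raw_obs.items()}
```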
@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):
+29 -2
@@ -20,7 +20,17 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
-from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv, RoboCasaEnv
+from lerobot.envs.configs import (
+    AlohaEnv,
+    EnvConfig,
+    HubEnvConfig,
+    IsaaclabArenaEnv,
+    LiberoEnv,
+    LiberoPlusEnv,
+    PushtEnv,
+    RoboCasaEnv,
+    RoboMMEEnv,
+)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
@@ -35,8 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "libero":
return LiberoEnv(**kwargs)
elif env_type == "libero_plus":
return LiberoPlusEnv(**kwargs)
elif env_type == "robocasa":
return RoboCasaEnv(**kwargs)
elif env_type == "robomme":
return RoboMMEEnv(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -72,7 +86,7 @@ def make_env_pre_post_processors(
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
-    if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
+    if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
# For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
@@ -201,6 +215,19 @@ def make_env(
env_cls=env_cls,
)
elif "robomme" in cfg.type:
from lerobot.envs.robomme import create_robomme_envs
return create_robomme_envs(
task=cfg.task,
n_envs=n_envs,
action_space_type=cfg.action_space,
dataset=cfg.dataset_split,
episode_length=cfg.episode_length,
task_ids=cfg.task_ids,
env_cls=env_cls,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
+8 -2
@@ -26,8 +26,14 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces
-from libero.libero import benchmark, get_libero_path
-from libero.libero.envs import OffScreenRenderEnv
+try:
+    from libero.libero import benchmark, get_libero_path
+    from libero.libero.envs import OffScreenRenderEnv
+except ImportError:
+    # LIBERO-plus may be installed from source with an extra nested package level.
+    from libero.libero.libero import benchmark, get_libero_path
+    from libero.libero.libero.envs import OffScreenRenderEnv
from lerobot.processor import RobotObservation
+154
@@ -0,0 +1,154 @@
"""RoboMME environment wrapper for LeRobot evaluation.
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.
RoboMME tasks:
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""
from __future__ import annotations
from typing import Any
import gymnasium as gym
import numpy as np
from gymnasium import spaces
ROBOMME_TASKS = [
"BinFill", "PickXtimes", "SwingXtimes", "StopCube",
"VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
"PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
"MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]
class RoboMMEGymEnv(gym.Env):
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
metadata = {"render_modes": ["rgb_array"]}
def __init__(
self,
task: str = "PickXtimes",
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_idx: int = 0,
max_steps: int = 300,
):
super().__init__()
from robomme.env_record_wrapper import BenchmarkEnvBuilder
self._task = task
self._action_space_type = action_space_type
self._dataset = dataset
self._episode_idx = episode_idx
self._max_steps = max_steps
self._builder = BenchmarkEnvBuilder(
env_id=task,
dataset=dataset,
action_space=action_space_type,
gui_render=False,
max_steps=max_steps,
)
self._env = None
action_dim = 8 if action_space_type == "joint_angle" else 7
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
self.observation_space = spaces.Dict({
"front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
"state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
})
def reset(self, *, seed=None, options=None):
super().reset(seed=seed)
self._env = self._builder.make_env_for_episode(
episode_idx=self._episode_idx, max_steps=self._max_steps,
)
obs, info = self._env.reset()
return self._convert_obs(obs), self._convert_info(info)
def step(self, action):
obs, reward, terminated, truncated, info = self._env.step(action)
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
status = info.get("status", "ongoing")
is_success = status == "success"
conv_info = self._convert_info(info)
conv_info["is_success"] = is_success
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
def _convert_obs(self, obs: dict) -> dict:
front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
state = np.concatenate([joint, gripper])
return {
"front_rgb": front_rgb,
"wrist_rgb": wrist_rgb,
"state": state,
}
def _convert_info(self, info: dict) -> dict:
return {
"status": info.get("status", "ongoing"),
"task_goal": info.get("task_goal", ""),
}
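The two conversion patterns in `_convert_obs` reduce to simple stdlib logic (the helpers below are illustrative stand-ins, with lists in place of numpy arrays): take the newest frame when the env returns a history list, and build the 8-D state from 7 joint angles plus 1 gripper value.

```python
def latest(frame_or_history):
    """Return the newest entry when given a history list, else the value itself."""
    return frame_or_history[-1] if isinstance(frame_or_history, list) else frame_or_history

def build_state(joint_state, gripper_state):
    """Concatenate 7 joint angles and 1 gripper value into an 8-D state."""
    joint = list(latest(joint_state))[:7]
    g = latest(gripper_state)
    gripper = list(g)[:1] if isinstance(g, list) else [g]
    return joint + gripper
```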
def create_robomme_envs(
task: str,
n_envs: int = 1,
action_space_type: str = "joint_angle",
dataset: str = "test",
episode_length: int = 300,
task_ids: list[int] | None = None,
env_cls=None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Create vectorized RoboMME environments for evaluation.
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
"""
if env_cls is None:
env_cls = gym.vector.SyncVectorEnv
if task_ids is None:
task_ids = [0]
suite_name = "robomme"
envs_by_task = {}
for task_id in task_ids:
        # Default argument binds the current task_id, avoiding Python's
        # late-binding closure pitfall inside the loop.
        def _make_one(ep_idx=task_id):
return RoboMMEGymEnv(
task=task,
action_space_type=action_space_type,
dataset=dataset,
episode_idx=ep_idx,
max_steps=episode_length,
)
vec = env_cls(
[_make_one for _ in range(n_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
envs_by_task[task_id] = vec
return {suite_name: envs_by_task}
+17 -4
@@ -29,7 +29,8 @@ from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.sampler import EpisodeAwareSampler
+from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset
+from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
@@ -343,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
-    if hasattr(cfg.policy, "drop_n_last_frames"):
+    drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0)
if isinstance(dataset, NewMultiLeRobotDataset):
shuffle = False
sampler = WeightedEpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
dataset_membership=dataset.meta.episodes["dataset_source"],
dataset_weights=dataset.dataset_weights,
drop_n_last_frames=drop_n_last,
shuffle=True,
)
elif drop_n_last > 0:
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
-            drop_n_last_frames=cfg.policy.drop_n_last_frames,
+            drop_n_last_frames=drop_n_last,
shuffle=True,
)
else:
@@ -360,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
-        shuffle=shuffle and not cfg.dataset.streaming,
+        shuffle=shuffle and not getattr(cfg.dataset, "streaming", False),
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
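The branch structure above follows from a DataLoader constraint: it accepts either a `sampler` or `shuffle=True`, never both, so each path that builds a sampler pins `shuffle` to `False`. A sketch of the decision (function and return values are illustrative, not the script's API; the plain-shuffle fallback for the remaining case is assumed):

```python
def pick_sampling(is_multi_dataset, drop_n_last, streaming=False):
    """Return (sampler_kind, shuffle) — at most one of the two is active."""
    if is_multi_dataset:
        return "weighted_episode_aware", False
    if drop_n_last > 0:
        return "episode_aware", False
    # No sampler needed: fall back to plain shuffling, unless streaming.
    return None, not streaming
```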