diff --git a/pyproject.toml b/pyproject.toml index 585af6f4b..b46868cd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,6 +180,9 @@ libero_plus = [ "libero @ git+https://github.com/sylvestf/LIBERO-plus.git@main ; sys_platform == 'linux'", "lerobot[scipy-dep]", ] +robomme = [ + "robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'", +] metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"] # All diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index dcb0cbd54..30eb8eb64 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -16,18 +16,13 @@ from dataclasses import dataclass, field -from lerobot.datasets.transforms import ImageTransformsConfig +from lerobot.datasets.transforms import DatasetTransformStepConfig, ImageTransformsConfig from lerobot.datasets.video_utils import get_safe_default_codec @dataclass class DatasetConfig: - # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data - # keys common between the datasets are kept. Each dataset gets and additional transform that inserts the - # "dataset_index" into the returned item. The index mapping is made according to the order in which the - # datasets are provided. repo_id: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. root: str | None = None episodes: list[int] | None = None image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) @@ -37,6 +32,32 @@ class DatasetConfig: streaming: bool = False +@dataclass +class SubDatasetConfig: + """Configuration for a single dataset within a MultiDatasetConfig.""" + + repo_id: str + root: str | None = None + episodes: list[int] | None = None + revision: str | None = None + video_backend: str = field(default_factory=get_safe_default_codec) + weight: float = 1.0 + # Maps dataset-local feature keys to unified policy keys. + # Keys not listed pass through unchanged. + feature_map: dict[str, str] = field(default_factory=dict) + # Per-dataset transforms applied after feature renaming, before cross-dataset padding. + transforms: list[DatasetTransformStepConfig] | None = None + + +@dataclass +class MultiDatasetConfig: + """Configuration for training on multiple datasets jointly.""" + + datasets: list[SubDatasetConfig] = field(default_factory=list) + image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) + use_imagenet_stats: bool = True + + @dataclass class WandBConfig: enable: bool = False diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 9d20afc68..fa1607f87 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot import envs from lerobot.configs import parser -from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig +from lerobot.configs.default import DatasetConfig, EvalConfig, MultiDatasetConfig, PeftConfig, WandBConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.optim import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig @@ -35,7 +35,7 @@ TRAIN_CONFIG_NAME = "train_config.json" @dataclass class TrainPipelineConfig(HubMixin): - dataset: DatasetConfig + dataset: DatasetConfig | MultiDatasetConfig env: envs.EnvConfig | None = None policy: PreTrainedConfig | None = None # Set `dir` to where you would like to save all of the run outputs. If you run another training session @@ -129,8 +129,9 @@ class TrainPipelineConfig(HubMixin): train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" self.output_dir = Path("outputs/train") / train_dir - if isinstance(self.dataset.repo_id, list): - raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") + if isinstance(self.dataset, MultiDatasetConfig): + if len(self.dataset.datasets) < 1: + raise ValueError("MultiDatasetConfig.datasets must contain at least one sub-dataset.") if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") @@ -143,8 +144,7 @@ class TrainPipelineConfig(HubMixin): "'policy.repo_id' argument missing. Please specify it to push the model to the hub." ) - if self.use_rabc and not self.rabc_progress_path: - # Auto-detect from dataset path + if self.use_rabc and not self.rabc_progress_path and isinstance(self.dataset, DatasetConfig): repo_id = self.dataset.repo_id if self.dataset.root: self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet") diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 31e939809..8cc2ff186 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -18,13 +18,14 @@ from pprint import pformat import torch +from lerobot.configs.default import DatasetConfig, MultiDatasetConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.lerobot_dataset import ( LeRobotDataset, LeRobotDatasetMetadata, - MultiLeRobotDataset, ) +from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD @@ -68,66 +69,81 @@ def resolve_delta_timestamps( return delta_timestamps -def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: - """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. - - Args: - cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig. - - Raises: - NotImplementedError: The MultiLeRobotDataset is currently deactivated. +def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | NewMultiLeRobotDataset: + """Create a single or multi-dataset depending on the config type. Returns: - LeRobotDataset | MultiLeRobotDataset + LeRobotDataset | NewMultiLeRobotDataset """ + if isinstance(cfg.dataset, MultiDatasetConfig): + return _make_multi_dataset(cfg) + + return _make_single_dataset(cfg) + + +def _make_single_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset: + ds_cfg: DatasetConfig = cfg.dataset # type: ignore[assignment] image_transforms = ( - ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None + ImageTransforms(ds_cfg.image_transforms) if ds_cfg.image_transforms.enable else None ) + ds_meta = LeRobotDatasetMetadata(ds_cfg.repo_id, root=ds_cfg.root, revision=ds_cfg.revision) + delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) - if isinstance(cfg.dataset.repo_id, str): - ds_meta = LeRobotDatasetMetadata( - cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision - ) - delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) - if not cfg.dataset.streaming: - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - episodes=cfg.dataset.episodes, - delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - revision=cfg.dataset.revision, - video_backend=cfg.dataset.video_backend, - tolerance_s=cfg.tolerance_s, - ) - else: - dataset = StreamingLeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - episodes=cfg.dataset.episodes, - delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - revision=cfg.dataset.revision, - max_num_shards=cfg.num_workers, - tolerance_s=cfg.tolerance_s, - ) - else: - raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") - dataset = MultiLeRobotDataset( - cfg.dataset.repo_id, - # TODO(aliberts): add proper support for multi dataset - # delta_timestamps=delta_timestamps, + if not ds_cfg.streaming: + dataset = LeRobotDataset( + ds_cfg.repo_id, + root=ds_cfg.root, + episodes=ds_cfg.episodes, + delta_timestamps=delta_timestamps, image_transforms=image_transforms, - video_backend=cfg.dataset.video_backend, + revision=ds_cfg.revision, + video_backend=ds_cfg.video_backend, + tolerance_s=cfg.tolerance_s, ) - logging.info( - "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " - f"{pformat(dataset.repo_id_to_index, indent=2)}" + else: + dataset = StreamingLeRobotDataset( + ds_cfg.repo_id, + root=ds_cfg.root, + episodes=ds_cfg.episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + revision=ds_cfg.revision, + max_num_shards=cfg.num_workers, + tolerance_s=cfg.tolerance_s, ) - if cfg.dataset.use_imagenet_stats: + if ds_cfg.use_imagenet_stats: for key in dataset.meta.camera_keys: - for stats_type, stats in IMAGENET_STATS.items(): - dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + for stats_type, stats_val in IMAGENET_STATS.items(): + dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32) + + return dataset + + +def _make_multi_dataset(cfg: TrainPipelineConfig) -> NewMultiLeRobotDataset: + multi_cfg: MultiDatasetConfig = cfg.dataset # type: ignore[assignment] + image_transforms = ( + ImageTransforms(multi_cfg.image_transforms) if multi_cfg.image_transforms.enable else None + ) + + dataset = NewMultiLeRobotDataset( + configs=multi_cfg.datasets, + image_transforms=image_transforms, + tolerance_s=cfg.tolerance_s, + ) + + logging.info( + "MultiLeRobotDataset created with %d sub-datasets:\n%s", + len(multi_cfg.datasets), + pformat( + {i: c.repo_id for i, c in enumerate(multi_cfg.datasets)}, + indent=2, + ), + ) + + if multi_cfg.use_imagenet_stats: + for key in dataset.meta.camera_keys: + for stats_type, stats_val in IMAGENET_STATS.items(): + dataset.meta.stats[key][stats_type] = torch.tensor(stats_val, dtype=torch.float32) return dataset diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py new file mode 100644 index 000000000..31e4f8285 --- /dev/null +++ b/src/lerobot/datasets/multi_dataset.py @@ -0,0 +1,364 @@ +"""MultiLeRobotDataset: joint training over heterogeneous LeRobot datasets. + +Supports: +- Per-dataset feature mapping (rename keys to a unified namespace) +- Automatic zero-padding for features missing in some datasets +- Per-dataset transform pipelines +- Weighted sampling via dataset weights +- Aggregated stats across all sub-datasets +- A ``meta`` shim compatible with EpisodeAwareSampler and make_policy +""" + +from __future__ import annotations + +import logging +from collections.abc import Callable + +import numpy as np +import torch +import torch.utils.data + +from lerobot.configs.default import SubDatasetConfig +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.transforms import DatasetTransformPipeline + + +class MultiDatasetMeta: + """Lightweight metadata shim that exposes the same interface as ``LeRobotDatasetMetadata``. + + Built by aggregating the metadata of multiple sub-datasets after their + feature keys have been mapped to a unified namespace. + """ + + def __init__( + self, + datasets: list[LeRobotDataset], + feature_maps: list[dict[str, str]], + ): + self._datasets = datasets + self._feature_maps = feature_maps + + self._unified_features = self._build_unified_features() + self._episodes = self._build_episodes() + self._stats = self._build_stats() + + # ------------------------------------------------------------------ + # Feature union + # ------------------------------------------------------------------ + + def _build_unified_features(self) -> dict[str, dict]: + """Build feature dict as the *union* of all mapped feature keys.""" + unified: dict[str, dict] = {} + for ds, fmap in zip(self._datasets, self._feature_maps): + for original_key, feat_info in ds.meta.features.items(): + mapped_key = fmap.get(original_key, original_key) + if mapped_key not in unified: + unified[mapped_key] = dict(feat_info) + else: + existing_shape = tuple(unified[mapped_key]["shape"]) + new_shape = tuple(feat_info["shape"]) + if existing_shape != new_shape and unified[mapped_key]["dtype"] == feat_info["dtype"]: + logging.warning( + "Feature '%s' has shape %s in one dataset but %s in another. " + "The larger shape will be used (padding applied automatically).", + mapped_key, + existing_shape, + new_shape, + ) + if np.prod(new_shape) > np.prod(existing_shape): + unified[mapped_key] = dict(feat_info) + return unified + + # ------------------------------------------------------------------ + # Episode metadata (global flat indexing) + # ------------------------------------------------------------------ + + def _build_episodes(self) -> dict[str, list]: + """Concatenate episode boundaries across sub-datasets with frame offsets. + + Produces the same column structure as ``load_episodes()`` so that + ``EpisodeAwareSampler`` and ``WeightedEpisodeAwareSampler`` can consume it. + """ + from_indices: list[int] = [] + to_indices: list[int] = [] + dataset_source: list[int] = [] + + frame_offset = 0 + for ds_idx, ds in enumerate(self._datasets): + eps = ds.meta.episodes + for ep in eps: + from_indices.append(ep["dataset_from_index"] + frame_offset) + to_indices.append(ep["dataset_to_index"] + frame_offset) + dataset_source.append(ds_idx) + frame_offset += ds.num_frames + + return { + "dataset_from_index": from_indices, + "dataset_to_index": to_indices, + "dataset_source": dataset_source, + } + + # ------------------------------------------------------------------ + # Stats aggregation + # ------------------------------------------------------------------ + + def _build_stats(self) -> dict[str, dict[str, np.ndarray]]: + """Aggregate stats across sub-datasets using mapped feature keys.""" + mapped_stats_list: list[dict[str, dict]] = [] + for ds, fmap in zip(self._datasets, self._feature_maps): + reverse_map = {v: k for k, v in fmap.items()} + mapped: dict[str, dict] = {} + for unified_key in self._unified_features: + original_key = reverse_map.get(unified_key, unified_key) + if original_key in ds.meta.stats: + mapped[unified_key] = ds.meta.stats[original_key] + mapped_stats_list.append(mapped) + + return aggregate_stats(mapped_stats_list) + + # ------------------------------------------------------------------ + # Properties matching LeRobotDatasetMetadata API + # ------------------------------------------------------------------ + + @property + def features(self) -> dict[str, dict]: + return self._unified_features + + @property + def image_keys(self) -> list[str]: + return [k for k, f in self._unified_features.items() if f["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + return [k for k, f in self._unified_features.items() if f["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + return [k for k, f in self._unified_features.items() if f["dtype"] in ("video", "image")] + + @property + def names(self) -> dict[str, list | dict]: + return {k: f["names"] for k, f in self._unified_features.items()} + + @property + def shapes(self) -> dict[str, tuple]: + return {k: tuple(f["shape"]) for k, f in self._unified_features.items()} + + @property + def fps(self) -> int: + fps_values = {ds.meta.fps for ds in self._datasets} + if len(fps_values) > 1: + logging.warning("Sub-datasets have different FPS values: %s. Using the first.", fps_values) + return self._datasets[0].meta.fps + + @property + def stats(self) -> dict[str, dict[str, np.ndarray]]: + return self._stats + + @stats.setter + def stats(self, value: dict): + self._stats = value + + @property + def episodes(self) -> dict[str, list]: + return self._episodes + + @property + def total_episodes(self) -> int: + return sum(ds.meta.total_episodes for ds in self._datasets) + + @property + def total_frames(self) -> int: + return sum(ds.meta.total_frames for ds in self._datasets) + + @property + def total_tasks(self) -> int: + return sum(ds.meta.total_tasks for ds in self._datasets) + + @property + def info(self) -> dict: + return { + "fps": self.fps, + "features": self._unified_features, + "total_episodes": self.total_episodes, + "total_frames": self.total_frames, + "total_tasks": self.total_tasks, + "codebase_version": "v3.0", + } + + +class NewMultiLeRobotDataset(torch.utils.data.Dataset): + """Dataset that wraps multiple ``LeRobotDataset`` instances with feature mapping and padding. + + Each sub-dataset can have different feature names and shapes. A per-dataset + ``feature_map`` renames keys into a shared namespace. Features that a given + sub-dataset does not provide are zero-padded so every ``__getitem__`` returns + the full unified feature set. + """ + + def __init__( + self, + configs: list[SubDatasetConfig], + image_transforms: Callable | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + tolerance_s: float = 1e-4, + ): + super().__init__() + self._configs = configs + self.image_transforms = image_transforms + + self._datasets: list[LeRobotDataset] = [] + self._feature_maps: list[dict[str, str]] = [] + self._transform_pipelines: list[DatasetTransformPipeline | None] = [] + self._weights: list[float] = [] + + for cfg in configs: + ds = LeRobotDataset( + repo_id=cfg.repo_id, + root=cfg.root, + episodes=cfg.episodes, + image_transforms=image_transforms, + delta_timestamps=delta_timestamps, + tolerance_s=tolerance_s, + revision=cfg.revision, + video_backend=cfg.video_backend, + ) + self._datasets.append(ds) + self._feature_maps.append(cfg.feature_map or {}) + self._transform_pipelines.append( + DatasetTransformPipeline(cfg.transforms) if cfg.transforms else None + ) + self._weights.append(cfg.weight) + + self._meta = MultiDatasetMeta(self._datasets, self._feature_maps) + + # Pre-compute cumulative frame counts for fast index mapping. + self._cumulative_frames: list[int] = [] + total = 0 + for ds in self._datasets: + total += ds.num_frames + self._cumulative_frames.append(total) + + # Build reverse maps (unified_key -> original_key) per dataset for padding. + self._reverse_maps: list[dict[str, str]] = [] + for fmap in self._feature_maps: + self._reverse_maps.append({v: k for k, v in fmap.items()}) + + logging.info( + "MultiLeRobotDataset: %d sub-datasets, %d total frames, %d total episodes, " + "%d unified features", + len(self._datasets), + self.num_frames, + self.num_episodes, + len(self._meta.features), + ) + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + @property + def meta(self) -> MultiDatasetMeta: + return self._meta + + @property + def dataset_weights(self) -> list[float]: + return self._weights + + @property + def num_frames(self) -> int: + return self._cumulative_frames[-1] if self._cumulative_frames else 0 + + @property + def num_episodes(self) -> int: + return sum(ds.num_episodes for ds in self._datasets) + + @property + def episodes(self) -> list[int] | None: + return None + + @property + def fps(self) -> int: + return self._meta.fps + + @property + def features(self) -> dict[str, dict]: + return self._meta.features + + @property + def camera_keys(self) -> list[str]: + return self._meta.camera_keys + + # ------------------------------------------------------------------ + # Indexing + # ------------------------------------------------------------------ + + def _locate(self, idx: int) -> tuple[int, int]: + """Map a global frame index to (dataset_index, local_index).""" + for ds_idx, cum in enumerate(self._cumulative_frames): + if idx < cum: + local = idx - (self._cumulative_frames[ds_idx - 1] if ds_idx > 0 else 0) + return ds_idx, local + raise IndexError(f"Index {idx} out of range (total {self.num_frames})") + + def __len__(self) -> int: + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + ds_idx, local_idx = self._locate(idx) + item = self._datasets[ds_idx][local_idx] + + # 1. Rename keys according to feature_map. + fmap = self._feature_maps[ds_idx] + if fmap: + renamed: dict[str, torch.Tensor] = {} + for key, value in item.items(): + renamed[fmap.get(key, key)] = value + item = renamed + + # 2. Apply per-dataset transform pipeline. + pipeline = self._transform_pipelines[ds_idx] + if pipeline is not None: + item = pipeline(item) + + # 3. Pad missing features with zeros. + reverse_map = self._reverse_maps[ds_idx] + ds_features = self._datasets[ds_idx].meta.features + for unified_key, feat_info in self._meta.features.items(): + if unified_key in item: + continue + original_key = reverse_map.get(unified_key, unified_key) + if original_key in ds_features: + continue + shape = tuple(feat_info["shape"]) + dtype = feat_info["dtype"] + if dtype in ("video", "image"): + # Camera tensors are (C, H, W) after transforms. + c, h, w = (shape[2], shape[0], shape[1]) if len(shape) == 3 else (3, shape[0], shape[1]) + item[unified_key] = torch.zeros(c, h, w, dtype=torch.float32) + elif dtype in ("float32", "float64"): + item[unified_key] = torch.zeros(shape, dtype=torch.float32) + elif dtype in ("int32", "int64"): + item[unified_key] = torch.zeros(shape, dtype=torch.int64) + elif dtype == "bool": + item[unified_key] = torch.zeros(shape, dtype=torch.bool) + else: + item[unified_key] = torch.zeros(shape, dtype=torch.float32) + item[f"{unified_key}_is_pad"] = torch.tensor(True) + + # 4. Tag which dataset this sample came from. + item["dataset_index"] = torch.tensor(ds_idx) + return item + + def __repr__(self) -> str: + repo_ids = [c.repo_id for c in self._configs] + return ( + f"NewMultiLeRobotDataset(\n" + f" repo_ids={repo_ids},\n" + f" num_frames={self.num_frames},\n" + f" num_episodes={self.num_episodes},\n" + f" unified_features={list(self._meta.features.keys())},\n" + f" weights={self._weights},\n" + f")" + ) diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index d0bb20c27..9d5365077 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -59,3 +59,80 @@ class EpisodeAwareSampler: def __len__(self) -> int: return len(self.indices) + + +class WeightedEpisodeAwareSampler: + """Sampler that draws frames from multiple datasets according to per-dataset weights. + + Each iteration first selects a sub-dataset proportionally to its weight, then + uniformly samples a frame from that sub-dataset's valid index set. Episode + boundary information is respected so that dropped frames are excluded. + + Args: + dataset_from_indices: Start index for each episode (global, flat). + dataset_to_indices: End index (exclusive) for each episode (global, flat). + dataset_membership: Which sub-dataset each episode belongs to (integer id). + dataset_weights: Relative sampling weight per sub-dataset. + episode_indices_to_use: If given, only episodes in this set are used. + drop_n_first_frames: Frames to skip at the start of each episode. + drop_n_last_frames: Frames to skip at the end of each episode. + shuffle: Whether to shuffle within each epoch. + num_samples: How many samples per epoch. Defaults to total valid frames. + generator: Optional torch.Generator for reproducibility. + """ + + def __init__( + self, + dataset_from_indices: list[int], + dataset_to_indices: list[int], + dataset_membership: list[int], + dataset_weights: list[float], + episode_indices_to_use: list | None = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + shuffle: bool = False, + num_samples: int | None = None, + generator: torch.Generator | None = None, + ): + n_datasets = max(dataset_membership) + 1 if dataset_membership else 0 + self._per_dataset_indices: list[list[int]] = [[] for _ in range(n_datasets)] + + episodes_to_use = set(episode_indices_to_use) if episode_indices_to_use is not None else None + + for ep_idx, (start, end, ds_id) in enumerate( + zip(dataset_from_indices, dataset_to_indices, dataset_membership, strict=True) + ): + if episodes_to_use is not None and ep_idx not in episodes_to_use: + continue + frame_range = range(start + drop_n_first_frames, end - drop_n_last_frames) + self._per_dataset_indices[ds_id].extend(frame_range) + + # Normalise weights (only over datasets that actually have frames). + raw_weights = list(dataset_weights[:n_datasets]) + self._weights = torch.zeros(n_datasets) + for i, w in enumerate(raw_weights): + if len(self._per_dataset_indices[i]) > 0: + self._weights[i] = w + total_w = self._weights.sum() + if total_w > 0: + self._weights /= total_w + + self._total_frames = sum(len(idx) for idx in self._per_dataset_indices) + self._num_samples = num_samples if num_samples is not None else self._total_frames + self.shuffle = shuffle + self._generator = generator + + def __iter__(self) -> Iterator[int]: + if not self.shuffle: + for ds_indices in self._per_dataset_indices: + yield from ds_indices + return + + for _ in range(self._num_samples): + ds_id = int(torch.multinomial(self._weights, 1, generator=self._generator).item()) + indices = self._per_dataset_indices[ds_id] + local_idx = int(torch.randint(len(indices), (1,), generator=self._generator).item()) + yield indices[local_idx] + + def __len__(self) -> int: + return self._num_samples diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index 5240619cb..9e52dd243 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -14,11 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import logging from collections.abc import Callable, Sequence from dataclasses import dataclass, field from typing import Any import torch +import torch.nn.functional as F_nn from torchvision.transforms import v2 from torchvision.transforms.v2 import ( Transform, @@ -258,3 +260,116 @@ class ImageTransforms(Transform): def forward(self, *inputs: Any) -> Any: return self.tf(*inputs) + + +# --------------------------------------------------------------------------- +# Per-dataset transform pipeline (used by MultiLeRobotDataset) +# --------------------------------------------------------------------------- + + +@dataclass +class DatasetTransformStepConfig: + """Config for a single per-dataset transform step.""" + + type: str + kwargs: dict[str, Any] = field(default_factory=dict) + + +_DATASET_TRANSFORM_REGISTRY: dict[str, type["DatasetTransformStep"]] = {} + + +def register_dataset_transform(name: str): + """Decorator to register a DatasetTransformStep by name.""" + + def decorator(cls: type["DatasetTransformStep"]) -> type["DatasetTransformStep"]: + _DATASET_TRANSFORM_REGISTRY[name] = cls + return cls + + return decorator + + +class DatasetTransformStep: + """Base class for a single per-dataset transform applied to a sample dict.""" + + def __call__(self, sample: dict) -> dict: + raise NotImplementedError + + +@register_dataset_transform("pad_action") +class PadAction(DatasetTransformStep): + """Zero-pad the ``action`` tensor to *target_dim* along the last axis.""" + + def __init__(self, target_dim: int): + self.target_dim = target_dim + + def __call__(self, sample: dict) -> dict: + action = sample.get("action") + if action is None: + return sample + current = action.shape[-1] + if current < self.target_dim: + sample["action"] = F_nn.pad(action, (0, self.target_dim - current)) + return sample + + +@register_dataset_transform("pad_state") +class PadState(DatasetTransformStep): + """Zero-pad ``observation.state`` to *target_dim* along the last axis.""" + + def __init__(self, target_dim: int): + self.target_dim = target_dim + + def __call__(self, sample: dict) -> dict: + state = sample.get("observation.state") + if state is None: + return sample + current = state.shape[-1] + if current < self.target_dim: + sample["observation.state"] = F_nn.pad(state, (0, self.target_dim - current)) + return sample + + +@register_dataset_transform("resize_images") +class ResizeImages(DatasetTransformStep): + """Resize all image/video camera tensors to (height, width).""" + + def __init__(self, height: int, width: int): + self.size = (height, width) + + def __call__(self, sample: dict) -> dict: + for key in list(sample.keys()): + if not key.startswith("observation.images."): + continue + img = sample[key] + if not isinstance(img, torch.Tensor) or img.ndim < 3: + continue + sample[key] = F.resize(img, self.size, antialias=True) + return sample + + +class DatasetTransformPipeline: + """Sequential pipeline of DatasetTransformStep instances.""" + + def __init__(self, configs: list[DatasetTransformStepConfig] | None = None): + self.steps: list[DatasetTransformStep] = [] + if configs: + for cfg in configs: + self.steps.append(self._build(cfg)) + + @staticmethod + def _build(cfg: DatasetTransformStepConfig) -> DatasetTransformStep: + cls = _DATASET_TRANSFORM_REGISTRY.get(cfg.type) + if cls is None: + raise ValueError( + f"Unknown dataset transform '{cfg.type}'. " + f"Available: {list(_DATASET_TRANSFORM_REGISTRY)}" + ) + return cls(**cfg.kwargs) + + def __call__(self, sample: dict) -> dict: + for step in self.steps: + sample = step(sample) + return sample + + def __repr__(self) -> str: + return f"DatasetTransformPipeline(steps={self.steps})" diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 2a73dd272..bdc16bc73 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -405,6 +405,46 @@ class RoboCasaEnv(EnvConfig): return {"split": self.split} +@EnvConfig.register_subclass("robomme") +@dataclass +class RoboMMEEnv(EnvConfig): + """RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN). + + 16 tasks across 4 suites: Counting, Permanence, Reference, Imitation. + Uses BenchmarkEnvBuilder from the robomme package. + """ + + task: str = "PickXtimes" + fps: int = 10 + episode_length: int = 300 + action_space: str = "joint_angle" + dataset_split: str = "test" + task_ids: list[int] | None = None + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)), + "front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)), + "wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + ACTION: ACTION, + "front_rgb": f"{OBS_IMAGES}.front", + "wrist_rgb": f"{OBS_IMAGES}.wrist", + OBS_STATE: OBS_STATE, + } + ) + + @property + def gym_kwargs(self) -> dict: + return { + "action_space": self.action_space, + "dataset": self.dataset_split, + } + + @EnvConfig.register_subclass("metaworld") @dataclass class MetaworldEnv(EnvConfig): diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 365f74088..2810e4025 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -29,6 +29,7 @@ from lerobot.envs.configs import ( LiberoPlusEnv, PushtEnv, RoboCasaEnv, + RoboMMEEnv, ) from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result from lerobot.policies.xvla.configuration_xvla import XVLAConfig @@ -48,6 +49,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return LiberoPlusEnv(**kwargs) elif env_type == "robocasa": return RoboCasaEnv(**kwargs) + elif env_type == "robomme": + return RoboMMEEnv(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") @@ -212,6 +215,19 @@ def make_env( env_cls=env_cls, ) + elif "robomme" in cfg.type: + from lerobot.envs.robomme import create_robomme_envs + + return create_robomme_envs( + task=cfg.task, + n_envs=n_envs, + action_space_type=cfg.action_space, + dataset=cfg.dataset_split, + episode_length=cfg.episode_length, + task_ids=cfg.task_ids, + env_cls=env_cls, + ) + elif "metaworld" in cfg.type: from lerobot.envs.metaworld import create_metaworld_envs diff --git a/src/lerobot/envs/robomme.py b/src/lerobot/envs/robomme.py new file mode 100644 index 000000000..e963edf6e --- /dev/null +++ b/src/lerobot/envs/robomme.py @@ -0,0 +1,154 @@ +"""RoboMME environment wrapper for LeRobot evaluation. + +Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible +``VectorEnv`` suitable for ``lerobot_eval``. + +RoboMME tasks: + Counting: BinFill, PickXtimes, SwingXtimes, StopCube + Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap + Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder + Imitation: MoveCube, InsertPeg, PatternLock, RouteStick + +Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark) +""" + +from __future__ import annotations + +from typing import Any + +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +ROBOMME_TASKS = [ + "BinFill", "PickXtimes", "SwingXtimes", "StopCube", + "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap", + "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder", + "MoveCube", "InsertPeg", "PatternLock", "RouteStick", +] + + +class RoboMMEGymEnv(gym.Env): + """Thin Gymnasium wrapper around a single RoboMME episode env.""" + + metadata = {"render_modes": ["rgb_array"]} + + def __init__( + self, + task: str = "PickXtimes", + action_space_type: str = "joint_angle", + dataset: str = "test", + episode_idx: int = 0, + max_steps: int = 300, + ): + super().__init__() + from robomme.env_record_wrapper import BenchmarkEnvBuilder + + self._task = task + self._action_space_type = action_space_type + self._dataset = dataset + self._episode_idx = episode_idx + self._max_steps = max_steps + + self._builder = BenchmarkEnvBuilder( + env_id=task, + dataset=dataset, + action_space=action_space_type, + gui_render=False, + max_steps=max_steps, + ) + self._env = None + + action_dim = 8 if action_space_type == "joint_angle" else 7 + self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32) + self.observation_space = spaces.Dict({ + "front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8), + "wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8), + "state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32), + }) + + def reset(self, *, seed=None, options=None): + super().reset(seed=seed) + self._env = self._builder.make_env_for_episode( + episode_idx=self._episode_idx, max_steps=self._max_steps, + ) + obs, info = self._env.reset() + return self._convert_obs(obs), self._convert_info(info) + + def step(self, action): + obs, reward, terminated, truncated, info = self._env.step(action) + + terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated) + truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated) + + status = info.get("status", "ongoing") + is_success = status == "success" + conv_info = self._convert_info(info) + conv_info["is_success"] = is_success + + return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info + + def _convert_obs(self, obs: dict) -> dict: + front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"] + wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"] + joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"] + gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"] + + front_rgb = np.asarray(front_rgb, dtype=np.uint8) + wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8) + joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7] + gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1] + state = np.concatenate([joint, gripper]) + + return { + "front_rgb": front_rgb, + "wrist_rgb": wrist_rgb, + "state": state, + } + + def _convert_info(self, info: dict) -> dict: + return { + "status": info.get("status", "ongoing"), + "task_goal": info.get("task_goal", ""), + } + + +def create_robomme_envs( + task: str, + n_envs: int = 1, + action_space_type: str = "joint_angle", + dataset: str = "test", + episode_length: int = 300, + task_ids: list[int] | None = None, + env_cls=None, +) -> dict[str, dict[int, gym.vector.VectorEnv]]: + """Create vectorized RoboMME environments for evaluation. + + Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format. + """ + if env_cls is None: + env_cls = gym.vector.SyncVectorEnv + + if task_ids is None: + task_ids = [0] + + suite_name = "robomme" + envs_by_task = {} + + for task_id in task_ids: + def _make_one(ep_idx=task_id): + return RoboMMEGymEnv( + task=task, + action_space_type=action_space_type, + dataset=dataset, + episode_idx=ep_idx, + max_steps=episode_length, + ) + + vec = env_cls( + [_make_one for _ in range(n_envs)], + autoreset_mode=gym.vector.AutoresetMode.SAME_STEP, + ) + envs_by_task[task_id] = vec + + return {suite_name: envs_by_task} diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 1fed3bee4..913b87264 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -29,7 +29,8 @@ from tqdm import tqdm from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset -from lerobot.datasets.sampler import EpisodeAwareSampler +from lerobot.datasets.multi_dataset import NewMultiLeRobotDataset +from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env, make_env_pre_post_processors from lerobot.envs.utils import close_envs @@ -343,13 +344,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(cfg.policy, "drop_n_last_frames"): + drop_n_last = getattr(cfg.policy, "drop_n_last_frames", 0) + + if isinstance(dataset, NewMultiLeRobotDataset): + shuffle = False + sampler = WeightedEpisodeAwareSampler( + dataset.meta.episodes["dataset_from_index"], + dataset.meta.episodes["dataset_to_index"], + dataset_membership=dataset.meta.episodes["dataset_source"], + dataset_weights=dataset.dataset_weights, + drop_n_last_frames=drop_n_last, + shuffle=True, + ) + elif drop_n_last > 0: shuffle = False sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, - drop_n_last_frames=cfg.policy.drop_n_last_frames, + drop_n_last_frames=drop_n_last, shuffle=True, ) else: @@ -360,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): dataset, num_workers=cfg.num_workers, batch_size=cfg.batch_size, - shuffle=shuffle and not cfg.dataset.streaming, + shuffle=shuffle and not getattr(cfg.dataset, "streaming", False), sampler=sampler, pin_memory=device.type == "cuda", drop_last=False,