diff --git a/examples/tester.py b/examples/tester.py new file mode 100644 index 000000000..132109f5c --- /dev/null +++ b/examples/tester.py @@ -0,0 +1,66 @@ +from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset + +REPO_A = "lerobot/pusht" +REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id + +feature_keys_mapping = { + REPO_A: { # pusht (1 camera, 2-dim) + "action": "actions", + "observation.state": "obs_state", + "observation.image": "obs_image.cam_high", + }, + REPO_B: { # dual arm (3 cameras, 14-dim) + "action": "actions", + "observation.state": "obs_state", + "observation.images.cam_high": "obs_image.cam_high", + "observation.images.cam_left_wrist": "obs_image.cam_left_wrist", + "observation.images.cam_right_wrist": "obs_image.cam_right_wrist", + }, +} + +from torchvision.transforms.v2 import Compose, ToImage, Resize +image_tf = Compose([ + ToImage(), # converts to tensor if needed + Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640) +]) + +from torch.utils.data import DataLoader + +dataset = MultiLeRobotDataset( + repo_ids=[REPO_A, REPO_B], + image_transforms=image_tf, # ensures same HxW + feature_keys_mapping=feature_keys_mapping, + train_on_all_features=True, # keep union of cameras; zero-fill missing + # optional: override if you want fixed maxima; else inferred: + # max_action_dim=14, + # max_state_dim=14, + max_action_dim=14, + max_state_dim=14, + max_image_dim=224, + ignore_keys=[ + "next.*", # drop reward/done/success + "index", + "timestamp", + "videos/*", # drop all video metadata + "observation.effort", # 👈 drop effort everywhere + ], +) +breakpoint() +loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True) +for _ in range(100): + batch = next(iter(loader)) + +breakpoint() +# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14) +assert batch["actions"].shape[-1] == 14 +assert batch["obs_state"].shape[-1] == 14 +assert batch["actions_padding_mask"].shape[-1] == 14 +assert batch["obs_state_padding_mask"].shape[-1] == 14 + +# cameras: all canonical keys exist; pusht will have wrists zero-filled +for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]: + assert cam in batch + assert f"{cam}_is_pad" in batch + # images should all be 3x224x224 (or your transform’s size) + img = batch[cam] + assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader diff --git a/examples/tester.sh b/examples/tester.sh new file mode 100644 index 000000000..6e6e95db0 --- /dev/null +++ b/examples/tester.sh @@ -0,0 +1,16 @@ +# storage / caches +RAID=/raid/jade +export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers +export HF_HOME=$RAID/.cache/huggingface +export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets +export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot +export WANDB_CACHE_DIR=$RAID/.cache/wandb +export TMPDIR=$RAID/.cache/tmp +mkdir -p $TMPDIR +export WANDB_MODE=offline +# export HF_DATASETS_OFFLINE=1 +# export HF_HUB_OFFLINE=1 +export TOKENIZERS_PARALLELISM=false +export MUJOCO_GL=egl + +python examples/tester.py \ No newline at end of file diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index bfe7b18b4..b13e5564f 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -174,3 +174,79 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np aggregated_stats[key] = aggregate_feature_stats(stats_with_key) return aggregated_stats + +import numpy as np + +def aggregate_stats_multi( + stats_list: list[dict[str, dict]], + max_action_dim: int | None = None, + max_state_dim: int | None = None, +) -> dict[str, dict[str, np.ndarray]]: + """Aggregate stats from multiple compute_stats outputs into a single set of stats. + + Supports heterogeneous robots by padding action/state stats to the max dim. + The final stats will have the union of all data keys from each of the stats dicts. + + - new_min = elementwise min across datasets + - new_max = elementwise max across datasets + - new_mean = weighted mean (by count) + - new_std = recomputed from total variance + """ + + data_keys = {key for stats in stats_list for key in stats} + aggregated_stats = {key: {} for key in data_keys} + + def _pad(arr: np.ndarray, target: int) -> np.ndarray: + if arr.ndim == 0: # scalar + return arr + if target is None or target <= 0 or arr.shape[-1] == target: + return arr + pad_width = [(0, 0)] * arr.ndim + pad_width[-1] = (0, target - arr.shape[-1]) + return np.pad(arr, pad_width, mode="constant") + + for key in data_keys: + stats_with_key = [stats[key] for stats in stats_list if key in stats] + + # decide if this key should be padded + target_dim = None + if "action" in key and max_action_dim: + target_dim = max_action_dim + elif "state" in key and max_state_dim: + target_dim = max_state_dim + + padded = [] + counts = [] + for s in stats_with_key: + mean = _pad(np.array(s["mean"]), target_dim) + std = _pad(np.array(s["std"]), target_dim) + min_ = _pad(np.array(s["min"]), target_dim) + max_ = _pad(np.array(s["max"]), target_dim) + count = s.get("count", 1) + + padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count)) + counts.append(count) + + counts = np.array(counts, dtype=np.float64) + total_count = counts.sum() + + means = np.stack([p["mean"] for p in padded]) + stds = np.stack([p["std"] for p in padded]) + mins = np.stack([p["min"] for p in padded]) + maxs = np.stack([p["max"] for p in padded]) + + # weighted mean (broadcast weights properly) + new_mean = np.average(means, axis=0, weights=counts) + new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts) + + new_std = np.sqrt(new_var) + + aggregated_stats[key] = { + "min": mins.min(axis=0), + "max": maxs.max(axis=0), + "mean": new_mean, + "std": new_std, + "count": int(total_count), + } + + return aggregated_stats diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 4ac7a841c..7f2983f9b 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -31,6 +31,7 @@ import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError +from collections import defaultdict from lerobot.constants import HF_LEROBOT_HOME from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.datasets.image_writer import AsyncImageWriter, write_image @@ -81,7 +82,12 @@ from lerobot.datasets.video_utils import ( ) CODEBASE_VERSION = "v3.0" - +OBS_IMAGE = "observation.image" +OBS_IMAGE_2 = "observation.image_2" +OBS_IMAGE_3 = "observation.image_3" +OBS_STATE = "observation.state" +OBS_ENV_STATE = "observation.env_state" +ACTION = "action" class LeRobotDatasetMetadata: def __init__( @@ -1322,13 +1328,139 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() return obj +ROBOT_TYPE_KEYS_MAPPING = { + "lerobot/stanford_hydra_dataset": "static_single_arm", + "lerobot/iamlab_cmu_pickup_insert": "static_single_arm", + "lerobot/berkeley_fanuc_manipulation": "static_single_arm", + "lerobot/toto": "static_single_arm", + "lerobot/roboturk": "static_single_arm", + "lerobot/jaco_play": "static_single_arm", + "lerobot/taco_play": "static_single_arm_7statedim", +} +class MultiLeRobotDatasetMeta: + def __init__( + self, + datasets: list[LeRobotDataset], + repo_ids: list[str], + keys_to_max_dim: dict[str, int], + train_on_all_features: bool = False, + ): + self.repo_ids = repo_ids + self.keys_to_max_dim = keys_to_max_dim + self.train_on_all_features = train_on_all_features + self.robot_types = [ds.meta.info["robot_type"] for ds in datasets] + # assign robot_type if missing + for ds in datasets: + ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"]) + ds.robot_type = ds.meta.info["robot_type"] + + # step 1: compute disabled features + self.disabled_features = set() + if not self.train_on_all_features: + intersection = set(datasets[0].features) + for ds in datasets: + intersection.intersection_update(ds.features) + if not intersection: + raise RuntimeError("No common features across datasets.") + for repo_id, ds in zip(repo_ids, datasets, strict=False): + extra = set(ds.features) - intersection + logging.warning(f"Disabling {extra} for repo {repo_id}") + self.disabled_features.update(extra) + + # step 2: build union_features excluding disabled + self.union_features = {} + for ds in datasets: + for k, v in ds.features.items(): + if k not in self.disabled_features: + self.union_features[k] = v + + # step 3: reshape feature schema + self.features = reshape_features_to_max_dim( + self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim + ) + + # step 4: aggregate stats + self.stats = aggregate_stats_per_robot_type(datasets) + for robot_type_, stats_ in self.stats.items(): + for feat_key, feat_stats in stats_.items(): + if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]: + for k, v in feat_stats.items(): + pad_value = 0 if k in ["min", "mean"] else 1 + self.stats[robot_type_][feat_key][k] = pad_tensor( + v, + max_size=self.keys_to_max_dim.get(feat_key, -1), + pad_dim=-1, + pad_value=pad_value, + ) + + # step 5: episodes & tasks + self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)} + self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)} + self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)} + + +class MultiLeRobotDatasetCleaner: + def __init__( + self, + datasets: list[LeRobotDataset], + repo_ids: list[str], + sampling_weights: list[float], + datasets_repo_ids: list[str], + min_fps: int = 1, + max_fps: int = 100, + ): + self.original_datasets = datasets + self.original_repo_ids = repo_ids + self.original_weights = sampling_weights + self.original_datasets_repo_ids = datasets_repo_ids + + # step 1: remove datasets with invalid fps + + # step 2: keep datasets with same features per robot type + consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type( + datasets + ) + + self.cleaned_datasets = consistent_datasets + self.keep_mask = keep_mask + self.cleaned_weights = [sampling_weights[i] for i in range(len(datasets)) if keep_mask[i]] + self.cleaned_repo_ids = [repo_ids[i] for i in range(len(datasets)) if keep_mask[i]] + self.cleaned_datasets_repo_ids = [ + datasets_repo_ids[i] for i in range(len(datasets)) if keep_mask[i] + ] + + self.cumulative_sizes = np.array( + [0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0)) + ) + self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32) + +# --- at the top of the file (same imports as before) --- +from collections import defaultdict +from typing import Callable +import copy +import numpy as np +import torch +import datasets +from pathlib import Path + +# If you already have these in your codebase, reuse them +try: + from lerobot.common.constants import ( + ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3 + ) +except Exception: + # Fallbacks if constants are already strings elsewhere + ACTION = "action" + OBS_ENV_STATE = "observation.env_state" + OBS_STATE = "observation.state" + OBS_IMAGE = "observation.image" + OBS_IMAGE_2 = "observation.image_2" + OBS_IMAGE_3 = "observation.image_3" + +IGNORED_KEYS = ["observation.effort"] class MultiLeRobotDataset(torch.utils.data.Dataset): - """A dataset consisting of multiple underlying `LeRobotDataset`s. - - The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API - structure of `LeRobotDataset`. - """ + # ... keep your existing docstring ... def __init__( self, @@ -1336,99 +1468,253 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): root: str | Path | None = None, episodes: dict | None = None, image_transforms: Callable | None = None, - delta_timestamps: dict[str, list[float]] | None = None, + delta_timestamps: dict[list[float]] | None = None, tolerances_s: dict | None = None, download_videos: bool = True, video_backend: str | None = None, + # --- NEW: simple add-ons --- + sampling_weights: list[float] | None = None, + feature_keys_mapping: dict[str, dict[str, str]] | None = None, + max_action_dim: int | None = None, + max_state_dim: int | None = None, + max_num_images: int | None = None, + max_image_dim: int | None = None, + train_on_all_features: bool = False, + min_fps: int = 1, + max_fps: int = 100, + ignore_keys: list[str] | None = None, # exact or glob patterns ): super().__init__() self.repo_ids = repo_ids self.root = Path(root) if root else HF_LEROBOT_HOME self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) - # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which - # are handled by this class. - self._datasets = [ - LeRobotDataset( - repo_id, - root=self.root / repo_id, - episodes=episodes[repo_id] if episodes else None, - image_transforms=image_transforms, - delta_timestamps=delta_timestamps, - tolerance_s=self.tolerances_s[repo_id], - download_videos=download_videos, - video_backend=video_backend, - ) - for repo_id in repo_ids - ] - # Disable any data keys that are not common across all of the datasets. Note: we may relax this - # restriction in future iterations of this class. For now, this is necessary at least for being able - # to use PyTorch's default DataLoader collate function. - self.disabled_features = set() - intersection_features = set(self._datasets[0].features) - for ds in self._datasets: - intersection_features.intersection_update(ds.features) - if len(intersection_features) == 0: - raise RuntimeError( - "Multiple datasets were provided but they had no keys common to all of them. " - "The multi-dataset functionality currently only keeps common keys." - ) - for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): - extra_keys = set(ds.features).difference(intersection_features) - logging.warning( - f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " - "other datasets." - ) - self.disabled_features.update(extra_keys) + # --- NEW: store mapping and simple knobs --- + self.feature_keys_mapping: dict[str, dict[str, str]] = feature_keys_mapping or {} + self.train_on_all_features = train_on_all_features + self.max_action_dim = max_action_dim + self.max_state_dim = max_state_dim + self.max_image_dim = max_image_dim + self.max_num_images = max_num_images # (optional, we don’t enforce count, we enforce names) + self._ignore_patterns = list(ignore_keys or []) + # Build underlying single datasets + _datasets = [] + datasets_repo_ids = [] + self.sampling_weights = [] + sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids) + assert len(sampling_weights) == len(repo_ids), ( + "The number of sampling weights must match the number of datasets. " + f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets." + ) + for i, repo_id in enumerate(repo_ids): + try: + _datasets.append( + LeRobotDataset( + repo_id, + root=self.root / repo_id, + episodes=episodes.get(repo_id, None) if episodes else None, + image_transforms=image_transforms, # transforms applied inside single ds + delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None, + tolerance_s=self.tolerances_s[repo_id], + download_videos=download_videos, + video_backend=video_backend, + ) + ) + datasets_repo_ids.append(repo_id) + self.sampling_weights.append(float(sampling_weights[i])) + except Exception as e: + print(f"Failed to load dataset: {repo_id} due to Exception: {e}") + + print( + f"Finish loading {len(_datasets)} datasets, with sampling weights: " + f"{self.sampling_weights} corresponding to: {datasets_repo_ids}" + ) + + # Bookkeeping for mapping & canonical image inventory self.image_transforms = image_transforms - self.delta_timestamps = delta_timestamps - # TODO(rcadene, aliberts): We should not perform this aggregation for datasets - # with multiple robots of different ranges. Instead we should have one normalization - # per robot. - self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None + self._datasets = _datasets + self.datasets_repo_ids = datasets_repo_ids + + # --- NEW: compute “canonical image keys” (targets across all mappings) --- + self._canonical_image_keys: set[str] = set() + self._source_keys_per_repo: dict[str, set[str]] = {} + self._target_keys_per_repo: dict[str, set[str]] = {} + for rid, mapping in self.feature_keys_mapping.items(): + src_keys = set(mapping.keys()) + tgt_keys = set(mapping.values()) + self._source_keys_per_repo[rid] = src_keys + self._target_keys_per_repo[rid] = tgt_keys + # union of target names (we will ensure these exist at __getitem__) + self._canonical_image_keys |= { + k for k in tgt_keys if self._is_image_key_like(k) + } + + # If user didn’t give any mapping, fall back to native keys (no-ops) + if not self._canonical_image_keys and self.train_on_all_features: + # discover all image-like keys from raw features + for ds in self._datasets: + for k, v in ds.hf_features.items(): + if isinstance(v, (datasets.Image, VideoFrame)): + self._canonical_image_keys.add(k) + + # Cleaner: keep fps & consistent feature sets per robot type (unchanged) + cleaner = MultiLeRobotDatasetCleaner( + datasets=self._datasets, + repo_ids=repo_ids, + sampling_weights=self.sampling_weights, + datasets_repo_ids=self.datasets_repo_ids, + min_fps=min_fps, + max_fps=max_fps, + ) + self._datasets = cleaner.cleaned_datasets + self.sampling_weights = cleaner.cleaned_weights + self.repo_ids = cleaner.cleaned_repo_ids + self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids + self.cumulative_sizes = cleaner.cumulative_sizes + + # Meta (unchanged): we give it dim maxima; it will reshape/pad vectors + self.meta = MultiLeRobotDatasetMeta( + datasets=self._datasets, + repo_ids=self.repo_ids, + keys_to_max_dim={ + ACTION: self.max_action_dim if self.max_action_dim is not None else -1, + OBS_ENV_STATE: self.max_state_dim if self.max_state_dim is not None else -1, + OBS_STATE: self.max_state_dim if self.max_state_dim is not None else -1, + OBS_IMAGE: self.max_image_dim if self.max_image_dim is not None else -1, + OBS_IMAGE_2: self.max_image_dim if self.max_image_dim is not None else -1, + OBS_IMAGE_3: self.max_image_dim if self.max_image_dim is not None else -1, + }, + train_on_all_features=train_on_all_features, + ) + + # --- NEW: track dropped (source) keys so collate won’t expect them + # Anything that we *rename away* should be considered disabled, + # otherwise downstream may expect them to exist. + self._dropped_keys = set() + for rid, mapping in self.feature_keys_mapping.items(): + self._dropped_keys |= set(mapping.keys()) + + # Merge with meta’s disabled features + self.disabled_features = set(self.meta.disabled_features) | self._dropped_keys + + self.stats = self.meta.stats + + # --- NEW: cache an example image shape per canonical key (lazy, filled on first use) + self._cached_img_shape: dict[str, torch.Size] = {} + + # ---------------------- NEW small helpers ---------------------- + + def _is_image_key_like(self, key: str) -> bool: + # A loose heuristic: rely on name OR on features later + return ("image" in key) or ("cam_" in key) or ("images." in key) + + def _should_ignore(self, key: str) -> bool: + # exact or glob-style match + for pat in self._ignore_patterns: + if key == pat or fnmatch.fnmatch(key, pat): + return True + return False + def _apply_feature_mapping(self, item: dict, repo_id: str) -> dict: + """ + Rename features according to feature_keys_mapping[repo_id]. + - Moves tensor/image under target key. + - Drops source key if moved. + - Adds *_is_pad=False for image targets we fill/keep. + """ + mapping = self.feature_keys_mapping.get(repo_id, {}) or {} + if not mapping: + return item + + for src, tgt in mapping.items(): + if src in item: + # Move value + item[tgt] = item[src] + # Drop the source to avoid duplication + del item[src] + return item + + def _ensure_union_image_keys(self, item: dict) -> dict: + """ + Ensure that every canonical image key exists. + When missing, create a zero tensor matching (B,C,H,W) or (C,H,W) of an available image. + Also add boolean mask at f"{key}_is_pad". + """ + if not self.train_on_all_features or not self._canonical_image_keys: + return item + + # find any existing image tensor in item to copy shape/dtype + exemplar = None + for k in list(item.keys()): + v = item[k] + if torch.is_tensor(v) and v.ndim in (3, 4, 5): # (C,H,W) or (B,C,H,W) or (B,T,C,H,W) + exemplar = v + break + + # fallback to a safe 3x224x224 if nothing found + def _fallback_image(): + return torch.zeros(3, 224, 224, dtype=torch.uint8) + + for key in self._canonical_image_keys: + if key not in item: + img = torch.zeros_like(exemplar) if exemplar is not None else _fallback_image() + item[key] = img + item[f"{key}_is_pad"] = torch.tensor(True, dtype=torch.bool) + else: + # Add a mask saying it’s *not* padded + if f"{key}_is_pad" not in item: + item[f"{key}_is_pad"] = torch.tensor(False, dtype=torch.bool) + return item + + # ---------------------- existing API below (mostly unchanged) ---------------------- @property def repo_id_to_index(self): - """Return a mapping from dataset repo_id to a dataset index automatically created by this class. - - This index is incorporated as a data key in the dictionary returned by `__getitem__`. - """ return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} @property def repo_index_to_id(self): - """Return the inverse mapping if repo_id_to_index.""" return {v: k for k, v in self.repo_id_to_index} @property def fps(self) -> int: - """Frames per second used during data collection. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ return self._datasets[0].meta.info["fps"] @property def video(self) -> bool: - """Returns True if this dataset loads video frames from mp4 files. - - Returns False if it only loads images from png files. - - NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. - """ return self._datasets[0].meta.info.get("video", False) @property def features(self) -> datasets.Features: - features = {} + """ + Extend native HF features with any *target* keys introduced by mapping. + We copy the source spec for targets that didn’t exist in any raw dataset. + """ + features: dict[str, datasets.features.Feature] = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + for k, v in dataset.hf_features.items(): + if k not in self.disabled_features: + features[k] = v + + # Add mapped target image specs if not present yet + for rid, mapping in self.feature_keys_mapping.items(): + ds = None + # find the dataset object to read feature spec for source + for _ds, _rid in zip(self._datasets, self.repo_ids, strict=False): + if _rid == rid: + ds = _ds + break + if ds is None: + continue + for src, tgt in mapping.items(): + if tgt not in features and src in ds.hf_features: + features[tgt] = ds.hf_features[src] + return features @property def camera_keys(self) -> list[str]: - """Keys to access image and video stream from cameras.""" keys = [] for key, feats in self.features.items(): if isinstance(feats, (datasets.Image, VideoFrame)): @@ -1437,12 +1723,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): @property def video_frame_keys(self) -> list[str]: - """Keys to access video frames that requires to be decoded into images. - - Note: It is empty if the dataset contains images only, - or equal to `self.cameras` if the dataset contains videos only, - or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. - """ video_frame_keys = [] for key, feats in self.features.items(): if isinstance(feats, VideoFrame): @@ -1451,21 +1731,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): @property def num_frames(self) -> int: - """Number of samples/frames.""" return sum(d.num_frames for d in self._datasets) @property def num_episodes(self) -> int: - """Number of episodes.""" return sum(d.num_episodes for d in self._datasets) @property def tolerance_s(self) -> float: - """Tolerance in seconds used to discard loaded frames when their timestamps - are not close enough from the requested frames. It is only used when `delta_timestamps` - is provided or when loading video frames from mp4 files. - """ - # 1e-4 to account for possible numerical error return 1 / self.fps - 1e-4 def __len__(self): @@ -1474,22 +1747,83 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: if idx >= len(self): raise IndexError(f"Index {idx} out of bounds.") - # Determine which dataset to get an item from based on the index. - start_idx = 0 - dataset_idx = 0 - for dataset in self._datasets: - if idx >= start_idx + dataset.num_frames: - start_idx += dataset.num_frames - dataset_idx += 1 - continue - break - else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") - item = self._datasets[dataset_idx][idx - start_idx] + dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1 + local_idx = (idx - self.cumulative_sizes[dataset_idx]).item() + item = self._datasets[dataset_idx][local_idx] + + # Identify which repo this sample came from + repo_id = self.datasets_repo_ids[dataset_idx] + + # --- NEW: apply mapping and ensure union of image keys --- + item = self._apply_feature_mapping(item, repo_id) + item = self._ensure_union_image_keys(item) + + # annotate dataset index for downstream item["dataset_index"] = torch.tensor(dataset_idx) + + # Pad vector features to max dims using meta (unchanged) + item = create_padded_features(item, self.meta.features) + + # Drop any disabled (including original source keys we remapped away) for data_key in self.disabled_features: if data_key in item: del item[data_key] + for k in IGNORED_KEYS: + if k in item: + item.pop(k) + # Convert any datasets.Image still present to tensor + if self.image_transforms is not None: + for cam in [k for k in item.keys() if self._is_image_key_like(k)]: + val = item[cam] + if not torch.is_tensor(val): + item[cam] = self.image_transforms(val) + # 🔑 Pad actions if too short + if "actions" in item and self.max_action_dim is not None: + act = item["actions"] + if act.shape[-1] < self.max_action_dim: + pad_len = self.max_action_dim - act.shape[-1] + item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) + item["actions_padding_mask"] = torch.cat( + [torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], + dim=-1, + ) + + # pad obs_state if too short + if "obs_state" in item and self.max_state_dim is not None: + st = item["obs_state"] + if st.shape[-1] < self.max_state_dim: + pad_len = self.max_state_dim - st.shape[-1] + item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) + item["obs_state_padding_mask"] = torch.cat( + [torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], + dim=-1, + ) + # actions + if "actions" in item and self.max_action_dim is not None: + act = item["actions"] + if act.shape[-1] < self.max_action_dim: + pad_len = self.max_action_dim - act.shape[-1] + item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1) + mask = torch.cat( + [torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], + dim=-1, + ) + else: + mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding + item["actions_padding_mask"] = mask + # obs state + if "obs_state" in item and self.max_state_dim is not None: + st = item["obs_state"] + if st.shape[-1] < self.max_state_dim: + pad_len = self.max_state_dim - st.shape[-1] + item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1) + mask = torch.cat( + [torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)], + dim=-1, + ) + else: + mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask + item["obs_state_padding_mask"] = mask return item @@ -1506,3 +1840,149 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): f" Transformations: {self.image_transforms},\n" f")" ) + +def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list: + """ + Filters datasets to only keep those with consistent feature shapes per robot type. + + Args: + ls_datasets (List): List of datasets, each with a `meta.info['robot_type']` + and `meta.episodes_stats` dictionary. + + Returns: + List: Filtered list of datasets with consistent feature shapes. + """ + robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets} + datasets_to_remove = set() + + for robot_type in robot_types: + # Collect all stats dicts for this robot type + stats_list = [ + ep_stats + for ds in ls_datasets + if ds.meta.info["robot_type"] == robot_type + for ep_stats in episode_stats_values(ds.meta) + ] + if not stats_list: + continue + + # Determine the most common shape for each key + all_keys = {key for stats in stats_list for key in stats} + for ds in ls_datasets: + if ds.meta.info["robot_type"] != robot_type: + continue + for key in all_keys: + shape_counter = defaultdict(int) + + for stats in stats_list: + value = stats.get(key) + if ( + value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray)) + ): # FIXME(mshukor): check all stats; min, mean, max + shape_counter[value["mean"].shape] += 1 + if not shape_counter: + continue + + # Identify the most frequent shape + main_shape = max(shape_counter, key=shape_counter.get) + # Flag datasets that don't match the main shape + # for ds in ls_datasets: + first_ep_stats = next(iter(episode_stats_values(ds.meta)), None) + if not first_ep_stats: + continue + value = first_ep_stats.get(key) + if ( + value + and "mean" in value + and isinstance(value["mean"], (torch.Tensor, np.ndarray)) + and value["mean"].shape != main_shape + ): + datasets_to_remove.add(ds) + break + + # Filter out inconsistent datasets + datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets] + filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove] + print( + f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}" + ) + return filtered_datasets, datasets_maks + + +def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]: + """Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type. + + The final stats will have the union of all data keys from each of the datasets. + + The final stats will have the union of all data keys from each of the datasets. For instance: + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_min = min(min_dataset_0, min_dataset_1, ...) + - new_mean = (mean of all data) + - new_std = (std of all data) + """ + + robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets} + stats = {robot_type: {} for robot_type in robot_types} + for robot_type in robot_types: + robot_type_datasets = [] + for ds in ls_datasets: + if ds.meta.info["robot_type"] == robot_type: + robot_type_datasets.extend(list(episode_stats_values(ds.meta))) + # robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type] + stat = aggregate_stats(robot_type_datasets) + stats[robot_type] = stat + return stats + +def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict: + """Reshape features to have a maximum dimension of `max_dim`.""" + reshaped_features = {} + for key in features: + if key in keys_to_max_dim and keys_to_max_dim[key] is not None: + reshaped_features[key] = features[key] + shape = list(features[key]["shape"]) + if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images + shape[-3] = keys_to_max_dim[key] + shape[-2] = keys_to_max_dim[key] + else: + shape[reshape_dim] = keys_to_max_dim[key] + reshaped_features[key]["shape"] = tuple(shape) + else: + reshaped_features[key] = features[key] + return reshaped_features + +def create_padded_features(item: dict, features: dict = {}): + for key, ft in features.items(): + if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack + continue + shape = ft["shape"] + if len(shape) == 3: # images to torch format (C, H, W) + shape = (shape[2], shape[0], shape[1]) + if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele) + shape = [] + if key not in item: + dtype = str_to_torch_dtype(ft["dtype"]) + item[key] = torch.zeros(shape, dtype=dtype) + item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64) + if "image" in key: # FIXME(mshukor): support other observations + item[f"{key}_is_pad"] = torch.BoolTensor([False]) + else: + item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64) + return item + +def str_to_torch_dtype(dtype_str): + """Convert a dtype string to a torch dtype.""" + mapping = { + "float32": torch.float32, + "int64": torch.int64, + "int16": torch.int16, + "bool": torch.bool, + "video": torch.float32, # Assuming video is stored as uint8 images + } + return mapping.get(dtype_str, torch.float32) # Default to float32 + +def episode_stats_values(meta): + episodes = meta.episodes.to_pandas().to_dict(orient="records") + return [ + {k: v for k, v in ep.items() if isinstance(v, dict) and "mean" in v} + for ep in episodes + ]