Compare commits

...

2 Commits

Author SHA1 Message Date
Jade Choghari 2e9b6d4b88 Merge branch 'main' into feat/add-multidataset-training 2025-09-23 18:17:09 +02:00
Jade Choghari 0a3851e2a3 add first commit 2025-09-18 14:12:54 +02:00
4 changed files with 727 additions and 89 deletions
+66
@@ -0,0 +1,66 @@
from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset
REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id
feature_keys_mapping = {
REPO_A: { # pusht (1 camera, 2-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.image": "obs_image.cam_high",
},
REPO_B: { # dual arm (3 cameras, 14-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.images.cam_high": "obs_image.cam_high",
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
},
}
from torchvision.transforms.v2 import Compose, ToImage, Resize
image_tf = Compose([
ToImage(), # converts to tensor if needed
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
])
from torch.utils.data import DataLoader
dataset = MultiLeRobotDataset(
repo_ids=[REPO_A, REPO_B],
image_transforms=image_tf, # ensures same HxW
feature_keys_mapping=feature_keys_mapping,
train_on_all_features=True, # keep union of cameras; zero-fill missing
# optional: fixed maxima (if omitted, they are inferred from the datasets)
max_action_dim=14,
max_state_dim=14,
max_image_dim=224,
ignore_keys=[
"next.*", # drop reward/done/success
"index",
"timestamp",
"videos/*", # drop all video metadata
"observation.effort", # 👈 drop effort everywhere
],
)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
it = iter(loader)
for _ in range(100):
batch = next(it)
# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
assert batch["actions"].shape[-1] == 14
assert batch["obs_state"].shape[-1] == 14
assert batch["actions_padding_mask"].shape[-1] == 14
assert batch["obs_state_padding_mask"].shape[-1] == 14
# cameras: all canonical keys exist; pusht will have wrists zero-filled
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
assert cam in batch
assert f"{cam}_is_pad" in batch
# images should all be 3x224x224 (or your transforms size)
img = batch[cam]
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader
+16
@@ -0,0 +1,16 @@
# storage / caches
RAID=/raid/jade
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
export HF_HOME=$RAID/.cache/huggingface
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
export WANDB_CACHE_DIR=$RAID/.cache/wandb
export TMPDIR=$RAID/.cache/tmp
mkdir -p $TMPDIR
export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl
python examples/tester.py
+76
@@ -174,3 +174,79 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats
import numpy as np
def aggregate_stats_multi(
stats_list: list[dict[str, dict]],
max_action_dim: int | None = None,
max_state_dim: int | None = None,
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
Supports heterogeneous robots by padding action/state stats to the max dim.
The final stats will have the union of all data keys from each of the stats dicts.
- new_min = elementwise min across datasets
- new_max = elementwise max across datasets
- new_mean = weighted mean (by count)
- new_std = recomputed from total variance
"""
data_keys = {key for stats in stats_list for key in stats}
aggregated_stats = {key: {} for key in data_keys}
def _pad(arr: np.ndarray, target: int) -> np.ndarray:
if arr.ndim == 0: # scalar
return arr
if target is None or target <= 0 or arr.shape[-1] == target:
return arr
pad_width = [(0, 0)] * arr.ndim
pad_width[-1] = (0, target - arr.shape[-1])
return np.pad(arr, pad_width, mode="constant")
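# quick illustrative check (not part of the PR): _pad(np.array([1.0, 2.0]), 4)
# returns array([1., 2., 0., 0.]); scalars and already-matching dims pass through.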
for key in data_keys:
stats_with_key = [stats[key] for stats in stats_list if key in stats]
# decide if this key should be padded
target_dim = None
if "action" in key and max_action_dim:
target_dim = max_action_dim
elif "state" in key and max_state_dim:
target_dim = max_state_dim
padded = []
counts = []
for s in stats_with_key:
mean = _pad(np.array(s["mean"]), target_dim)
std = _pad(np.array(s["std"]), target_dim)
min_ = _pad(np.array(s["min"]), target_dim)
max_ = _pad(np.array(s["max"]), target_dim)
count = s.get("count", 1)
padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count))
counts.append(count)
counts = np.array(counts, dtype=np.float64)
total_count = counts.sum()
means = np.stack([p["mean"] for p in padded])
stds = np.stack([p["std"] for p in padded])
mins = np.stack([p["min"] for p in padded])
maxs = np.stack([p["max"] for p in padded])
# weighted mean (broadcast weights properly)
new_mean = np.average(means, axis=0, weights=counts)
new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts)
new_std = np.sqrt(new_var)
aggregated_stats[key] = {
"min": mins.min(axis=0),
"max": maxs.max(axis=0),
"mean": new_mean,
"std": new_std,
"count": int(total_count),
}
return aggregated_stats
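# A hedged sanity check of the pooled-variance formula above; the toy numbers
# below are made up and this snippet is not part of the PR:
#
#   stats_a = {"action": {"mean": np.array([1.0, 2.0]), "std": np.array([0.0, 0.0]),
#                         "min": np.array([1.0, 2.0]), "max": np.array([1.0, 2.0]), "count": 1}}
#   stats_b = {"action": {"mean": np.array([3.0]), "std": np.array([0.0]),
#                         "min": np.array([3.0]), "max": np.array([3.0]), "count": 1}}
#   agg = aggregate_stats_multi([stats_a, stats_b], max_action_dim=2)
#   # stats_b is zero-padded to mean [3, 0]; the count-weighted mean is [2.0, 1.0]
#   # var = mean(std**2 + (mean - new_mean)**2) = [1.0, 1.0] -> std = [1.0, 1.0]
#   # min -> [1.0, 0.0], max -> [3.0, 2.0], count -> 2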
+569 -89
@@ -31,6 +31,7 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from collections import defaultdict
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
@@ -81,7 +82,12 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = "v3.0"
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image_2"
OBS_IMAGE_3 = "observation.image_3"
OBS_STATE = "observation.state"
OBS_ENV_STATE = "observation.env_state"
ACTION = "action"
class LeRobotDatasetMetadata:
def __init__(
@@ -1322,13 +1328,139 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj
ROBOT_TYPE_KEYS_MAPPING = {
"lerobot/stanford_hydra_dataset": "static_single_arm",
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
"lerobot/toto": "static_single_arm",
"lerobot/roboturk": "static_single_arm",
"lerobot/jaco_play": "static_single_arm",
"lerobot/taco_play": "static_single_arm_7statedim",
}
class MultiLeRobotDatasetMeta:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
keys_to_max_dim: dict[str, int],
train_on_all_features: bool = False,
):
self.repo_ids = repo_ids
self.keys_to_max_dim = keys_to_max_dim
self.train_on_all_features = train_on_all_features
# assign robot_type overrides before caching the list of types
for ds in datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
# step 1: compute disabled features
self.disabled_features = set()
if not self.train_on_all_features:
intersection = set(datasets[0].features)
for ds in datasets:
intersection.intersection_update(ds.features)
if not intersection:
raise RuntimeError("No common features across datasets.")
for repo_id, ds in zip(repo_ids, datasets, strict=False):
extra = set(ds.features) - intersection
logging.warning(f"Disabling {extra} for repo {repo_id}")
self.disabled_features.update(extra)
# step 2: build union_features excluding disabled
self.union_features = {}
for ds in datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
self.union_features[k] = v
# step 3: reshape feature schema
self.features = reshape_features_to_max_dim(
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# step 4: aggregate stats
self.stats = aggregate_stats_per_robot_type(datasets)
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
pad_value = 0 if k in ["min", "mean"] else 1
self.stats[robot_type_][feat_key][k] = pad_tensor(
v,
max_size=self.keys_to_max_dim.get(feat_key, -1),
pad_dim=-1,
pad_value=pad_value,
)
# step 5: episodes & tasks
self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
sampling_weights: list[float],
datasets_repo_ids: list[str],
min_fps: int = 1,
max_fps: int = 100,
):
self.original_datasets = datasets
self.original_repo_ids = repo_ids
self.original_weights = sampling_weights
self.original_datasets_repo_ids = datasets_repo_ids
# step 1: remove datasets with an invalid fps
fps_ok = [min_fps <= ds.meta.info["fps"] <= max_fps for ds in datasets]
# step 2: keep datasets with the same features per robot type
consistent_datasets, feature_ok = keep_datasets_with_the_same_features_per_robot_type(
[ds for ds, ok in zip(datasets, fps_ok) if ok]
)
# combine both masks back onto the original dataset list
feature_ok_iter = iter(feature_ok)
keep_mask = [ok and next(feature_ok_iter) for ok in fps_ok]
self.cleaned_datasets = consistent_datasets
self.keep_mask = keep_mask
self.cleaned_weights = [sampling_weights[i] for i in range(len(datasets)) if keep_mask[i]]
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(datasets)) if keep_mask[i]]
self.cleaned_datasets_repo_ids = [
datasets_repo_ids[i] for i in range(len(datasets)) if keep_mask[i]
]
self.cumulative_sizes = np.concatenate(
[[0], np.cumsum([len(d) for d in consistent_datasets])]
)
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
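# e.g. (illustrative): kept dataset lengths [100, 50] give
# cumulative_sizes [0, 100, 150], which __getitem__ below searches to route
# a global index to (dataset_idx, local_idx).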
# --- at the top of the file (same imports as before) ---
from collections import defaultdict
from typing import Callable
import copy
import fnmatch # needed by _should_ignore below
import numpy as np
import torch
import datasets
from pathlib import Path
# If you already have these in your codebase, reuse them
try:
from lerobot.constants import (
ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
)
except Exception:
# Fallbacks if the constants cannot be imported
ACTION = "action"
OBS_ENV_STATE = "observation.env_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image_2"
OBS_IMAGE_3 = "observation.image_3"
IGNORED_KEYS = ["observation.effort"]
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
# ... keep your existing docstring ...
def __init__(
self,
@@ -1336,99 +1468,253 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, dict[str, list[float]]] | None = None, # keyed by repo_id
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
# --- NEW: simple add-ons ---
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int | None = None,
max_state_dim: int | None = None,
max_num_images: int | None = None,
max_image_dim: int | None = None,
train_on_all_features: bool = False,
min_fps: int = 1,
max_fps: int = 100,
ignore_keys: list[str] | None = None, # exact or glob patterns
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
# --- NEW: store mapping and simple knobs ---
self.feature_keys_mapping: dict[str, dict[str, str]] = feature_keys_mapping or {}
self.train_on_all_features = train_on_all_features
self.max_action_dim = max_action_dim
self.max_state_dim = max_state_dim
self.max_image_dim = max_image_dim
self.max_num_images = max_num_images # (optional: we don't enforce count, we enforce names)
self._ignore_patterns = list(ignore_keys or [])
# Build underlying single datasets
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms, # transforms applied inside single ds
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: "
f"{self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
# Bookkeeping for mapping & canonical image inventory
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self._datasets = _datasets
self.datasets_repo_ids = datasets_repo_ids
# --- NEW: compute “canonical image keys” (targets across all mappings) ---
self._canonical_image_keys: set[str] = set()
self._source_keys_per_repo: dict[str, set[str]] = {}
self._target_keys_per_repo: dict[str, set[str]] = {}
for rid, mapping in self.feature_keys_mapping.items():
src_keys = set(mapping.keys())
tgt_keys = set(mapping.values())
self._source_keys_per_repo[rid] = src_keys
self._target_keys_per_repo[rid] = tgt_keys
# union of target names (we will ensure these exist at __getitem__)
self._canonical_image_keys |= {
k for k in tgt_keys if self._is_image_key_like(k)
}
# If the user didn't give any mapping, fall back to native keys (no-ops)
if not self._canonical_image_keys and self.train_on_all_features:
# discover all image-like keys from raw features
for ds in self._datasets:
for k, v in ds.hf_features.items():
if isinstance(v, (datasets.Image, VideoFrame)):
self._canonical_image_keys.add(k)
# Cleaner: keep fps & consistent feature sets per robot type (unchanged)
cleaner = MultiLeRobotDatasetCleaner(
datasets=self._datasets,
repo_ids=repo_ids,
sampling_weights=self.sampling_weights,
datasets_repo_ids=self.datasets_repo_ids,
min_fps=min_fps,
max_fps=max_fps,
)
self._datasets = cleaner.cleaned_datasets
self.sampling_weights = cleaner.cleaned_weights
self.repo_ids = cleaner.cleaned_repo_ids
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
self.cumulative_sizes = cleaner.cumulative_sizes
# Meta (unchanged): we give it dim maxima; it will reshape/pad vectors
self.meta = MultiLeRobotDatasetMeta(
datasets=self._datasets,
repo_ids=self.repo_ids,
keys_to_max_dim={
ACTION: self.max_action_dim if self.max_action_dim is not None else -1,
OBS_ENV_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
OBS_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
OBS_IMAGE: self.max_image_dim if self.max_image_dim is not None else -1,
OBS_IMAGE_2: self.max_image_dim if self.max_image_dim is not None else -1,
OBS_IMAGE_3: self.max_image_dim if self.max_image_dim is not None else -1,
},
train_on_all_features=train_on_all_features,
)
# --- NEW: track dropped (source) keys so collate won't expect them
# Anything that we *rename away* should be considered disabled,
# otherwise downstream may expect them to exist.
self._dropped_keys = set()
for rid, mapping in self.feature_keys_mapping.items():
self._dropped_keys |= set(mapping.keys())
# Merge with meta's disabled features
self.disabled_features = set(self.meta.disabled_features) | self._dropped_keys
self.stats = self.meta.stats
# --- NEW: cache an example image shape per canonical key (lazy, filled on first use)
self._cached_img_shape: dict[str, torch.Size] = {}
# ---------------------- NEW small helpers ----------------------
def _is_image_key_like(self, key: str) -> bool:
# A loose heuristic: rely on name OR on features later
return ("image" in key) or ("cam_" in key) or ("images." in key)
def _should_ignore(self, key: str) -> bool:
# exact or glob-style match
for pat in self._ignore_patterns:
if key == pat or fnmatch.fnmatch(key, pat):
return True
return False
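# illustrative (not part of the PR): with ignore_keys=["next.*", "videos/*"],
# _should_ignore("next.reward") -> True, _should_ignore("observation.state") -> False.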
def _apply_feature_mapping(self, item: dict, repo_id: str) -> dict:
"""
Rename features according to feature_keys_mapping[repo_id].
- Moves the tensor/image under the target key.
- Drops the source key once moved.
(The *_is_pad masks for image targets are added later by _ensure_union_image_keys.)
"""
mapping = self.feature_keys_mapping.get(repo_id, {}) or {}
if not mapping:
return item
for src, tgt in mapping.items():
if src in item:
# Move value
item[tgt] = item[src]
# Drop the source to avoid duplication
del item[src]
return item
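# illustrative (not part of the PR): a mapping {"action": "actions"} turns
# {"action": t, "timestamp": ts} into {"actions": t, "timestamp": ts}.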
def _ensure_union_image_keys(self, item: dict) -> dict:
"""
Ensure that every canonical image key exists.
When missing, create a zero tensor matching (B,C,H,W) or (C,H,W) of an available image.
Also add boolean mask at f"{key}_is_pad".
"""
if not self.train_on_all_features or not self._canonical_image_keys:
return item
# find any existing image tensor in item to copy shape/dtype
exemplar = None
for k in list(item.keys()):
v = item[k]
if torch.is_tensor(v) and v.ndim in (3, 4, 5): # (C,H,W) or (B,C,H,W) or (B,T,C,H,W)
exemplar = v
break
# fallback to a safe 3x224x224 if nothing found
def _fallback_image():
return torch.zeros(3, 224, 224, dtype=torch.uint8)
for key in self._canonical_image_keys:
if key not in item:
img = torch.zeros_like(exemplar) if exemplar is not None else _fallback_image()
item[key] = img
item[f"{key}_is_pad"] = torch.tensor(True, dtype=torch.bool)
else:
# Add a mask saying it's *not* padded
if f"{key}_is_pad" not in item:
item[f"{key}_is_pad"] = torch.tensor(False, dtype=torch.bool)
return item
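# illustrative (not part of the PR): a pusht sample carrying only
# "obs_image.cam_high" gains zero images under the two wrist keys, each with
# "<key>_is_pad" = True, so the default collate sees a uniform schema.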
# ---------------------- existing API below (mostly unchanged) ----------------------
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping of repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index.items()}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
"""
Extend native HF features with any *target* keys introduced by mapping.
We copy the source spec for targets that didn't exist in any raw dataset.
"""
features = {}
for dataset in self._datasets:
for k, v in dataset.hf_features.items():
if k not in self.disabled_features:
features[k] = v
# Add mapped target image specs if not present yet
for rid, mapping in self.feature_keys_mapping.items():
ds = None
# find the dataset object to read feature spec for source
for _ds, _rid in zip(self._datasets, self.repo_ids, strict=False):
if _rid == rid:
ds = _ds
break
if ds is None:
continue
for src, tgt in mapping.items():
if tgt not in features and src in ds.hf_features:
features[tgt] = ds.hf_features[src]
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
@@ -1437,12 +1723,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
@@ -1451,21 +1731,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
@@ -1474,22 +1747,83 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
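# e.g. (illustrative): cumulative_sizes [0, 100, 150] and idx=120 give
# searchsorted(..., side="right") == 2, hence dataset_idx=1 and local_idx=20.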
# Identify which repo this sample came from
repo_id = self.datasets_repo_ids[dataset_idx]
# --- NEW: apply mapping and ensure union of image keys ---
item = self._apply_feature_mapping(item, repo_id)
item = self._ensure_union_image_keys(item)
# annotate dataset index for downstream
item["dataset_index"] = torch.tensor(dataset_idx)
# Pad vector features to max dims using meta (unchanged)
item = create_padded_features(item, self.meta.features)
# Drop any disabled (including original source keys we remapped away)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
for k in list(item.keys()):
if k in IGNORED_KEYS or self._should_ignore(k):
item.pop(k)
# Convert any datasets.Image still present to tensor
if self.image_transforms is not None:
for cam in [k for k in item.keys() if self._is_image_key_like(k)]:
val = item[cam]
if not torch.is_tensor(val):
item[cam] = self.image_transforms(val)
# 🔑 Pad vector features to the fixed maxima and always emit padding masks
# actions
if "actions" in item and self.max_action_dim is not None:
act = item["actions"]
if act.shape[-1] < self.max_action_dim:
pad_len = self.max_action_dim - act.shape[-1]
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
mask = torch.cat(
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
else:
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding
item["actions_padding_mask"] = mask
# obs state
if "obs_state" in item and self.max_state_dim is not None:
st = item["obs_state"]
if st.shape[-1] < self.max_state_dim:
pad_len = self.max_state_dim - st.shape[-1]
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
mask = torch.cat(
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
else:
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask
item["obs_state_padding_mask"] = mask
return item
@@ -1506,3 +1840,149 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Transformations: {self.image_transforms},\n"
f")"
)
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
"""
Filters datasets to only keep those with consistent feature shapes per robot type.
Args:
ls_datasets (List): List of datasets, each with `meta.info['robot_type']`
and per-episode stats reachable via `episode_stats_values(ds.meta)`.
Returns:
tuple[list, list[bool]]: Filtered list of datasets with consistent feature
shapes, plus a boolean keep-mask aligned with the input list.
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
datasets_to_remove = set()
for robot_type in robot_types:
# Collect all stats dicts for this robot type
stats_list = [
ep_stats
for ds in ls_datasets
if ds.meta.info["robot_type"] == robot_type
for ep_stats in episode_stats_values(ds.meta)
]
if not stats_list:
continue
# Determine the most common shape for each key
all_keys = {key for stats in stats_list for key in stats}
for ds in ls_datasets:
if ds.meta.info["robot_type"] != robot_type:
continue
for key in all_keys:
shape_counter = defaultdict(int)
for stats in stats_list:
value = stats.get(key)
if (
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
): # FIXME(mshukor): check all stats; min, mean, max
shape_counter[value["mean"].shape] += 1
if not shape_counter:
continue
# Identify the most frequent shape
main_shape = max(shape_counter, key=shape_counter.get)
# Flag datasets that don't match the main shape
first_ep_stats = next(iter(episode_stats_values(ds.meta)), None)
if not first_ep_stats:
continue
value = first_ep_stats.get(key)
if (
value
and "mean" in value
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
and value["mean"].shape != main_shape
):
datasets_to_remove.add(ds)
break
# Filter out inconsistent datasets
datasets_mask = [ds not in datasets_to_remove for ds in ls_datasets]
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
print(
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
)
return filtered_datasets, datasets_mask
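# illustrative (not part of the PR): three "static_single_arm" datasets whose
# action means have shapes (7,), (7,), (8,) -> the (8,) dataset is dropped and
# its entry in the returned mask is False.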
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
stats = {robot_type: {} for robot_type in robot_types}
for robot_type in robot_types:
robot_type_datasets = []
for ds in ls_datasets:
if ds.meta.info["robot_type"] == robot_type:
robot_type_datasets.extend(list(episode_stats_values(ds.meta)))
stat = aggregate_stats(robot_type_datasets)
stats[robot_type] = stat
return stats
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
"""Reshape feature specs so the padded dim (or HxW for images) matches the per-key maxima."""
reshaped_features = {}
for key in features:
target = keys_to_max_dim.get(key)
if target is not None and target > 0: # None / -1 mean "leave unchanged"
reshaped_features[key] = dict(features[key]) # copy to avoid mutating the input spec
shape = list(features[key]["shape"])
if any(k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]): # assume square images
shape[-3] = target
shape[-2] = target
else:
shape[reshape_dim] = target
reshaped_features[key]["shape"] = tuple(shape)
else:
reshaped_features[key] = features[key]
return reshaped_features
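# illustrative (not part of the PR): with keys_to_max_dim={"observation.state": 14},
# a spec {"shape": (7,)} becomes {"shape": (14,)}; image keys instead get a square
# HxW, e.g. (480, 640, 3) -> (224, 224, 3) for a target of 224.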
def create_padded_features(item: dict, features: dict = {}):
for key, ft in features.items():
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
continue
shape = ft["shape"]
if len(shape) == 3: # images to torch format (C, H, W)
shape = (shape[2], shape[0], shape[1])
if len(shape) == 1 and shape[0] == 1: # features declared with shape (1,) are stored as scalar tensors
shape = []
if key not in item:
dtype = str_to_torch_dtype(ft["dtype"])
item[key] = torch.zeros(shape, dtype=dtype)
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
if "image" in key: # FIXME(mshukor): support other observations
item[f"{key}_is_pad"] = torch.BoolTensor([False])
else:
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
return item
def str_to_torch_dtype(dtype_str):
"""Convert a dtype string to a torch dtype."""
mapping = {
"float32": torch.float32,
"int64": torch.int64,
"int16": torch.int16,
"bool": torch.bool,
"video": torch.float32, # Assuming video is stored as uint8 images
}
return mapping.get(dtype_str, torch.float32) # Default to float32
def episode_stats_values(meta):
episodes = meta.episodes.to_pandas().to_dict(orient="records")
return [
{k: v for k, v in ep.items() if isinstance(v, dict) and "mean" in v}
for ep in episodes
]