mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e9b6d4b88 | |||
| 0a3851e2a3 |
"""Smoke test: mix two heterogeneous LeRobot datasets in one MultiLeRobotDataset.

Checks that per-repo feature renaming, vector padding (to max action/state
dims) and camera-key union (zero-filling missing cameras) behave as configured.
"""

from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset

from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, Resize, ToImage

REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet"  # replace with the actual repo id

# Per-repo renaming: source feature key -> canonical (target) key.
feature_keys_mapping = {
    REPO_A: {  # pusht (1 camera, 2-dim)
        "action": "actions",
        "observation.state": "obs_state",
        "observation.image": "obs_image.cam_high",
    },
    REPO_B: {  # dual arm (3 cameras, 14-dim)
        "action": "actions",
        "observation.state": "obs_state",
        "observation.images.cam_high": "obs_image.cam_high",
        "observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
        "observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
    },
}

image_tf = Compose([
    ToImage(),  # converts to tensor if needed
    Resize((224, 224)),  # unify sizes across datasets (96x96 vs 480x640)
])

dataset = MultiLeRobotDataset(
    repo_ids=[REPO_A, REPO_B],
    image_transforms=image_tf,  # ensures same HxW
    feature_keys_mapping=feature_keys_mapping,
    train_on_all_features=True,  # keep union of cameras; zero-fill missing
    # optional: override if you want fixed maxima; else inferred:
    max_action_dim=14,
    max_state_dim=14,
    max_image_dim=224,
    ignore_keys=[
        "next.*",  # drop reward/done/success
        "index",
        "timestamp",
        "videos/*",  # drop all video metadata
        "observation.effort",  # drop effort everywhere
    ],
)

loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

# Iterate ONE iterator for 100 steps. The previous version called
# `next(iter(loader))` inside the loop, which rebuilds the iterator every
# step and therefore re-samples only the first batch each time.
for step, batch in enumerate(loader):
    if step >= 100:
        break

    # vectors padded to maxima (pusht: 2 -> 14; dual-arm: 14 -> 14)
    assert batch["actions"].shape[-1] == 14
    assert batch["obs_state"].shape[-1] == 14
    # NOTE(review): mask keys use the `_padding_mask` suffix here but the
    # `_is_pad` suffix for cameras below — confirm both spellings are intended.
    assert batch["actions_padding_mask"].shape[-1] == 14
    assert batch["obs_state_padding_mask"].shape[-1] == 14

    # cameras: all canonical keys exist; pusht will have wrists zero-filled
    for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
        assert cam in batch
        assert f"{cam}_is_pad" in batch
        # images should all be 3x224x224 (or your transform's size)
        img = batch[cam]
        assert img.ndim in (4, 5)  # (B,C,H,W) or (B,T,C,H,W) depending on your loader
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env bash
# storage / caches: keep all HuggingFace / wandb / tmp artifacts on the RAID
# volume instead of the (small) home partition.
RAID=/raid/jade
# Quote every expansion so a RAID path with spaces cannot word-split.
export TRANSFORMERS_CACHE="$RAID/.cache/huggingface/transformers"
export HF_HOME="$RAID/.cache/huggingface"
export HF_DATASETS_CACHE="$RAID/.cache/huggingface/datasets"
export HF_LEROBOT_HOME="$RAID/.cache/huggingface/lerobot"
export WANDB_CACHE_DIR="$RAID/.cache/wandb"
export TMPDIR="$RAID/.cache/tmp"
mkdir -p "$TMPDIR"

export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl  # headless (EGL) rendering for MuJoCo

python examples/tester.py
||||||
@@ -174,3 +174,79 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
|
|||||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||||
|
|
||||||
return aggregated_stats
|
return aggregated_stats
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def aggregate_stats_multi(
|
||||||
|
stats_list: list[dict[str, dict]],
|
||||||
|
max_action_dim: int | None = None,
|
||||||
|
max_state_dim: int | None = None,
|
||||||
|
) -> dict[str, dict[str, np.ndarray]]:
|
||||||
|
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||||
|
|
||||||
|
Supports heterogeneous robots by padding action/state stats to the max dim.
|
||||||
|
The final stats will have the union of all data keys from each of the stats dicts.
|
||||||
|
|
||||||
|
- new_min = elementwise min across datasets
|
||||||
|
- new_max = elementwise max across datasets
|
||||||
|
- new_mean = weighted mean (by count)
|
||||||
|
- new_std = recomputed from total variance
|
||||||
|
"""
|
||||||
|
|
||||||
|
data_keys = {key for stats in stats_list for key in stats}
|
||||||
|
aggregated_stats = {key: {} for key in data_keys}
|
||||||
|
|
||||||
|
def _pad(arr: np.ndarray, target: int) -> np.ndarray:
|
||||||
|
if arr.ndim == 0: # scalar
|
||||||
|
return arr
|
||||||
|
if target is None or target <= 0 or arr.shape[-1] == target:
|
||||||
|
return arr
|
||||||
|
pad_width = [(0, 0)] * arr.ndim
|
||||||
|
pad_width[-1] = (0, target - arr.shape[-1])
|
||||||
|
return np.pad(arr, pad_width, mode="constant")
|
||||||
|
|
||||||
|
for key in data_keys:
|
||||||
|
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||||
|
|
||||||
|
# decide if this key should be padded
|
||||||
|
target_dim = None
|
||||||
|
if "action" in key and max_action_dim:
|
||||||
|
target_dim = max_action_dim
|
||||||
|
elif "state" in key and max_state_dim:
|
||||||
|
target_dim = max_state_dim
|
||||||
|
|
||||||
|
padded = []
|
||||||
|
counts = []
|
||||||
|
for s in stats_with_key:
|
||||||
|
mean = _pad(np.array(s["mean"]), target_dim)
|
||||||
|
std = _pad(np.array(s["std"]), target_dim)
|
||||||
|
min_ = _pad(np.array(s["min"]), target_dim)
|
||||||
|
max_ = _pad(np.array(s["max"]), target_dim)
|
||||||
|
count = s.get("count", 1)
|
||||||
|
|
||||||
|
padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count))
|
||||||
|
counts.append(count)
|
||||||
|
|
||||||
|
counts = np.array(counts, dtype=np.float64)
|
||||||
|
total_count = counts.sum()
|
||||||
|
|
||||||
|
means = np.stack([p["mean"] for p in padded])
|
||||||
|
stds = np.stack([p["std"] for p in padded])
|
||||||
|
mins = np.stack([p["min"] for p in padded])
|
||||||
|
maxs = np.stack([p["max"] for p in padded])
|
||||||
|
|
||||||
|
# weighted mean (broadcast weights properly)
|
||||||
|
new_mean = np.average(means, axis=0, weights=counts)
|
||||||
|
new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts)
|
||||||
|
|
||||||
|
new_std = np.sqrt(new_var)
|
||||||
|
|
||||||
|
aggregated_stats[key] = {
|
||||||
|
"min": mins.min(axis=0),
|
||||||
|
"max": maxs.max(axis=0),
|
||||||
|
"mean": new_mean,
|
||||||
|
"std": new_std,
|
||||||
|
"count": int(total_count),
|
||||||
|
}
|
||||||
|
|
||||||
|
return aggregated_stats
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import torch.utils
|
|||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
from lerobot.constants import HF_LEROBOT_HOME
|
from lerobot.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||||
@@ -81,7 +82,12 @@ from lerobot.datasets.video_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
CODEBASE_VERSION = "v3.0"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
OBS_IMAGE = "observation.image"
|
||||||
|
OBS_IMAGE_2 = "observation.image_2"
|
||||||
|
OBS_IMAGE_3 = "observation.image_3"
|
||||||
|
OBS_STATE = "observation.state"
|
||||||
|
OBS_ENV_STATE = "observation.env_state"
|
||||||
|
ACTION = "action"
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1322,13 +1328,139 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
# Manual robot_type overrides keyed by repo id. Used by MultiLeRobotDatasetMeta
# to assign/normalize `robot_type` in dataset metadata so that stats can be
# aggregated per robot type.
ROBOT_TYPE_KEYS_MAPPING = {
    "lerobot/stanford_hydra_dataset": "static_single_arm",
    "lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
    "lerobot/berkeley_fanuc_manipulation": "static_single_arm",
    "lerobot/toto": "static_single_arm",
    "lerobot/roboturk": "static_single_arm",
    "lerobot/jaco_play": "static_single_arm",
    # separate type — presumably because taco_play uses a 7-dim state vector;
    # verify against the dataset's actual feature shapes.
    "lerobot/taco_play": "static_single_arm_7statedim",
}
|
||||||
|
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a collection of heterogeneous LeRobotDatasets.

    Builds a unified feature schema (vectors reshaped/padded to per-key max
    dims) and per-robot-type normalization stats, plus per-repo episode/task
    bookkeeping.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],  # already-constructed per-repo datasets
        repo_ids: list[str],  # repo id of each dataset, aligned with `datasets`
        keys_to_max_dim: dict[str, int],  # target last-dim per key; -1 = no padding
        train_on_all_features: bool = False,  # True: union of features; False: intersection
    ):
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features
        # NOTE(review): captured BEFORE the override loop below, so this list
        # holds the pre-override robot types — confirm that is intended.
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]

        # assign robot_type if missing (manual overrides take precedence over
        # whatever the dataset metadata declares)
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]

        # step 1: compute disabled features — only when restricting to the
        # intersection of features across all datasets
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                logging.warning(f"Disabling {extra} for repo {repo_id}")
                self.disabled_features.update(extra)

        # step 2: build union_features excluding disabled
        # (later datasets win on key collisions)
        self.union_features = {}
        for ds in datasets:
            for k, v in ds.features.items():
                if k not in self.disabled_features:
                    self.union_features[k] = v

        # step 3: reshape feature schema so vector features match the
        # configured per-key maxima
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )

        # step 4: aggregate stats per robot type, then pad the vector stats:
        # min/mean are padded with 0, all other entries with 1
        self.stats = aggregate_stats_per_robot_type(datasets)
        for robot_type_, stats_ in self.stats.items():
            for feat_key, feat_stats in stats_.items():
                if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
                    for k, v in feat_stats.items():
                        pad_value = 0 if k in ["min", "mean"] else 1
                        self.stats[robot_type_][feat_key][k] = pad_tensor(
                            v,
                            max_size=self.keys_to_max_dim.get(feat_key, -1),
                            pad_dim=-1,
                            pad_value=pad_value,
                        )

        # step 5: episodes & tasks & info, keyed by repo id
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||||
|
|
||||||
|
|
||||||
|
class MultiLeRobotDatasetCleaner:
    """Filter a list of LeRobotDatasets down to a mixable, consistent subset.

    Keeps only datasets whose feature sets are consistent within each robot
    type, keeps weights/repo-id bookkeeping aligned with the surviving
    datasets, and precomputes cumulative sizes for global-index -> dataset
    lookup.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],  # candidate datasets
        repo_ids: list[str],  # repo id per dataset
        sampling_weights: list[float],  # sampling weight per dataset
        datasets_repo_ids: list[str],  # repo ids of successfully-loaded datasets
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids

        # step 1: remove datasets with invalid fps
        # NOTE(review): not implemented — min_fps/max_fps are accepted but
        # currently unused; confirm whether fps filtering is still planned.

        # step 2: keep datasets with same features per robot type
        # (helper returns the surviving datasets plus a boolean keep mask
        # aligned with the input list)
        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
            datasets
        )

        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask
        self.cleaned_weights = [sampling_weights[i] for i in range(len(datasets)) if keep_mask[i]]
        self.cleaned_repo_ids = [repo_ids[i] for i in range(len(datasets)) if keep_mask[i]]
        self.cleaned_datasets_repo_ids = [
            datasets_repo_ids[i] for i in range(len(datasets)) if keep_mask[i]
        ]

        # cumulative frame counts [0, n0, n0+n1, ...] for index -> dataset lookup
        self.cumulative_sizes = np.array(
            [0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
        )
        # re-cast weights to a float32 array for sampling
        self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
|
||||||
|
|
||||||
|
# --- at the top of the file (same imports as before) ---
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Callable
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import datasets
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# If you already have these in your codebase, reuse them
|
||||||
|
try:
|
||||||
|
from lerobot.common.constants import (
|
||||||
|
ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallbacks if constants are already strings elsewhere
|
||||||
|
ACTION = "action"
|
||||||
|
OBS_ENV_STATE = "observation.env_state"
|
||||||
|
OBS_STATE = "observation.state"
|
||||||
|
OBS_IMAGE = "observation.image"
|
||||||
|
OBS_IMAGE_2 = "observation.image_2"
|
||||||
|
OBS_IMAGE_3 = "observation.image_3"
|
||||||
|
|
||||||
|
IGNORED_KEYS = ["observation.effort"]
|
||||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
# ... keep your existing docstring ...
|
||||||
|
|
||||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
|
||||||
structure of `LeRobotDataset`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1336,99 +1468,253 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: dict | None = None,
|
episodes: dict | None = None,
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[str, list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
tolerances_s: dict | None = None,
|
tolerances_s: dict | None = None,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
# --- NEW: simple add-ons ---
|
||||||
|
sampling_weights: list[float] | None = None,
|
||||||
|
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||||
|
max_action_dim: int | None = None,
|
||||||
|
max_state_dim: int | None = None,
|
||||||
|
max_num_images: int | None = None,
|
||||||
|
max_image_dim: int | None = None,
|
||||||
|
train_on_all_features: bool = False,
|
||||||
|
min_fps: int = 1,
|
||||||
|
max_fps: int = 100,
|
||||||
|
ignore_keys: list[str] | None = None, # exact or glob patterns
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_ids = repo_ids
|
self.repo_ids = repo_ids
|
||||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
|
||||||
# are handled by this class.
|
|
||||||
self._datasets = [
|
|
||||||
LeRobotDataset(
|
|
||||||
repo_id,
|
|
||||||
root=self.root / repo_id,
|
|
||||||
episodes=episodes[repo_id] if episodes else None,
|
|
||||||
image_transforms=image_transforms,
|
|
||||||
delta_timestamps=delta_timestamps,
|
|
||||||
tolerance_s=self.tolerances_s[repo_id],
|
|
||||||
download_videos=download_videos,
|
|
||||||
video_backend=video_backend,
|
|
||||||
)
|
|
||||||
for repo_id in repo_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
# --- NEW: store mapping and simple knobs ---
|
||||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
self.feature_keys_mapping: dict[str, dict[str, str]] = feature_keys_mapping or {}
|
||||||
# to use PyTorch's default DataLoader collate function.
|
self.train_on_all_features = train_on_all_features
|
||||||
self.disabled_features = set()
|
self.max_action_dim = max_action_dim
|
||||||
intersection_features = set(self._datasets[0].features)
|
self.max_state_dim = max_state_dim
|
||||||
for ds in self._datasets:
|
self.max_image_dim = max_image_dim
|
||||||
intersection_features.intersection_update(ds.features)
|
self.max_num_images = max_num_images # (optional, we don’t enforce count, we enforce names)
|
||||||
if len(intersection_features) == 0:
|
self._ignore_patterns = list(ignore_keys or [])
|
||||||
raise RuntimeError(
|
# Build underlying single datasets
|
||||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
_datasets = []
|
||||||
"The multi-dataset functionality currently only keeps common keys."
|
datasets_repo_ids = []
|
||||||
)
|
self.sampling_weights = []
|
||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
|
||||||
logging.warning(
|
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
|
||||||
"other datasets."
|
|
||||||
)
|
|
||||||
self.disabled_features.update(extra_keys)
|
|
||||||
|
|
||||||
|
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||||
|
assert len(sampling_weights) == len(repo_ids), (
|
||||||
|
"The number of sampling weights must match the number of datasets. "
|
||||||
|
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||||
|
)
|
||||||
|
for i, repo_id in enumerate(repo_ids):
|
||||||
|
try:
|
||||||
|
_datasets.append(
|
||||||
|
LeRobotDataset(
|
||||||
|
repo_id,
|
||||||
|
root=self.root / repo_id,
|
||||||
|
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||||
|
image_transforms=image_transforms, # transforms applied inside single ds
|
||||||
|
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||||
|
tolerance_s=self.tolerances_s[repo_id],
|
||||||
|
download_videos=download_videos,
|
||||||
|
video_backend=video_backend,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
datasets_repo_ids.append(repo_id)
|
||||||
|
self.sampling_weights.append(float(sampling_weights[i]))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Finish loading {len(_datasets)} datasets, with sampling weights: "
|
||||||
|
f"{self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bookkeeping for mapping & canonical image inventory
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
self._datasets = _datasets
|
||||||
# with multiple robots of different ranges. Instead we should have one normalization
|
self.datasets_repo_ids = datasets_repo_ids
|
||||||
# per robot.
|
|
||||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
# --- NEW: compute “canonical image keys” (targets across all mappings) ---
|
||||||
|
self._canonical_image_keys: set[str] = set()
|
||||||
|
self._source_keys_per_repo: dict[str, set[str]] = {}
|
||||||
|
self._target_keys_per_repo: dict[str, set[str]] = {}
|
||||||
|
for rid, mapping in self.feature_keys_mapping.items():
|
||||||
|
src_keys = set(mapping.keys())
|
||||||
|
tgt_keys = set(mapping.values())
|
||||||
|
self._source_keys_per_repo[rid] = src_keys
|
||||||
|
self._target_keys_per_repo[rid] = tgt_keys
|
||||||
|
# union of target names (we will ensure these exist at __getitem__)
|
||||||
|
self._canonical_image_keys |= {
|
||||||
|
k for k in tgt_keys if self._is_image_key_like(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
# If user didn’t give any mapping, fall back to native keys (no-ops)
|
||||||
|
if not self._canonical_image_keys and self.train_on_all_features:
|
||||||
|
# discover all image-like keys from raw features
|
||||||
|
for ds in self._datasets:
|
||||||
|
for k, v in ds.hf_features.items():
|
||||||
|
if isinstance(v, (datasets.Image, VideoFrame)):
|
||||||
|
self._canonical_image_keys.add(k)
|
||||||
|
|
||||||
|
# Cleaner: keep fps & consistent feature sets per robot type (unchanged)
|
||||||
|
cleaner = MultiLeRobotDatasetCleaner(
|
||||||
|
datasets=self._datasets,
|
||||||
|
repo_ids=repo_ids,
|
||||||
|
sampling_weights=self.sampling_weights,
|
||||||
|
datasets_repo_ids=self.datasets_repo_ids,
|
||||||
|
min_fps=min_fps,
|
||||||
|
max_fps=max_fps,
|
||||||
|
)
|
||||||
|
self._datasets = cleaner.cleaned_datasets
|
||||||
|
self.sampling_weights = cleaner.cleaned_weights
|
||||||
|
self.repo_ids = cleaner.cleaned_repo_ids
|
||||||
|
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||||
|
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||||
|
|
||||||
|
# Meta (unchanged): we give it dim maxima; it will reshape/pad vectors
|
||||||
|
self.meta = MultiLeRobotDatasetMeta(
|
||||||
|
datasets=self._datasets,
|
||||||
|
repo_ids=self.repo_ids,
|
||||||
|
keys_to_max_dim={
|
||||||
|
ACTION: self.max_action_dim if self.max_action_dim is not None else -1,
|
||||||
|
OBS_ENV_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
|
||||||
|
OBS_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
|
||||||
|
OBS_IMAGE: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||||
|
OBS_IMAGE_2: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||||
|
OBS_IMAGE_3: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||||
|
},
|
||||||
|
train_on_all_features=train_on_all_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- NEW: track dropped (source) keys so collate won’t expect them
|
||||||
|
# Anything that we *rename away* should be considered disabled,
|
||||||
|
# otherwise downstream may expect them to exist.
|
||||||
|
self._dropped_keys = set()
|
||||||
|
for rid, mapping in self.feature_keys_mapping.items():
|
||||||
|
self._dropped_keys |= set(mapping.keys())
|
||||||
|
|
||||||
|
# Merge with meta’s disabled features
|
||||||
|
self.disabled_features = set(self.meta.disabled_features) | self._dropped_keys
|
||||||
|
|
||||||
|
self.stats = self.meta.stats
|
||||||
|
|
||||||
|
# --- NEW: cache an example image shape per canonical key (lazy, filled on first use)
|
||||||
|
self._cached_img_shape: dict[str, torch.Size] = {}
|
||||||
|
|
||||||
|
# ---------------------- NEW small helpers ----------------------
|
||||||
|
|
||||||
|
def _is_image_key_like(self, key: str) -> bool:
|
||||||
|
# A loose heuristic: rely on name OR on features later
|
||||||
|
return ("image" in key) or ("cam_" in key) or ("images." in key)
|
||||||
|
|
||||||
|
def _should_ignore(self, key: str) -> bool:
|
||||||
|
# exact or glob-style match
|
||||||
|
for pat in self._ignore_patterns:
|
||||||
|
if key == pat or fnmatch.fnmatch(key, pat):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
def _apply_feature_mapping(self, item: dict, repo_id: str) -> dict:
|
||||||
|
"""
|
||||||
|
Rename features according to feature_keys_mapping[repo_id].
|
||||||
|
- Moves tensor/image under target key.
|
||||||
|
- Drops source key if moved.
|
||||||
|
- Adds *_is_pad=False for image targets we fill/keep.
|
||||||
|
"""
|
||||||
|
mapping = self.feature_keys_mapping.get(repo_id, {}) or {}
|
||||||
|
if not mapping:
|
||||||
|
return item
|
||||||
|
|
||||||
|
for src, tgt in mapping.items():
|
||||||
|
if src in item:
|
||||||
|
# Move value
|
||||||
|
item[tgt] = item[src]
|
||||||
|
# Drop the source to avoid duplication
|
||||||
|
del item[src]
|
||||||
|
return item
|
||||||
|
|
||||||
|
def _ensure_union_image_keys(self, item: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Ensure that every canonical image key exists.
|
||||||
|
When missing, create a zero tensor matching (B,C,H,W) or (C,H,W) of an available image.
|
||||||
|
Also add boolean mask at f"{key}_is_pad".
|
||||||
|
"""
|
||||||
|
if not self.train_on_all_features or not self._canonical_image_keys:
|
||||||
|
return item
|
||||||
|
|
||||||
|
# find any existing image tensor in item to copy shape/dtype
|
||||||
|
exemplar = None
|
||||||
|
for k in list(item.keys()):
|
||||||
|
v = item[k]
|
||||||
|
if torch.is_tensor(v) and v.ndim in (3, 4, 5): # (C,H,W) or (B,C,H,W) or (B,T,C,H,W)
|
||||||
|
exemplar = v
|
||||||
|
break
|
||||||
|
|
||||||
|
# fallback to a safe 3x224x224 if nothing found
|
||||||
|
def _fallback_image():
|
||||||
|
return torch.zeros(3, 224, 224, dtype=torch.uint8)
|
||||||
|
|
||||||
|
for key in self._canonical_image_keys:
|
||||||
|
if key not in item:
|
||||||
|
img = torch.zeros_like(exemplar) if exemplar is not None else _fallback_image()
|
||||||
|
item[key] = img
|
||||||
|
item[f"{key}_is_pad"] = torch.tensor(True, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
# Add a mask saying it’s *not* padded
|
||||||
|
if f"{key}_is_pad" not in item:
|
||||||
|
item[f"{key}_is_pad"] = torch.tensor(False, dtype=torch.bool)
|
||||||
|
return item
|
||||||
|
|
||||||
|
# ---------------------- existing API below (mostly unchanged) ----------------------
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repo_id_to_index(self):
|
def repo_id_to_index(self):
|
||||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
|
||||||
|
|
||||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
|
||||||
"""
|
|
||||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||||
|
|
||||||
@property
def repo_index_to_id(self):
    """Return the inverse mapping of repo_id_to_index (index -> repo_id)."""
    # BUGFIX: iterating a dict yields keys only; unpacking `k, v` from it
    # raised at runtime. Iterate .items() to get (repo_id, index) pairs.
    return {v: k for k, v in self.repo_id_to_index.items()}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self) -> int:
|
def fps(self) -> int:
|
||||||
"""Frames per second used during data collection.
|
|
||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
||||||
"""
|
|
||||||
return self._datasets[0].meta.info["fps"]
|
return self._datasets[0].meta.info["fps"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video(self) -> bool:
|
def video(self) -> bool:
|
||||||
"""Returns True if this dataset loads video frames from mp4 files.
|
|
||||||
|
|
||||||
Returns False if it only loads images from png files.
|
|
||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
||||||
"""
|
|
||||||
return self._datasets[0].meta.info.get("video", False)
|
return self._datasets[0].meta.info.get("video", False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> datasets.Features:
|
def features(self) -> datasets.Features:
|
||||||
features = {}
|
"""
|
||||||
|
Extend native HF features with any *target* keys introduced by mapping.
|
||||||
|
We copy the source spec for targets that didn’t exist in any raw dataset.
|
||||||
|
"""
|
||||||
|
features: dict[str, datasets.features.Feature] = {}
|
||||||
for dataset in self._datasets:
|
for dataset in self._datasets:
|
||||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
for k, v in dataset.hf_features.items():
|
||||||
|
if k not in self.disabled_features:
|
||||||
|
features[k] = v
|
||||||
|
|
||||||
|
# Add mapped target image specs if not present yet
|
||||||
|
for rid, mapping in self.feature_keys_mapping.items():
|
||||||
|
ds = None
|
||||||
|
# find the dataset object to read feature spec for source
|
||||||
|
for _ds, _rid in zip(self._datasets, self.repo_ids, strict=False):
|
||||||
|
if _rid == rid:
|
||||||
|
ds = _ds
|
||||||
|
break
|
||||||
|
if ds is None:
|
||||||
|
continue
|
||||||
|
for src, tgt in mapping.items():
|
||||||
|
if tgt not in features and src in ds.hf_features:
|
||||||
|
features[tgt] = ds.hf_features[src]
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
"""Keys to access image and video stream from cameras."""
|
|
||||||
keys = []
|
keys = []
|
||||||
for key, feats in self.features.items():
|
for key, feats in self.features.items():
|
||||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||||
@@ -1437,12 +1723,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def video_frame_keys(self) -> list[str]:
|
def video_frame_keys(self) -> list[str]:
|
||||||
"""Keys to access video frames that requires to be decoded into images.
|
|
||||||
|
|
||||||
Note: It is empty if the dataset contains images only,
|
|
||||||
or equal to `self.cameras` if the dataset contains videos only,
|
|
||||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
|
||||||
"""
|
|
||||||
video_frame_keys = []
|
video_frame_keys = []
|
||||||
for key, feats in self.features.items():
|
for key, feats in self.features.items():
|
||||||
if isinstance(feats, VideoFrame):
|
if isinstance(feats, VideoFrame):
|
||||||
@@ -1451,21 +1731,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_frames(self) -> int:
|
def num_frames(self) -> int:
|
||||||
"""Number of samples/frames."""
|
|
||||||
return sum(d.num_frames for d in self._datasets)
|
return sum(d.num_frames for d in self._datasets)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_episodes(self) -> int:
|
def num_episodes(self) -> int:
|
||||||
"""Number of episodes."""
|
|
||||||
return sum(d.num_episodes for d in self._datasets)
|
return sum(d.num_episodes for d in self._datasets)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tolerance_s(self) -> float:
|
def tolerance_s(self) -> float:
|
||||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
|
||||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
|
||||||
is provided or when loading video frames from mp4 files.
|
|
||||||
"""
|
|
||||||
# 1e-4 to account for possible numerical error
|
|
||||||
return 1 / self.fps - 1e-4
|
return 1 / self.fps - 1e-4
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -1474,22 +1747,83 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
if idx >= len(self):
|
if idx >= len(self):
|
||||||
raise IndexError(f"Index {idx} out of bounds.")
|
raise IndexError(f"Index {idx} out of bounds.")
|
||||||
# Determine which dataset to get an item from based on the index.
|
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||||
start_idx = 0
|
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||||
dataset_idx = 0
|
item = self._datasets[dataset_idx][local_idx]
|
||||||
for dataset in self._datasets:
|
|
||||||
if idx >= start_idx + dataset.num_frames:
|
# Identify which repo this sample came from
|
||||||
start_idx += dataset.num_frames
|
repo_id = self.datasets_repo_ids[dataset_idx]
|
||||||
dataset_idx += 1
|
|
||||||
continue
|
# --- NEW: apply mapping and ensure union of image keys ---
|
||||||
break
|
item = self._apply_feature_mapping(item, repo_id)
|
||||||
else:
|
item = self._ensure_union_image_keys(item)
|
||||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
|
||||||
item = self._datasets[dataset_idx][idx - start_idx]
|
# annotate dataset index for downstream
|
||||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||||
|
|
||||||
|
# Pad vector features to max dims using meta (unchanged)
|
||||||
|
item = create_padded_features(item, self.meta.features)
|
||||||
|
|
||||||
|
# Drop any disabled (including original source keys we remapped away)
|
||||||
for data_key in self.disabled_features:
|
for data_key in self.disabled_features:
|
||||||
if data_key in item:
|
if data_key in item:
|
||||||
del item[data_key]
|
del item[data_key]
|
||||||
|
for k in IGNORED_KEYS:
|
||||||
|
if k in item:
|
||||||
|
item.pop(k)
|
||||||
|
# Convert any datasets.Image still present to tensor
|
||||||
|
if self.image_transforms is not None:
|
||||||
|
for cam in [k for k in item.keys() if self._is_image_key_like(k)]:
|
||||||
|
val = item[cam]
|
||||||
|
if not torch.is_tensor(val):
|
||||||
|
item[cam] = self.image_transforms(val)
|
||||||
|
# 🔑 Pad actions if too short
|
||||||
|
if "actions" in item and self.max_action_dim is not None:
|
||||||
|
act = item["actions"]
|
||||||
|
if act.shape[-1] < self.max_action_dim:
|
||||||
|
pad_len = self.max_action_dim - act.shape[-1]
|
||||||
|
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
|
||||||
|
item["actions_padding_mask"] = torch.cat(
|
||||||
|
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# pad obs_state if too short
|
||||||
|
if "obs_state" in item and self.max_state_dim is not None:
|
||||||
|
st = item["obs_state"]
|
||||||
|
if st.shape[-1] < self.max_state_dim:
|
||||||
|
pad_len = self.max_state_dim - st.shape[-1]
|
||||||
|
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
|
||||||
|
item["obs_state_padding_mask"] = torch.cat(
|
||||||
|
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
# actions
|
||||||
|
if "actions" in item and self.max_action_dim is not None:
|
||||||
|
act = item["actions"]
|
||||||
|
if act.shape[-1] < self.max_action_dim:
|
||||||
|
pad_len = self.max_action_dim - act.shape[-1]
|
||||||
|
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
|
||||||
|
mask = torch.cat(
|
||||||
|
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding
|
||||||
|
item["actions_padding_mask"] = mask
|
||||||
|
# obs state
|
||||||
|
if "obs_state" in item and self.max_state_dim is not None:
|
||||||
|
st = item["obs_state"]
|
||||||
|
if st.shape[-1] < self.max_state_dim:
|
||||||
|
pad_len = self.max_state_dim - st.shape[-1]
|
||||||
|
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
|
||||||
|
mask = torch.cat(
|
||||||
|
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask
|
||||||
|
item["obs_state_padding_mask"] = mask
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
@@ -1506,3 +1840,149 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
f" Transformations: {self.image_transforms},\n"
|
f" Transformations: {self.image_transforms},\n"
|
||||||
f")"
|
f")"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
|
||||||
|
"""
|
||||||
|
Filters datasets to only keep those with consistent feature shapes per robot type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
|
||||||
|
and `meta.episodes_stats` dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: Filtered list of datasets with consistent feature shapes.
|
||||||
|
"""
|
||||||
|
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
|
||||||
|
datasets_to_remove = set()
|
||||||
|
|
||||||
|
for robot_type in robot_types:
|
||||||
|
# Collect all stats dicts for this robot type
|
||||||
|
stats_list = [
|
||||||
|
ep_stats
|
||||||
|
for ds in ls_datasets
|
||||||
|
if ds.meta.info["robot_type"] == robot_type
|
||||||
|
for ep_stats in episode_stats_values(ds.meta)
|
||||||
|
]
|
||||||
|
if not stats_list:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine the most common shape for each key
|
||||||
|
all_keys = {key for stats in stats_list for key in stats}
|
||||||
|
for ds in ls_datasets:
|
||||||
|
if ds.meta.info["robot_type"] != robot_type:
|
||||||
|
continue
|
||||||
|
for key in all_keys:
|
||||||
|
shape_counter = defaultdict(int)
|
||||||
|
|
||||||
|
for stats in stats_list:
|
||||||
|
value = stats.get(key)
|
||||||
|
if (
|
||||||
|
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||||
|
): # FIXME(mshukor): check all stats; min, mean, max
|
||||||
|
shape_counter[value["mean"].shape] += 1
|
||||||
|
if not shape_counter:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Identify the most frequent shape
|
||||||
|
main_shape = max(shape_counter, key=shape_counter.get)
|
||||||
|
# Flag datasets that don't match the main shape
|
||||||
|
# for ds in ls_datasets:
|
||||||
|
first_ep_stats = next(iter(episode_stats_values(ds.meta)), None)
|
||||||
|
if not first_ep_stats:
|
||||||
|
continue
|
||||||
|
value = first_ep_stats.get(key)
|
||||||
|
if (
|
||||||
|
value
|
||||||
|
and "mean" in value
|
||||||
|
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||||
|
and value["mean"].shape != main_shape
|
||||||
|
):
|
||||||
|
datasets_to_remove.add(ds)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Filter out inconsistent datasets
|
||||||
|
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
|
||||||
|
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
|
||||||
|
print(
|
||||||
|
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
|
||||||
|
)
|
||||||
|
return filtered_datasets, datasets_maks
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
|
||||||
|
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
|
||||||
|
|
||||||
|
The final stats will have the union of all data keys from each of the datasets.
|
||||||
|
|
||||||
|
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||||
|
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||||
|
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||||
|
- new_mean = (mean of all data)
|
||||||
|
- new_std = (std of all data)
|
||||||
|
"""
|
||||||
|
|
||||||
|
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
|
||||||
|
stats = {robot_type: {} for robot_type in robot_types}
|
||||||
|
for robot_type in robot_types:
|
||||||
|
robot_type_datasets = []
|
||||||
|
for ds in ls_datasets:
|
||||||
|
if ds.meta.info["robot_type"] == robot_type:
|
||||||
|
robot_type_datasets.extend(list(episode_stats_values(ds.meta)))
|
||||||
|
# robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
|
||||||
|
stat = aggregate_stats(robot_type_datasets)
|
||||||
|
stats[robot_type] = stat
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
|
||||||
|
"""Reshape features to have a maximum dimension of `max_dim`."""
|
||||||
|
reshaped_features = {}
|
||||||
|
for key in features:
|
||||||
|
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
|
||||||
|
reshaped_features[key] = features[key]
|
||||||
|
shape = list(features[key]["shape"])
|
||||||
|
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
|
||||||
|
shape[-3] = keys_to_max_dim[key]
|
||||||
|
shape[-2] = keys_to_max_dim[key]
|
||||||
|
else:
|
||||||
|
shape[reshape_dim] = keys_to_max_dim[key]
|
||||||
|
reshaped_features[key]["shape"] = tuple(shape)
|
||||||
|
else:
|
||||||
|
reshaped_features[key] = features[key]
|
||||||
|
return reshaped_features
|
||||||
|
|
||||||
|
def create_padded_features(item: dict, features: dict = {}):
|
||||||
|
for key, ft in features.items():
|
||||||
|
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
|
||||||
|
continue
|
||||||
|
shape = ft["shape"]
|
||||||
|
if len(shape) == 3: # images to torch format (C, H, W)
|
||||||
|
shape = (shape[2], shape[0], shape[1])
|
||||||
|
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
|
||||||
|
shape = []
|
||||||
|
if key not in item:
|
||||||
|
dtype = str_to_torch_dtype(ft["dtype"])
|
||||||
|
item[key] = torch.zeros(shape, dtype=dtype)
|
||||||
|
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
|
||||||
|
if "image" in key: # FIXME(mshukor): support other observations
|
||||||
|
item[f"{key}_is_pad"] = torch.BoolTensor([False])
|
||||||
|
else:
|
||||||
|
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
|
||||||
|
return item
|
||||||
|
|
||||||
|
def str_to_torch_dtype(dtype_str):
|
||||||
|
"""Convert a dtype string to a torch dtype."""
|
||||||
|
mapping = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"int16": torch.int16,
|
||||||
|
"bool": torch.bool,
|
||||||
|
"video": torch.float32, # Assuming video is stored as uint8 images
|
||||||
|
}
|
||||||
|
return mapping.get(dtype_str, torch.float32) # Default to float32
|
||||||
|
|
||||||
|
def episode_stats_values(meta):
|
||||||
|
episodes = meta.episodes.to_pandas().to_dict(orient="records")
|
||||||
|
return [
|
||||||
|
{k: v for k, v in ep.items() if isinstance(v, dict) and "mean" in v}
|
||||||
|
for ep in episodes
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user