@@ -31,6 +31,7 @@ import torch.utils
from huggingface_hub import HfApi , snapshot_download
from huggingface_hub . errors import RevisionNotFoundError
from collections import defaultdict
from lerobot . constants import HF_LEROBOT_HOME
from lerobot . datasets . compute_stats import aggregate_stats , compute_episode_stats
from lerobot . datasets . image_writer import AsyncImageWriter , write_image
@@ -81,7 +82,12 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = " v3.0 "
OBS_IMAGE = " observation.image "
OBS_IMAGE_2 = " observation.image_2 "
OBS_IMAGE_3 = " observation.image_3 "
OBS_STATE = " observation.state "
OBS_ENV_STATE = " observation.env_state "
ACTION = " action "
class LeRobotDatasetMetadata :
def __init__ (
@@ -1322,13 +1328,139 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj . video_backend = video_backend if video_backend is not None else get_safe_default_codec ( )
return obj
ROBOT_TYPE_KEYS_MAPPING = {
" lerobot/stanford_hydra_dataset " : " static_single_arm " ,
" lerobot/iamlab_cmu_pickup_insert " : " static_single_arm " ,
" lerobot/berkeley_fanuc_manipulation " : " static_single_arm " ,
" lerobot/toto " : " static_single_arm " ,
" lerobot/roboturk " : " static_single_arm " ,
" lerobot/jaco_play " : " static_single_arm " ,
" lerobot/taco_play " : " static_single_arm_7statedim " ,
}
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a collection of `LeRobotDataset`s.

    Computes the shared (or union) feature schema, per-robot-type normalization
    stats padded to common dimensions, and per-repo episodes/tasks/info maps.

    Args:
        datasets: Loaded sub-datasets.
        repo_ids: Repo ids aligned with `datasets`.
        keys_to_max_dim: Per-feature-key target dimension (-1 = unchanged).
        train_on_all_features: When False, features not common to all datasets
            are disabled; when True, the union of features is kept.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features
        # Apply robot_type overrides *before* recording robot types, so that
        # self.robot_types agrees with the per-robot-type stats computed below.
        # (Bug fix: robot_types used to be captured before the override loop.)
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
        # step 1: compute disabled features (those not common to all datasets)
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                logging.warning(f"Disabling {extra} for repo {repo_id}")
                self.disabled_features.update(extra)
        # step 2: build union_features excluding disabled keys
        self.union_features = {}
        for ds in datasets:
            for k, v in ds.features.items():
                if k not in self.disabled_features:
                    self.union_features[k] = v
        # step 3: reshape the feature schema to the per-key maxima
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )
        # step 4: aggregate stats per robot type and pad vector stats to max dims
        self.stats = aggregate_stats_per_robot_type(datasets)
        for robot_type_, stats_ in self.stats.items():
            for feat_key, feat_stats in stats_.items():
                if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
                    for k, v in feat_stats.items():
                        # min/mean pad with 0; max/std pad with 1 so padded dims
                        # normalize to zero without division by zero.
                        pad_value = 0 if k in ["min", "mean"] else 1
                        self.stats[robot_type_][feat_key][k] = pad_tensor(
                            v,
                            max_size=self.keys_to_max_dim.get(feat_key, -1),
                            pad_dim=-1,
                            pad_value=pad_value,
                        )
        # step 5: episodes, tasks & info keyed by repo_id
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
    """Filter a list of datasets for multi-dataset training.

    Removes datasets whose fps is outside [min_fps, max_fps] and datasets whose
    feature shapes are inconsistent within their robot type, then exposes the
    cleaned datasets/weights/repo-ids plus cumulative sizes for flat indexing.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids
        # step 1: remove datasets with invalid fps.
        # (Bug fix: min_fps/max_fps were previously accepted but never used.)
        # NOTE(review): assumes each dataset exposes a numeric `fps` — confirm.
        fps_mask = [min_fps <= ds.fps <= max_fps for ds in datasets]
        fps_valid_datasets = [ds for ds, ok in zip(datasets, fps_mask, strict=False) if ok]
        # step 2: keep datasets with the same features per robot type
        consistent_datasets, consistency_mask = keep_datasets_with_the_same_features_per_robot_type(
            fps_valid_datasets
        )
        # Merge both masks back onto the original dataset ordering.
        consistency_iter = iter(consistency_mask)
        keep_mask = [next(consistency_iter) if ok else False for ok in fps_mask]
        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask
        self.cleaned_repo_ids = [repo_ids[i] for i in range(len(datasets)) if keep_mask[i]]
        self.cleaned_datasets_repo_ids = [
            datasets_repo_ids[i] for i in range(len(datasets)) if keep_mask[i]
        ]
        # Cumulative frame counts with a leading 0, for searchsorted indexing.
        self.cumulative_sizes = np.array(
            [0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
        )
        # Single assignment (previously set twice: once as list, once as array).
        self.cleaned_weights = np.array(
            [sampling_weights[i] for i in range(len(datasets)) if keep_mask[i]], dtype=np.float32
        )
# --- at the top of the file (same imports as before) ---
from collections import defaultdict
from typing import Callable
import copy
import numpy as np
import torch
import datasets
from pathlib import Path
# If you already have these in your codebase, reuse them
# Reuse the project-wide feature-key constants when available.
try:
    from lerobot.common.constants import (
        ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
    )
except Exception:
    # Fallbacks if constants are already strings elsewhere
    ACTION = "action"
    OBS_ENV_STATE = "observation.env_state"
    OBS_STATE = "observation.state"
    OBS_IMAGE = "observation.image"
    OBS_IMAGE_2 = "observation.image_2"
    OBS_IMAGE_3 = "observation.image_3"
# Keys unconditionally dropped from every sample in __getitem__.
IGNORED_KEYS = ["observation.effort"]
class MultiLeRobotDataset ( torch . utils . data . Dataset ) :
""" A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying ` LeRobotDataset ` s are effectively concatenated , and this class adopts much of the API
structure of ` LeRobotDataset ` .
"""
# ... keep your existing docstring ...
def __init__ (
self ,
@@ -1336,99 +1468,253 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root : str | Path | None = None ,
episodes : dict | None = None ,
image_transforms : Callable | None = None ,
delta_timestamps : dict [ str , list [ float ] ] | None = None ,
delta_timestamps : dict [ list [ float ] ] | None = None ,
tolerances_s : dict | None = None ,
download_videos : bool = True ,
video_backend : str | None = None ,
# --- NEW: simple add-ons ---
sampling_weights : list [ float ] | None = None ,
feature_keys_mapping : dict [ str , dict [ str , str ] ] | None = None ,
max_action_dim : int | None = None ,
max_state_dim : int | None = None ,
max_num_images : int | None = None ,
max_image_dim : int | None = None ,
train_on_all_features : bool = False ,
min_fps : int = 1 ,
max_fps : int = 100 ,
ignore_keys : list [ str ] | None = None , # exact or glob patterns
) :
super ( ) . __init__ ( )
self . repo_ids = repo_ids
self . root = Path ( root ) if root else HF_LEROBOT_HOME
self . tolerances_s = tolerances_s if tolerances_s else dict . fromkeys ( repo_ids , 0.0001 )
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self . _datasets = [
LeRobotDataset (
repo_id ,
root = self . root / repo_id ,
episodes = episodes [ repo_id ] if episodes else None ,
image_transforms = image_transforms ,
delta_timestamps = delta_timestamps ,
tolerance_s = self . tolerances_s [ repo_id ] ,
download_videos = download_videos ,
video_backend = video_backend ,
)
for repo_id in repo_ids
]
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self . disabled_features = set ( )
intersection_features = set ( self . _datasets [ 0 ] . features )
for ds in self . _datasets :
intersection_features . intersection_update ( ds . features )
if len ( intersection_features ) == 0 :
raise RuntimeError (
" Multiple datasets were provided but they had no keys common to all of them. "
" The multi-dataset functionality currently only keeps common keys. "
)
for repo_id , ds in zip ( self . repo_ids , self . _datasets , strict = True ) :
extra_keys = set ( ds . features ) . difference ( intersection_features )
logging . warning (
f " keys { extra_keys } of { repo_id } were disabled as they are not contained in all the "
" other datasets. "
)
self . disabled_features . update ( extra_keys )
# --- NEW: store mapping and simple knobs ---
self . feature_keys_mapping : dict [ str , dict [ str , str ] ] = feature_keys_mapping or { }
self . train_on_all_features = train_on_all_features
self . max_action_dim = max_action_dim
self . max_state_dim = max_state_dim
self . max_image_dim = max_image_dim
self . max_num_images = max_num_images # (optional, we don’ t enforce count, we enforce names)
self . _ignore_patterns = list ( ignore_keys or [ ] )
# Build underlying single datasets
_datasets = [ ]
datasets_repo_ids = [ ]
self . sampling_weights = [ ]
sampling_weights = sampling_weights if sampling_weights is not None else [ 1 ] * len ( repo_ids )
assert len ( sampling_weights ) == len ( repo_ids ) , (
" The number of sampling weights must match the number of datasets. "
f " Got { len ( sampling_weights ) } weights for { len ( repo_ids ) } datasets. "
)
for i , repo_id in enumerate ( repo_ids ) :
try :
_datasets . append (
LeRobotDataset (
repo_id ,
root = self . root / repo_id ,
episodes = episodes . get ( repo_id , None ) if episodes else None ,
image_transforms = image_transforms , # transforms applied inside single ds
delta_timestamps = delta_timestamps . get ( repo_id , None ) if delta_timestamps else None ,
tolerance_s = self . tolerances_s [ repo_id ] ,
download_videos = download_videos ,
video_backend = video_backend ,
)
)
datasets_repo_ids . append ( repo_id )
self . sampling_weights . append ( float ( sampling_weights [ i ] ) )
except Exception as e :
print ( f " Failed to load dataset: { repo_id } due to Exception: { e } " )
print (
f " Finish loading { len ( _datasets ) } datasets, with sampling weights: "
f " { self . sampling_weights } corresponding to: { datasets_repo_ids } "
)
# Bookkeeping for mapping & canonical image inventory
self . image_transforms = image_transforms
self . delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self . stats = aggregate_stats ( [ dataset . meta . stats for dataset in self . _datasets ] )
self . delta_timestamps = delta_timestamps . get ( repo_id , None ) if delta_timestamps else None
self . _datasets = _datasets
self . datasets_repo_ids = datasets_repo_ids
# --- NEW: compute “canonical image keys” (targets across all mappings) ---
self . _canonical_image_keys : set [ str ] = set ( )
self . _source_keys_per_repo : dict [ str , set [ str ] ] = { }
self . _target_keys_per_repo : dict [ str , set [ str ] ] = { }
for rid , mapping in self . feature_keys_mapping . items ( ) :
src_keys = set ( mapping . keys ( ) )
tgt_keys = set ( mapping . values ( ) )
self . _source_keys_per_repo [ rid ] = src_keys
self . _target_keys_per_repo [ rid ] = tgt_keys
# union of target names (we will ensure these exist at __getitem__)
self . _canonical_image_keys | = {
k for k in tgt_keys if self . _is_image_key_like ( k )
}
# If user didn’ t give any mapping, fall back to native keys (no-ops)
if not self . _canonical_image_keys and self . train_on_all_features :
# discover all image-like keys from raw features
for ds in self . _datasets :
for k , v in ds . hf_features . items ( ) :
if isinstance ( v , ( datasets . Image , VideoFrame ) ) :
self . _canonical_image_keys . add ( k )
# Cleaner: keep fps & consistent feature sets per robot type (unchanged)
cleaner = MultiLeRobotDatasetCleaner (
datasets = self . _datasets ,
repo_ids = repo_ids ,
sampling_weights = self . sampling_weights ,
datasets_repo_ids = self . datasets_repo_ids ,
min_fps = min_fps ,
max_fps = max_fps ,
)
self . _datasets = cleaner . cleaned_datasets
self . sampling_weights = cleaner . cleaned_weights
self . repo_ids = cleaner . cleaned_repo_ids
self . datasets_repo_ids = cleaner . cleaned_datasets_repo_ids
self . cumulative_sizes = cleaner . cumulative_sizes
# Meta (unchanged): we give it dim maxima; it will reshape/pad vectors
self . meta = MultiLeRobotDatasetMeta (
datasets = self . _datasets ,
repo_ids = self . repo_ids ,
keys_to_max_dim = {
ACTION : self . max_action_dim if self . max_action_dim is not None else - 1 ,
OBS_ENV_STATE : self . max_state_dim if self . max_state_dim is not None else - 1 ,
OBS_STATE : self . max_state_dim if self . max_state_dim is not None else - 1 ,
OBS_IMAGE : self . max_image_dim if self . max_image_dim is not None else - 1 ,
OBS_IMAGE_2 : self . max_image_dim if self . max_image_dim is not None else - 1 ,
OBS_IMAGE_3 : self . max_image_dim if self . max_image_dim is not None else - 1 ,
} ,
train_on_all_features = train_on_all_features ,
)
# --- NEW: track dropped (source) keys so collate won’ t expect them
# Anything that we *rename away* should be considered disabled,
# otherwise downstream may expect them to exist.
self . _dropped_keys = set ( )
for rid , mapping in self . feature_keys_mapping . items ( ) :
self . _dropped_keys | = set ( mapping . keys ( ) )
# Merge with meta’ s disabled features
self . disabled_features = set ( self . meta . disabled_features ) | self . _dropped_keys
self . stats = self . meta . stats
# --- NEW: cache an example image shape per canonical key (lazy, filled on first use)
self . _cached_img_shape : dict [ str , torch . Size ] = { }
# ---------------------- NEW small helpers ----------------------
def _is_image_key_like ( self , key : str ) - > bool :
# A loose heuristic: rely on name OR on features later
return ( " image " in key ) or ( " cam_ " in key ) or ( " images. " in key )
def _should_ignore ( self , key : str ) - > bool :
# exact or glob-style match
for pat in self . _ignore_patterns :
if key == pat or fnmatch . fnmatch ( key , pat ) :
return True
return False
def _apply_feature_mapping ( self , item : dict , repo_id : str ) - > dict :
"""
Rename features according to feature_keys_mapping [ repo_id ] .
- Moves tensor / image under target key .
- Drops source key if moved .
- Adds * _is_pad = False for image targets we fill / keep .
"""
mapping = self . feature_keys_mapping . get ( repo_id , { } ) or { }
if not mapping :
return item
for src , tgt in mapping . items ( ) :
if src in item :
# Move value
item [ tgt ] = item [ src ]
# Drop the source to avoid duplication
del item [ src ]
return item
def _ensure_union_image_keys ( self , item : dict ) - > dict :
"""
Ensure that every canonical image key exists .
When missing , create a zero tensor matching ( B , C , H , W ) or ( C , H , W ) of an available image .
Also add boolean mask at f " { key } _is_pad " .
"""
if not self . train_on_all_features or not self . _canonical_image_keys :
return item
# find any existing image tensor in item to copy shape/dtype
exemplar = None
for k in list ( item . keys ( ) ) :
v = item [ k ]
if torch . is_tensor ( v ) and v . ndim in ( 3 , 4 , 5 ) : # (C,H,W) or (B,C,H,W) or (B,T,C,H,W)
exemplar = v
break
# fallback to a safe 3x224x224 if nothing found
def _fallback_image ( ) :
return torch . zeros ( 3 , 224 , 224 , dtype = torch . uint8 )
for key in self . _canonical_image_keys :
if key not in item :
img = torch . zeros_like ( exemplar ) if exemplar is not None else _fallback_image ( )
item [ key ] = img
item [ f " { key } _is_pad " ] = torch . tensor ( True , dtype = torch . bool )
else :
# Add a mask saying it’ s *not* padded
if f " { key } _is_pad " not in item :
item [ f " { key } _is_pad " ] = torch . tensor ( False , dtype = torch . bool )
return item
# ---------------------- existing API below (mostly unchanged) ----------------------
@property
def repo_id_to_index(self):
    """Mapping from dataset repo_id to the integer index assigned by this class.

    The index is also exposed per-sample via the "dataset_index" key returned
    by __getitem__.
    """
    index_by_repo = {}
    for position, repo in enumerate(self.repo_ids):
        index_by_repo[repo] = position
    return index_by_repo
@property
def repo_index_to_id(self):
    """Return the inverse mapping of repo_id_to_index (index -> repo_id).

    Bug fix: iterate `.items()`; iterating the dict directly yields only keys,
    which cannot be unpacked into (k, v).
    """
    return {v: k for k, v in self.repo_id_to_index.items()}
@property
def fps(self) -> int:
    """Frames per second used during data collection.

    NOTE: For now, this relies on a check in __init__ to make sure all
    sub-datasets have the same info.
    """
    first_dataset = self._datasets[0]
    return first_dataset.meta.info["fps"]
@property
def video(self) -> bool:
    """True if this dataset loads video frames from mp4 files, False when it
    only loads images from png files.

    NOTE: For now, this relies on a check in __init__ to make sure all
    sub-datasets have the same info.
    """
    first_info = self._datasets[0].meta.info
    return first_info.get("video", False)
@property
def features(self) -> datasets.Features:
    """Union of HF feature specs across sub-datasets, minus disabled keys.

    Also extends the union with any *target* keys introduced by
    feature_keys_mapping, copying the source spec for targets that did not
    exist in any raw dataset.
    (Cleanup: removed a dead `features = {}` assignment, a docstring placed
    after it, and a leftover one-line update duplicating the loop below.)
    """
    features: dict = {}
    for dataset in self._datasets:
        for k, v in dataset.hf_features.items():
            if k not in self.disabled_features:
                features[k] = v
    # Resolve each repo to its dataset once, keeping the *first* match
    # (same behavior as the previous linear scan with break).
    repo_to_ds = {}
    for ds_, rid_ in zip(self._datasets, self.repo_ids, strict=False):
        repo_to_ds.setdefault(rid_, ds_)
    # Add mapped target image specs if not present yet.
    for rid, mapping in self.feature_keys_mapping.items():
        ds = repo_to_ds.get(rid)
        if ds is None:
            continue
        for src, tgt in mapping.items():
            if tgt not in features and src in ds.hf_features:
                features[tgt] = ds.hf_features[src]
    return features
@property
def camera_keys ( self ) - > list [ str ] :
""" Keys to access image and video stream from cameras. """
keys = [ ]
for key , feats in self . features . items ( ) :
if isinstance ( feats , ( datasets . Image , VideoFrame ) ) :
@@ -1437,12 +1723,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def video_frame_keys ( self ) - > list [ str ] :
""" Keys to access video frames that requires to be decoded into images.
Note : It is empty if the dataset contains images only ,
or equal to ` self . cameras ` if the dataset contains videos only ,
or can even be a subset of ` self . cameras ` in a case of a mixed image / video dataset .
"""
video_frame_keys = [ ]
for key , feats in self . features . items ( ) :
if isinstance ( feats , VideoFrame ) :
@@ -1451,21 +1731,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
    """Total number of samples/frames across all sub-datasets."""
    total = 0
    for dataset in self._datasets:
        total += dataset.num_frames
    return total
@property
def num_episodes(self) -> int:
    """Total number of episodes across all sub-datasets."""
    total = 0
    for dataset in self._datasets:
        total += dataset.num_episodes
    return total
@property
def tolerance_s(self) -> float:
    """Tolerance in seconds used to discard loaded frames when their timestamps
    are not close enough to the requested frames. Only used when
    `delta_timestamps` is provided or when decoding video frames from mp4 files.
    """
    frame_period = 1 / self.fps
    # Subtract 1e-4 to account for possible numerical error.
    return frame_period - 1e-4
def __len__ ( self ) :
@@ -1474,22 +1747,83 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
    """Return one schema-aligned sample from the concatenated datasets.

    Locates the owning sub-dataset via cumulative sizes, applies the per-repo
    feature-key mapping, ensures the union of canonical image keys, pads
    vector features to the shared maxima, and drops disabled/ignored keys.

    Raises:
        IndexError: if `idx` is outside [0, len(self)).
    """
    if idx >= len(self):
        raise IndexError(f"Index {idx} out of bounds.")
    # Locate the sub-dataset with the precomputed cumulative sizes.
    # (Cleanup: removed a leftover linear scan that recomputed the same index.)
    dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
    local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
    item = self._datasets[dataset_idx][local_idx]
    # Identify which repo this sample came from.
    repo_id = self.datasets_repo_ids[dataset_idx]
    # Apply mapping and ensure the union of image keys.
    item = self._apply_feature_mapping(item, repo_id)
    item = self._ensure_union_image_keys(item)
    # Annotate dataset index for downstream consumers.
    item["dataset_index"] = torch.tensor(dataset_idx)
    # Pad vector features to max dims using the aggregated meta schema.
    item = create_padded_features(item, self.meta.features)
    # Drop disabled keys (including source keys remapped away) and ignored keys.
    for data_key in self.disabled_features:
        item.pop(data_key, None)
    for k in IGNORED_KEYS:
        item.pop(k, None)
    # Convert any image still present as a non-tensor via the transforms.
    if self.image_transforms is not None:
        for cam in [k for k in item if self._is_image_key_like(k)]:
            if not torch.is_tensor(item[cam]):
                item[cam] = self.image_transforms(item[cam])
    # Pad "actions"/"obs_state" to their max dims, always emitting a mask.
    # (Bug fix: this padding previously ran twice, and the second pass
    # overwrote the real padding mask with an all-False one.)
    item = self._pad_vector_with_mask(item, "actions", self.max_action_dim)
    item = self._pad_vector_with_mask(item, "obs_state", self.max_state_dim)
    return item

def _pad_vector_with_mask(self, item: dict, key: str, max_dim: int | None) -> dict:
    """Right-pad item[key] with zeros up to `max_dim` and set
    f"{key}_padding_mask" (True on padded entries, all False when nothing was
    padded). No-op when the key is absent or `max_dim` is None.
    """
    if key not in item or max_dim is None:
        return item
    vec = item[key]
    pad_len = max_dim - vec.shape[-1]
    if pad_len > 0:
        item[key] = torch.cat([vec, torch.zeros(pad_len, dtype=vec.dtype)], dim=-1)
        mask = torch.cat(
            [torch.zeros_like(vec, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
            dim=-1,
        )
    else:
        mask = torch.zeros(max_dim, dtype=torch.bool)
    item[f"{key}_padding_mask"] = mask
    return item
@@ -1506,3 +1840,149 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f " Transformations: { self . image_transforms } , \n "
f " ) "
)
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> tuple[list, list[bool]]:
    """Filter datasets, keeping those with consistent feature shapes per robot type.

    For each robot type, the dominant shape of each stats key (by frequency of
    the "mean" entry's shape across all episode stats) is the reference; any
    dataset whose first episode stats disagree with it is removed.

    Args:
        ls_datasets: Datasets, each with `meta.info["robot_type"]` and episode
            stats reachable via `episode_stats_values(ds.meta)`.

    Returns:
        tuple: (filtered datasets, boolean keep-mask aligned with ls_datasets).
            (Doc fix: the previous annotation claimed a single list.)
    """
    robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
    datasets_to_remove = set()
    for robot_type in robot_types:
        # Collect all episode stats dicts for this robot type.
        stats_list = [
            ep_stats
            for ds in ls_datasets
            if ds.meta.info["robot_type"] == robot_type
            for ep_stats in episode_stats_values(ds.meta)
        ]
        if not stats_list:
            continue
        all_keys = {key for stats in stats_list for key in stats}
        # The dominant shape per key depends only on stats_list, so compute it
        # once per robot type (previously recomputed for every dataset).
        main_shapes = {}
        for key in all_keys:
            shape_counter = defaultdict(int)
            for stats in stats_list:
                value = stats.get(key)
                if (
                    value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
                ):  # FIXME(mshukor): check all stats; min, mean, max
                    shape_counter[value["mean"].shape] += 1
            if shape_counter:
                main_shapes[key] = max(shape_counter, key=shape_counter.get)
        # Flag datasets whose first episode stats don't match the main shape.
        for ds in ls_datasets:
            if ds.meta.info["robot_type"] != robot_type:
                continue
            first_ep_stats = next(iter(episode_stats_values(ds.meta)), None)
            if not first_ep_stats:
                continue
            for key, main_shape in main_shapes.items():
                value = first_ep_stats.get(key)
                if (
                    value
                    and "mean" in value
                    and isinstance(value["mean"], (torch.Tensor, np.ndarray))
                    and value["mean"].shape != main_shape
                ):
                    datasets_to_remove.add(ds)
                    break
    # Filter out inconsistent datasets (typo fix: datasets_maks -> datasets_mask).
    datasets_mask = [ds not in datasets_to_remove for ds in ls_datasets]
    filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
    print(
        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
    )
    return filtered_datasets, datasets_mask
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
    """Aggregate stats of multiple LeRobot datasets into one stats set per robot type.

    The final stats hold the union of all data keys from each of the datasets:
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_mean = mean of all data
    - new_std = std of all data
    (Cleanup: removed a duplicated docstring line and a dead pre-initialization
    of the result dict.)
    """
    stats = {}
    for robot_type in {ds.meta.info["robot_type"] for ds in ls_datasets}:
        # Gather every episode's stats for all datasets of this robot type.
        episode_stats = []
        for ds in ls_datasets:
            if ds.meta.info["robot_type"] == robot_type:
                episode_stats.extend(episode_stats_values(ds.meta))
        stats[robot_type] = aggregate_stats(episode_stats)
    return stats
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict | None = None) -> dict:
    """Return a copy of `features` with shapes resized to per-key maxima.

    Args:
        features: Mapping of feature key -> spec dict with a "shape" entry.
        reshape_dim: Axis to resize for non-image features (default: last).
        keys_to_max_dim: Per-key target size; keys absent or mapped to None are
            left untouched. (Bug fix: was a mutable default argument.)

    Returns:
        dict: New mapping; resized entries are shallow-copied so the caller's
        spec dicts are never mutated (previously mutated in place).
    """
    if keys_to_max_dim is None:
        keys_to_max_dim = {}
    reshaped_features = {}
    for key, ft in features.items():
        if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
            shape = list(ft["shape"])
            if any(k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]):  # Assume square images
                shape[-3] = keys_to_max_dim[key]
                shape[-2] = keys_to_max_dim[key]
            else:
                shape[reshape_dim] = keys_to_max_dim[key]
            reshaped_features[key] = {**ft, "shape": tuple(shape)}
        else:
            reshaped_features[key] = ft
    return reshaped_features
def create_padded_features ( item : dict , features : dict = { } ) :
for key , ft in features . items ( ) :
if any ( [ k in key for k in [ " cam " , " effort " , " absolute " ] ] ) : # FIXME(mshukor): temporary hack
continue
shape = ft [ " shape " ]
if len ( shape ) == 3 : # images to torch format (C, H, W)
shape = ( shape [ 2 ] , shape [ 0 ] , shape [ 1 ] )
if len ( shape ) == 1 and shape [ 0 ] == 1 : # ft with shape are actually tensor(ele)
shape = [ ]
if key not in item :
dtype = str_to_torch_dtype ( ft [ " dtype " ] )
item [ key ] = torch . zeros ( shape , dtype = dtype )
item [ f " { key } _padding_mask " ] = torch . tensor ( 0 , dtype = torch . int64 )
if " image " in key : # FIXME(mshukor): support other observations
item [ f " { key } _is_pad " ] = torch . BoolTensor ( [ False ] )
else :
item [ f " { key } _padding_mask " ] = torch . tensor ( 1 , dtype = torch . int64 )
return item
def str_to_torch_dtype(dtype_str):
    """Map a dtype name to the corresponding torch dtype.

    Unknown names — and "video", which decodes to float image tensors — fall
    back to torch.float32.
    """
    known = {
        "int64": torch.int64,
        "int16": torch.int16,
        "bool": torch.bool,
    }
    if dtype_str in known:
        return known[dtype_str]
    # "float32", "video" and anything unrecognized all resolve to float32.
    return torch.float32
def episode_stats_values(meta):
    """Extract, for each episode record in meta.episodes, the sub-dict of
    entries that carry per-feature stats (i.e. dict values holding a "mean").
    """
    records = meta.episodes.to_pandas().to_dict(orient="records")
    per_episode_stats = []
    for record in records:
        stats_only = {}
        for name, value in record.items():
            if isinstance(value, dict) and "mean" in value:
                stats_only[name] = value
        per_episode_stats.append(stats_only)
    return per_episode_stats