add multi

This commit is contained in:
Jade
2025-06-30 13:11:16 -04:00
parent 483be9aac2
commit ddb26b7189
3 changed files with 915 additions and 79 deletions
+300 -67
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import logging
import shutil
from pathlib import Path
@@ -73,6 +74,28 @@ from lerobot.common.datasets.video_utils import (
get_video_info,
)
# mustafa stuff here
from lerobot.common.datasets.utils_must import (
reshape_features_to_max_dim,
keep_datasets_with_valid_fps,
keep_datasets_with_the_same_features_per_robot_type,
aggregate_stats_per_robot_type,
create_padded_features,
pad_tensor,
map_dict_keys,
ROBOT_TYPE_KEYS_MAPPING,
OBS_IMAGE,
OBS_IMAGE_2,
OBS_IMAGE_3,
TASKS_KEYS_MAPPING,
)
from lerobot.common.constants import (
ACTION,
OBS_ENV_STATE,
OBS_STATE,
)
CODEBASE_VERSION = "v2.1"
@@ -83,6 +106,7 @@ class LeRobotDatasetMetadata:
root: str | Path | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
feature_keys_mapping: dict[str, str] | None = None,
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
@@ -99,6 +123,14 @@ class LeRobotDatasetMetadata:
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
# added by mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
self.info["features"] = map_dict_keys(
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
)
def load_metadata(self):
self.info = load_info(self.root)
@@ -177,7 +209,15 @@ class LeRobotDatasetMetadata:
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
# changed
keys = []
for key, ft in self.features.items():
key_ = (
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
)
if ft["dtype"] == "video":
keys.append(key_)
return keys
@property
def camera_keys(self) -> list[str]:
@@ -342,6 +382,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
# new thing by M
feature_keys_mapping: dict[str, str] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
training_features: list | None = None,
discard_first_n_frames: int = 0,
discard_first_idle_frames: bool = False,
motion_threshold: float = 5e-2,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -455,15 +508,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# by mshukor
self.training_features = training_features
self.discard_first_n_frames = discard_first_n_frames
self.discard_first_idle_frames = discard_first_idle_frames
self.motion_threshold = motion_threshold
self.motion_window_size = motion_window_size
self.motion_buffer = motion_buffer
# Unused attributes
self.image_writer = None
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
# more mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
# Load metadata
# TODO: change
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync,
feature_keys_mapping=feature_keys_mapping,
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
@@ -482,17 +551,62 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# mustafa code
if self.discard_first_n_frames > 0:
print("Discarding first n frames:", self.discard_first_n_frames)
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
# TODO implement advanced strategy
self.subset_frame_ids += [frame_idx for frame_idx in range(from_ + int(self.fps*self.discard_first_n_frames), to_)]
elif self.discard_first_idle_frames:
print(f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}")
self.robot_states = torch.stack(self.hf_dataset[OBS_ROBOT]).numpy() # shape: [T, D]
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
ep_states = self.robot_states[from_:to_]
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
velocities = np.concatenate([[0.0], velocities])
start_idx = find_start_of_motion(velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer)
self.subset_frame_ids += list(range(from_ + start_idx, to_))
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# commented TODO: check why
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
# TODO: check why commented
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Mustafa
self.meta.info["features"] = map_dict_keys(
self.meta.info["features"], feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
)
self.keys_to_max_dim = {
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
}
self.meta.info["features"] = reshape_features_to_max_dim(
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
self.meta.stats = map_dict_keys(self.meta.stats, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
self.robot_type = self.meta.info.get("robot_type", "")
# Override tasks
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
def push_to_hub(
self,
branch: str | None = None,
@@ -647,6 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
# FIXME(mshukor): what if we train on multiple datasets with different features
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
@@ -670,12 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
# TODO: changed by mustafa
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
queries = {}
for key, q_idx in query_indices.items():
if key not in self.meta.video_keys and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys:
key_ = (
self.inverse_feature_keys_mapping.get(key, key)
if self.inverse_feature_keys_mapping
else key
)
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
return queries
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -699,8 +820,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __len__(self):
return self.num_frames
# changed by mshukor
def __getitem__(self, idx) -> dict:
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
idx = self.subset_frame_ids[idx]
item = self.hf_dataset[idx]
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
ep_idx = item["episode_index"].item()
query_indices = None
@@ -717,15 +842,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]
try:
item["task"] = self.meta.tasks[task_idx]
except:
print(self.meta.tasks, task_idx, self.repo_id)
if "robot_type" not in item:
item["robot_type"] = self.robot_type
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
# Add padded features
# item = self._add_padded_features(item, self.training_features)
if self.image_transforms is not None:
for cam in item:
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
item[cam] = self.image_transforms(item[cam])
# Map pad keys
# print(item.keys(), "before")
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
# print(item.keys())
return item
def __repr__(self):
@@ -1022,54 +1157,161 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
# add
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
train_on_all_features: bool = False,
training_features: list | None = None,
discard_first_n_frames: int = 0,
min_fps: int = 1,
max_fps: int = 100,
discard_first_idle_frames: bool = False,
motion_threshold: float = 0.05,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
self.training_features = training_features
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms,
delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
feature_keys_mapping=feature_keys_mapping,
training_features=training_features,
discard_first_n_frames=discard_first_n_frames,
discard_first_idle_frames=discard_first_idle_frames,
motion_threshold=motion_threshold,
motion_window_size=motion_window_size,
motion_buffer=motion_buffer,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
# FIXME(mshukor): apply mapping to unify used keys
self.train_on_all_features = train_on_all_features
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
if not self.train_on_all_features:
intersection_features = set(_datasets[0].features)
for ds in _datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
union_features = {}
for ds in _datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
union_features[k] = v
if len(union_features) == 0:
raise RuntimeError("Multiple datasets were provided, but no features were found.")
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
self.delta_timestamps = (
delta_timestamps.get(repo_id, None) if delta_timestamps else None
) # delta_timestamps # FIXME(mshukor): last repo?
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
for ds in _datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# In case datasets with the same robot_type have different features
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
# Compute cumulative sizes for fast indexing
self.cumulative_sizes = np.array(
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
)
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
self.stats = aggregate_stats_per_robot_type(self._datasets)
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
self.meta.info = {
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
}
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
# FIXME(mshukor): pad based on types in case we have more than one state?
self.keys_to_max_dim = {
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
}
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
self.meta.info["features"] = reshape_features_to_max_dim(
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# reshape stats
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
if k in ["min", "mean"]:
pad_value = 0
elif k in ["max", "std"]:
pad_value = 1
else:
continue
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
self.meta.stats = self.stats
# self.meta.info["features"] = aggregate_features(self._datasets)
self.meta.tasks = {
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
}
self.meta.episodes = {
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
}
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
@@ -1156,23 +1398,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
item = create_padded_features(item, self.meta.info["features"])
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
if data_key in item:
del item[data_key]
return item
def __repr__(self):
File diff suppressed because one or more lines are too long
+34 -12
View File
@@ -396,35 +396,57 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
"""Check that all dataset frames are incorporated and aligned correctly."""
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
# dummy padding dimensions (simulate training setup)
MAX_ACTION_DIM = 14
MAX_STATE_DIM = 30
MAX_NUM_IMAGES = 3
MAX_IMAGE_DIM = 224
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
dataset = MultiLeRobotDataset(
repo_ids,
max_action_dim=MAX_ACTION_DIM,
max_state_dim=MAX_STATE_DIM,
max_num_images=MAX_NUM_IMAGES,
max_image_dim=MAX_IMAGE_DIM,
)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
for expected_dataset_index, sub_item, multi_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
dataset_index = multi_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# we ignore padding_mask and dataset_index keys in multi_item
extra_keys = {k for k in multi_item if "padding_mask" in k}
filtered_multi_keys = set(multi_item.keys()) - extra_keys
assert set(sub_item.keys()) == filtered_multi_keys, f"mismatch in keys"
for k in sub_item:
if k not in multi_item:
continue
v1, v2 = sub_item[k], multi_item[k]
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
else:
assert v1 == v2, f"value mismatch on key: {k}"
# TODO(aliberts): Move to more appropriate location