diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 47794493e..7b777e0d7 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -1139,6 +1139,105 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         return obj
 
+class MultiLeRobotDatasetMeta:
+    def __init__(
+        self,
+        datasets: list[LeRobotDataset],
+        repo_ids: list[str],
+        keys_to_max_dim: dict[str, int],
+        train_on_all_features: bool = False,
+    ):
+        self.repo_ids = repo_ids
+        self.keys_to_max_dim = keys_to_max_dim
+        self.train_on_all_features = train_on_all_features
+
+        # override robot_type when the repo_id has a known remapping, then record it
+        for ds in datasets:
+            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
+            ds.robot_type = ds.meta.info["robot_type"]
+        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
+
+        # step 1: compute disabled features
+        self.disabled_features = set()
+        if not self.train_on_all_features:
+            intersection = set(datasets[0].features)
+            for ds in datasets:
+                intersection.intersection_update(ds.features)
+            if not intersection:
+                raise RuntimeError("No common features across datasets.")
+            for repo_id, ds in zip(repo_ids, datasets, strict=True):
+                extra = set(ds.features) - intersection
+                logging.warning(f"Disabling {extra} for repo {repo_id}")
+                self.disabled_features.update(extra)
+
+        # step 2: build union_features excluding disabled
+        self.union_features = {}
+        for ds in datasets:
+            for k, v in ds.features.items():
+                if k not in self.disabled_features:
+                    self.union_features[k] = v
+
+        # step 3: reshape feature schema
+        self.features = reshape_features_to_max_dim(
+            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
+        )
+
+        # step 4: aggregate stats
+        self.stats = aggregate_stats_per_robot_type(datasets)
+        for robot_type_, stats_ in self.stats.items():
+            for feat_key, feat_stats in stats_.items():
+                if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
+                    for k, v in feat_stats.items():
+                        if k not in ["min", "mean", "max", "std"]:
+                            continue  # leave non per-dim stats (e.g. "count") unpadded
+                        pad_value = 0 if k in ["min", "mean"] else 1
+                        self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
+
+        # step 5: per-repo episodes, tasks and info
+        self.episodes = {
+            repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets)
+        }
+        self.tasks = {
+            repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets)
+        }
+        self.info = {
+            repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets)
+        }
+
+class MultiLeRobotDatasetCleaner:
+    def __init__(
+        self,
+        datasets: list[LeRobotDataset],
+        repo_ids: list[str],
+        sampling_weights: list[float],
+        datasets_repo_ids: list[str],
+        min_fps: int = 1,
+        max_fps: int = 100,
+    ):
+        self.original_datasets = datasets
+        self.original_repo_ids = repo_ids
+        self.original_weights = sampling_weights
+        self.original_datasets_repo_ids = datasets_repo_ids
+
+        # step 1: remove datasets with invalid fps
+        valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
+
+        # step 2: keep datasets with the same features per robot type
+        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(valid_fps_datasets)
+        # FIXME: indexing below assumes step 1 dropped nothing; otherwise weights/repo_ids misalign with keep_mask
+        self.cleaned_datasets = consistent_datasets
+        self.keep_mask = keep_mask
+        self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
+        self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
+        self.cleaned_datasets_repo_ids = [
+            datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
+        ]
+
+        self.cumulative_sizes = np.array(
+            [0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
+        )
+        self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
+
 class MultiLeRobotDataset(torch.utils.data.Dataset):
     """A dataset consisting of multiple underlying `LeRobotDataset`s.
@@ -1225,93 +1324,43 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
 
         # restriction in future iterations of this class. For now, this is necessary at least for being able
         # to use PyTorch's default DataLoader collate function.
         # FIXME(mshukor): apply mapping to unify used keys
-        self.train_on_all_features = train_on_all_features
-        self.disabled_features = set()
-        if not self.train_on_all_features:
-            intersection_features = set(_datasets[0].features)
-            for ds in _datasets:
-                intersection_features.intersection_update(ds.features)
-            if len(intersection_features) == 0:
-                raise RuntimeError(
-                    "Multiple datasets were provided but they had no keys common to all of them. "
-                    "The multi-dataset functionality currently only keeps common keys."
-                )
-            for repo_id, ds in zip(repo_ids, _datasets, strict=True):
-                extra_keys = set(ds.features).difference(intersection_features)
-                logging.warning(
-                    f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
-                    "other datasets."
-                )
-                self.disabled_features.update(extra_keys)
-        union_features = {}
-        for ds in _datasets:
-            for k, v in ds.features.items():
-                if k not in self.disabled_features:
-                    union_features[k] = v
-
-        if len(union_features) == 0:
-            raise RuntimeError("Multiple datasets were provided, but no features were found.")
-
+        # FIXME(mshukor): pad based on types in case we have more than one state?
         self.image_transforms = image_transforms
         self.delta_timestamps = (
             delta_timestamps.get(repo_id, None) if delta_timestamps else None
         )  # delta_timestamps  # FIXME(mshukor): last repo?
-        # self.stats = aggregate_stats(self._datasets)  # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
-        for ds in _datasets:
-            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
-            ds.robot_type = ds.meta.info["robot_type"]
         # In case datasets with the same robot_type have different features
-        _datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
-        self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
-        self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
-        self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
-        self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
-        # Compute cumulative sizes for fast indexing
-        self.cumulative_sizes = np.array(
-            [0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
+        cleaner = MultiLeRobotDatasetCleaner(
+            datasets=_datasets,
+            repo_ids=repo_ids,
+            sampling_weights=self.sampling_weights,
+            datasets_repo_ids=datasets_repo_ids,
+            min_fps=min_fps,
+            max_fps=max_fps,
         )
-        self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
-        self.stats = aggregate_stats_per_robot_type(self._datasets)
-        self.meta = copy.deepcopy(self._datasets[0].meta)  # FIXME(mshukor): aggregate meta from all datasets
-        self.meta.info = {
-            repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
-        }
+        self._datasets = cleaner.cleaned_datasets
+        self.sampling_weights = cleaner.cleaned_weights
+        self.repo_ids = cleaner.cleaned_repo_ids
+        self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
+        self.cumulative_sizes = cleaner.cumulative_sizes
         # self.meta.info["features"] = self._datasets[0].meta.info["features"]  # Assume all datasets have the same features
-        # FIXME(mshukor): pad based on types in case we have more than one state?
-        self.keys_to_max_dim = {
-            ACTION: max_action_dim,
-            OBS_ENV_STATE: max_state_dim,
-            OBS_STATE: max_state_dim,
-            OBS_IMAGE: max_image_dim,
-            OBS_IMAGE_2: max_image_dim,
-            OBS_IMAGE_3: max_image_dim,
-        }
-        # self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
-        self.meta.info["features"] = reshape_features_to_max_dim(
-            union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
+        self.meta = MultiLeRobotDatasetMeta(
+            datasets=self._datasets,
+            repo_ids=self.repo_ids,
+            keys_to_max_dim={
+                ACTION: max_action_dim,
+                OBS_ENV_STATE: max_state_dim,
+                OBS_STATE: max_state_dim,
+                OBS_IMAGE: max_image_dim,
+                OBS_IMAGE_2: max_image_dim,
+                OBS_IMAGE_3: max_image_dim,
+            },
+            train_on_all_features=train_on_all_features,
         )
-        # reshape stats
-        for robot_type_, stats_ in self.stats.items():
-            for feat_key, feat_stats in stats_.items():
-                if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
-                    for k, v in feat_stats.items():
-                        if k in ["min", "mean"]:
-                            pad_value = 0
-                        elif k in ["max", "std"]:
-                            pad_value = 1
-                        else:
-                            continue
-                        self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
+        self.disabled_features = self.meta.disabled_features
+        self.stats = self.meta.stats
+
-        self.meta.stats = self.stats
-        # self.meta.info["features"] = aggregate_features(self._datasets)
-        self.meta.tasks = {
-            repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
-        }
-        self.meta.episodes = {
-            repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
-        }
-        self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
 
     @property
     def repo_id_to_index(self):
         """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
@@ -1402,7 +1451,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
         item = self._datasets[dataset_idx][local_idx]
         item["dataset_index"] = torch.tensor(dataset_idx)
-        item = create_padded_features(item, self.meta.info["features"])
+        item = create_padded_features(item, self.meta.features)
         for data_key in self.disabled_features:  # FIXME(mshukor): not in getitem?
             if data_key in item:
                 del item[data_key]
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 4bd0baccb..4a8af1cb7 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -394,7 +394,6 @@ def test_factory(env_name, repo_id, policy_name):
 
 
 # TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
-@pytest.mark.skip("TODO after fix multidataset")
 def test_multidataset_frames():
     """Check that all dataset frames are incorporated and aligned correctly."""
     repo_ids = [
@@ -421,7 +421,6 @@
     assert len(dataset) == sum(len(d) for d in sub_datasets)
     assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
     assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
-
     expected_dataset_indices = []
     for i, sub_dataset in enumerate(sub_datasets):
         expected_dataset_indices.extend([i] * len(sub_dataset))
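A few notes on the behavior introduced above. First, with `train_on_all_features=False`, step 1 of `MultiLeRobotDatasetMeta` keeps only the features common to every dataset; everything else is disabled and later stripped from items in `__getitem__`. A toy illustration (the feature names are made up for the example):

```python
# Step 1 of MultiLeRobotDatasetMeta, reduced to plain set algebra.
features_a = {"action", "observation.state", "observation.images.cam_high"}
features_b = {"action", "observation.state", "observation.environment_state"}

intersection = features_a & features_b  # features every dataset provides
disabled = (features_a | features_b) - intersection  # dropped from items in __getitem__

assert intersection == {"action", "observation.state"}
assert disabled == {"observation.images.cam_high", "observation.environment_state"}
```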
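Second, the stat-padding rule in step 4 pads `min`/`mean` with 0 and `max`/`std` with 1, which keeps `(x - min) / (max - min)` and `(x - mean) / std` finite on padded dimensions (padded action/state entries are themselves zeros). The sketch below uses a stand-in for the repo's `pad_tensor` helper; only the call signature is taken from the diff, the body is an assumption:

```python
import torch
import torch.nn.functional as F


def pad_tensor(t: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0) -> torch.Tensor:
    # Stand-in: right-pad the last dim up to max_size (max_size=-1 disables padding).
    assert pad_dim == -1, "sketch only handles the last dim"
    if max_size < 0 or t.shape[-1] >= max_size:
        return t
    return F.pad(t, (0, max_size - t.shape[-1]), value=pad_value)


stats = {
    "min": torch.zeros(6),
    "max": torch.ones(6),
    "mean": torch.zeros(6),
    "std": torch.ones(6),
    "count": torch.tensor([100]),
}
for k, v in stats.items():
    if k not in ["min", "mean", "max", "std"]:
        continue  # "count" is not a per-dim stat and must stay untouched
    pad_value = 0 if k in ["min", "mean"] else 1
    stats[k] = pad_tensor(v, max_size=32, pad_dim=-1, pad_value=pad_value)

assert stats["min"].shape == (32,) and stats["count"].shape == (1,)
```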
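Finally, `cumulative_sizes` exists so that `__getitem__` can resolve a global index into a (dataset, local) pair with a binary search. Only the `local_idx` subtraction appears verbatim in the diff; the `searchsorted` lookup below is an assumed implementation of the `dataset_idx` step:

```python
import numpy as np

# cumulative_sizes as built by MultiLeRobotDatasetCleaner: [0, len(d0), len(d0) + len(d1), ...]
cumulative_sizes = np.array([0, 10, 25, 40])  # three datasets of sizes 10, 15 and 15


def locate(idx: int) -> tuple[int, int]:
    dataset_idx = int(np.searchsorted(cumulative_sizes, idx, side="right")) - 1
    local_idx = idx - int(cumulative_sizes[dataset_idx])
    return dataset_idx, local_idx


assert locate(0) == (0, 0)    # first frame of the first dataset
assert locate(10) == (1, 0)   # boundary: first frame of the second dataset
assert locate(39) == (2, 14)  # last frame overall
```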