mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
cleanup/encapsulation
This commit is contained in:
@@ -1139,6 +1139,105 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a group of `LeRobotDataset`s.

    Construction performs, in order:

    1. Robot-type override: each dataset's ``meta.info["robot_type"]`` is replaced by the
       entry registered for its ``repo_id`` in ``ROBOT_TYPE_KEYS_MAPPING`` (kept unchanged
       when there is no entry) and mirrored onto ``ds.robot_type``.
    2. Disabled features: unless ``train_on_all_features`` is true, every feature that is
       not present in all datasets is disabled.
    3. Feature schema: the remaining features of all datasets are merged and reshaped with
       ``reshape_features_to_max_dim``.
    4. Stats: stats are aggregated per robot type and the action/state statistics are
       padded along their last dimension.
    5. Per-repo ``episodes``, ``tasks`` and ``info`` dictionaries are collected.

    Note:
        Step 1 mutates the datasets that are passed in.

    Args:
        datasets: Datasets to aggregate.
        repo_ids: One repo id per dataset, in the same order as ``datasets``.
        keys_to_max_dim: Target size per feature key, used for reshaping features and
            padding stats.
        train_on_all_features: When true, no feature is disabled.

    Raises:
        RuntimeError: If ``train_on_all_features`` is false and the datasets share no
            feature at all.
    """

    # Fill value used when padding each statistic. Statistics that are not listed here
    # are left untouched.
    _STAT_PAD_VALUES = {"min": 0, "mean": 0, "max": 1, "std": 1}

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features

        # Apply the robot-type override first so that `robot_types` below (and the stats
        # aggregated per robot type) see the mapped values rather than the raw ones.
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]

        # step 1: compute disabled features
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                # Only warn when something is actually being disabled.
                if extra:
                    logging.warning(f"Disabling {extra} for repo {repo_id}")
                    self.disabled_features.update(extra)

        # step 2: build union_features excluding disabled
        # When several datasets define the same key, the definition of the last one wins.
        self.union_features = {}
        for ds in datasets:
            for key, feature in ds.features.items():
                if key not in self.disabled_features:
                    self.union_features[key] = feature

        # step 3: reshape feature schema
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )

        # step 4: aggregate stats, then pad the action/state statistics
        self.stats = aggregate_stats_per_robot_type(datasets)
        padded_keys = (ACTION, OBS_ENV_STATE, OBS_STATE)
        for robot_stats in self.stats.values():
            for feat_key, feat_stats in robot_stats.items():
                if feat_key not in padded_keys:
                    continue
                max_size = self.keys_to_max_dim.get(feat_key, -1)
                # Iterate over a snapshot because entries are reassigned in the loop.
                for stat_name, value in list(feat_stats.items()):
                    pad_value = self._STAT_PAD_VALUES.get(stat_name)
                    if pad_value is None:
                        # Not one of min/mean/max/std: leave this entry as it is.
                        continue
                    feat_stats[stat_name] = pad_tensor(
                        value, max_size=max_size, pad_dim=-1, pad_value=pad_value
                    )

        # step 5: episodes & tasks
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
|
||||
class MultiLeRobotDatasetCleaner:
    """Filter a list of datasets and keep the per-dataset bookkeeping aligned.

    Two filters are applied in sequence: ``keep_datasets_with_valid_fps`` and then
    ``keep_datasets_with_the_same_features_per_robot_type``. The sampling weights and
    repo ids of the datasets that survive both filters are exposed next to the
    filtered datasets.

    Args:
        datasets: Candidate datasets.
        repo_ids: One repo id per entry of ``datasets``.
        sampling_weights: One sampling weight per entry of ``datasets``.
        datasets_repo_ids: One dataset repo id per entry of ``datasets``.
        min_fps: Lower fps bound forwarded to ``keep_datasets_with_valid_fps``.
        max_fps: Upper fps bound forwarded to ``keep_datasets_with_valid_fps``.

    Attributes:
        original_datasets, original_repo_ids, original_weights, original_datasets_repo_ids:
            The inputs, stored unchanged.
        cleaned_datasets: Datasets that passed both filters.
        keep_mask: Mask returned by the feature-consistency filter; it refers to the
            fps-filtered list, not to ``datasets``.
        cleaned_weights: ``float32`` array with the weights of ``cleaned_datasets``.
        cleaned_repo_ids, cleaned_datasets_repo_ids: Ids of ``cleaned_datasets``.
        cumulative_sizes: ``[0, len(d0), len(d0) + len(d1), ...]`` over ``cleaned_datasets``.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids

        # step 1: remove datasets with invalid fps
        valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)

        # step 2: keep datasets with same features per robot type
        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(valid_fps_datasets)

        # `keep_mask` is aligned with `valid_fps_datasets`, while the weights and repo ids
        # are aligned with the unfiltered `datasets`. Translate positions back to the
        # unfiltered list before selecting, otherwise every dataset following one dropped
        # by the fps filter would receive the weight and ids of a different dataset.
        source_positions = self._source_positions(datasets, valid_fps_datasets)
        # NOTE(review): the mask presumably holds one flag per fps-filtered dataset - confirm.
        kept_positions = [pos for pos, keep in zip(source_positions, keep_mask, strict=False) if keep]

        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask
        self.cleaned_repo_ids = [repo_ids[pos] for pos in kept_positions]
        self.cleaned_datasets_repo_ids = [datasets_repo_ids[pos] for pos in kept_positions]
        self.cleaned_weights = np.array([sampling_weights[pos] for pos in kept_positions], dtype=np.float32)

        # Leading 0 followed by the running total of dataset lengths.
        self.cumulative_sizes = np.cumsum([0, *(len(ds) for ds in consistent_datasets)], dtype=np.int64)

    @staticmethod
    def _source_positions(datasets, kept):
        """Return, for each dataset in ``kept``, its position in ``datasets``.

        Datasets are matched by identity, scanning ``datasets`` left to right.

        NOTE(review): this assumes the fps filter returns an order-preserving subset of
        the very same objects - confirm against ``keep_datasets_with_valid_fps``. When
        that does not hold, the positions within ``kept`` are returned instead, which is
        the indexing this class used before.
        """
        positions = []
        cursor = 0
        for ds in kept:
            while cursor < len(datasets) and datasets[cursor] is not ds:
                cursor += 1
            if cursor == len(datasets):
                return list(range(len(kept)))
            positions.append(cursor)
            cursor += 1
        return positions
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
@@ -1225,93 +1324,47 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
# FIXME(mshukor): apply mapping to unify used keys
|
||||
self.train_on_all_features = train_on_all_features
|
||||
self.disabled_features = set()
|
||||
if not self.train_on_all_features:
|
||||
intersection_features = set(_datasets[0].features)
|
||||
for ds in _datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
union_features = {}
|
||||
for ds in _datasets:
|
||||
for k, v in ds.features.items():
|
||||
if k not in self.disabled_features:
|
||||
union_features[k] = v
|
||||
|
||||
if len(union_features) == 0:
|
||||
raise RuntimeError("Multiple datasets were provided, but no features were found.")
|
||||
|
||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = (
|
||||
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
|
||||
for ds in _datasets:
|
||||
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
|
||||
ds.robot_type = ds.meta.info["robot_type"]
|
||||
# In case datasets with the same robot_type have different features
|
||||
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
|
||||
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
|
||||
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
# Compute cumulative sizes for fast indexing
|
||||
self.cumulative_sizes = np.array(
|
||||
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
|
||||
cleaner = MultiLeRobotDatasetCleaner(
|
||||
datasets=_datasets,
|
||||
repo_ids=repo_ids,
|
||||
sampling_weights=self.sampling_weights,
|
||||
datasets_repo_ids=datasets_repo_ids,
|
||||
min_fps=min_fps,
|
||||
max_fps=max_fps,
|
||||
)
|
||||
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
|
||||
self.stats = aggregate_stats_per_robot_type(self._datasets)
|
||||
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||
self.meta.info = {
|
||||
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
self._datasets = cleaner.cleaned_datasets
|
||||
self.sampling_weights = cleaner.cleaned_weights
|
||||
self.repo_ids = cleaner.cleaned_repo_ids
|
||||
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||
# self.meta.info = {
|
||||
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
# }
|
||||
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||
self.keys_to_max_dim = {
|
||||
self.meta = MultiLeRobotDatasetMeta(
|
||||
datasets=self._datasets,
|
||||
repo_ids=self.repo_ids,
|
||||
keys_to_max_dim={
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
}
|
||||
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
|
||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
},
|
||||
train_on_all_features=train_on_all_features,
|
||||
)
|
||||
# reshape stats
|
||||
for robot_type_, stats_ in self.stats.items():
|
||||
for feat_key, feat_stats in stats_.items():
|
||||
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
|
||||
for k, v in feat_stats.items():
|
||||
if k in ["min", "mean"]:
|
||||
pad_value = 0
|
||||
elif k in ["max", "std"]:
|
||||
pad_value = 1
|
||||
else:
|
||||
continue
|
||||
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
|
||||
self.disabled_features = self.meta.disabled_features
|
||||
self.stats = self.meta.stats
|
||||
|
||||
|
||||
self.meta.stats = self.stats
|
||||
# self.meta.info["features"] = aggregate_features(self._datasets)
|
||||
self.meta.tasks = {
|
||||
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
self.meta.episodes = {
|
||||
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
@@ -1402,7 +1455,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||
item = self._datasets[dataset_idx][local_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
item = create_padded_features(item, self.meta.info["features"])
|
||||
item = create_padded_features(item, self.meta.features)
|
||||
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
@@ -394,7 +394,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
# @pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_frames():
|
||||
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||
repo_ids = [
|
||||
@@ -421,7 +421,6 @@ def test_multidataset_frames():
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
expected_dataset_indices = []
|
||||
for i, sub_dataset in enumerate(sub_datasets):
|
||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||
|
||||
Reference in New Issue
Block a user