cleanup/encapsulation

This commit is contained in:
Jade
2025-07-05 13:08:25 -04:00
parent ddb26b7189
commit a9251e612f
2 changed files with 133 additions and 81 deletions
+132 -79
View File
@@ -1139,6 +1139,105 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj return obj
class MultiLeRobotDatasetMeta:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
keys_to_max_dim: dict[str, int],
train_on_all_features: bool = False,
):
self.repo_ids = repo_ids
self.keys_to_max_dim = keys_to_max_dim
self.train_on_all_features = train_on_all_features
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
# assign robot_type if missing
for ds in datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# step 1: compute disabled features
self.disabled_features = set()
if not self.train_on_all_features:
intersection = set(datasets[0].features)
for ds in datasets:
intersection.intersection_update(ds.features)
if not intersection:
raise RuntimeError("No common features across datasets.")
for repo_id, ds in zip(repo_ids, datasets):
extra = set(ds.features) - intersection
logging.warning(f"Disabling {extra} for repo {repo_id}")
self.disabled_features.update(extra)
# step 2: build union_features excluding disabled
self.union_features = {}
for ds in datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
self.union_features[k] = v
# step 3: reshape feature schema
self.features = reshape_features_to_max_dim(
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# step 4: aggregate stats
self.stats = aggregate_stats_per_robot_type(datasets)
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
pad_value = 0 if k in ["min", "mean"] else 1
self.stats[robot_type_][feat_key][k] = pad_tensor(
v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value
)
# step 5: episodes & tasks
self.episodes = {
repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets)
}
self.tasks = {
repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets)
}
self.info = {
repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets)
}
class MultiLeRobotDatasetCleaner:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
sampling_weights: list[float],
datasets_repo_ids: list[str],
min_fps: int = 1,
max_fps: int = 100,
):
self.original_datasets = datasets
self.original_repo_ids = repo_ids
self.original_weights = sampling_weights
self.original_datasets_repo_ids = datasets_repo_ids
# step 1: remove datasets with invalid fps
valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
# step 2: keep datasets with same features per robot type
consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(valid_fps_datasets)
self.cleaned_datasets = consistent_datasets
self.keep_mask = keep_mask
self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_datasets_repo_ids = [
datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
]
self.cumulative_sizes = np.array(
[0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
)
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
class MultiLeRobotDataset(torch.utils.data.Dataset): class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s. """A dataset consisting of multiple underlying `LeRobotDataset`s.
@@ -1225,93 +1324,47 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
# restriction in future iterations of this class. For now, this is necessary at least for being able # restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function. # to use PyTorch's default DataLoader collate function.
# FIXME(mshukor): apply mapping to unify used keys # FIXME(mshukor): apply mapping to unify used keys
self.train_on_all_features = train_on_all_features # FIXME(mshukor): pad based on types in case we have more than one state?
self.disabled_features = set()
if not self.train_on_all_features:
intersection_features = set(_datasets[0].features)
for ds in _datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
union_features = {}
for ds in _datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
union_features[k] = v
if len(union_features) == 0:
raise RuntimeError("Multiple datasets were provided, but no features were found.")
self.image_transforms = image_transforms self.image_transforms = image_transforms
self.delta_timestamps = ( self.delta_timestamps = (
delta_timestamps.get(repo_id, None) if delta_timestamps else None delta_timestamps.get(repo_id, None) if delta_timestamps else None
) # delta_timestamps # FIXME(mshukor): last repo? ) # delta_timestamps # FIXME(mshukor): last repo?
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
for ds in _datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# In case datasets with the same robot_type have different features # In case datasets with the same robot_type have different features
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps) cleaner = MultiLeRobotDatasetCleaner(
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets) datasets=_datasets,
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]] repo_ids=repo_ids,
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]] sampling_weights=self.sampling_weights,
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]] datasets_repo_ids=datasets_repo_ids,
# Compute cumulative sizes for fast indexing min_fps=min_fps,
self.cumulative_sizes = np.array( max_fps=max_fps,
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
) )
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32) self._datasets = cleaner.cleaned_datasets
self.stats = aggregate_stats_per_robot_type(self._datasets) self.sampling_weights = cleaner.cleaned_weights
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets self.repo_ids = cleaner.cleaned_repo_ids
self.meta.info = { self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False) self.cumulative_sizes = cleaner.cumulative_sizes
} # self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
# self.meta.info = {
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
# }
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features # self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
# FIXME(mshukor): pad based on types in case we have more than one state? self.meta = MultiLeRobotDatasetMeta(
self.keys_to_max_dim = { datasets=self._datasets,
ACTION: max_action_dim, repo_ids=self.repo_ids,
OBS_ENV_STATE: max_state_dim, keys_to_max_dim={
OBS_STATE: max_state_dim, ACTION: max_action_dim,
OBS_IMAGE: max_image_dim, OBS_ENV_STATE: max_state_dim,
OBS_IMAGE_2: max_image_dim, OBS_STATE: max_state_dim,
OBS_IMAGE_3: max_image_dim, OBS_IMAGE: max_image_dim,
} OBS_IMAGE_2: max_image_dim,
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim) OBS_IMAGE_3: max_image_dim,
self.meta.info["features"] = reshape_features_to_max_dim( },
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim train_on_all_features=train_on_all_features,
) )
# reshape stats self.disabled_features = self.meta.disabled_features
for robot_type_, stats_ in self.stats.items(): self.stats = self.meta.stats
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
if k in ["min", "mean"]:
pad_value = 0
elif k in ["max", "std"]:
pad_value = 1
else:
continue
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
self.meta.stats = self.stats
# self.meta.info["features"] = aggregate_features(self._datasets)
self.meta.tasks = {
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
}
self.meta.episodes = {
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
}
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
@property @property
def repo_id_to_index(self): def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class. """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
@@ -1402,7 +1455,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item() local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx] item = self._datasets[dataset_idx][local_idx]
item["dataset_index"] = torch.tensor(dataset_idx) item["dataset_index"] = torch.tensor(dataset_idx)
item = create_padded_features(item, self.meta.info["features"]) item = create_padded_features(item, self.meta.features)
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem? for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
if data_key in item: if data_key in item:
del item[data_key] del item[data_key]
+1 -2
View File
@@ -394,7 +394,7 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds. # TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset") # @pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames(): def test_multidataset_frames():
"""Check that all dataset frames are incorporated and aligned correctly.""" """Check that all dataset frames are incorporated and aligned correctly."""
repo_ids = [ repo_ids = [
@@ -421,7 +421,6 @@ def test_multidataset_frames():
assert len(dataset) == sum(len(d) for d in sub_datasets) assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets) assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets) assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
expected_dataset_indices = [] expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets): for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset)) expected_dataset_indices.extend([i] * len(sub_dataset))