mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
cleanup/encapsulation
This commit is contained in:
@@ -1139,6 +1139,105 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a group of `LeRobotDataset`s used together.

    Construction mutates the given datasets: each dataset's
    ``meta.info["robot_type"]`` (and ``robot_type`` attribute) is overridden by
    ``ROBOT_TYPE_KEYS_MAPPING`` when its ``repo_id`` has an entry there.

    Attributes set by ``__init__``:
        repo_ids: The repo ids, as given.
        keys_to_max_dim: The feature-key -> max-dimension mapping, as given.
        train_on_all_features: The flag, as given.
        robot_types: Robot type of every dataset, read after the override.
        disabled_features: Feature keys not shared by all datasets (empty when
            ``train_on_all_features`` is true).
        union_features: Feature specs of all datasets merged by key, without
            the disabled ones. On key collisions the last dataset wins.
        features: ``union_features`` passed through
            ``reshape_features_to_max_dim``.
        stats: Result of ``aggregate_stats_per_robot_type``, with the
            min/mean/max/std entries of action/state features padded.
        episodes, tasks, info: ``repo_id -> ds.meta.<field>`` dictionaries.
    """

    # Fill value used when padding each statistic up to the max feature size.
    # Statistics not listed here (anything other than min/mean/max/std) are left
    # untouched. NOTE(review): 0 for min/mean and 1 for max/std presumably keeps
    # padded dimensions neutral under normalization - confirm with the policy code.
    _STAT_PAD_VALUES = {"min": 0, "mean": 0, "max": 1, "std": 1}

    def __init__(
        self,
        datasets: "list[LeRobotDataset]",
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        """Build the aggregated metadata.

        Args:
            datasets: Datasets to aggregate; must line up one-to-one with ``repo_ids``.
            repo_ids: Repo id of each dataset, in the same order as ``datasets``.
            keys_to_max_dim: Max size per feature key, used for feature reshaping
                and stats padding.
            train_on_all_features: When false, only features present in every
                dataset are kept and the others are recorded as disabled.

        Raises:
            RuntimeError: If the datasets share no feature (only checked when
                ``train_on_all_features`` is false) or if no feature is left at all.
            ValueError: If ``repo_ids`` and ``datasets`` differ in length while
                disabled features are being computed.
        """
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features

        # Apply the robot_type override first so that everything derived below
        # (robot_types, per-robot-type stats) sees the final value.
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]

        # step 1: compute disabled features (keys that are not common to all datasets)
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=True):
                extra = set(ds.features) - intersection
                if extra:  # do not emit a warning for datasets that lose nothing
                    logging.warning(f"Disabling {extra} for repo {repo_id}")
                    self.disabled_features.update(extra)

        # step 2: build union_features excluding disabled
        self.union_features = {}
        for ds in datasets:
            for key, spec in ds.features.items():
                if key not in self.disabled_features:
                    self.union_features[key] = spec
        if not self.union_features:
            raise RuntimeError("Multiple datasets were provided, but no features were found.")

        # step 3: reshape feature schema
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )

        # step 4: aggregate stats, then pad the vector-valued action/state statistics
        self.stats = aggregate_stats_per_robot_type(datasets)
        padded_feature_keys = (ACTION, OBS_ENV_STATE, OBS_STATE)
        for robot_stats in self.stats.values():
            for feat_key, feat_stats in robot_stats.items():
                if feat_key not in padded_feature_keys:
                    continue
                max_size = self.keys_to_max_dim.get(feat_key, -1)
                for stat_name, value in feat_stats.items():
                    pad_value = self._STAT_PAD_VALUES.get(stat_name)
                    if pad_value is None:
                        continue
                    # Re-assigning an existing key while iterating is safe (dict size is unchanged).
                    feat_stats[stat_name] = pad_tensor(
                        value, max_size=max_size, pad_dim=-1, pad_value=pad_value
                    )

        # step 5: episodes & tasks
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||||
|
|
||||||
|
class MultiLeRobotDatasetCleaner:
    """Filter a list of datasets and keep the per-dataset side lists aligned.

    Two filters are applied in order: ``keep_datasets_with_valid_fps`` and then
    ``keep_datasets_with_the_same_features_per_robot_type``. The repo ids,
    sampling weights and dataset repo ids of the surviving datasets are
    selected with the same indices so they stay paired with their dataset.

    Attributes set by ``__init__``:
        original_datasets, original_repo_ids, original_weights,
        original_datasets_repo_ids: The inputs, untouched.
        cleaned_datasets: Datasets that passed both filters.
        keep_mask: Mask returned by the second filter (relative to the
            fps-filtered list, not to the original input).
        cleaned_repo_ids, cleaned_datasets_repo_ids: Side lists for the
            surviving datasets.
        cleaned_weights: ``np.float32`` array of the surviving sampling weights.
        cumulative_sizes: Array starting at 0 followed by the running sum of
            the surviving datasets' lengths.
    """

    def __init__(
        self,
        datasets: "list[LeRobotDataset]",
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        """Run both filters and compute the aligned outputs.

        Args:
            datasets: Candidate datasets.
            repo_ids: Repo id per dataset, same order as ``datasets``.
            sampling_weights: Sampling weight per dataset, same order as ``datasets``.
            datasets_repo_ids: Dataset repo id per dataset, same order as ``datasets``.
            min_fps: Forwarded to ``keep_datasets_with_valid_fps``.
            max_fps: Forwarded to ``keep_datasets_with_valid_fps``.

        Raises:
            ValueError: If the mask returned by the feature filter does not have
                one entry per fps-filtered dataset.
        """
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids

        # step 1: remove datasets with invalid fps
        valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)

        # step 2: keep datasets with same features per robot type
        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(valid_fps_datasets)

        # The mask is relative to the fps-filtered list. Translate it back to
        # positions in the original input before indexing the side lists,
        # otherwise every dataset after a dropped one gets its neighbour's
        # repo id and sampling weight.
        source_indices = self._source_indices(datasets, valid_fps_datasets)
        kept_indices = [idx for idx, keep in zip(source_indices, keep_mask, strict=True) if keep]

        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask
        self.cleaned_repo_ids = [repo_ids[i] for i in kept_indices]
        self.cleaned_datasets_repo_ids = [datasets_repo_ids[i] for i in kept_indices]
        self.cleaned_weights = np.array([sampling_weights[i] for i in kept_indices], dtype=np.float32)

        # Cumulative sizes for mapping a global index to (dataset, local index).
        lengths = [len(ds) for ds in consistent_datasets]
        self.cumulative_sizes = np.array([0] + list(torch.cumsum(torch.tensor(lengths), dim=0)))

    @staticmethod
    def _source_indices(datasets, kept):
        """Return the position in ``datasets`` of every element of ``kept``.

        Elements are matched by identity with a forward-only scan, which handles
        the same object appearing more than once.
        NOTE(review): this assumes the fps filter returns an order-preserving
        subset of the very same objects - confirm against
        ``keep_datasets_with_valid_fps``. If an element cannot be matched that
        way, the previous positional behaviour (``0..len(kept)-1``) is used.
        """
        indices = []
        cursor = 0
        total = len(datasets)
        for ds in kept:
            while cursor < total and datasets[cursor] is not ds:
                cursor += 1
            if cursor == total:
                return list(range(len(kept)))
            indices.append(cursor)
            cursor += 1
        return indices
|
||||||
|
|
||||||
|
|
||||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||||
@@ -1225,93 +1324,47 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||||
# to use PyTorch's default DataLoader collate function.
|
# to use PyTorch's default DataLoader collate function.
|
||||||
# FIXME(mshukor): apply mapping to unify used keys
|
# FIXME(mshukor): apply mapping to unify used keys
|
||||||
self.train_on_all_features = train_on_all_features
|
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||||
self.disabled_features = set()
|
|
||||||
if not self.train_on_all_features:
|
|
||||||
intersection_features = set(_datasets[0].features)
|
|
||||||
for ds in _datasets:
|
|
||||||
intersection_features.intersection_update(ds.features)
|
|
||||||
if len(intersection_features) == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
|
||||||
"The multi-dataset functionality currently only keeps common keys."
|
|
||||||
)
|
|
||||||
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
|
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
|
||||||
logging.warning(
|
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
|
||||||
"other datasets."
|
|
||||||
)
|
|
||||||
self.disabled_features.update(extra_keys)
|
|
||||||
union_features = {}
|
|
||||||
for ds in _datasets:
|
|
||||||
for k, v in ds.features.items():
|
|
||||||
if k not in self.disabled_features:
|
|
||||||
union_features[k] = v
|
|
||||||
|
|
||||||
if len(union_features) == 0:
|
|
||||||
raise RuntimeError("Multiple datasets were provided, but no features were found.")
|
|
||||||
|
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = (
|
self.delta_timestamps = (
|
||||||
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||||
) # delta_timestamps # FIXME(mshukor): last repo?
|
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||||
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
|
|
||||||
for ds in _datasets:
|
|
||||||
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
|
|
||||||
ds.robot_type = ds.meta.info["robot_type"]
|
|
||||||
# In case datasets with the same robot_type have different features
|
# In case datasets with the same robot_type have different features
|
||||||
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
|
cleaner = MultiLeRobotDatasetCleaner(
|
||||||
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
|
datasets=_datasets,
|
||||||
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
repo_ids=repo_ids,
|
||||||
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
sampling_weights=self.sampling_weights,
|
||||||
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
datasets_repo_ids=datasets_repo_ids,
|
||||||
# Compute cumulative sizes for fast indexing
|
min_fps=min_fps,
|
||||||
self.cumulative_sizes = np.array(
|
max_fps=max_fps,
|
||||||
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
|
|
||||||
)
|
)
|
||||||
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
|
self._datasets = cleaner.cleaned_datasets
|
||||||
self.stats = aggregate_stats_per_robot_type(self._datasets)
|
self.sampling_weights = cleaner.cleaned_weights
|
||||||
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
self.repo_ids = cleaner.cleaned_repo_ids
|
||||||
self.meta.info = {
|
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||||
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||||
}
|
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||||
|
# self.meta.info = {
|
||||||
|
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||||
|
# }
|
||||||
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
self.meta = MultiLeRobotDatasetMeta(
|
||||||
self.keys_to_max_dim = {
|
datasets=self._datasets,
|
||||||
ACTION: max_action_dim,
|
repo_ids=self.repo_ids,
|
||||||
OBS_ENV_STATE: max_state_dim,
|
keys_to_max_dim={
|
||||||
OBS_STATE: max_state_dim,
|
ACTION: max_action_dim,
|
||||||
OBS_IMAGE: max_image_dim,
|
OBS_ENV_STATE: max_state_dim,
|
||||||
OBS_IMAGE_2: max_image_dim,
|
OBS_STATE: max_state_dim,
|
||||||
OBS_IMAGE_3: max_image_dim,
|
OBS_IMAGE: max_image_dim,
|
||||||
}
|
OBS_IMAGE_2: max_image_dim,
|
||||||
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
|
OBS_IMAGE_3: max_image_dim,
|
||||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
},
|
||||||
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
train_on_all_features=train_on_all_features,
|
||||||
)
|
)
|
||||||
# reshape stats
|
self.disabled_features = self.meta.disabled_features
|
||||||
for robot_type_, stats_ in self.stats.items():
|
self.stats = self.meta.stats
|
||||||
for feat_key, feat_stats in stats_.items():
|
|
||||||
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
|
|
||||||
for k, v in feat_stats.items():
|
|
||||||
if k in ["min", "mean"]:
|
|
||||||
pad_value = 0
|
|
||||||
elif k in ["max", "std"]:
|
|
||||||
pad_value = 1
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
|
|
||||||
|
|
||||||
self.meta.stats = self.stats
|
|
||||||
# self.meta.info["features"] = aggregate_features(self._datasets)
|
|
||||||
self.meta.tasks = {
|
|
||||||
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
|
||||||
}
|
|
||||||
self.meta.episodes = {
|
|
||||||
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
|
||||||
}
|
|
||||||
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
|
|
||||||
@property
|
@property
|
||||||
def repo_id_to_index(self):
|
def repo_id_to_index(self):
|
||||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||||
@@ -1402,7 +1455,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||||
item = self._datasets[dataset_idx][local_idx]
|
item = self._datasets[dataset_idx][local_idx]
|
||||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||||
item = create_padded_features(item, self.meta.info["features"])
|
item = create_padded_features(item, self.meta.features)
|
||||||
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||||
if data_key in item:
|
if data_key in item:
|
||||||
del item[data_key]
|
del item[data_key]
|
||||||
|
|||||||
@@ -394,7 +394,7 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
|
|
||||||
|
|
||||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||||
@pytest.mark.skip("TODO after fix multidataset")
|
# @pytest.mark.skip("TODO after fix multidataset")
|
||||||
def test_multidataset_frames():
|
def test_multidataset_frames():
|
||||||
"""Check that all dataset frames are incorporated and aligned correctly."""
|
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||||
repo_ids = [
|
repo_ids = [
|
||||||
@@ -421,7 +421,6 @@ def test_multidataset_frames():
|
|||||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||||
|
|
||||||
expected_dataset_indices = []
|
expected_dataset_indices = []
|
||||||
for i, sub_dataset in enumerate(sub_datasets):
|
for i, sub_dataset in enumerate(sub_datasets):
|
||||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||||
|
|||||||
Reference in New Issue
Block a user