mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
add multi
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -73,6 +74,28 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_video_info,
|
||||
)
|
||||
|
||||
# mustafa stuff here
|
||||
from lerobot.common.datasets.utils_must import (
|
||||
reshape_features_to_max_dim,
|
||||
keep_datasets_with_valid_fps,
|
||||
keep_datasets_with_the_same_features_per_robot_type,
|
||||
aggregate_stats_per_robot_type,
|
||||
create_padded_features,
|
||||
pad_tensor,
|
||||
map_dict_keys,
|
||||
ROBOT_TYPE_KEYS_MAPPING,
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGE_2,
|
||||
OBS_IMAGE_3,
|
||||
TASKS_KEYS_MAPPING,
|
||||
)
|
||||
from lerobot.common.constants import (
|
||||
ACTION,
|
||||
OBS_ENV_STATE,
|
||||
OBS_STATE,
|
||||
|
||||
)
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
|
||||
|
||||
@@ -83,6 +106,7 @@ class LeRobotDatasetMetadata:
|
||||
root: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
@@ -99,6 +123,14 @@ class LeRobotDatasetMetadata:
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
# added by mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
self.info["features"] = map_dict_keys(
|
||||
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
|
||||
)
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
@@ -177,7 +209,15 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
# changed
|
||||
keys = []
|
||||
for key, ft in self.features.items():
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
|
||||
)
|
||||
if ft["dtype"] == "video":
|
||||
keys.append(key_)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
@@ -342,6 +382,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
|
||||
# new thing by M
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 5e-2,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -455,15 +508,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
|
||||
# by mshukor
|
||||
self.training_features = training_features
|
||||
self.discard_first_n_frames = discard_first_n_frames
|
||||
self.discard_first_idle_frames = discard_first_idle_frames
|
||||
self.motion_threshold = motion_threshold
|
||||
self.motion_window_size = motion_window_size
|
||||
self.motion_buffer = motion_buffer
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# more mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
|
||||
# Load metadata
|
||||
# TODO: change
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
@@ -482,17 +551,62 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# mustafa code
|
||||
if self.discard_first_n_frames > 0:
|
||||
print("Discarding first n frames:", self.discard_first_n_frames)
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
# TODO implement advanced strategy
|
||||
self.subset_frame_ids += [frame_idx for frame_idx in range(from_ + int(self.fps*self.discard_first_n_frames), to_)]
|
||||
elif self.discard_first_idle_frames:
|
||||
print(f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}")
|
||||
self.robot_states = torch.stack(self.hf_dataset[OBS_ROBOT]).numpy() # shape: [T, D]
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
ep_states = self.robot_states[from_:to_]
|
||||
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
|
||||
velocities = np.concatenate([[0.0], velocities])
|
||||
start_idx = find_start_of_motion(velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer)
|
||||
self.subset_frame_ids += list(range(from_ + start_idx, to_))
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
# commented TODO: check why
|
||||
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
# TODO: check why commented
|
||||
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Mustafa
|
||||
self.meta.info["features"] = map_dict_keys(
|
||||
self.meta.info["features"], feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
|
||||
)
|
||||
self.keys_to_max_dim = {
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
}
|
||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
)
|
||||
self.meta.stats = map_dict_keys(self.meta.stats, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||
self.robot_type = self.meta.info.get("robot_type", "")
|
||||
# Override tasks
|
||||
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
|
||||
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -647,6 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
# FIXME(mshukor): what if we train on multiple datasets with different features
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
@@ -670,12 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return query_timestamps
|
||||
|
||||
# TODO: changed by mustafa
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
queries = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key not in self.meta.video_keys and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys:
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key)
|
||||
if self.inverse_feature_keys_mapping
|
||||
else key
|
||||
)
|
||||
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
|
||||
return queries
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -699,8 +820,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
# changed by mshukor
|
||||
def __getitem__(self, idx) -> dict:
|
||||
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
|
||||
idx = self.subset_frame_ids[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
@@ -717,15 +842,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
try:
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
except:
|
||||
print(self.meta.tasks, task_idx, self.repo_id)
|
||||
if "robot_type" not in item:
|
||||
item["robot_type"] = self.robot_type
|
||||
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||
# Add padded features
|
||||
# item = self._add_padded_features(item, self.training_features)
|
||||
if self.image_transforms is not None:
|
||||
for cam in item:
|
||||
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
# Map pad keys
|
||||
# print(item.keys(), "before")
|
||||
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||
# print(item.keys())
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1022,54 +1157,161 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
|
||||
# add
|
||||
sampling_weights: list[float] | None = None,
|
||||
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
train_on_all_features: bool = False,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 100,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 0.05,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
_datasets = []
|
||||
datasets_repo_ids = []
|
||||
self.sampling_weights = []
|
||||
self.training_features = training_features
|
||||
|
||||
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||
assert len(sampling_weights) == len(repo_ids), (
|
||||
"The number of sampling weights must match the number of datasets. "
|
||||
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||
)
|
||||
for i, repo_id in enumerate(repo_ids):
|
||||
try:
|
||||
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
_datasets.append(
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
training_features=training_features,
|
||||
discard_first_n_frames=discard_first_n_frames,
|
||||
discard_first_idle_frames=discard_first_idle_frames,
|
||||
motion_threshold=motion_threshold,
|
||||
motion_window_size=motion_window_size,
|
||||
motion_buffer=motion_buffer,
|
||||
)
|
||||
)
|
||||
datasets_repo_ids.append(repo_id)
|
||||
self.sampling_weights.append(float(sampling_weights[i]))
|
||||
except Exception as e:
|
||||
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||
print(
|
||||
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||
)
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
# FIXME(mshukor): apply mapping to unify used keys
|
||||
self.train_on_all_features = train_on_all_features
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
if not self.train_on_all_features:
|
||||
intersection_features = set(_datasets[0].features)
|
||||
for ds in _datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
union_features = {}
|
||||
for ds in _datasets:
|
||||
for k, v in ds.features.items():
|
||||
if k not in self.disabled_features:
|
||||
union_features[k] = v
|
||||
|
||||
if len(union_features) == 0:
|
||||
raise RuntimeError("Multiple datasets were provided, but no features were found.")
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
self.delta_timestamps = (
|
||||
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
|
||||
for ds in _datasets:
|
||||
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
|
||||
ds.robot_type = ds.meta.info["robot_type"]
|
||||
# In case datasets with the same robot_type have different features
|
||||
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
|
||||
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
|
||||
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||
# Compute cumulative sizes for fast indexing
|
||||
self.cumulative_sizes = np.array(
|
||||
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
|
||||
)
|
||||
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
|
||||
self.stats = aggregate_stats_per_robot_type(self._datasets)
|
||||
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||
self.meta.info = {
|
||||
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||
self.keys_to_max_dim = {
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
}
|
||||
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
|
||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
)
|
||||
# reshape stats
|
||||
for robot_type_, stats_ in self.stats.items():
|
||||
for feat_key, feat_stats in stats_.items():
|
||||
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
|
||||
for k, v in feat_stats.items():
|
||||
if k in ["min", "mean"]:
|
||||
pad_value = 0
|
||||
elif k in ["max", "std"]:
|
||||
pad_value = 1
|
||||
else:
|
||||
continue
|
||||
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
|
||||
|
||||
self.meta.stats = self.stats
|
||||
# self.meta.info["features"] = aggregate_features(self._datasets)
|
||||
self.meta.tasks = {
|
||||
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
self.meta.episodes = {
|
||||
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
}
|
||||
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
@@ -1156,23 +1398,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||
item = self._datasets[dataset_idx][local_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
item = create_padded_features(item, self.meta.info["features"])
|
||||
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -396,35 +396,57 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
|
||||
# dummy padding dimensions (simulate training setup)
|
||||
MAX_ACTION_DIM = 14
|
||||
MAX_STATE_DIM = 30
|
||||
MAX_NUM_IMAGES = 3
|
||||
MAX_IMAGE_DIM = 224
|
||||
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
dataset = MultiLeRobotDataset(
|
||||
repo_ids,
|
||||
max_action_dim=MAX_ACTION_DIM,
|
||||
max_state_dim=MAX_STATE_DIM,
|
||||
max_num_images=MAX_NUM_IMAGES,
|
||||
max_image_dim=MAX_IMAGE_DIM,
|
||||
)
|
||||
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
# check they match.
|
||||
expected_dataset_indices = []
|
||||
for i, sub_dataset in enumerate(sub_datasets):
|
||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||
|
||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
||||
for expected_dataset_index, sub_item, multi_item in zip(
|
||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||
):
|
||||
dataset_index = dataset_item.pop("dataset_index")
|
||||
dataset_index = multi_item.pop("dataset_index")
|
||||
assert dataset_index == expected_dataset_index
|
||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
||||
for k in sub_dataset_item:
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
# we ignore padding_mask and dataset_index keys in multi_item
|
||||
extra_keys = {k for k in multi_item if "padding_mask" in k}
|
||||
filtered_multi_keys = set(multi_item.keys()) - extra_keys
|
||||
assert set(sub_item.keys()) == filtered_multi_keys, f"mismatch in keys"
|
||||
|
||||
for k in sub_item:
|
||||
if k not in multi_item:
|
||||
continue
|
||||
v1, v2 = sub_item[k], multi_item[k]
|
||||
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
|
||||
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
|
||||
else:
|
||||
assert v1 == v2, f"value mismatch on key: {k}"
|
||||
|
||||
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
|
||||
Reference in New Issue
Block a user