mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
add multi
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -73,6 +74,28 @@ from lerobot.common.datasets.video_utils import (
|
|||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE(mustafa): helpers and key-mapping constants from utils_must (feature remapping, padding, per-robot-type stats).
|
||||||
|
from lerobot.common.datasets.utils_must import (
|
||||||
|
reshape_features_to_max_dim,
|
||||||
|
keep_datasets_with_valid_fps,
|
||||||
|
keep_datasets_with_the_same_features_per_robot_type,
|
||||||
|
aggregate_stats_per_robot_type,
|
||||||
|
create_padded_features,
|
||||||
|
pad_tensor,
|
||||||
|
map_dict_keys,
|
||||||
|
ROBOT_TYPE_KEYS_MAPPING,
|
||||||
|
OBS_IMAGE,
|
||||||
|
OBS_IMAGE_2,
|
||||||
|
OBS_IMAGE_3,
|
||||||
|
TASKS_KEYS_MAPPING,
|
||||||
|
)
|
||||||
|
from lerobot.common.constants import (
|
||||||
|
ACTION,
|
||||||
|
OBS_ENV_STATE,
|
||||||
|
OBS_STATE,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v2.1"
|
||||||
|
|
||||||
|
|
||||||
@@ -83,6 +106,7 @@ class LeRobotDatasetMetadata:
|
|||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
|
feature_keys_mapping: dict[str, str] | None = None,
|
||||||
):
|
):
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
@@ -99,6 +123,14 @@ class LeRobotDatasetMetadata:
|
|||||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||||
self.pull_from_repo(allow_patterns="meta/")
|
self.pull_from_repo(allow_patterns="meta/")
|
||||||
self.load_metadata()
|
self.load_metadata()
|
||||||
|
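# NOTE(mshukor): look up this repo's feature-key mapping, build its inverse (non-empty targets only), and rename the feature keys in info.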
|
||||||
|
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||||
|
self.inverse_feature_keys_mapping = (
|
||||||
|
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||||
|
)
|
||||||
|
self.info["features"] = map_dict_keys(
|
||||||
|
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
|
||||||
|
)
|
||||||
|
|
||||||
def load_metadata(self):
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
@@ -177,7 +209,15 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def video_keys(self) -> list[str]:
|
def video_keys(self) -> list[str]:
|
||||||
"""Keys to access visual modalities stored as videos."""
|
"""Keys to access visual modalities stored as videos."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
# NOTE: video keys are translated back through inverse_feature_keys_mapping when one is set.
|
||||||
|
keys = []
|
||||||
|
for key, ft in self.features.items():
|
||||||
|
key_ = (
|
||||||
|
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
|
||||||
|
)
|
||||||
|
if ft["dtype"] == "video":
|
||||||
|
keys.append(key_)
|
||||||
|
return keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
@@ -342,6 +382,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
|
||||||
|
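# NOTE(M): options for feature-key mapping, max padding dims, the training-features list, and discarding leading frames.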
|
||||||
|
feature_keys_mapping: dict[str, str] | None = None,
|
||||||
|
max_action_dim: int = None,
|
||||||
|
max_state_dim: int = None,
|
||||||
|
max_num_images: int = None,
|
||||||
|
max_image_dim: int = None,
|
||||||
|
training_features: list | None = None,
|
||||||
|
discard_first_n_frames: int = 0,
|
||||||
|
discard_first_idle_frames: bool = False,
|
||||||
|
motion_threshold: float = 5e-2,
|
||||||
|
motion_window_size: int = 10,
|
||||||
|
motion_buffer: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||||
@@ -455,15 +508,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
self.delta_indices = None
|
self.delta_indices = None
|
||||||
|
|
||||||
|
# by mshukor
|
||||||
|
self.training_features = training_features
|
||||||
|
self.discard_first_n_frames = discard_first_n_frames
|
||||||
|
self.discard_first_idle_frames = discard_first_idle_frames
|
||||||
|
self.motion_threshold = motion_threshold
|
||||||
|
self.motion_window_size = motion_window_size
|
||||||
|
self.motion_buffer = motion_buffer
|
||||||
|
|
||||||
# Unused attributes
|
# Unused attributes
|
||||||
self.image_writer = None
|
self.image_writer = None
|
||||||
self.episode_buffer = None
|
self.episode_buffer = None
|
||||||
|
|
||||||
self.root.mkdir(exist_ok=True, parents=True)
|
self.root.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
# more mshukor
|
||||||
|
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||||
|
self.inverse_feature_keys_mapping = (
|
||||||
|
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||||
|
)
|
||||||
|
|
||||||
# Load metadata
|
# Load metadata
|
||||||
|
# TODO: change
|
||||||
self.meta = LeRobotDatasetMetadata(
|
self.meta = LeRobotDatasetMetadata(
|
||||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync,
|
||||||
|
feature_keys_mapping=feature_keys_mapping,
|
||||||
)
|
)
|
||||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||||
@@ -482,17 +551,62 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||||
|
|
||||||
|
# mustafa code
|
||||||
|
if self.discard_first_n_frames > 0:
|
||||||
|
print("Discarding first n frames:", self.discard_first_n_frames)
|
||||||
|
self.subset_frame_ids = []
|
||||||
|
for ep_idx in range(self.num_episodes):
|
||||||
|
from_ = self.episode_data_index["from"][ep_idx]
|
||||||
|
to_ = self.episode_data_index["to"][ep_idx]
|
||||||
|
# TODO implement advanced strategy
|
||||||
|
self.subset_frame_ids += [frame_idx for frame_idx in range(from_ + int(self.fps*self.discard_first_n_frames), to_)]
|
||||||
|
elif self.discard_first_idle_frames:
|
||||||
|
print(f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}")
|
||||||
|
self.robot_states = torch.stack(self.hf_dataset[OBS_ROBOT]).numpy() # shape: [T, D]
|
||||||
|
self.subset_frame_ids = []
|
||||||
|
for ep_idx in range(self.num_episodes):
|
||||||
|
from_ = self.episode_data_index["from"][ep_idx]
|
||||||
|
to_ = self.episode_data_index["to"][ep_idx]
|
||||||
|
ep_states = self.robot_states[from_:to_]
|
||||||
|
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
|
||||||
|
velocities = np.concatenate([[0.0], velocities])
|
||||||
|
start_idx = find_start_of_motion(velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer)
|
||||||
|
self.subset_frame_ids += list(range(from_ + start_idx, to_))
|
||||||
|
|
||||||
# Check timestamps
|
# Check timestamps
|
||||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
# TODO: the timestamp sync check below is commented out; find out why and re-enable it if possible.
|
||||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||||
|
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||||
|
|
||||||
# Setup delta_indices
|
# Setup delta_indices
|
||||||
if self.delta_timestamps is not None:
|
if self.delta_timestamps is not None:
|
||||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
# TODO: check_delta_timestamps is commented out below; find out why and re-enable it if possible.
|
||||||
|
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||||
|
|
||||||
|
# Mustafa
|
||||||
|
self.meta.info["features"] = map_dict_keys(
|
||||||
|
self.meta.info["features"], feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
|
||||||
|
)
|
||||||
|
self.keys_to_max_dim = {
|
||||||
|
ACTION: max_action_dim,
|
||||||
|
OBS_ENV_STATE: max_state_dim,
|
||||||
|
OBS_STATE: max_state_dim,
|
||||||
|
OBS_IMAGE: max_image_dim,
|
||||||
|
OBS_IMAGE_2: max_image_dim,
|
||||||
|
OBS_IMAGE_3: max_image_dim,
|
||||||
|
}
|
||||||
|
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||||
|
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||||
|
)
|
||||||
|
self.meta.stats = map_dict_keys(self.meta.stats, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||||
|
self.robot_type = self.meta.info.get("robot_type", "")
|
||||||
|
# Override tasks
|
||||||
|
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
|
||||||
|
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
|
||||||
|
|
||||||
def push_to_hub(
|
def push_to_hub(
|
||||||
self,
|
self,
|
||||||
branch: str | None = None,
|
branch: str | None = None,
|
||||||
@@ -647,6 +761,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||||
for key, delta_idx in self.delta_indices.items()
|
for key, delta_idx in self.delta_indices.items()
|
||||||
}
|
}
|
||||||
|
# FIXME(mshukor): what if we train on multiple datasets with different features
|
||||||
padding = { # Pad values outside of current episode range
|
padding = { # Pad values outside of current episode range
|
||||||
f"{key}_is_pad": torch.BoolTensor(
|
f"{key}_is_pad": torch.BoolTensor(
|
||||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||||
@@ -670,12 +785,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
|
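# NOTE(mustafa): non-video features are read from hf_dataset under the key given by inverse_feature_keys_mapping and returned under the query key.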
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
queries = {}
|
||||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
for key, q_idx in query_indices.items():
|
||||||
for key, q_idx in query_indices.items()
|
if key not in self.meta.video_keys and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys:
|
||||||
if key not in self.meta.video_keys
|
key_ = (
|
||||||
}
|
self.inverse_feature_keys_mapping.get(key, key)
|
||||||
|
if self.inverse_feature_keys_mapping
|
||||||
|
else key
|
||||||
|
)
|
||||||
|
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
|
||||||
|
return queries
|
||||||
|
|
||||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
@@ -699,8 +820,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_frames
|
return self.num_frames
|
||||||
|
|
||||||
|
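# NOTE(mshukor): when leading frames are discarded, idx is first translated through subset_frame_ids; item keys are then renamed with map_dict_keys.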
|
||||||
def __getitem__(self, idx) -> dict:
|
def __getitem__(self, idx) -> dict:
|
||||||
|
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
|
||||||
|
idx = self.subset_frame_ids[idx]
|
||||||
item = self.hf_dataset[idx]
|
item = self.hf_dataset[idx]
|
||||||
|
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
|
||||||
ep_idx = item["episode_index"].item()
|
ep_idx = item["episode_index"].item()
|
||||||
|
|
||||||
query_indices = None
|
query_indices = None
|
||||||
@@ -717,15 +842,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
item = {**video_frames, **item}
|
item = {**video_frames, **item}
|
||||||
|
|
||||||
if self.image_transforms is not None:
|
|
||||||
image_keys = self.meta.camera_keys
|
|
||||||
for cam in image_keys:
|
|
||||||
item[cam] = self.image_transforms(item[cam])
|
|
||||||
|
|
||||||
# Add task as a string
|
# Add task as a string
|
||||||
task_idx = item["task_index"].item()
|
task_idx = item["task_index"].item()
|
||||||
|
try:
|
||||||
item["task"] = self.meta.tasks[task_idx]
|
item["task"] = self.meta.tasks[task_idx]
|
||||||
|
except:
|
||||||
|
print(self.meta.tasks, task_idx, self.repo_id)
|
||||||
|
if "robot_type" not in item:
|
||||||
|
item["robot_type"] = self.robot_type
|
||||||
|
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||||
|
# Add padded features
|
||||||
|
# item = self._add_padded_features(item, self.training_features)
|
||||||
|
if self.image_transforms is not None:
|
||||||
|
for cam in item:
|
||||||
|
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
|
||||||
|
item[cam] = self.image_transforms(item[cam])
|
||||||
|
# Map pad keys
|
||||||
|
# print(item.keys(), "before")
|
||||||
|
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||||
|
# print(item.keys())
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -1022,54 +1157,161 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
tolerances_s: dict | None = None,
|
tolerances_s: dict | None = None,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
|
||||||
|
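# NOTE: added options: sampling weights, per-repo feature-key mapping, max padding dims, fps bounds, and discarding leading frames.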
|
||||||
|
sampling_weights: list[float] | None = None,
|
||||||
|
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||||
|
max_action_dim: int = None,
|
||||||
|
max_state_dim: int = None,
|
||||||
|
max_num_images: int = None,
|
||||||
|
max_image_dim: int = None,
|
||||||
|
train_on_all_features: bool = False,
|
||||||
|
training_features: list | None = None,
|
||||||
|
discard_first_n_frames: int = 0,
|
||||||
|
min_fps: int = 1,
|
||||||
|
max_fps: int = 100,
|
||||||
|
discard_first_idle_frames: bool = False,
|
||||||
|
motion_threshold: float = 0.05,
|
||||||
|
motion_window_size: int = 10,
|
||||||
|
motion_buffer: int = 3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_ids = repo_ids
|
self.repo_ids = repo_ids
|
||||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
|
||||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||||
# are handled by this class.
|
# are handled by this class.
|
||||||
self._datasets = [
|
_datasets = []
|
||||||
|
datasets_repo_ids = []
|
||||||
|
self.sampling_weights = []
|
||||||
|
self.training_features = training_features
|
||||||
|
|
||||||
|
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||||
|
assert len(sampling_weights) == len(repo_ids), (
|
||||||
|
"The number of sampling weights must match the number of datasets. "
|
||||||
|
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||||
|
)
|
||||||
|
for i, repo_id in enumerate(repo_ids):
|
||||||
|
try:
|
||||||
|
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||||
|
_datasets.append(
|
||||||
LeRobotDataset(
|
LeRobotDataset(
|
||||||
repo_id,
|
repo_id,
|
||||||
root=self.root / repo_id,
|
root=self.root / repo_id,
|
||||||
episodes=episodes[repo_id] if episodes else None,
|
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||||
tolerance_s=self.tolerances_s[repo_id],
|
tolerance_s=self.tolerances_s[repo_id],
|
||||||
download_videos=download_videos,
|
download_videos=download_videos,
|
||||||
video_backend=video_backend,
|
video_backend=video_backend,
|
||||||
|
feature_keys_mapping=feature_keys_mapping,
|
||||||
|
training_features=training_features,
|
||||||
|
discard_first_n_frames=discard_first_n_frames,
|
||||||
|
discard_first_idle_frames=discard_first_idle_frames,
|
||||||
|
motion_threshold=motion_threshold,
|
||||||
|
motion_window_size=motion_window_size,
|
||||||
|
motion_buffer=motion_buffer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
datasets_repo_ids.append(repo_id)
|
||||||
|
self.sampling_weights.append(float(sampling_weights[i]))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||||
|
print(
|
||||||
|
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||||
)
|
)
|
||||||
for repo_id in repo_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||||
# to use PyTorch's default DataLoader collate function.
|
# to use PyTorch's default DataLoader collate function.
|
||||||
|
# FIXME(mshukor): apply mapping to unify used keys
|
||||||
|
self.train_on_all_features = train_on_all_features
|
||||||
self.disabled_features = set()
|
self.disabled_features = set()
|
||||||
intersection_features = set(self._datasets[0].features)
|
if not self.train_on_all_features:
|
||||||
for ds in self._datasets:
|
intersection_features = set(_datasets[0].features)
|
||||||
|
for ds in _datasets:
|
||||||
intersection_features.intersection_update(ds.features)
|
intersection_features.intersection_update(ds.features)
|
||||||
if len(intersection_features) == 0:
|
if len(intersection_features) == 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||||
"The multi-dataset functionality currently only keeps common keys."
|
"The multi-dataset functionality currently only keeps common keys."
|
||||||
)
|
)
|
||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
for repo_id, ds in zip(repo_ids, _datasets, strict=True):
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
extra_keys = set(ds.features).difference(intersection_features)
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||||
"other datasets."
|
"other datasets."
|
||||||
)
|
)
|
||||||
self.disabled_features.update(extra_keys)
|
self.disabled_features.update(extra_keys)
|
||||||
|
union_features = {}
|
||||||
|
for ds in _datasets:
|
||||||
|
for k, v in ds.features.items():
|
||||||
|
if k not in self.disabled_features:
|
||||||
|
union_features[k] = v
|
||||||
|
|
||||||
|
if len(union_features) == 0:
|
||||||
|
raise RuntimeError("Multiple datasets were provided, but no features were found.")
|
||||||
|
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = (
|
||||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||||
# with multiple robots of different ranges. Instead we should have one normalization
|
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||||
# per robot.
|
# self.stats = aggregate_stats(self._datasets) # FIXME(mshukor): stats should be computed per robot type and then the robot type should be passed as input to the model
|
||||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
for ds in _datasets:
|
||||||
|
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
|
||||||
|
ds.robot_type = ds.meta.info["robot_type"]
|
||||||
|
# In case datasets with the same robot_type have different features
|
||||||
|
_datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps)
|
||||||
|
self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets)
|
||||||
|
self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||||
|
self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||||
|
self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]]
|
||||||
|
# Compute cumulative sizes for fast indexing
|
||||||
|
self.cumulative_sizes = np.array(
|
||||||
|
[0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0))
|
||||||
|
)
|
||||||
|
self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32)
|
||||||
|
self.stats = aggregate_stats_per_robot_type(self._datasets)
|
||||||
|
self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||||
|
self.meta.info = {
|
||||||
|
repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||||
|
}
|
||||||
|
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||||
|
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||||
|
self.keys_to_max_dim = {
|
||||||
|
ACTION: max_action_dim,
|
||||||
|
OBS_ENV_STATE: max_state_dim,
|
||||||
|
OBS_STATE: max_state_dim,
|
||||||
|
OBS_IMAGE: max_image_dim,
|
||||||
|
OBS_IMAGE_2: max_image_dim,
|
||||||
|
OBS_IMAGE_3: max_image_dim,
|
||||||
|
}
|
||||||
|
# self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim)
|
||||||
|
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||||
|
union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||||
|
)
|
||||||
|
# reshape stats
|
||||||
|
for robot_type_, stats_ in self.stats.items():
|
||||||
|
for feat_key, feat_stats in stats_.items():
|
||||||
|
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
|
||||||
|
for k, v in feat_stats.items():
|
||||||
|
if k in ["min", "mean"]:
|
||||||
|
pad_value = 0
|
||||||
|
elif k in ["max", "std"]:
|
||||||
|
pad_value = 1
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value)
|
||||||
|
|
||||||
|
self.meta.stats = self.stats
|
||||||
|
# self.meta.info["features"] = aggregate_features(self._datasets)
|
||||||
|
self.meta.tasks = {
|
||||||
|
repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||||
|
}
|
||||||
|
self.meta.episodes = {
|
||||||
|
repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||||
|
}
|
||||||
|
self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets]
|
||||||
@property
|
@property
|
||||||
def repo_id_to_index(self):
|
def repo_id_to_index(self):
|
||||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||||
@@ -1156,23 +1398,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
if idx >= len(self):
|
if idx >= len(self):
|
||||||
raise IndexError(f"Index {idx} out of bounds.")
|
raise IndexError(f"Index {idx} out of bounds.")
|
||||||
# Determine which dataset to get an item from based on the index.
|
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||||
start_idx = 0
|
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||||
dataset_idx = 0
|
item = self._datasets[dataset_idx][local_idx]
|
||||||
for dataset in self._datasets:
|
|
||||||
if idx >= start_idx + dataset.num_frames:
|
|
||||||
start_idx += dataset.num_frames
|
|
||||||
dataset_idx += 1
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
|
||||||
item = self._datasets[dataset_idx][idx - start_idx]
|
|
||||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||||
for data_key in self.disabled_features:
|
item = create_padded_features(item, self.meta.info["features"])
|
||||||
|
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||||
if data_key in item:
|
if data_key in item:
|
||||||
del item[data_key]
|
del item[data_key]
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -396,35 +396,57 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||||
@pytest.mark.skip("TODO after fix multidataset")
|
@pytest.mark.skip("TODO after fix multidataset")
|
||||||
def test_multidataset_frames():
|
def test_multidataset_frames():
|
||||||
"""Check that all dataset frames are incorporated."""
|
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
|
||||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
|
||||||
# logic that wouldn't be caught with two repo IDs.
|
|
||||||
repo_ids = [
|
repo_ids = [
|
||||||
"lerobot/aloha_sim_insertion_human_image",
|
"lerobot/aloha_sim_insertion_human_image",
|
||||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||||
"lerobot/aloha_sim_insertion_scripted_image",
|
"lerobot/aloha_sim_insertion_scripted_image",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# dummy padding dimensions (simulate training setup)
|
||||||
|
MAX_ACTION_DIM = 14
|
||||||
|
MAX_STATE_DIM = 30
|
||||||
|
MAX_NUM_IMAGES = 3
|
||||||
|
MAX_IMAGE_DIM = 224
|
||||||
|
|
||||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||||
dataset = MultiLeRobotDataset(repo_ids)
|
dataset = MultiLeRobotDataset(
|
||||||
|
repo_ids,
|
||||||
|
max_action_dim=MAX_ACTION_DIM,
|
||||||
|
max_state_dim=MAX_STATE_DIM,
|
||||||
|
max_num_images=MAX_NUM_IMAGES,
|
||||||
|
max_image_dim=MAX_IMAGE_DIM,
|
||||||
|
)
|
||||||
|
|
||||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||||
|
|
||||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
|
||||||
# check they match.
|
|
||||||
expected_dataset_indices = []
|
expected_dataset_indices = []
|
||||||
for i, sub_dataset in enumerate(sub_datasets):
|
for i, sub_dataset in enumerate(sub_datasets):
|
||||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||||
|
|
||||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
for expected_dataset_index, sub_item, multi_item in zip(
|
||||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||||
):
|
):
|
||||||
dataset_index = dataset_item.pop("dataset_index")
|
dataset_index = multi_item.pop("dataset_index")
|
||||||
assert dataset_index == expected_dataset_index
|
assert dataset_index == expected_dataset_index
|
||||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
|
||||||
for k in sub_dataset_item:
|
# we ignore padding_mask and dataset_index keys in multi_item
|
||||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
extra_keys = {k for k in multi_item if "padding_mask" in k}
|
||||||
|
filtered_multi_keys = set(multi_item.keys()) - extra_keys
|
||||||
|
assert set(sub_item.keys()) == filtered_multi_keys, f"mismatch in keys"
|
||||||
|
|
||||||
|
for k in sub_item:
|
||||||
|
if k not in multi_item:
|
||||||
|
continue
|
||||||
|
v1, v2 = sub_item[k], multi_item[k]
|
||||||
|
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
|
||||||
|
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
|
||||||
|
else:
|
||||||
|
assert v1 == v2, f"value mismatch on key: {k}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): Move to more appropriate location
|
# TODO(aliberts): Move to more appropriate location
|
||||||
|
|||||||
Reference in New Issue
Block a user