Compare commits

...

15 Commits

Author SHA1 Message Date
danaaubakirova 7848b15bfb fix: changes to compute stats and modeling 2025-07-11 15:50:22 +02:00
pre-commit-ci[bot] 008b592545 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-07-11 03:55:05 +00:00
Jade 55a61259e8 make training work 2025-07-10 23:51:47 -04:00
danaaubakirova e94d78f8a0 Merge branch 'pr-1451' into danaaubakirova/25_06_2025 2025-07-10 10:26:31 +02:00
danaaubakirova 67c8d27e9c add 2025-07-09 14:22:34 +02:00
danaaubakirova c8b51ef205 Merge remote-tracking branch 'origin/main' into danaaubakirova/25_06_2025 2025-07-09 14:22:16 +02:00
Jade 9779c58b07 add training support 2025-07-08 23:57:46 -04:00
Jade 5fbe4f9987 remove yamls 2025-07-07 09:23:22 -04:00
Jade f47fd7aabf add hf hub integration 2025-07-07 09:22:48 -04:00
Jade a9251e612f cleanup/encapsulation 2025-07-05 13:08:25 -04:00
danaaubakirova 2b27084d63 Refactor embed prefix in modeling_smolvla2.py
Add old collator functions and constants for dataset handling
2025-07-01 14:35:02 +02:00
Jade ddb26b7189 add multi 2025-06-30 13:11:16 -04:00
Dana 96550e4ad1 nit 2025-06-30 12:01:32 +02:00
danaaubakirova c0146eed7f updates to lerobot_dataset.py 2025-06-27 14:43:33 +02:00
Dana 0447604fce some changes 2025-06-26 15:07:38 +02:00
16 changed files with 6372 additions and 100 deletions
File diff suppressed because it is too large Load Diff
+15
View File
@@ -37,6 +37,21 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
# Multi-dataset support
sampling_weights: str | None = None
max_action_dim: int | None = None
max_state_dim: int | None = None
max_num_images: int | None = None
max_image_dim: int | None = None
train_on_all_features: bool = False
features_version: int = 0
discard_first_n_frames: int = 0
min_fps: int = 1
max_fps: int = 100
discard_first_idle_frames: bool = False
motion_threshold: float = 5e-2
motion_window_size: int = 10
motion_buffer: int = 3
@dataclass
+4
View File
@@ -22,9 +22,13 @@ OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
OBS_IMAGE_4 = "observation.image4"
REWARD = "next.reward"
ROBOTS = "robots"
TASK = "task"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
+68
View File
@@ -0,0 +1,68 @@
from typing import Dict, List
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
def pad_tensor(
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
is_numpy = isinstance(tensor, np.ndarray)
if is_numpy:
tensor = torch.tensor(tensor)
pad = max_size - tensor.shape[pad_dim]
if pad > 0:
pad_sizes = (0, pad) # pad right
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
return tensor.numpy() if is_numpy else tensor
def pad_list_of_tensors(
tensors: List[torch.Tensor], pad_dim: int = -1, pad_value: float = 0.0
) -> List[torch.Tensor]:
max_size = max([v.shape[pad_dim] for v in tensors])
return [pad_tensor(tensor, max_size, pad_dim=pad_dim, pad_value=pad_value) for tensor in tensors]
def multidataset_collate_fn(
batch: List[Dict[str, torch.Tensor]],
pad_dim: int = -1,
pad_value: float = 0.0,
keys_to_max_dim: dict = {},
) -> Dict[str, torch.Tensor]:
"""
Custom collate function to pad tensors with multiple dimensions.
Args:
batch (List[Dict[str, torch.Tensor]]): List of dataset samples (each sample is a dictionary).
Returns:
Dict[str, torch.Tensor]: Batch with padded tensors.
"""
batch_keys = batch[0].keys()
collated_batch = [{} for _ in range(len(batch))]
# FIXME(mshukor): pad to max shape per feature type
for key in batch_keys:
values = [sample[key] for sample in batch]
if (
key in keys_to_max_dim
and isinstance(values[0], torch.Tensor)
and is_batch_need_padding(values, pad_dim=pad_dim)
and keys_to_max_dim[key] is not None
):
max_size = keys_to_max_dim[key]
for i in range(len(batch)):
collated_batch[i][key] = pad_tensor(
batch[i][key], max_size, pad_dim=pad_dim, pad_value=pad_value
)
else:
for i in range(len(batch)):
collated_batch[i][key] = batch[i][key]
collated_batch = default_collate(collated_batch)
return collated_batch
+26 -5
View File
@@ -125,9 +125,30 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
# Filter out stats that don't have required keys
valid_stats = []
for s in stats_ft_list:
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
valid_stats.append(s)
else:
# If count is missing, add it with a default value
if "count" not in s:
s["count"] = np.array([1]) # Default count
valid_stats.append(s)
if not valid_stats:
# If no valid stats, return empty stats
return {
"min": np.array([0]),
"max": np.array([0]),
"mean": np.array([0]),
"std": np.array([0]),
"count": np.array([0]),
}
means = np.stack([s["mean"] for s in valid_stats])
variances = np.stack([s["std"] ** 2 for s in valid_stats])
counts = np.stack([s["count"] for s in valid_stats])
total_count = counts.sum(axis=0)
# Prepare weighted mean by matching number of dimensions
@@ -144,8 +165,8 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,
+53 -9
View File
@@ -32,6 +32,8 @@ IMAGENET_STATS = {
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
@@ -81,35 +83,77 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if isinstance(cfg.dataset.repo_id, str):
if "," in cfg.dataset.repo_id:
repo_id = cfg.dataset.repo_id.split(",")
repo_id = [r for r in repo_id if r]
else:
repo_id = cfg.dataset.repo_id
sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None
feature_keys_mapping = FEATURE_KEYS_MAPPING
if isinstance(repo_id, str):
revision = getattr(cfg.dataset, "revision", None)
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
cfg.dataset.repo_id,
feature_keys_mapping=feature_keys_mapping,
revision=revision,
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
root=getattr(cfg.dataset, "root", None),
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
revision=revision,
video_backend=cfg.dataset.video_backend,
download_videos=True,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.dataset.max_action_dim,
max_state_dim=cfg.dataset.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
delta_timestamps = {}
episodes = {}
for i in range(len(repo_id)):
ds_meta = LeRobotDatasetMetadata(
repo_id[i],
feature_keys_mapping=feature_keys_mapping,
) # FIXME(mshukor): ?
delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta)
episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes)
# training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None)
# FIXME: (jadechoghari): check support for training features
training_features = None
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
episodes=episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
download_videos=True,
sampling_weights=sampling_weights,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.policy.max_action_dim,
max_state_dim=cfg.policy.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
train_on_all_features=cfg.dataset.train_on_all_features,
training_features=training_features,
discard_first_n_frames=cfg.dataset.discard_first_n_frames,
min_fps=cfg.dataset.min_fps,
max_fps=cfg.dataset.max_fps,
discard_first_idle_frames=cfg.dataset.discard_first_idle_frames,
motion_threshold=cfg.dataset.motion_threshold,
motion_window_size=cfg.dataset.motion_window_size,
motion_buffer=cfg.dataset.motion_buffer,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
+395 -71
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import contextlib
import logging
import os
import shutil
from pathlib import Path
from typing import Callable
@@ -30,8 +31,16 @@ from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.constants import (
ACTION,
HF_LEROBOT_HOME,
OBS_ENV_STATE,
OBS_STATE,
)
from lerobot.datasets.compute_stats import ( # aggregate_stats_per_robot_type,
aggregate_stats,
compute_episode_stats,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_FEATURES,
@@ -41,7 +50,6 @@ from lerobot.datasets.utils import (
_validate_feature_names,
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@@ -58,12 +66,34 @@ from lerobot.datasets.utils import (
load_info,
load_stats,
load_tasks,
map_dict_keys,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
write_json,
# keep_datasets_with_the_same_features_per_robot_type,
# map_dict_pad_keys,
# keep_datasets_with_valid_fps,
# find_start_of_motion,
)
# mustafa stuff here
from lerobot.datasets.utils_must import (
OBS_IMAGE,
OBS_IMAGE_2,
OBS_IMAGE_3,
ROBOT_TYPE_KEYS_MAPPING,
TASKS_KEYS_MAPPING,
aggregate_stats_per_robot_type,
create_padded_features,
find_start_of_motion,
keep_datasets_with_the_same_features_per_robot_type,
keep_datasets_with_valid_fps,
map_dict_keys,
pad_tensor,
reshape_features_to_max_dim,
)
from lerobot.datasets.video_utils import (
VideoFrame,
@@ -74,6 +104,15 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = "v2.1"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
for t in range(len(velocities) - window_size):
window_mean = velocities[t : t + window_size].mean()
if window_mean > threshold:
return max(0, t - motion_buffer) # include slight context before motion
return 0
class LeRobotDatasetMetadata:
@@ -81,10 +120,13 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
feature_keys_mapping: dict[str, str] | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.local_files_only = local_files_only
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
@@ -99,6 +141,14 @@ class LeRobotDatasetMetadata:
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
# added by mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
self.info["features"] = map_dict_keys(
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
)
def load_metadata(self):
self.info = load_info(self.root)
@@ -177,7 +227,15 @@ class LeRobotDatasetMetadata:
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
# changed
keys = []
for key, ft in self.features.items():
key_ = (
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
)
if ft["dtype"] == "video":
keys.append(key_)
return keys
@property
def camera_keys(self) -> list[str]:
@@ -342,6 +400,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
# new thing by M
feature_keys_mapping: dict[str, str] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
training_features: list | None = None,
discard_first_n_frames: int = 0,
discard_first_idle_frames: bool = False,
motion_threshold: float = 5e-2,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -455,15 +525,34 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# by mshukor
self.training_features = training_features
self.discard_first_n_frames = discard_first_n_frames
self.discard_first_idle_frames = discard_first_idle_frames
self.motion_threshold = motion_threshold
self.motion_window_size = motion_window_size
self.motion_buffer = motion_buffer
# Unused attributes
self.image_writer = None
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
# more mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
# Load metadata
# TODO: change
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id,
self.root,
self.revision,
force_cache_sync=force_cache_sync,
feature_keys_mapping=feature_keys_mapping,
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
@@ -482,17 +571,74 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# mustafa code
if self.discard_first_n_frames > 0:
print("Discarding first n frames:", self.discard_first_n_frames)
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
# TODO implement advanced strategy
self.subset_frame_ids += [
frame_idx for frame_idx in range(from_ + int(self.fps * self.discard_first_n_frames), to_)
]
elif self.discard_first_idle_frames:
print(
f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}"
)
self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D]
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
ep_states = self.robot_states[from_:to_]
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
velocities = np.concatenate([[0.0], velocities])
start_idx = find_start_of_motion(
velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer
)
self.subset_frame_ids += list(range(from_ + start_idx, to_))
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# commented TODO: check why
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
# TODO: check why commented
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Mustafa
self.meta.info["features"] = map_dict_keys(
self.meta.info["features"],
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.keys_to_max_dim = {
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
}
self.meta.info["features"] = reshape_features_to_max_dim(
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
self.meta.stats = map_dict_keys(
self.meta.stats,
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.robot_type = self.meta.info.get("robot_type", "")
# Override tasks
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
def push_to_hub(
self,
branch: str | None = None,
@@ -647,6 +793,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
# FIXME(mshukor): what if we train on multiple datasets with different features
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
@@ -670,12 +817,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
# TODO: changed by mustafa
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
queries = {}
for key, q_idx in query_indices.items():
if (
key not in self.meta.video_keys
and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys
):
key_ = (
self.inverse_feature_keys_mapping.get(key, key)
if self.inverse_feature_keys_mapping
else key
)
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
return queries
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -699,8 +855,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __len__(self):
return self.num_frames
# changed by mshukor
def __getitem__(self, idx) -> dict:
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
idx = self.subset_frame_ids[idx]
item = self.hf_dataset[idx]
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
ep_idx = item["episode_index"].item()
query_indices = None
@@ -717,15 +877,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]
try:
item["task"] = self.meta.tasks[task_idx]
except:
print(self.meta.tasks, task_idx, self.repo_id)
if "robot_type" not in item:
item["robot_type"] = self.robot_type
item = map_dict_keys(
item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
)
# Add padded features
# item = self._add_padded_features(item, self.training_features)
if self.image_transforms is not None:
for cam in item:
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
item[cam] = self.image_transforms(item[cam])
# Map pad keys
# print(item.keys(), "before")
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
# print(item.keys())
return item
def __repr__(self):
@@ -985,6 +1157,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@@ -1005,6 +1178,106 @@ class LeRobotDataset(torch.utils.data.Dataset):
return obj
class MultiLeRobotDatasetMeta:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
keys_to_max_dim: dict[str, int],
train_on_all_features: bool = False,
):
self.repo_ids = repo_ids
self.keys_to_max_dim = keys_to_max_dim
self.train_on_all_features = train_on_all_features
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
# assign robot_type if missing
for ds in datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# step 1: compute disabled features
self.disabled_features = set()
if not self.train_on_all_features:
intersection = set(datasets[0].features)
for ds in datasets:
intersection.intersection_update(ds.features)
if not intersection:
raise RuntimeError("No common features across datasets.")
for repo_id, ds in zip(repo_ids, datasets, strict=False):
extra = set(ds.features) - intersection
logging.warning(f"Disabling {extra} for repo {repo_id}")
self.disabled_features.update(extra)
# step 2: build union_features excluding disabled
self.union_features = {}
for ds in datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
self.union_features[k] = v
# step 3: reshape feature schema
self.features = reshape_features_to_max_dim(
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# step 4: aggregate stats
self.stats = aggregate_stats_per_robot_type(datasets)
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
pad_value = 0 if k in ["min", "mean"] else 1
self.stats[robot_type_][feat_key][k] = pad_tensor(
v,
max_size=self.keys_to_max_dim.get(feat_key, -1),
pad_dim=-1,
pad_value=pad_value,
)
# step 5: episodes & tasks
self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
sampling_weights: list[float],
datasets_repo_ids: list[str],
min_fps: int = 1,
max_fps: int = 100,
):
self.original_datasets = datasets
self.original_repo_ids = repo_ids
self.original_weights = sampling_weights
self.original_datasets_repo_ids = datasets_repo_ids
# step 1: remove datasets with invalid fps
valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
# step 2: keep datasets with same features per robot type
consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
valid_fps_datasets
)
self.cleaned_datasets = consistent_datasets
self.keep_mask = keep_mask
self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
self.cleaned_datasets_repo_ids = [
datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
]
self.cumulative_sizes = np.array(
[0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
)
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
@@ -1021,7 +1294,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
# add
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
train_on_all_features: bool = False,
training_features: list | None = None,
discard_first_n_frames: int = 0,
min_fps: int = 1,
max_fps: int = 100,
discard_first_idle_frames: bool = False,
motion_threshold: float = 0.05,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
super().__init__()
self.repo_ids = repo_ids
@@ -1029,46 +1319,89 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
self.training_features = training_features
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
feature_keys_mapping=feature_keys_mapping,
training_features=training_features,
discard_first_n_frames=discard_first_n_frames,
discard_first_idle_frames=discard_first_idle_frames,
motion_threshold=motion_threshold,
motion_window_size=motion_window_size,
motion_buffer=motion_buffer,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
# FIXME(mshukor): apply mapping to unify used keys
# FIXME(mshukor): pad based on types in case we have more than one state?
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
self.delta_timestamps = (
delta_timestamps.get(repo_id, None) if delta_timestamps else None
) # delta_timestamps # FIXME(mshukor): last repo?
# In case datasets with the same robot_type have different features
cleaner = MultiLeRobotDatasetCleaner(
datasets=_datasets,
repo_ids=repo_ids,
sampling_weights=self.sampling_weights,
datasets_repo_ids=datasets_repo_ids,
min_fps=min_fps,
max_fps=max_fps,
)
self._datasets = cleaner.cleaned_datasets
self.sampling_weights = cleaner.cleaned_weights
self.repo_ids = cleaner.cleaned_repo_ids
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
self.cumulative_sizes = cleaner.cumulative_sizes
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
# self.meta.info = {
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
# }
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
self.meta = MultiLeRobotDatasetMeta(
datasets=self._datasets,
repo_ids=self.repo_ids,
keys_to_max_dim={
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
},
train_on_all_features=train_on_all_features,
)
self.disabled_features = self.meta.disabled_features
self.stats = self.meta.stats
@property
def repo_id_to_index(self):
@@ -1156,23 +1489,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
item = create_padded_features(item, self.meta.features)
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
if data_key in item:
del item[data_key]
return item
def __repr__(self):
+18
View File
@@ -858,3 +858,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def map_dict_keys(
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
"""Maps feature keys from the dataset to the keys used in the model."""
if feature_keys_mapping is None:
return item
features = {}
for key in item:
if key in feature_keys_mapping:
if feature_keys_mapping[key] is not None:
if training_features is None or feature_keys_mapping[key] in training_features:
features[feature_keys_mapping[key]] = item[key]
else:
if training_features is None or key in training_features or pad_key in key:
features[key] = item[key]
return features
+409
View File
@@ -0,0 +1,409 @@
"""
Utils function by Mustafa to refactor
"""
from collections import defaultdict
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from lerobot.datasets.compute_stats import aggregate_stats
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
"""Reshape features to have a maximum dimension of `max_dim`."""
reshaped_features = {}
for key in features:
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
reshaped_features[key] = features[key]
shape = list(features[key]["shape"])
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
shape[-3] = keys_to_max_dim[key]
shape[-2] = keys_to_max_dim[key]
else:
shape[reshape_dim] = keys_to_max_dim[key]
reshaped_features[key]["shape"] = tuple(shape)
else:
reshaped_features[key] = features[key]
return reshaped_features
def keep_datasets_with_valid_fps(ls_datasets: list, min_fps: int = 1, max_fps: int = 100) -> list:
print(
f"Keeping datasets with fps between {min_fps} and {max_fps}. Considering {len(ls_datasets)} datasets."
)
for ds in ls_datasets:
if ds.fps < min_fps or ds.fps > max_fps:
print(f"Dataset {ds} has invalid fps: {ds.fps}. Removing it.")
ls_datasets.remove(ds)
print(f"Keeping {len(ls_datasets)} datasets with valid fps.")
return ls_datasets
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
"""
Filters datasets to only keep those with consistent feature shapes per robot type.
Args:
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
and `meta.episodes_stats` dictionary.
Returns:
List: Filtered list of datasets with consistent feature shapes.
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
datasets_to_remove = set()
for robot_type in robot_types:
# Collect all stats dicts for this robot type
stats_list = [
ep_stats
for ds in ls_datasets
if ds.meta.info["robot_type"] == robot_type
for ep_stats in ds.meta.episodes_stats.values()
]
if not stats_list:
continue
# Determine the most common shape for each key
all_keys = {key for stats in stats_list for key in stats}
for ds in ls_datasets:
if ds.meta.info["robot_type"] != robot_type:
continue
for key in all_keys:
shape_counter = defaultdict(int)
for stats in stats_list:
value = stats.get(key)
if (
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
): # FIXME(mshukor): check all stats; min, mean, max
shape_counter[value["mean"].shape] += 1
if not shape_counter:
continue
# Identify the most frequent shape
main_shape = max(shape_counter, key=shape_counter.get)
# Flag datasets that don't match the main shape
# for ds in ls_datasets:
first_ep_stats = next(iter(ds.meta.episodes_stats.values()), None)
if not first_ep_stats:
continue
value = first_ep_stats.get(key)
if (
value
and "mean" in value
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
and value["mean"].shape != main_shape
):
datasets_to_remove.add(ds)
break
# Filter out inconsistent datasets
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
print(
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
)
return filtered_datasets, datasets_maks
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
stats = {robot_type: {} for robot_type in robot_types}
for robot_type in robot_types:
robot_type_datasets = []
for ds in ls_datasets:
if ds.meta.info["robot_type"] == robot_type:
robot_type_datasets.extend(list(ds.meta.episodes_stats.values()))
# robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
stat = aggregate_stats(robot_type_datasets)
stats[robot_type] = stat
return stats
def str_to_torch_dtype(dtype_str):
"""Convert a dtype string to a torch dtype."""
mapping = {
"float32": torch.float32,
"int64": torch.int64,
"int16": torch.int16,
"bool": torch.bool,
"video": torch.float32, # Assuming video is stored as uint8 images
}
return mapping.get(dtype_str, torch.float32) # Default to float32
def create_padded_features(item: dict, features: dict = {}):
for key, ft in features.items():
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
continue
shape = ft["shape"]
if len(shape) == 3: # images to torch format (C, H, W)
shape = (shape[2], shape[0], shape[1])
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
shape = []
if key not in item:
dtype = str_to_torch_dtype(ft["dtype"])
item[key] = torch.zeros(shape, dtype=dtype)
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
if "image" in key: # FIXME(mshukor): support other observations
item[f"{key}_is_pad"] = torch.BoolTensor([False])
else:
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
return item
ROBOT_TYPE_KEYS_MAPPING = {
"lerobot/stanford_hydra_dataset": "static_single_arm",
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
"lerobot/toto": "static_single_arm",
"lerobot/roboturk": "static_single_arm",
"lerobot/jaco_play": "static_single_arm",
"lerobot/taco_play": "static_single_arm_7statedim",
}
def pad_tensor(
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
is_numpy = isinstance(tensor, np.ndarray)
if is_numpy:
tensor = torch.tensor(tensor)
if tensor.ndim == 0:
# Scalar — return as-is, no padding needed
return tensor
pad = max_size - tensor.shape[pad_dim]
if pad > 0:
pad_sizes = (0, pad) # pad right
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
return tensor.numpy() if is_numpy else tensor
def map_dict_keys(
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
"""Maps feature keys from the dataset to the keys used in the model."""
if feature_keys_mapping is None:
return item
features = {}
for key in item:
if key in feature_keys_mapping:
if feature_keys_mapping[key] is not None:
if training_features is None or feature_keys_mapping[key] in training_features:
features[feature_keys_mapping[key]] = item[key]
else:
if training_features is None or key in training_features or pad_key in key:
features[key] = item[key]
# breakpoint()
return features
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
for t in range(len(velocities) - window_size):
window_mean = velocities[t : t + window_size].mean()
if window_mean > threshold:
return max(0, t - motion_buffer) # include slight context before motion
return 0
import requests
import yaml
def load_yaml_mapping(name: str) -> dict:
"""
Loads a YAML mapping from a Hugging Face repo.
Example: name='features' → https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/features.yaml
"""
url = f"https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/{name}.yaml"
response = requests.get(url)
response.raise_for_status() # raise if the download fails
return yaml.safe_load(response.text)
# Example usage
TASKS_KEYS_MAPPING = load_yaml_mapping("tasks")
FEATURE_KEYS_MAPPING = load_yaml_mapping("features")
EPISODES_DATASET_MAPPING = {
"cadene/droid_1.0.1": list(range(50)),
"danaaubakirova/svla_so100_task5_v3": [
0,
1,
2,
3,
4,
5,
6,
7,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
],
"danaaubakirova/svla_so100_task4_v3": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
],
}
ACTION = "action"
OBS_STATE = "observation.state"
TASK = "task"
ROBOT = "robot_type"
TRAINING_FEATURES = {
0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
}
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor:
"""Pads a tensor to the target shape (right/bottom only)."""
pad = []
for actual, target in zip(reversed(tensor.shape), reversed(target_shape), strict=False):
pad.extend([0, max(target - actual, 0)])
return F.pad(tensor, pad, value=pad_value)
def multidataset_collate_fn(
batch: List[Dict[str, torch.Tensor]],
keys_to_max_dim: Dict[str, tuple] = {},
pad_value: float = 0.0,
) -> Dict[str, torch.Tensor]:
"""
Pads tensors to given target shape (if provided), otherwise uses per-batch max.
Supports 1D (e.g. action), 3D (e.g. [C,H,W] images).
"""
collated_batch = [{} for _ in range(len(batch))]
batch_keys = batch[0].keys()
for key in batch_keys:
values = [sample[key] for sample in batch]
sample = values[0]
if not isinstance(sample, torch.Tensor):
for i in range(len(batch)):
collated_batch[i][key] = values[i]
continue
# use user-specified shape if available
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
target_shape = keys_to_max_dim[key]
else:
# compute per-batch max shape
target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim))
for i in range(len(batch)):
collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value)
return default_collate(collated_batch)
+1
View File
@@ -16,5 +16,6 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
+7
View File
@@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolvla2":
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
return SmolVLA2Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "smolvla2":
return SmolVLA2Config(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:
@@ -0,0 +1,191 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@dataclass
class PEFTConfig:
r: int = 4
lora_alpha: int = 16
lora_dropout: float = 0.1
target_modules: str = "q_proj,v_proj"
@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (512, 512)
# Add empty images. Used by smolvla_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
proj_width: int = 480
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5 # 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
optimizer_lr_vlm: float = 0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
checkpoint_path: str = None
peft_method: str = ""
peft_config: PEFTConfig = field(default_factory=PEFTConfig)
peft_target_model: str = ""
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
attention_mode: str = "cross_attn"
prefix_length: int = -1
pad_language_to: str = "longest" # "max_length"
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
num_vlm_layers: int = 16
past_obs_keys: str = "image"
add_local_special_image_tokens: bool = False
reverse_images_order: bool = False
state_to_prefix: bool = False
pad_language_to: str = "longest" # "max_length"
causal_action_attention_mask: bool = False
self_attn_every_n_layers: int = -1 # Number of layers used in the VLM (first num_vlm_layers layers)
# self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
robot_type: str = ""
self_attn_only_actions: bool = False
causal_attention_on_history: bool = False
predict_relative_actions: bool = False
relative_actions_mode: str = "first"
shuffle_camera_positions: bool = False
vlm_img_size: int = -1
regression_loss: bool = False
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,600 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from peft import LoraConfig, TaskType, get_peft_model
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
return hidden_dim
class SmolVLMWithExpertModel(nn.Module):
def __init__(
self,
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
load_vlm_weights: bool = True,
train_expert_only: bool = True,
freeze_vision_encoder: bool = False,
attention_mode: str = "self_attn",
num_expert_layers: int = -1,
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)
config = self.vlm.config
else:
config = AutoConfig.from_pretrained(model_id)
self.vlm = SmolVLMForConditionalGeneration(config=config)
self.processor = AutoProcessor.from_pretrained(model_id)
if num_vlm_layers > 0:
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
self.config = config
# Smaller lm expert
lm_expert_config = copy.deepcopy(config.text_config)
hidden_size = lm_expert_config.hidden_size
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
lm_expert_config.num_hidden_layers = self.num_vlm_layers
if num_expert_layers > 0:
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
)
lm_expert_config.num_hidden_layers = num_expert_layers
self.lm_expert = AutoModel.from_config(lm_expert_config)
self.num_expert_layers = len(self.lm_expert.layers)
self.self_attn_every_n_layers = self_attn_every_n_layers
if "cross" in attention_mode:
# Reshape qkv projections to have the same input dimension as the vlm
for layer_idx in range(len(self.lm_expert.layers)):
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
continue
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
config.text_config.num_key_value_heads * config.text_config.head_dim,
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
bias=lm_expert_config.attention_bias,
)
# Remove unused embed_tokens
self.lm_expert.embed_tokens = None
self.num_attention_heads = self.config.text_config.num_attention_heads
self.num_key_value_heads = self.config.text_config.num_key_value_heads
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_mode = attention_mode
self.expert_hidden_size = lm_expert_config.hidden_size
self.set_requires_grad()
def configure_peft(self, config):
# return model
self.peft_method = config.peft_method
self.peft_target_model = config.peft_target_model
if "lora" in self.peft_method:
peft_config = config.peft_config
target_modules = peft_config.target_modules
if not isinstance(target_modules, list):
target_modules = target_modules.split(",")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
r=peft_config.r, # The rank of the low-rank adaptation
lora_alpha=peft_config.lora_alpha, # Scaling factor
lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
target_modules=target_modules, # The components where LoRA is applied
exclude_modules=[
"lm_expert",
"model.lm_expert.model.layers",
], # FIXME(mshukor): this does not work for now
)
self.lora_config = lora_config
# Apply LoRA and ensure only LoRA parameters are trainable
if "text" in self.peft_target_model:
self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
else:
self.vlm = get_peft_model(self.vlm, lora_config)
# assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
for name, param in self.vlm.named_parameters():
if (
"lora" in name and "text_model.model.layers.17" not in name
): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
param.requires_grad = True
else:
param.requires_grad = False
def merge_lora_weights(self):
"""
Merge LoRA weights into the base model.
"""
if "text" in self.peft_target_model:
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
else:
self.vlm = self.vlm.merge_and_unload()
def get_vlm_model(
self,
):
if hasattr(self.vlm.model, "model"): # When using peft
return self.vlm.model.model
else:
return self.vlm.model
def set_requires_grad(self):
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
for params in self.get_vlm_model().vision_model.parameters():
params.requires_grad = False
if self.train_expert_only:
self.vlm.eval()
for params in self.vlm.parameters():
params.requires_grad = False
else:
# To avoid unused params issue with distributed training
last_layers = [self.num_vlm_layers - 1]
if (
self.num_vlm_layers != self.num_expert_layers
and self.num_vlm_layers % self.num_expert_layers == 0
):
last_layers.append(self.num_vlm_layers - 2)
frozen_layers = [
"lm_head",
"text_model.model.norm.weight",
]
for layer in last_layers:
frozen_layers.append(f"text_model.model.layers.{layer}.")
for name, params in self.vlm.named_parameters():
if any(k in name for k in frozen_layers):
params.requires_grad = False
# To avoid unused params issue with distributed training
for name, params in self.lm_expert.named_parameters():
if "lm_head" in name:
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.get_vlm_model().vision_model.eval()
if self.train_expert_only:
self.vlm.eval()
def embed_image(self, image: torch.Tensor):
patch_attention_mask = None
# Get sequence from the vision encoder
image_hidden_states = (
self.get_vlm_model()
.vision_model(
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
patch_attention_mask=patch_attention_mask,
)
.last_hidden_state
)
# Modality projection & resampling
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
if hidden_states is None or layer is None:
continue
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
seq_len = query_states.shape[1]
if seq_len < position_ids.shape[1]:
_position_ids = position_ids[:, :seq_len]
_attention_mask = attention_mask[:, :seq_len, :seq_len]
else:
_position_ids = position_ids
_attention_mask = attention_mask
attention_mask_ = _attention_mask
position_ids_ = _position_ids
query_states = apply_rope(query_states, position_ids_)
key_states = apply_rope(key_states, position_ids_)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
)
return [att_output], past_key_values
def forward_cross_attn_layer(
self,
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache: bool = True,
fill_kv_cache: bool = True,
past_key_values=None,
) -> list[torch.Tensor]:
attention_interface = self.get_attention_interface()
att_outputs = []
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
)
if len(inputs_embeds) == 2 and not past_key_values:
# Prefix attention
seq_len = inputs_embeds[0].shape[1]
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
layer = model_layers[0][layer_idx]
hidden_states = layer.input_layernorm(inputs_embeds[0])
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
# B,L,H,D with L sequence length, H number of heads, D head dim
query_states = apply_rope(query_state, position_id)
key_states = apply_rope(key_state, position_id)
att_output = attention_interface(
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_outputs.append(att_output)
else:
expert_position_id = position_ids
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = past_key_values[layer_idx]["key_states"]
value_states = past_key_values[layer_idx]["value_states"]
# Expert
expert_layer = model_layers[1][layer_idx]
if expert_layer is not None:
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
expert_input_shape = expert_hidden_states.shape[:-1]
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
*key_states.shape[:2], -1
)
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
) # k_proj should have same dim as kv
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
*value_states.shape[:2], -1
)
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
)
expert_position_id = (
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
) # start from 0
expert_attention_mask = attention_mask[
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
] # take into account kv
expert_query_states = apply_rope(expert_query_state, expert_position_id)
att_output = attention_interface(
expert_attention_mask,
batch_size,
head_dim,
expert_query_states,
expert_key_states,
expert_value_states,
)
att_outputs.append(att_output)
else:
att_outputs.append(None)
# att_output = att_output.to(dtype=models[i].dtype)
return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
vlm_layers = []
expert_layers = []
multiple_of = self.num_vlm_layers // self.num_expert_layers
for i in range(self.num_vlm_layers):
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
expert_layer = None
else:
expert_layer_index = i // multiple_of if multiple_of > 0 else i
expert_layer = models[1].layers[expert_layer_index]
vlm_layers.append(models[0].layers[i])
expert_layers.append(expert_layer)
return [vlm_layers, expert_layers]
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: List[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
models = [self.get_vlm_model().text_model, self.lm_expert]
model_layers = self.get_model_layers(models)
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.num_vlm_layers
head_dim = self.vlm.config.text_config.head_dim
for layer_idx in range(num_layers):
if (
fill_kv_cache
or "cross" not in self.attention_mode
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
):
att_outputs, past_key_values = self.forward_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
else:
att_outputs, past_key_values = self.forward_cross_attn_layer(
model_layers,
inputs_embeds,
layer_idx,
position_ids,
attention_mask,
batch_size,
head_dim,
use_cache=use_cache,
fill_kv_cache=fill_kv_cache,
past_key_values=past_key_values,
)
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = model_layers[i][layer_idx]
att_output = (
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
) # in case of self_attn
if hidden_states is not None:
if layer is None:
outputs_embeds.append(hidden_states)
continue
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
att_out = att_output[:, start:end]
out_emb = layer.self_attn.o_proj(att_out)
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end if len(att_outputs) == 1 else 0
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
attention_interface = self.eager_attention_forward
return attention_interface
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.num_attention_heads
num_key_value_heads = self.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
att_weights = att_weights.to(dtype=torch.float32)
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
+12 -1
View File
@@ -16,6 +16,7 @@
import logging
import time
from contextlib import nullcontext
from functools import partial
from pprint import pformat
from typing import Any
@@ -29,6 +30,7 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.datasets.utils_must import multidataset_collate_fn
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
@@ -173,9 +175,18 @@ def train(cfg: TrainPipelineConfig):
else:
shuffle = True
sampler = None
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
keys_to_max_dim = {
"action": (32,),
"observation.state": (32,),
"observation.image": (3, 1080, 1920),
"observation.image2": (3, 1080, 1920),
}
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=collate_fn,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
+33 -14
View File
@@ -394,37 +394,56 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset")
# @pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
"""Check that all dataset frames are incorporated and aligned correctly."""
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
# dummy padding dimensions (simulate training setup)
MAX_ACTION_DIM = 14
MAX_STATE_DIM = 30
MAX_NUM_IMAGES = 3
MAX_IMAGE_DIM = 224
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
dataset = MultiLeRobotDataset(
repo_ids,
max_action_dim=MAX_ACTION_DIM,
max_state_dim=MAX_STATE_DIM,
max_num_images=MAX_NUM_IMAGES,
max_image_dim=MAX_IMAGE_DIM,
)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
for expected_dataset_index, sub_item, multi_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
dataset_index = multi_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# we ignore padding_mask and dataset_index keys in multi_item
extra_keys = {k for k in multi_item if "padding_mask" in k}
filtered_multi_keys = set(multi_item.keys()) - extra_keys
assert set(sub_item.keys()) == filtered_multi_keys, "mismatch in keys"
for k in sub_item:
if k not in multi_item:
continue
v1, v2 = sub_item[k], multi_item[k]
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
else:
assert v1 == v2, f"value mismatch on key: {k}"
# TODO(aliberts): Move to more appropriate location