Compare commits

...

17 Commits

Author SHA1 Message Date
danaaubakirova 61580a8596 Fix multi-GPU training script for local datasets 2025-09-16 16:37:10 +00:00
danaaubakirova 6c8f1f962b Update lerobot Python modules and add test training script
- Enhanced dataset processing and statistics computation
- Updated policy factory and normalization
- Improved SmolVLA2 modeling and expert integration
- Enhanced training and evaluation scripts
- Added utility improvements for training and wandb integration
- Added test training script with 2 datasets for validation
2025-09-16 16:11:26 +00:00
danaaubakirova 7848b15bfb fix: changes to compute stats and modeling 2025-07-11 15:50:22 +02:00
pre-commit-ci[bot] 008b592545 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-07-11 03:55:05 +00:00
Jade 55a61259e8 make training work 2025-07-10 23:51:47 -04:00
danaaubakirova e94d78f8a0 Merge branch 'pr-1451' into danaaubakirova/25_06_2025 2025-07-10 10:26:31 +02:00
danaaubakirova 67c8d27e9c add 2025-07-09 14:22:34 +02:00
danaaubakirova c8b51ef205 Merge remote-tracking branch 'origin/main' into danaaubakirova/25_06_2025 2025-07-09 14:22:16 +02:00
Jade 9779c58b07 add training support 2025-07-08 23:57:46 -04:00
Jade 5fbe4f9987 remove yamls 2025-07-07 09:23:22 -04:00
Jade f47fd7aabf add hf hub integration 2025-07-07 09:22:48 -04:00
Jade a9251e612f cleanup/encapsulation 2025-07-05 13:08:25 -04:00
danaaubakirova 2b27084d63 Refactor embed prefix in modeling_smolvla2.py
Add old collator functions and constants for dataset handling
2025-07-01 14:35:02 +02:00
Jade ddb26b7189 add multi 2025-06-30 13:11:16 -04:00
Dana 96550e4ad1 nit 2025-06-30 12:01:32 +02:00
danaaubakirova c0146eed7f updates to lerobot_dataset.py 2025-06-27 14:43:33 +02:00
Dana 0447604fce some changes 2025-06-26 15:07:38 +02:00
22 changed files with 6806 additions and 146 deletions
File diff suppressed because it is too large Load Diff
+15
View File
@@ -37,6 +37,21 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
# Multi-dataset support
sampling_weights: str | None = None
max_action_dim: int | None = None
max_state_dim: int | None = None
max_num_images: int | None = None
max_image_dim: int | None = None
train_on_all_features: bool = False
features_version: int = 0
discard_first_n_frames: int = 0
min_fps: int = 1
max_fps: int = 100
discard_first_idle_frames: bool = False
motion_threshold: float = 5e-2
motion_window_size: int = 10
motion_buffer: int = 3
@dataclass
+4
View File
@@ -22,9 +22,13 @@ OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
# Per-camera image keys for datasets that expose additional camera streams.
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
OBS_IMAGE_4 = "observation.image4"
REWARD = "next.reward"
ROBOTS = "robots"
TASK = "task"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
+68
View File
@@ -0,0 +1,68 @@
from typing import Dict, List
import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> bool:
    """Return True when the tensors in `values` are candidates for padding.

    Currently this only checks that the tensors are non-scalar (rank >= 1);
    the stricter shape-mismatch check along `pad_dim` is disabled (kept as a
    comment below), so `pad_dim` is presently unused.

    Note: the original annotated the return type as `int`, but the expression
    is a boolean comparison — the annotation is corrected to `bool`.
    """
    return len(values[0].shape) > 0  # and len(set([v.shape[pad_dim] for v in values])) > 1
def pad_tensor(
    tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
    """Right-pad `tensor` along `pad_dim` up to `max_size` with `pad_value`.

    Accepts either a `torch.Tensor` or a `np.ndarray`; numpy inputs are
    converted for padding and returned as numpy. A tensor already at or above
    `max_size` along `pad_dim` is returned unchanged.

    Bug fix: the original always padded the LAST dimension via
    `pad_sizes = (0, pad)`, silently ignoring `pad_dim`. The pad tuple is now
    positioned to match the requested dimension (identical behavior for the
    default `pad_dim=-1`).
    """
    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.tensor(tensor)
    pad = max_size - tensor.shape[pad_dim]
    if pad > 0:
        # F.pad consumes (left, right) pairs ordered from the LAST dimension
        # backwards; place the (0, pad) pair at the slot for `pad_dim`.
        dim = pad_dim if pad_dim >= 0 else tensor.dim() + pad_dim
        pad_sizes = [0, 0] * (tensor.dim() - 1 - dim) + [0, pad]  # pad right
        tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
    return tensor.numpy() if is_numpy else tensor
def pad_list_of_tensors(
    tensors: List[torch.Tensor], pad_dim: int = -1, pad_value: float = 0.0
) -> List[torch.Tensor]:
    """Pad every tensor in `tensors` along `pad_dim` to the largest size found
    in the list, using `pad_value` as the fill value."""
    target = max(t.shape[pad_dim] for t in tensors)
    return [pad_tensor(t, target, pad_dim=pad_dim, pad_value=pad_value) for t in tensors]
def multidataset_collate_fn(
    batch: list[dict[str, torch.Tensor]],
    pad_dim: int = -1,
    pad_value: float = 0.0,
    keys_to_max_dim: dict | None = None,
) -> dict[str, torch.Tensor]:
    """
    Custom collate function to pad tensors with multiple dimensions.

    Tensor values under keys listed in `keys_to_max_dim` are right-padded
    along `pad_dim` to the configured size before delegating to
    `default_collate`; all other values pass through unchanged.

    Args:
        batch: List of dataset samples (each sample is a dictionary).
        pad_dim: Dimension along which to pad (default: last).
        pad_value: Fill value used for padding.
        keys_to_max_dim: Mapping from feature key to target size; entries
            mapped to None are left untouched. Defaults to no padding.

    Returns:
        Dict[str, torch.Tensor]: Batch with padded tensors.
    """
    # The default was a shared mutable `{}`; use None to avoid the
    # mutable-default-argument pitfall. Behavior is unchanged.
    if keys_to_max_dim is None:
        keys_to_max_dim = {}
    batch_keys = batch[0].keys()
    collated_batch = [{} for _ in range(len(batch))]
    # FIXME(mshukor): pad to max shape per feature type
    for key in batch_keys:
        values = [sample[key] for sample in batch]
        needs_padding = (
            key in keys_to_max_dim
            and isinstance(values[0], torch.Tensor)
            and is_batch_need_padding(values, pad_dim=pad_dim)
            and keys_to_max_dim[key] is not None
        )
        if needs_padding:
            max_size = keys_to_max_dim[key]
            for i, sample in enumerate(batch):
                collated_batch[i][key] = pad_tensor(
                    sample[key], max_size, pad_dim=pad_dim, pad_value=pad_value
                )
        else:
            for i, sample in enumerate(batch):
                collated_batch[i][key] = sample[key]
    return default_collate(collated_batch)
+28 -6
View File
@@ -125,9 +125,30 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
# Filter out stats that don't have required keys
valid_stats = []
for s in stats_ft_list:
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
valid_stats.append(s)
else:
# If count is missing, add it with a default value
if "count" not in s:
s["count"] = np.array([1]) # Default count
valid_stats.append(s)
if not valid_stats:
# If no valid stats, return empty stats
return {
"min": np.array([0]),
"max": np.array([0]),
"mean": np.array([0]),
"std": np.array([0]),
"count": np.array([0]),
}
means = np.stack([s["mean"] for s in valid_stats])
variances = np.stack([s["std"] ** 2 for s in valid_stats])
counts = np.stack([s["count"] for s in valid_stats])
total_count = counts.sum(axis=0)
# Prepare weighted mean by matching number of dimensions
@@ -142,12 +163,13 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
delta_means = means - total_mean
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
total_std = np.sqrt(total_variance)
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"std": total_std,
"count": total_count,
}
+61 -9
View File
@@ -32,6 +32,8 @@ IMAGENET_STATS = {
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
def resolve_delta_timestamps(
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
@@ -81,37 +83,87 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
if isinstance(cfg.dataset.repo_id, str):
if "," in cfg.dataset.repo_id:
repo_id = cfg.dataset.repo_id.split(",")
repo_id = [r for r in repo_id if r]
else:
repo_id = cfg.dataset.repo_id
sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None
feature_keys_mapping = FEATURE_KEYS_MAPPING
if isinstance(repo_id, str):
revision = getattr(cfg.dataset, "revision", None)
ds_meta = LeRobotDatasetMetadata(
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
cfg.dataset.repo_id,
feature_keys_mapping=feature_keys_mapping,
revision=revision,
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
root=getattr(cfg.dataset, "root", None),
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
revision=revision,
video_backend=cfg.dataset.video_backend,
download_videos=True,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.dataset.max_action_dim,
max_state_dim=cfg.dataset.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
delta_timestamps = {}
episodes = {}
for i in range(len(repo_id)):
ds_meta = LeRobotDatasetMetadata(
repo_id[i],
feature_keys_mapping=feature_keys_mapping,
) # FIXME(mshukor): ?
delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta)
episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes)
# training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None)
# FIXME: (jadechoghari): check support for training features
training_features = None
dataset = MultiLeRobotDataset(
cfg.dataset.repo_id,
repo_id,
# TODO(aliberts): add proper support for multi dataset
# delta_timestamps=delta_timestamps,
episodes=episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
download_videos=True,
sampling_weights=sampling_weights,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.policy.max_action_dim,
max_state_dim=cfg.policy.max_state_dim,
max_num_images=cfg.dataset.max_num_images,
max_image_dim=cfg.dataset.max_image_dim,
train_on_all_features=cfg.dataset.train_on_all_features,
training_features=training_features,
discard_first_n_frames=cfg.dataset.discard_first_n_frames,
min_fps=cfg.dataset.min_fps,
max_fps=cfg.dataset.max_fps,
discard_first_idle_frames=cfg.dataset.discard_first_idle_frames,
motion_threshold=cfg.dataset.motion_threshold,
motion_window_size=cfg.dataset.motion_window_size,
motion_buffer=cfg.dataset.motion_buffer,
)
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(dataset.repo_id_to_index, indent=2)}"
)
if cfg.dataset.use_imagenet_stats:
# Initialize stats structure if it doesn't exist
if dataset.meta.stats is None:
dataset.meta.stats = {}
for key in dataset.meta.camera_keys:
# Initialize stats for this camera key if it doesn't exist
if key not in dataset.meta.stats or dataset.meta.stats[key] is None:
dataset.meta.stats[key] = {}
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
+409 -77
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import contextlib
import logging
import os
import shutil
from pathlib import Path
from typing import Callable
@@ -30,8 +31,16 @@ from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.constants import (
ACTION,
HF_LEROBOT_HOME,
OBS_ENV_STATE,
OBS_STATE,
)
from lerobot.datasets.compute_stats import ( # aggregate_stats_per_robot_type,
aggregate_stats,
compute_episode_stats,
)
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
DEFAULT_FEATURES,
@@ -41,7 +50,6 @@ from lerobot.datasets.utils import (
_validate_feature_names,
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@@ -58,12 +66,34 @@ from lerobot.datasets.utils import (
load_info,
load_stats,
load_tasks,
map_dict_keys,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
write_json,
# keep_datasets_with_the_same_features_per_robot_type,
# map_dict_pad_keys,
# keep_datasets_with_valid_fps,
# find_start_of_motion,
)
# mustafa stuff here
from lerobot.datasets.utils_must import (
OBS_IMAGE,
OBS_IMAGE_2,
OBS_IMAGE_3,
ROBOT_TYPE_KEYS_MAPPING,
TASKS_KEYS_MAPPING,
aggregate_stats_per_robot_type,
create_padded_features,
find_start_of_motion,
keep_datasets_with_the_same_features_per_robot_type,
keep_datasets_with_valid_fps,
map_dict_keys,
pad_tensor,
reshape_features_to_max_dim,
)
from lerobot.datasets.video_utils import (
VideoFrame,
@@ -74,6 +104,15 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = "v2.1"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
    """Locate the first frame where motion begins.

    Scans `velocities` with a sliding window of `window_size`; the first
    window whose mean exceeds `threshold` marks the onset of motion. The
    returned index is backed off by `motion_buffer` frames (clamped at 0) to
    keep slight context before the motion. Returns 0 when no window exceeds
    the threshold.
    """
    n_windows = len(velocities) - window_size
    for start in range(n_windows):
        if velocities[start : start + window_size].mean() > threshold:
            # back off to include context, never before frame 0
            return start - motion_buffer if start > motion_buffer else 0
    return 0
class LeRobotDatasetMetadata:
@@ -81,10 +120,13 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
feature_keys_mapping: dict[str, str] | None = None,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.local_files_only = local_files_only
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
@@ -99,18 +141,27 @@ class LeRobotDatasetMetadata:
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
# added by mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
self.info["features"] = map_dict_keys(
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
)
def load_metadata(self):
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
# Force all datasets to use v2.1 format (episodes_stats.jsonl) to avoid missing stats.json issues, because I converted all the datasets to v2.1 format.
# if self._version < packaging.version.parse("v2.1"):
# self.stats = load_stats(self.root)
# self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
# else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo(
self,
@@ -177,7 +228,15 @@ class LeRobotDatasetMetadata:
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
# changed
keys = []
for key, ft in self.features.items():
key_ = (
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
)
if ft["dtype"] == "video":
keys.append(key_)
return keys
@property
def camera_keys(self) -> list[str]:
@@ -342,6 +401,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
local_files_only: bool = False,
# new thing by M
feature_keys_mapping: dict[str, str] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
training_features: list | None = None,
discard_first_n_frames: int = 0,
discard_first_idle_frames: bool = False,
motion_threshold: float = 5e-2,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -455,15 +527,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
# by mshukor
self.training_features = training_features
self.discard_first_n_frames = discard_first_n_frames
self.discard_first_idle_frames = discard_first_idle_frames
self.motion_threshold = motion_threshold
self.motion_window_size = motion_window_size
self.motion_buffer = motion_buffer
# Unused attributes
self.image_writer = None
self.episode_buffer = None
self.root.mkdir(exist_ok=True, parents=True)
# more mshukor
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
self.inverse_feature_keys_mapping = (
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
)
# Load metadata
# TODO: change
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
self.repo_id,
self.root,
local_files_only=local_files_only,
revision=self.revision,
force_cache_sync=force_cache_sync,
feature_keys_mapping=feature_keys_mapping,
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
@@ -482,17 +574,74 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# mustafa code
if self.discard_first_n_frames > 0:
print("Discarding first n frames:", self.discard_first_n_frames)
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
# TODO implement advanced strategy
self.subset_frame_ids += [
frame_idx for frame_idx in range(from_ + int(self.fps * self.discard_first_n_frames), to_)
]
elif self.discard_first_idle_frames:
print(
f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}"
)
self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D]
self.subset_frame_ids = []
for ep_idx in range(self.num_episodes):
from_ = self.episode_data_index["from"][ep_idx]
to_ = self.episode_data_index["to"][ep_idx]
ep_states = self.robot_states[from_:to_]
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
velocities = np.concatenate([[0.0], velocities])
start_idx = find_start_of_motion(
velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer
)
self.subset_frame_ids += list(range(from_ + start_idx, to_))
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# commented TODO: check why
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
# TODO: check why commented
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Mustafa
self.meta.info["features"] = map_dict_keys(
self.meta.info["features"],
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.keys_to_max_dim = {
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
}
self.meta.info["features"] = reshape_features_to_max_dim(
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
self.meta.stats = map_dict_keys(
self.meta.stats,
feature_keys_mapping=self.feature_keys_mapping,
training_features=self.training_features,
)
self.robot_type = self.meta.info.get("robot_type", "")
# Override tasks
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
def push_to_hub(
self,
branch: str | None = None,
@@ -641,12 +790,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
# Bounds check to prevent IndexError when episode_index is out of range
if ep_idx >= len(self.episode_data_index["from"]):
# Fall back to the last valid episode
ep_idx = len(self.episode_data_index["from"]) - 1
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
# FIXME(mshukor): what if we train on multiple datasets with different features
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
@@ -670,12 +825,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
# TODO: changed by mustafa
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
queries = {}
for key, q_idx in query_indices.items():
if (
key not in self.meta.video_keys
and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys
):
key_ = (
self.inverse_feature_keys_mapping.get(key, key)
if self.inverse_feature_keys_mapping
else key
)
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
return queries
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -699,8 +863,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __len__(self):
return self.num_frames
# changed by mshukor
def __getitem__(self, idx) -> dict:
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
idx = self.subset_frame_ids[idx]
item = self.hf_dataset[idx]
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
ep_idx = item["episode_index"].item()
query_indices = None
@@ -717,15 +885,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]
try:
item["task"] = self.meta.tasks[task_idx]
except:
print(self.meta.tasks, task_idx, self.repo_id)
if "robot_type" not in item:
item["robot_type"] = self.robot_type
item = map_dict_keys(
item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
)
# Add padded features
# item = self._add_padded_features(item, self.training_features)
if self.image_transforms is not None:
for cam in item:
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
item[cam] = self.image_transforms(item[cam])
# Map pad keys
# print(item.keys(), "before")
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
# print(item.keys())
return item
def __repr__(self):
@@ -985,6 +1165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@@ -1005,6 +1186,106 @@ class LeRobotDataset(torch.utils.data.Dataset):
return obj
class MultiLeRobotDatasetMeta:
    """Aggregated metadata for a collection of `LeRobotDataset`s.

    Builds, in order: the set of features disabled because they are not common
    to all datasets, the union feature schema reshaped to per-key max dims,
    per-robot-type normalization stats padded to those dims, and per-repo
    episodes/tasks/info maps.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        # datasets: already-constructed per-repo datasets (one per repo_id).
        # keys_to_max_dim: feature key -> target padded size along the last dim.
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features
        # NOTE(review): robot_types is captured BEFORE the remapping below, so
        # it may hold the pre-override values — confirm this is intentional.
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
        # assign robot_type if missing
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]
        # step 1: compute disabled features
        # Features not shared by every dataset are disabled unless
        # train_on_all_features is set (then nothing is disabled).
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                logging.warning(f"Disabling {extra} for repo {repo_id}")
                self.disabled_features.update(extra)
        # step 2: build union_features excluding disabled
        # Later datasets overwrite earlier ones on key collisions.
        self.union_features = {}
        for ds in datasets:
            for k, v in ds.features.items():
                if k not in self.disabled_features:
                    self.union_features[k] = v
        # step 3: reshape feature schema
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )
        # step 4: aggregate stats
        # Stats for action/state features are padded so every robot type's
        # normalization tensors share the padded feature dimension. Padding
        # uses 0 for min/mean and 1 for max/std (neutral for normalization).
        self.stats = aggregate_stats_per_robot_type(datasets)
        for robot_type_, stats_ in self.stats.items():
            for feat_key, feat_stats in stats_.items():
                if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
                    for k, v in feat_stats.items():
                        pad_value = 0 if k in ["min", "mean"] else 1
                        self.stats[robot_type_][feat_key][k] = pad_tensor(
                            v,
                            max_size=self.keys_to_max_dim.get(feat_key, -1),
                            pad_dim=-1,
                            pad_value=pad_value,
                        )
        # step 5: episodes & tasks
        # Per-repo views keyed by repo_id for downstream lookups.
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
    """Filters a dataset collection down to a mutually-consistent subset.

    Two passes: (1) drop datasets whose fps falls outside [min_fps, max_fps],
    (2) keep only datasets whose features agree within each robot type. The
    surviving datasets, their sampling weights, repo ids, and cumulative frame
    offsets (for global-index -> dataset-index lookup) are exposed as
    `cleaned_*` / `cumulative_sizes` attributes.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        # Keep the uncleaned inputs for inspection/debugging.
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids
        # step 1: remove datasets with invalid fps
        valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
        # step 2: keep datasets with same features per robot type
        consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
            valid_fps_datasets
        )
        self.cleaned_datasets = consistent_datasets
        self.keep_mask = keep_mask
        # NOTE(review): keep_mask is indexed over `valid_fps_datasets`, but
        # sampling_weights / repo_ids / datasets_repo_ids below are indexed as
        # if aligned with the ORIGINAL lists — if step 1 dropped any dataset,
        # these selections misalign. Verify keep_datasets_with_valid_fps never
        # removes entries, or re-map the indices.
        self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
        self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
        self.cleaned_datasets_repo_ids = [
            datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
        ]
        # Cumulative frame counts with a leading 0, so dataset k owns global
        # indices [cumulative_sizes[k], cumulative_sizes[k+1]).
        self.cumulative_sizes = np.array(
            [0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
        )
        self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
@@ -1021,7 +1302,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
# add
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int = None,
max_state_dim: int = None,
max_num_images: int = None,
max_image_dim: int = None,
train_on_all_features: bool = False,
training_features: list | None = None,
discard_first_n_frames: int = 0,
min_fps: int = 1,
max_fps: int = 100,
discard_first_idle_frames: bool = False,
motion_threshold: float = 0.05,
motion_window_size: int = 10,
motion_buffer: int = 3,
):
super().__init__()
self.repo_ids = repo_ids
@@ -1029,46 +1327,89 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
self.training_features = training_features
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
feature_keys_mapping=feature_keys_mapping,
training_features=training_features,
discard_first_n_frames=discard_first_n_frames,
discard_first_idle_frames=discard_first_idle_frames,
motion_threshold=motion_threshold,
motion_window_size=motion_window_size,
motion_buffer=motion_buffer,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
# FIXME(mshukor): apply mapping to unify used keys
# FIXME(mshukor): pad based on types in case we have more than one state?
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
self.delta_timestamps = (
delta_timestamps.get(repo_id, None) if delta_timestamps else None
) # delta_timestamps # FIXME(mshukor): last repo?
# In case datasets with the same robot_type have different features
cleaner = MultiLeRobotDatasetCleaner(
datasets=_datasets,
repo_ids=repo_ids,
sampling_weights=self.sampling_weights,
datasets_repo_ids=datasets_repo_ids,
min_fps=min_fps,
max_fps=max_fps,
)
self._datasets = cleaner.cleaned_datasets
self.sampling_weights = cleaner.cleaned_weights
self.repo_ids = cleaner.cleaned_repo_ids
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
self.cumulative_sizes = cleaner.cumulative_sizes
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
# self.meta.info = {
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
# }
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
self.meta = MultiLeRobotDatasetMeta(
datasets=self._datasets,
repo_ids=self.repo_ids,
keys_to_max_dim={
ACTION: max_action_dim,
OBS_ENV_STATE: max_state_dim,
OBS_STATE: max_state_dim,
OBS_IMAGE: max_image_dim,
OBS_IMAGE_2: max_image_dim,
OBS_IMAGE_3: max_image_dim,
},
train_on_all_features=train_on_all_features,
)
self.disabled_features = self.meta.disabled_features
self.stats = self.meta.stats
@property
def repo_id_to_index(self):
@@ -1156,23 +1497,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:
item = create_padded_features(item, self.meta.features)
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
if data_key in item:
del item[data_key]
return item
def __repr__(self):
+18
View File
@@ -858,3 +858,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def map_dict_keys(
    item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
    """Rename dataset feature keys to the names the model expects.

    Keys present in ``feature_keys_mapping`` are renamed (a ``None`` target
    drops the key); unmapped keys are kept under their original name. When
    ``training_features`` is given, only keys in that list survive, except
    that unmapped keys containing ``pad_key`` are always kept.
    """
    if feature_keys_mapping is None:
        return item
    remapped = {}
    for src_key, value in item.items():
        if src_key not in feature_keys_mapping:
            keep = training_features is None or src_key in training_features or pad_key in src_key
            if keep:
                remapped[src_key] = value
            continue
        dst_key = feature_keys_mapping[src_key]
        if dst_key is not None and (training_features is None or dst_key in training_features):
            remapped[dst_key] = value
    return remapped
+416
View File
@@ -0,0 +1,416 @@
"""
Utility functions (collected for a later refactor).
"""
from collections import defaultdict
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from lerobot.datasets.compute_stats import aggregate_stats
# Canonical image-observation keys (up to three camera views per sample).
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image2"
OBS_IMAGE_3 = "observation.image3"
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = None) -> dict:
    """Return a copy of ``features`` whose shapes are enlarged to per-key maxima.

    Fixes two defects of the original: the mutable default argument ``{}``
    (shared across calls) and in-place mutation of the caller's feature dicts
    (the output aliased the input entries and overwrote their ``"shape"``).

    Args:
        features: mapping key -> feature descriptor dict with a ``"shape"`` entry.
        reshape_dim: which axis to enlarge for non-image features.
        keys_to_max_dim: per-key target size; keys absent or mapped to None are
            passed through unchanged.

    Returns:
        New dict; entries with a target size get a fresh descriptor with the
        updated shape, other entries alias the originals.
    """
    if keys_to_max_dim is None:
        keys_to_max_dim = {}
    reshaped_features = {}
    for key, ft in features.items():
        max_dim = keys_to_max_dim.get(key)
        if max_dim is None:
            reshaped_features[key] = ft
            continue
        shape = list(ft["shape"])
        if any(k in key for k in (OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3)):  # Assume square images
            # Overwrite both spatial dims; assumes shape ends with (H, W, C) — TODO confirm
            shape[-3] = max_dim
            shape[-2] = max_dim
        else:
            shape[reshape_dim] = max_dim
        # Shallow-copy the descriptor so the input dict is left untouched.
        reshaped_features[key] = {**ft, "shape": tuple(shape)}
    return reshaped_features
def keep_datasets_with_valid_fps(ls_datasets: list, min_fps: int = 1, max_fps: int = 100) -> list:
    """Keep only datasets whose ``fps`` lies in [min_fps, max_fps].

    Bug fix: the original removed elements from ``ls_datasets`` while
    iterating over it, which silently skips the element that follows every
    removal (e.g. two consecutive invalid datasets kept the second one).
    The list is filtered first, then updated in place so callers holding a
    reference to the same list object still see the result.
    """
    print(
        f"Keeping datasets with fps between {min_fps} and {max_fps}. Considering {len(ls_datasets)} datasets."
    )
    kept = []
    for ds in ls_datasets:
        if min_fps <= ds.fps <= max_fps:
            kept.append(ds)
        else:
            print(f"Dataset {ds} has invalid fps: {ds.fps}. Removing it.")
    ls_datasets[:] = kept  # in-place update preserves aliasing like the original
    print(f"Keeping {len(ls_datasets)} datasets with valid fps.")
    return ls_datasets
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
    """
    Filters datasets to only keep those with consistent feature shapes per robot type.

    For each robot type, the most frequent shape of every stats key (judged by
    the shape of the per-episode ``"mean"`` entry) is treated as canonical; a
    dataset is flagged for removal as soon as its *first* episode disagrees
    with the canonical shape on some key.

    Args:
        ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
        and `meta.episodes_stats` dictionary.
    Returns:
        Tuple of (filtered datasets, keep-mask aligned with the input list —
        True where the dataset was kept).
    """
    robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
    datasets_to_remove = set()
    for robot_type in robot_types:
        # Collect all stats dicts for this robot type
        stats_list = [
            ep_stats
            for ds in ls_datasets
            if ds.meta.info["robot_type"] == robot_type
            for ep_stats in ds.meta.episodes_stats.values()
            if ep_stats is not None  # Filter out None values
        ]
        if not stats_list:
            continue
        # Determine the most common shape for each key
        all_keys = {key for stats in stats_list for key in stats}
        for ds in ls_datasets:
            if ds.meta.info["robot_type"] != robot_type:
                continue
            for key in all_keys:
                # NOTE(review): the canonical-shape count is recomputed for every
                # dataset even though it only depends on (robot_type, key) —
                # O(datasets x keys x episodes); hoisting it would be cheaper.
                shape_counter = defaultdict(int)
                for stats in stats_list:
                    value = stats.get(key)
                    if (
                        value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
                    ):  # FIXME(mshukor): check all stats; min, mean, max
                        shape_counter[value["mean"].shape] += 1
                if not shape_counter:
                    continue
                # Identify the most frequent shape
                main_shape = max(shape_counter, key=shape_counter.get)
                # Flag datasets that don't match the main shape
                # for ds in ls_datasets:
                # Only the first episode of the dataset is checked against the canon.
                first_ep_stats = next(iter(ds.meta.episodes_stats.values()), None)
                if not first_ep_stats:
                    continue
                value = first_ep_stats.get(key)
                if (
                    value
                    and "mean" in value
                    and isinstance(value["mean"], (torch.Tensor, np.ndarray))
                    and value["mean"].shape != main_shape
                ):
                    datasets_to_remove.add(ds)
                    break
    # Filter out inconsistent datasets
    # NOTE(review): "datasets_maks" looks like a typo for "datasets_mask";
    # kept as-is since callers may already unpack by position.
    datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
    filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
    print(
        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
    )
    return filtered_datasets, datasets_maks
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
    """Aggregate per-episode stats into one stats dict per robot type.

    For every robot type found in ``ls_datasets``, all non-None episode stats
    from datasets of that type are pooled and reduced with
    ``aggregate_stats`` (union of data keys; min/max/mean/std combined).
    Robot types with no valid episode stats map to an empty dict.
    """
    robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
    per_robot_stats = {robot_type: {} for robot_type in robot_types}
    for robot_type in robot_types:
        pooled_episode_stats = [
            ep_stats
            for ds in ls_datasets
            if ds.meta.info["robot_type"] == robot_type
            for ep_stats in ds.meta.episodes_stats.values()
            if ep_stats is not None  # skip episodes with missing stats
        ]
        if pooled_episode_stats:
            per_robot_stats[robot_type] = aggregate_stats(pooled_episode_stats)
        else:
            print(f"Warning: No valid episode stats found for robot type {robot_type}, skipping aggregation")
            per_robot_stats[robot_type] = {}
    return per_robot_stats
def str_to_torch_dtype(dtype_str):
    """Map a feature dtype string to the matching torch dtype.

    Unknown strings — and "video", whose frames are handled as float
    tensors here — fall back to torch.float32.
    """
    if dtype_str == "int64":
        return torch.int64
    if dtype_str == "int16":
        return torch.int16
    if dtype_str == "bool":
        return torch.bool
    # "float32", "video", and anything unrecognized default to float32.
    return torch.float32
def create_padded_features(item: dict, features: dict = None):
    """Add zero-filled placeholders (plus padding masks) for features missing from ``item``.

    Bug fix: ``features`` defaulted to a mutable ``{}`` shared across calls;
    it now defaults to ``None``.

    For each feature descriptor: if the key is absent from ``item``, a zero
    tensor of the feature's shape/dtype is inserted together with
    ``<key>_padding_mask = 0`` (and ``<key>_is_pad`` for images); if the key
    is present, ``<key>_padding_mask = 1`` is added. ``item`` is mutated in
    place and returned.
    """
    if features is None:
        features = {}
    for key, ft in features.items():
        if any(k in key for k in ("cam", "effort", "absolute")):  # FIXME(mshukor): temporary hack
            continue
        shape = ft["shape"]
        if len(shape) == 3:  # images to torch format (C, H, W)
            shape = (shape[2], shape[0], shape[1])
        if len(shape) == 1 and shape[0] == 1:  # ft with shape are actually tensor(ele)
            shape = []
        if key not in item:
            dtype = str_to_torch_dtype(ft["dtype"])
            item[key] = torch.zeros(shape, dtype=dtype)
            item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
            if "image" in key:  # FIXME(mshukor): support other observations
                item[f"{key}_is_pad"] = torch.BoolTensor([False])
        else:
            item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
    return item
# Maps a dataset repo_id to the robot-type bucket used for stats grouping.
# Datasets in the same bucket share one normalization-stats aggregate.
ROBOT_TYPE_KEYS_MAPPING = {
    "lerobot/stanford_hydra_dataset": "static_single_arm",
    "lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
    "lerobot/berkeley_fanuc_manipulation": "static_single_arm",
    "lerobot/toto": "static_single_arm",
    "lerobot/roboturk": "static_single_arm",
    "lerobot/jaco_play": "static_single_arm",
    "lerobot/taco_play": "static_single_arm_7statedim",
}
def pad_tensor(
    tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
) -> torch.Tensor:
    """Right-pad ``tensor`` along ``pad_dim`` up to ``max_size``.

    Accepts numpy arrays as well (the result is converted back to numpy).
    0-d tensors and tensors already of size >= ``max_size`` along ``pad_dim``
    are returned unchanged.

    Bug fix: the original computed the missing length from ``pad_dim`` but
    always padded the *last* dimension; padding is now applied to
    ``pad_dim`` itself.
    """
    is_numpy = isinstance(tensor, np.ndarray)
    if is_numpy:
        tensor = torch.tensor(tensor)
    if tensor.ndim == 0:
        # Scalar — return as-is, no padding needed
        return tensor
    dim = pad_dim if pad_dim >= 0 else tensor.ndim + pad_dim
    missing = max_size - tensor.shape[dim]
    if missing > 0:
        # F.pad consumes (left, right) pairs starting from the LAST dim.
        pad_sizes = [0, 0] * tensor.ndim
        pad_sizes[2 * (tensor.ndim - 1 - dim) + 1] = missing  # pad right only
        tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
    return tensor.numpy() if is_numpy else tensor
def map_dict_keys(
    item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
) -> dict:
    """Maps feature keys from the dataset to the keys used in the model.

    A key present in ``feature_keys_mapping`` is renamed to its target (a
    ``None`` target drops it); unmapped keys keep their name. When
    ``training_features`` is given, only listed keys are kept, except that
    unmapped keys containing ``pad_key`` always survive.

    Cleanup: removed a leftover ``# breakpoint()`` debug comment.
    NOTE(review): this duplicates ``map_dict_keys`` defined in the dataset
    utils module; the two copies should be consolidated.
    """
    if feature_keys_mapping is None:
        return item
    features = {}
    for key in item:
        if key in feature_keys_mapping:
            if feature_keys_mapping[key] is not None:
                if training_features is None or feature_keys_mapping[key] in training_features:
                    features[feature_keys_mapping[key]] = item[key]
        else:
            if training_features is None or key in training_features or pad_key in key:
                features[key] = item[key]
    return features
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
    """Return the index where motion starts, minus a small context buffer.

    Slides a window of ``window_size`` over ``velocities`` and returns the
    first start index whose window mean exceeds ``threshold``, shifted back
    by ``motion_buffer`` (clamped at 0). Returns 0 if no window qualifies.
    """
    n_windows = len(velocities) - window_size
    for start in range(n_windows):
        if velocities[start : start + window_size].mean() > threshold:
            return max(0, start - motion_buffer)  # include slight context before motion
    return 0
import requests
import yaml
def load_yaml_mapping(name: str) -> dict:
    """
    Loads a YAML mapping from a Hugging Face repo.
    Example: name='features' https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/features.yaml

    Raises:
        requests.HTTPError: if the download fails (non-2xx status).
        requests.Timeout: if the request exceeds the timeout.
    """
    url = f"https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/{name}.yaml"
    # A request without a timeout can block forever; bound it explicitly.
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # raise if the download fails
    return yaml.safe_load(response.text)
# Example usage
# NOTE(review): these run at import time and hit the network; importing this
# module fails without connectivity to huggingface.co. Consider lazy loading.
TASKS_KEYS_MAPPING = load_yaml_mapping("tasks")
FEATURE_KEYS_MAPPING = load_yaml_mapping("features")
# Episode indices to keep per dataset; a handful of bad episodes are excluded.
EPISODES_DATASET_MAPPING = {
    "cadene/droid_1.0.1": list(range(50)),
    # Episodes 0-51 except 8 and 23.
    "danaaubakirova/svla_so100_task5_v3": [i for i in range(52) if i not in (8, 23)],
    # Episodes 0-53 except 20 and 36-39.
    "danaaubakirova/svla_so100_task4_v3": [i for i in range(54) if i != 20 and not 36 <= i <= 39],
}
# Canonical feature-key names shared across datasets and models.
ACTION = "action"
OBS_STATE = "observation.state"
TASK = "task"
ROBOT = "robot_type"
# Feature sets keyed by the number of extra camera views (0, 1 or 2).
TRAINING_FEATURES = {
    0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
    1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
    2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
}
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> bool:
    """Return True when the batch holds non-scalar tensors (padding candidates).

    Fixes the return annotation (the function returns a bool, not an int) and
    drops a commented-out per-dim size comparison. Only the first element's
    rank is inspected; ``pad_dim`` is currently unused but kept for
    backward compatibility with existing callers.
    """
    return len(values[0].shape) > 0
def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor:
"""Pads a tensor to the target shape (right/bottom only)."""
pad = []
for actual, target in zip(reversed(tensor.shape), reversed(target_shape), strict=False):
pad.extend([0, max(target - actual, 0)])
return F.pad(tensor, pad, value=pad_value)
def multidataset_collate_fn(
    batch: List[Dict[str, torch.Tensor]],
    keys_to_max_dim: Dict[str, tuple] = None,
    pad_value: float = 0.0,
) -> Dict[str, torch.Tensor]:
    """
    Pads tensors to given target shape (if provided), otherwise uses per-batch max.
    Supports 1D (e.g. action), 3D (e.g. [C,H,W] images).

    Bug fix: ``keys_to_max_dim`` defaulted to a mutable ``{}`` shared across
    calls; it now defaults to ``None``.

    Args:
        batch: list of per-sample feature dicts (all sharing the first
            sample's keys).
        keys_to_max_dim: optional per-key target shapes; keys absent or
            mapped to None are padded to the per-batch maximum instead.
        pad_value: fill value used when padding.

    Non-tensor values are passed through to ``default_collate`` unchanged.
    """
    if keys_to_max_dim is None:
        keys_to_max_dim = {}
    collated_batch = [{} for _ in range(len(batch))]
    batch_keys = batch[0].keys()
    for key in batch_keys:
        values = [sample[key] for sample in batch]
        sample = values[0]
        if not isinstance(sample, torch.Tensor):
            # Strings / scalars / lists: leave as-is for default_collate.
            for i in range(len(batch)):
                collated_batch[i][key] = values[i]
            continue
        # use user-specified shape if available
        if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
            target_shape = keys_to_max_dim[key]
        else:
            # compute per-batch max shape
            target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim))
        for i in range(len(batch)):
            collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value)
    return default_collate(collated_batch)
+23 -5
View File
@@ -43,14 +43,32 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
if ft["dtype"] in ["image", "video"]:
# Handle variable dimensions for image/video data
# Expected formats: (frames, channels, height, width) or (channels, height, width)
if ep_ft_data.ndim == 4:
# Standard case: (frames, channels, height, width)
axes_to_reduce = (0, 2, 3) # reduce over frames, height, width
elif ep_ft_data.ndim == 3:
# Squeezed case: (channels, height, width) - single frame
axes_to_reduce = (1, 2) # reduce over height, width
else:
raise ValueError(f"Unexpected dimensions for {ft['dtype']} data: {ep_ft_data.shape}")
keepdims = True
else:
axes_to_reduce = 0
keepdims = ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
if ep_ft_data.ndim == 4:
# For 4D data, squeeze the first axis (batch/frames)
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
elif ep_ft_data.ndim == 3:
# For 3D data, the stats already have correct shape (channels,)
pass
dataset.meta.episodes_stats[ep_idx] = ep_stats
+1
View File
@@ -16,5 +16,6 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
+19 -1
View File
@@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolvla2":
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
return SmolVLA2Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "smolvla2":
return SmolVLA2Config(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:
@@ -147,7 +154,18 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
# Handle robot-type grouped stats - flatten to feature-level stats
if ds_meta.stats and len(ds_meta.stats) == 1:
# Single robot type - use its stats directly
robot_type = list(ds_meta.stats.keys())[0]
kwargs["dataset_stats"] = ds_meta.stats[robot_type]
elif ds_meta.stats and len(ds_meta.stats) > 1:
# Multiple robot types - need to aggregate across all robot types
# For now, use the first robot type (TODO: proper multi-robot handling)
robot_type = list(ds_meta.stats.keys())[0]
kwargs["dataset_stats"] = ds_meta.stats[robot_type]
else:
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
+1 -1
View File
@@ -79,7 +79,7 @@ def create_stats_buffers(
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if stats and key in stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
@@ -0,0 +1,191 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@dataclass
class PEFTConfig:
    # LoRA hyper-parameters forwarded to ``peft.LoraConfig``.
    r: int = 4  # rank of the low-rank update matrices
    lora_alpha: int = 16  # scaling factor applied to the LoRA update
    lora_dropout: float = 0.1  # dropout on the LoRA branch
    target_modules: str = "q_proj,v_proj"  # comma-separated module names to adapt
@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(PreTrainedConfig):
    """Configuration for the SmolVLA2 policy (SmolVLM2 backbone + action expert).

    Fixes: the ``pad_language_to`` field was declared twice with the same
    default (the duplicate is removed; field order and the ``__init__``
    signature are unchanged), and the stray string after
    ``super().__post_init__()`` is now a real docstring.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (512, 512)

    # Add empty images. Used by smolvla_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    proj_width: int = 480

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5  # 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10
    optimizer_grad_clip_norm: float = 10
    optimizer_lr_vlm: float = 0

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # Select the VLM backbone.
    load_vlm_weights: bool = False  # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights

    checkpoint_path: str | None = None

    peft_method: str = ""
    peft_config: PEFTConfig = field(default_factory=PEFTConfig)
    peft_target_model: str = ""

    add_image_special_tokens: bool = False  # Whether to use special image tokens around image features.

    attention_mode: str = "cross_attn"

    prefix_length: int = -1

    pad_language_to: str = "longest"  # "max_length"

    num_expert_layers: int = -1  # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
    num_vlm_layers: int = 16  # Number of layers used in the VLM (first num_vlm_layers layers)
    past_obs_keys: str = "image"
    add_local_special_image_tokens: bool = False
    reverse_images_order: bool = False
    state_to_prefix: bool = False
    causal_action_attention_mask: bool = False
    self_attn_every_n_layers: int = -1  # Interleave SA layers each self_attn_every_n_layers
    expert_width_multiplier: float = 0.75  # The action expert hidden size (wrt to the VLM)

    min_period: float = 4e-3  # sensitivity range for the timestep used in sine-cosine positional encoding
    max_period: float = 4.0
    robot_type: str = ""
    self_attn_only_actions: bool = False
    causal_attention_on_history: bool = False
    predict_relative_actions: bool = False
    relative_actions_mode: str = "first"
    shuffle_camera_positions: bool = False
    vlm_img_size: int = -1
    regression_loss: bool = False

    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        """Register placeholder camera features for the `empty_cameras` extra views."""
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """AdamW preset built from the optimizer_* fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Cosine decay with warmup, peaking at the optimizer learning rate."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list:
        # Only the current observation step is used.
        return [0]

    @property
    def action_delta_indices(self) -> list:
        # One action per chunk step.
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,605 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from peft import LoraConfig, TaskType, get_peft_model
def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].

    The rotation is computed in float32 and the result is cast back to the
    input dtype.
    """
    half = x.shape[-1] // 2
    input_dtype = x.dtype
    x = x.to(torch.float32)
    exponents = (2.0 / x.shape[-1]) * torch.arange(half, dtype=torch.float32, device=x.device)
    timescale = max_wavelength**exponents
    angles = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
    angles = angles[..., None, :]  # broadcast over the head axis
    sin = torch.sin(angles)
    cos = torch.cos(angles)
    first, second = x.split(half, dim=-1)
    rotated = torch.cat((first * cos - second * sin, second * cos + first * sin), dim=-1)
    return rotated.to(input_dtype)
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Compute the FFN width: multiplier * (2/3 * hidden_dim), rounded up to a
    multiple of ``multiple_of`` (LLaMA-style sizing)."""
    scaled = int(ffn_dim_multiplier * int(2 * hidden_dim / 3))
    n_chunks = (scaled + multiple_of - 1) // multiple_of  # ceiling division
    return multiple_of * n_chunks
class SmolVLMWithExpertModel(nn.Module):
    def __init__(
        self,
        model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
        load_vlm_weights: bool = True,
        train_expert_only: bool = True,
        freeze_vision_encoder: bool = False,
        attention_mode: str = "self_attn",
        num_expert_layers: int = -1,
        num_vlm_layers: int = -1,
        self_attn_every_n_layers: int = -1,
        expert_width_multiplier: float = 0.5,
    ):
        """Build the SmolVLM backbone plus a narrower action-expert LM.

        Args:
            model_id: HF hub id of the SmolVLM checkpoint/config.
            load_vlm_weights: load pretrained weights; otherwise build the
                architecture from config only.
            train_expert_only: freeze the whole VLM, train only the expert.
            freeze_vision_encoder: freeze the vision tower.
            attention_mode: "self_attn" or a mode containing "cross" that
                rewires the expert's k/v projections to read VLM features.
            num_expert_layers: expert layer count (<=0: match the VLM).
            num_vlm_layers: truncate the VLM text model to this many layers (<=0: keep all).
            self_attn_every_n_layers: every n-th expert layer stays self-attention
                (skipped when rewiring k/v for cross-attention).
            expert_width_multiplier: expert hidden size as a fraction of the VLM's.
        """
        super().__init__()
        if load_vlm_weights:
            print(f"Loading {model_id} weights ...")
            # Disable device_map when using Accelerate for multi-GPU training
            import os
            # Heuristic: Accelerate is assumed active when a config file is set
            # and FSDP is not forced — TODO confirm this covers all launch modes.
            use_accelerate = os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() != 'true' and \
                'ACCELERATE_CONFIG_FILE' in os.environ
            device_map = None if use_accelerate else "auto"
            self.vlm = AutoModelForImageTextToText.from_pretrained(
                model_id,
                device_map=device_map,
                torch_dtype="bfloat16",
                low_cpu_mem_usage=True,
            )
            config = self.vlm.config
        else:
            # Architecture only — weights are randomly initialized.
            config = AutoConfig.from_pretrained(model_id)
            self.vlm = SmolVLMForConditionalGeneration(config=config)
        self.processor = AutoProcessor.from_pretrained(model_id)
        if num_vlm_layers > 0:
            print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
            self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
        self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
        self.config = config
        # Smaller lm expert
        lm_expert_config = copy.deepcopy(config.text_config)
        hidden_size = lm_expert_config.hidden_size
        lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier)  # hidden_size // 2
        lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
        lm_expert_config.num_hidden_layers = self.num_vlm_layers
        if num_expert_layers > 0:
            # The expert must divide the VLM depth so layers can be paired up.
            assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
                f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
            )
            lm_expert_config.num_hidden_layers = num_expert_layers
        self.lm_expert = AutoModel.from_config(lm_expert_config)
        self.num_expert_layers = len(self.lm_expert.layers)
        self.self_attn_every_n_layers = self_attn_every_n_layers
        if "cross" in attention_mode:
            # Reshape qkv projections to have the same input dimension as the vlm
            for layer_idx in range(len(self.lm_expert.layers)):
                # Interleaved self-attention layers keep their original k/v.
                if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
                    continue
                self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
                    config.text_config.num_key_value_heads * config.text_config.head_dim,
                    lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
                    bias=lm_expert_config.attention_bias,
                )
                self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
                    config.text_config.num_key_value_heads * config.text_config.head_dim,
                    lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
                    bias=lm_expert_config.attention_bias,
                )
        # Remove unused embed_tokens
        self.lm_expert.embed_tokens = None
        self.num_attention_heads = self.config.text_config.num_attention_heads
        self.num_key_value_heads = self.config.text_config.num_key_value_heads
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_mode = attention_mode
        self.expert_hidden_size = lm_expert_config.hidden_size
        self.set_requires_grad()
    def configure_peft(self, config):
        """Optionally wrap the VLM with LoRA adapters.

        Args:
            config: policy config providing ``peft_method``, ``peft_config``
                (rank/alpha/dropout/target_modules) and ``peft_target_model``
                ("text..." adapts only the text model, otherwise the whole VLM).

        After wrapping, only LoRA parameters are left trainable (with a
        hard-coded exclusion of layer 17 — presumably the truncated last
        layer; TODO confirm).
        """
        self.peft_method = config.peft_method
        self.peft_target_model = config.peft_target_model
        if "lora" in self.peft_method:
            peft_config = config.peft_config
            target_modules = peft_config.target_modules
            if not isinstance(target_modules, list):
                # Comma-separated string form, e.g. "q_proj,v_proj".
                target_modules = target_modules.split(",")
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,  # Based on the task type (e.g., language modeling, etc.)
                r=peft_config.r,  # The rank of the low-rank adaptation
                lora_alpha=peft_config.lora_alpha,  # Scaling factor
                lora_dropout=peft_config.lora_dropout,  # Dropout applied to LoRA layers
                target_modules=target_modules,  # The components where LoRA is applied
                exclude_modules=[
                    "lm_expert",
                    "model.lm_expert.model.layers",
                ],  # FIXME(mshukor): this does not work for now
            )
            self.lora_config = lora_config
            # Apply LoRA and ensure only LoRA parameters are trainable
            if "text" in self.peft_target_model:
                self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
            else:
                self.vlm = get_peft_model(self.vlm, lora_config)
            # assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
            for name, param in self.vlm.named_parameters():
                if (
                    "lora" in name and "text_model.model.layers.17" not in name
                ):  # lm_head is not a parameter in most LLMs because it's tied to the embedding layer
                    param.requires_grad = True
                else:
                    param.requires_grad = False
def merge_lora_weights(self):
"""
Merge LoRA weights into the base model.
"""
if "text" in self.peft_target_model:
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
else:
self.vlm = self.vlm.merge_and_unload()
def get_vlm_model(
self,
):
if hasattr(self.vlm.model, "model"): # When using peft
return self.vlm.model.model
else:
return self.vlm.model
    def set_requires_grad(self):
        """Freeze parameters according to the freeze/train flags set in __init__.

        - ``freeze_vision_encoder``: vision tower frozen and put in eval mode.
        - ``train_expert_only``: the whole VLM frozen; otherwise a few tail
          parameters (lm_head, final norm, last layer(s)) are frozen so that
          distributed training sees no unused parameters.
        - The expert's lm_head (tied / unused) is always frozen.
        """
        if self.freeze_vision_encoder:
            self.get_vlm_model().vision_model.eval()
            for params in self.get_vlm_model().vision_model.parameters():
                params.requires_grad = False
        if self.train_expert_only:
            self.vlm.eval()
            for params in self.vlm.parameters():
                params.requires_grad = False
        else:
            # To avoid unused params issue with distributed training
            last_layers = [self.num_vlm_layers - 1]
            # When the expert is shallower but divides the VLM depth, the
            # second-to-last VLM layer is also unused — freeze it too.
            if (
                self.num_vlm_layers != self.num_expert_layers
                and self.num_vlm_layers % self.num_expert_layers == 0
            ):
                last_layers.append(self.num_vlm_layers - 2)
            frozen_layers = [
                "lm_head",
                "text_model.model.norm.weight",
            ]
            for layer in last_layers:
                frozen_layers.append(f"text_model.model.layers.{layer}.")
            # Substring match against fully-qualified parameter names.
            for name, params in self.vlm.named_parameters():
                if any(k in name for k in frozen_layers):
                    params.requires_grad = False
        # To avoid unused params issue with distributed training
        for name, params in self.lm_expert.named_parameters():
            if "lm_head" in name:
                params.requires_grad = False
    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping frozen submodules in eval.

        nn.Module.train() recursively flips every child, so eval() is
        re-applied to the vision encoder and/or the whole VLM when they are
        meant to stay frozen (matters for dropout/batch-norm behavior).
        """
        super().train(mode)
        if self.freeze_vision_encoder:
            self.get_vlm_model().vision_model.eval()
        if self.train_expert_only:
            self.vlm.eval()
    def embed_image(self, image: torch.Tensor):
        """Encode pixel values into connector-projected image features.

        Args:
            image: pixel values cast to the vision model's dtype; expected
                layout is whatever the SmolVLM vision tower takes
                (presumably (B, C, H, W) — TODO confirm).

        Returns:
            Image hidden states after the modality-projection connector.
        """
        # No patch mask: all patches attend (full images, no padding).
        patch_attention_mask = None
        # Get sequence from the vision encoder
        image_hidden_states = (
            self.get_vlm_model()
            .vision_model(
                pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
                patch_attention_mask=patch_attention_mask,
            )
            .last_hidden_state
        )
        # Modality projection & resampling
        image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
        return image_hidden_states
def embed_language_tokens(self, tokens: torch.Tensor):
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
def forward_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> list[torch.Tensor]:
    """Joint self-attention at depth `layer_idx` over all input streams.

    Q/K/V are projected per stream with that stream's own layer weights,
    concatenated along the sequence axis, rotated with RoPE, and attended
    in a single joint attention call.

    Returns:
        A one-element list holding the joint attention output, plus the
        (possibly created/updated) per-layer KV cache dict.
    """
    query_states = []
    key_states = []
    value_states = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = model_layers[i][layer_idx]
        # A stream may be absent (None embeds) or have no layer at this
        # depth (interleaved expert layers); skip it in either case.
        if hidden_states is None or layer is None:
            continue
        hidden_states = layer.input_layernorm(hidden_states)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        # Match the projection weights' dtype (streams can differ).
        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)
    # B,L,H,D with L sequence length, H number of heads, D head dim
    # concatenate on the number of embeddings/tokens
    query_states = torch.cat(query_states, dim=1)
    key_states = torch.cat(key_states, dim=1)
    value_states = torch.cat(value_states, dim=1)
    seq_len = query_states.shape[1]
    if seq_len < position_ids.shape[1]:
        # A stream was skipped above: slice the precomputed ids/mask down
        # to the actual concatenated sequence length.
        _position_ids = position_ids[:, :seq_len]
        _attention_mask = attention_mask[:, :seq_len, :seq_len]
    else:
        _position_ids = position_ids
        _attention_mask = attention_mask
    attention_mask_ = _attention_mask
    position_ids_ = _position_ids
    query_states = apply_rope(query_states, position_ids_)
    key_states = apply_rope(key_states, position_ids_)
    if use_cache and past_key_values is None:
        past_key_values = {}
    if use_cache:
        if fill_kv_cache:
            # Prefill: store this layer's K/V for later decoding steps.
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # Decoding: prepend cached K/V to the fresh ones.
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
            value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
    attention_interface = self.get_attention_interface()
    att_output = attention_interface(
        attention_mask_, batch_size, head_dim, query_states, key_states, value_states
    )
    return [att_output], past_key_values
def forward_cross_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> list[torch.Tensor]:
    """Cross-attention at depth `layer_idx`: the VLM prefix self-attends,
    and the expert cross-attends to the prefix's keys/values.

    During prefill (two streams, empty cache) the prefix runs plain
    self-attention and its raw K/V are cached; the expert then re-projects
    those K/V with its own k/v projections and attends to them with its
    queries.

    Returns:
        One attention output per stream (None where a stream has no layer
        at this depth), plus the (possibly updated) KV cache.
    """
    attention_interface = self.get_attention_interface()
    att_outputs = []
    # Either prefilling with both streams, or decoding from a filled cache.
    assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
        f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
    )
    if len(inputs_embeds) == 2 and not past_key_values:
        # Prefix attention
        seq_len = inputs_embeds[0].shape[1]
        position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
        prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
        layer = model_layers[0][layer_idx]
        hidden_states = layer.input_layernorm(inputs_embeds[0])
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
        value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
        # B,L,H,D with L sequence length, H number of heads, D head dim
        query_states = apply_rope(query_state, position_id)
        key_states = apply_rope(key_state, position_id)
        att_output = attention_interface(
            prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
        )
        att_outputs.append(att_output)
    else:
        # Decoding: only the expert stream is present.
        expert_position_id = position_ids
    if use_cache and past_key_values is None:
        past_key_values = {}
    if use_cache:
        if fill_kv_cache:
            # NOTE(review): if this point is reached via the `else` branch
            # above, `key_states`/`value_states` are undefined (NameError).
            # The assert only rules that out when len(inputs_embeds) != 2 —
            # confirm callers never prefill with a non-empty past_key_values.
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            key_states = past_key_values[layer_idx]["key_states"]
            value_states = past_key_values[layer_idx]["value_states"]
    # Expert
    expert_layer = model_layers[1][layer_idx]
    if expert_layer is not None:
        expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
        expert_input_shape = expert_hidden_states.shape[:-1]
        expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
        expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
        expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
        # Flatten cached prefix K/V over heads, re-project them with the
        # expert's own k/v projections, then split back into heads.
        _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
            *key_states.shape[:2], -1
        )
        expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
            *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )  # k_proj should have same dim as kv
        _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
            *value_states.shape[:2], -1
        )
        expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
            *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )
        # Rebase expert positions so they start from 0.
        expert_position_id = (
            expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
        )  # start from 0
        # Slice the mask to expert queries x (prefix) keys.
        expert_attention_mask = attention_mask[
            :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
        ]  # take into account kv
        expert_query_states = apply_rope(expert_query_state, expert_position_id)
        att_output = attention_interface(
            expert_attention_mask,
            batch_size,
            head_dim,
            expert_query_states,
            expert_key_states,
            expert_value_states,
        )
        att_outputs.append(att_output)
    else:
        # No expert layer at this depth (interleaved experts).
        att_outputs.append(None)
    # att_output = att_output.to(dtype=models[i].dtype)
    return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:
    """Pair each VLM layer index with its expert layer (or None).

    With `stride = num_vlm_layers // num_expert_layers`, expert layers are
    interleaved every `stride`-th VLM layer; intermediate depths get None.
    A non-positive stride maps expert layers one-to-one.
    """
    stride = self.num_vlm_layers // self.num_expert_layers
    vlm_layers = []
    expert_layers = []
    for idx in range(self.num_vlm_layers):
        vlm_layers.append(models[0].layers[idx])
        if stride > 0 and idx > 0 and idx % stride:
            # No expert layer at this depth.
            expert_layers.append(None)
        else:
            mapped = idx // stride if stride > 0 else idx
            expert_layers.append(models[1].layers[mapped])
    return [vlm_layers, expert_layers]
def forward(
    self,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: List[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    fill_kv_cache: Optional[bool] = None,
):
    """Run all transformer layers over the (VLM, expert) embedding streams.

    Per layer, dispatches to joint self-attention (during prefill, when
    the attention mode is not cross-attention, or on periodic self-attn
    layers) or to expert->prefix cross-attention, then applies the
    per-stream attention output projection, residuals, and MLP.

    Returns:
        Final per-stream hidden states after each model's final norm
        (None where a stream was None), plus the KV cache.
    """
    models = [self.get_vlm_model().text_model, self.lm_expert]
    model_layers = self.get_model_layers(models)
    for hidden_states in inputs_embeds:
        # TODO this is very inefficient
        # dtype is always the same, batch size too (if > 1 len)
        # device could be trickier in multi gpu edge cases but that's it
        if hidden_states is None:
            continue
        batch_size = hidden_states.shape[0]
    # RMSNorm
    num_layers = self.num_vlm_layers
    head_dim = self.vlm.config.text_config.head_dim
    for layer_idx in range(num_layers):
        # Self-attention path: prefill, non-cross mode, or every n-th layer.
        if (
            fill_kv_cache
            or "cross" not in self.attention_mode
            or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
        ):
            att_outputs, past_key_values = self.forward_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        else:
            att_outputs, past_key_values = self.forward_cross_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        outputs_embeds = []
        # `start` tracks each stream's slice inside a joint attention
        # output; per-stream outputs always start at 0.
        start = 0
        for i, hidden_states in enumerate(inputs_embeds):
            layer = model_layers[i][layer_idx]
            att_output = (
                att_outputs[i] if i < len(att_outputs) else att_outputs[0]
            )  # in case of self_attn
            if hidden_states is not None:
                if layer is None:
                    # No layer at this depth for this stream: pass through.
                    outputs_embeds.append(hidden_states)
                    continue
                end = start + hidden_states.shape[1]
                if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                    att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                att_out = att_output[:, start:end]
                out_emb = layer.self_attn.o_proj(att_out)
                # Residual around attention, then pre-MLP norm, MLP, and
                # second residual.
                out_emb += hidden_states
                after_first_residual = out_emb.clone()
                out_emb = layer.post_attention_layernorm(out_emb)
                out_emb = layer.mlp(out_emb)
                out_emb += after_first_residual
                outputs_embeds.append(out_emb)
                start = end if len(att_outputs) == 1 else 0
            else:
                outputs_embeds.append(None)
        # Feed this layer's outputs into the next layer.
        inputs_embeds = outputs_embeds
    # final norm
    outputs_embeds = []
    for i, hidden_states in enumerate(inputs_embeds):
        if hidden_states is not None:
            out_emb = models[i].norm(hidden_states)
            outputs_embeds.append(out_emb)
        else:
            outputs_embeds.append(None)
    return outputs_embeds, past_key_values
def get_attention_interface(self):
    """Return the attention implementation to use (eager only, for now)."""
    return self.eager_attention_forward
def eager_attention_forward(
    self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
    """Plain (eager) scaled dot-product attention with GQA K/V repeat.

    Inputs are laid out as (B, L, H, D). K/V heads are repeated
    ``num_attention_heads // num_key_value_heads`` times so they match the
    query head count; scores are computed in float32 for stability.
    Returns the attention output as (B, L, H * D).
    """
    num_att_heads = self.num_attention_heads
    num_key_value_heads = self.num_key_value_heads
    num_key_value_groups = num_att_heads // num_key_value_heads
    sequence_length = key_states.shape[1]
    # Repeat each KV head `num_key_value_groups` times (grouped-query
    # attention): expand a new group axis, then fold it into the head axis.
    key_states = key_states[:, :, :, None, :].expand(
        batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
    )
    key_states = key_states.reshape(
        batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
    )
    value_states = value_states[:, :, :, None, :].expand(
        batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
    )
    value_states = value_states.reshape(
        batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
    )
    # Attention here is upcasted to float32 to match the original eager implementation.
    query_states = query_states.to(dtype=torch.float32)
    key_states = key_states.to(dtype=torch.float32)
    # (B, L, H, D) -> (B, H, L, D) for batched matmul over heads.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
    att_weights *= head_dim**-0.5
    att_weights = att_weights.to(dtype=torch.float32)
    big_neg = torch.finfo(att_weights.dtype).min  # -2.3819763e38 # See gemma/modules.py
    # Mask broadcasts over heads; disallowed positions get a huge negative
    # score so softmax drives their probability to ~0.
    masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
    probs = nn.functional.softmax(masked_att_weights, dim=-1)
    probs = probs.to(dtype=value_states.dtype)
    att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
    att_output = att_output.permute(0, 2, 1, 3)
    # we use -1 because sequence length can change
    att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
    return att_output
+5 -1
View File
@@ -243,7 +243,11 @@ def eval_policy(
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
if not isinstance(policy, PreTrainedPolicy):
# Handle accelerate-wrapped models by unwrapping them
if hasattr(policy, 'module') and isinstance(policy.module, PreTrainedPolicy):
# This is likely an accelerate-wrapped model (DistributedDataParallel)
policy = policy.module
elif not isinstance(policy, PreTrainedPolicy):
raise ValueError(
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
)
+95 -28
View File
@@ -16,6 +16,7 @@
import logging
import time
from contextlib import nullcontext
from functools import partial
from pprint import pformat
from typing import Any
@@ -23,12 +24,16 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
import os
from datetime import timedelta
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.datasets.utils_must import multidataset_collate_fn
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
@@ -52,6 +57,8 @@ from lerobot.utils.utils import (
)
from lerobot.utils.wandb_utils import WandBLogger
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def update_policy(
train_metrics: MetricsTracker,
@@ -63,41 +70,55 @@ def update_policy(
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator=None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
grad_norm = 0.0 # Initialize grad_norm to avoid undefined variable
if accelerator:
with accelerator.accumulate(policy):
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
optimizer.zero_grad()
else:
# Standard training loop without accelerate
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
grad_scaler.update()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
if accelerator:
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
policy.update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
@@ -108,8 +129,34 @@ def update_policy(
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
accelerator = None # Initialize accelerator variable
if is_launched_with_accelerate():
import accelerate
# For example pi0 has unused params (last llm block)
from accelerate import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
from accelerate import InitProcessGroupKwargs
# Set NCCL timeout (default 30 minutes = 1800 seconds)
nccl_timeout = getattr(cfg, 'nccl_timeout', 1800)
ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout))  # Timeout is user-configurable via cfg.nccl_timeout; it should be longer than the evaluation time
# Set gradient accumulation steps (default 1)
gradient_accumulation_steps = getattr(cfg, 'gradient_accumulation_steps', 1)
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_init_kwargs, ddp_kwargs])
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
logging.info(pformat(cfg.to_dict()))
if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
@@ -173,16 +220,29 @@ def train(cfg: TrainPipelineConfig):
else:
shuffle = True
sampler = None
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
keys_to_max_dim = {
"action": (32,),
"observation.state": (32,),
"observation.image": (3, 1080, 1920),
"observation.image2": (3, 1080, 1920),
}
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=collate_fn,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
)  # The collate_fn pads variable-dim features so batches from multiple datasets can be stacked
if accelerator:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
@@ -218,6 +278,7 @@ def train(cfg: TrainPipelineConfig):
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.policy.use_amp,
accelerator=accelerator,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -240,7 +301,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
# Unwrap policy from accelerate if needed
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -252,9 +315,11 @@ def train(cfg: TrainPipelineConfig):
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
# Unwrap policy from accelerate if needed for evaluation
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
eval_info = eval_policy(
eval_env,
policy,
unwrapped_policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
@@ -283,7 +348,9 @@ def train(cfg: TrainPipelineConfig):
logging.info("End of training")
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
# Unwrap policy from accelerate if needed
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
unwrapped_policy.push_model_to_hub(cfg)
if __name__ == "__main__":
+31 -3
View File
@@ -60,11 +60,39 @@ def load_training_step(save_dir: Path) -> int:
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
import fcntl
import tempfile
import os
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
last_checkpoint_dir.unlink()
relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
last_checkpoint_dir.symlink_to(relative_target)
# Use file locking to prevent race conditions in multi-GPU training
lock_file = checkpoint_dir.parent / ".symlink_lock"
try:
with open(lock_file, 'w') as f:
# Get exclusive lock
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
# Update symlink atomically
if last_checkpoint_dir.exists() or last_checkpoint_dir.is_symlink():
last_checkpoint_dir.unlink()
last_checkpoint_dir.symlink_to(relative_target)
except (OSError, FileExistsError) as e:
# Handle race conditions gracefully - another process may have already updated
if not last_checkpoint_dir.exists():
try:
last_checkpoint_dir.symlink_to(relative_target)
except FileExistsError:
pass # Another process created it, that's fine
finally:
# Clean up lock file
try:
lock_file.unlink()
except FileNotFoundError:
pass
def save_checkpoint(
+10 -1
View File
@@ -33,7 +33,16 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
f"seed:{cfg.seed}",
]
if cfg.dataset is not None:
lst.append(f"dataset:{cfg.dataset.repo_id}")
# Create shorter dataset tag to avoid wandb 64-char limit
repo_id = cfg.dataset.repo_id
if "," in repo_id:
# Multiple datasets - use count
dataset_count = len(repo_id.split(","))
lst.append(f"datasets:{dataset_count}")
else:
# Single dataset - use last part of path
dataset_name = repo_id.split("/")[-1][:20] # Truncate to 20 chars
lst.append(f"dataset:{dataset_name}")
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
+33 -14
View File
@@ -394,37 +394,56 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after fix multidataset")
# @pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
"""Check that all dataset frames are incorporated and aligned correctly."""
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
# dummy padding dimensions (simulate training setup)
MAX_ACTION_DIM = 14
MAX_STATE_DIM = 30
MAX_NUM_IMAGES = 3
MAX_IMAGE_DIM = 224
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
dataset = MultiLeRobotDataset(
repo_ids,
max_action_dim=MAX_ACTION_DIM,
max_state_dim=MAX_STATE_DIM,
max_num_images=MAX_NUM_IMAGES,
max_image_dim=MAX_IMAGE_DIM,
)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
for expected_dataset_index, sub_item, multi_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
dataset_index = multi_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# we ignore padding_mask and dataset_index keys in multi_item
extra_keys = {k for k in multi_item if "padding_mask" in k}
filtered_multi_keys = set(multi_item.keys()) - extra_keys
assert set(sub_item.keys()) == filtered_multi_keys, "mismatch in keys"
for k in sub_item:
if k not in multi_item:
continue
v1, v2 = sub_item[k], multi_item[k]
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
else:
assert v1 == v2, f"value mismatch on key: {k}"
# TODO(aliberts): Move to more appropriate location
+228
View File
@@ -0,0 +1,228 @@
#!/bin/bash
#SBATCH --job-name=smolvla_optimized_8gpu_fresh
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=88
#SBATCH --gres=gpu:8
#SBATCH --mem=0
#SBATCH --time=72:00:00
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/dana_aubakirova/vla/logs/smolvla_optimized_8gpu_fresh_%j.out
#SBATCH --error=/fsx/dana_aubakirova/vla/logs/smolvla_optimized_8gpu_fresh_%j.err
#SBATCH --exclusive
# Create logs directory if it doesn't exist
mkdir -p /fsx/dana_aubakirova/vla/logs
# Activate conda environment
source /fsx/dana_aubakirova/miniconda/etc/profile.d/conda.sh
conda activate lerobot
# Add local lerobot source to Python path to use development version
export PYTHONPATH="/fsx/dana_aubakirova/vla/lerobot/src:$PYTHONPATH"
# OPTIMIZED 8-GPU CUDA environment - high performance configuration
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512,expandable_segments:True,garbage_collection_threshold:0.8
export TORCH_DISTRIBUTED_DEBUG=OFF
export NCCL_DEBUG=WARN
export CUDA_LAUNCH_BLOCKING=0
export ACCELERATE_USE_FSDP=false
export ACCELERATE_USE_DEEPSPEED=false
export HF_ACCELERATE_DEVICE_MAP=false
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# 8-GPU optimizations
export NCCL_IB_DISABLE=1
export NCCL_P2P_DISABLE=1
# Change to working directory
cd /fsx/dana_aubakirova/vla
# FRESH START 8-GPU training configuration - NEW OUTPUT DIRECTORY
export OUTPUT_DIR="/fsx/dana_aubakirova/vla/outputs/test_smolvla_2datasets_$(date +%Y%m%d_%H%M%S)"
# Use ALL datasets from relative_datasets_list.txt - full scale training
export REPO_IDS="AndrejOrsula/lerobot_double_ball_stacking_random, koenvanwijk/orange50-variation-2"
# Model configuration - optimized for 8-GPU with global batch size 32
export VLM_REPO_ID=HuggingFaceTB/SmolVLM2-500M-Video-Instruct
export STEPS=100 # Quick test run
export BATCH_SIZE=8   # 8 per GPU x 8 GPUs = 64 global batch size
export EVAL_FREQ=-1 # Disable evaluation for faster training
export NUM_WORKERS=0 # MEMORY FIX: Disable workers to prevent memory exhaustion
export SAVE_FREQ=10000 # Save every 10k steps
# Model config - optimized settings inspired by SmolPi0
export POLICY=smolvla2
export USE_AMP=false # DISABLE AMP for stability
export OPTIMIZER_LR=5e-4 # Optimized learning rate
export PEFT_METHOD=lora
export LOAD_VLM_WEIGHTS=true
export MAX_ACTION_DIM=32
export MAX_STATE_DIM=32
# Dataset config - optimized from analysis
export USE_IMAGENET_STATS=false
export ENABLE_IMG_TRANSFORM=true
export MAX_NUM_IMAGES=2 # OPTIMIZED: 2 images for better context
export MAX_IMAGE_DIM=256 # OPTIMIZED: 256px resolution
export TRAIN_ON_ALL_FEATURES=true
export FEATURES_VERSION=2
# Advanced optimizations for 8-GPU setup
export FPS_MIN=30
export FPS_MAX=30
export GRADIENT_ACCUMULATION_STEPS=1  # Global batch size = 8 (per GPU) x 8 GPUs x 1 = 64
export PRECISION=no
export DROP_LAST=true
# SmolPi0-inspired VLM parameters
export VLM_LAYERS=16
export EXPERT_WIDTH_MULTIPLIER=0.75
export CAUSAL_ACTION_ATTENTION=true
export SELF_ATTN_EVERY_N_LAYERS=2
export ATTENTION_MODE=cross_attn
export LORA_R=32
export LORA_TARGET_MODULES=q_proj,v_proj
export PREFIX_LENGTH=0
# Learning rate schedule inspired by SmolPi0
export DECAY_LR=1e-6
export DECAY_STEPS=50000
export LR_VLM=1e-4
export WARMUP_STEPS=1000
# Set environment variables for model cache and offline mode
export HF_LEROBOT_HOME="/fsx/dana_aubakirova/vla"
export HF_HOME="/fsx/dana_aubakirova/vla/.cache/huggingface"
export HF_HUB_CACHE="/fsx/dana_aubakirova/vla/.cache/huggingface"
export TRANSFORMERS_CACHE="/fsx/dana_aubakirova/vla/.cache/huggingface"
export HF_HUB_OFFLINE=0
export TRANSFORMERS_OFFLINE=0
# Optimized accelerate config
export ACCELERATE_CONFIG_FILE="/fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml"
# Wandb configuration - FRESH START
export WANDB_PROJECT="smolvla2-training"
export WANDB_NOTES="8-GPU optimized training FRESH START - same parameters as previous run but from scratch"
export WANDB_MODE="disabled"
# Print a comprehensive banner of the resolved configuration so the job log
# records exactly what this run was launched with.
echo "🚀 =============================================="
echo "🚀 OPTIMIZED 8-GPU FRESH START TRAINING"
echo "🚀 =============================================="
echo "🆕 FRESH START - No resume, new output directory"
echo "📊 Datasets: ALL available datasets (same as previous run)"
echo "📁 Output directory: $OUTPUT_DIR"
echo "🎯 Policy: $POLICY"
echo "🔧 Batch size per GPU: $BATCH_SIZE (GLOBAL BATCH SIZE: $((BATCH_SIZE * 8)))"
echo "🔄 Gradient accumulation steps: $GRADIENT_ACCUMULATION_STEPS"
echo "📈 Training steps: $STEPS"
echo "💾 Save frequency: $SAVE_FREQ"
echo "🔬 Evaluation frequency: $EVAL_FREQ"
echo "⚡ AMP enabled: $USE_AMP (no mixed precision - stable)"
echo "📚 Learning rate: $OPTIMIZER_LR"
echo "🎓 VLM Learning rate: $LR_VLM"
echo "🔥 Warmup steps: $WARMUP_STEPS"
echo "📷 Max images: $MAX_NUM_IMAGES"
echo "🖼️ Image dimension: $MAX_IMAGE_DIM"
echo "👥 Data workers per GPU: $NUM_WORKERS (memory optimized)"
echo "🧠 VLM layers: $VLM_LAYERS"
echo "🔄 Expert width multiplier: $EXPERT_WIDTH_MULTIPLIER"
echo "🎯 LORA rank: $LORA_R"
echo "🖥️ GPUs: 8 (HIGH PERFORMANCE)"
echo "📊 Wandb project: $WANDB_PROJECT"
echo "🚀 =============================================="
# Check GPU availability.
# Robustness fix: guard the call -- nvidia-smi is absent on CPU-only submit
# nodes, and an unguarded invocation would emit a noisy "command not found".
if command -v nvidia-smi >/dev/null 2>&1; then
    echo "🖥️ GPU Information:"
    nvidia-smi --list-gpus
else
    echo "⚠️ nvidia-smi not found - skipping GPU listing"
fi
# Create optimized 8-GPU accelerate config
mkdir -p /fsx/dana_aubakirova/vla/accelerate_configs
cat > /fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml << EOF
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: '0,1,2,3,4,5,6,7'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
EOF
echo "📋 Created optimized accelerate config with no mixed precision for stability"
# Run distributed training with optimized accelerate configuration - FRESH START
accelerate launch --config_file /fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml \
lerobot/src/lerobot/scripts/train.py \
--policy.type=$POLICY \
--dataset.repo_id="$REPO_IDS" \
--dataset.root="/fsx/dana_aubakirova/vla/community_dataset_v1" \
--dataset.use_imagenet_stats=$USE_IMAGENET_STATS \
--dataset.image_transforms.enable=$ENABLE_IMG_TRANSFORM \
--dataset.train_on_all_features=$TRAIN_ON_ALL_FEATURES \
--dataset.features_version=$FEATURES_VERSION \
--policy.max_action_dim=$MAX_ACTION_DIM \
--policy.max_state_dim=$MAX_STATE_DIM \
--output_dir=$OUTPUT_DIR \
--batch_size=$BATCH_SIZE \
--steps=$STEPS \
--eval_freq=$EVAL_FREQ \
--save_freq=$SAVE_FREQ \
--policy.use_amp=$USE_AMP \
--policy.optimizer_lr=$OPTIMIZER_LR \
--policy.optimizer_lr_vlm=$LR_VLM \
--policy.scheduler_decay_lr=$DECAY_LR \
--policy.scheduler_decay_steps=$DECAY_STEPS \
--policy.scheduler_warmup_steps=$WARMUP_STEPS \
--policy.peft_method=$PEFT_METHOD \
--policy.peft_config.r=$LORA_R \
--policy.peft_config.target_modules=$LORA_TARGET_MODULES \
--policy.load_vlm_weights=$LOAD_VLM_WEIGHTS \
--policy.repo_id=$VLM_REPO_ID \
--policy.push_to_hub=false \
--dataset.max_num_images=$MAX_NUM_IMAGES \
--dataset.max_image_dim=$MAX_IMAGE_DIM \
--dataset.video_backend=pyav \
--num_workers=$NUM_WORKERS \
--wandb.enable=false \
--wandb.project=$WANDB_PROJECT \
--wandb.notes="$WANDB_NOTES" \
--dataset.min_fps=$FPS_MIN \
--dataset.max_fps=$FPS_MAX \
--policy.num_vlm_layers=$VLM_LAYERS \
--policy.expert_width_multiplier=$EXPERT_WIDTH_MULTIPLIER \
--policy.causal_action_attention_mask=$CAUSAL_ACTION_ATTENTION \
--policy.self_attn_every_n_layers=$SELF_ATTN_EVERY_N_LAYERS \
--policy.attention_mode=$ATTENTION_MODE \
--policy.prefix_length=$PREFIX_LENGTH
# Final summary banner.
echo "✅ Optimized 8-GPU FRESH START training completed! Check results in: $OUTPUT_DIR"
# Messaging fix: wandb is explicitly disabled above (WANDB_MODE="disabled",
# --wandb.enable=false), so do not point users at a wandb.ai dashboard or
# claim a run was created.
echo "📊 WandB logging was disabled for this run (WANDB_MODE=$WANDB_MODE)"
echo "🆕 FRESH START TRAINING SUMMARY:"
echo "   • Started from scratch with new output directory"
echo "   • Training from step 0 to step $STEPS"
echo "   • Same optimized parameters as previous successful run"
echo "   • No WandB run was created (re-enable via WANDB_MODE and --wandb.enable)"
echo ""
echo "🚀 Key 8-GPU optimizations applied:"
echo "   • 8 GPUs with global batch size $((BATCH_SIZE * 8))"
echo "   • Memory-optimized data loading: 0 workers (prevents OOM)"
echo "   • STABLE: No mixed precision (matches conservative setup)"
echo "   • Optimized NCCL settings for 8-GPU communication"
echo "   • Enhanced memory allocation for high-throughput"
echo ""
echo "🏃‍♂️ Expected performance gains:"
echo "   • ~4x faster training throughput vs single GPU"
echo "   • Clean start without any checkpoint compatibility issues"
echo "   • Proven parameter configuration from previous run"