mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 22:59:50 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 61580a8596 | |||
| 6c8f1f962b | |||
| 7848b15bfb | |||
| 008b592545 | |||
| 55a61259e8 | |||
| e94d78f8a0 | |||
| 67c8d27e9c | |||
| c8b51ef205 | |||
| 9779c58b07 | |||
| 5fbe4f9987 | |||
| f47fd7aabf | |||
| a9251e612f | |||
| 2b27084d63 | |||
| ddb26b7189 | |||
| 96550e4ad1 | |||
| c0146eed7f | |||
| 0447604fce |
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,21 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
# Multi-dataset support
|
||||
sampling_weights: str | None = None
|
||||
max_action_dim: int | None = None
|
||||
max_state_dim: int | None = None
|
||||
max_num_images: int | None = None
|
||||
max_image_dim: int | None = None
|
||||
train_on_all_features: bool = False
|
||||
features_version: int = 0
|
||||
discard_first_n_frames: int = 0
|
||||
min_fps: int = 1
|
||||
max_fps: int = 100
|
||||
discard_first_idle_frames: bool = False
|
||||
motion_threshold: float = 5e-2
|
||||
motion_window_size: int = 10
|
||||
motion_buffer: int = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -22,9 +22,13 @@ OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGE_3 = "observation.image3"
|
||||
OBS_IMAGE_4 = "observation.image4"
|
||||
REWARD = "next.reward"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TASK = "task"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
|
||||
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
|
||||
|
||||
|
||||
def pad_tensor(
|
||||
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
|
||||
) -> torch.Tensor:
|
||||
is_numpy = isinstance(tensor, np.ndarray)
|
||||
if is_numpy:
|
||||
tensor = torch.tensor(tensor)
|
||||
pad = max_size - tensor.shape[pad_dim]
|
||||
if pad > 0:
|
||||
pad_sizes = (0, pad) # pad right
|
||||
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
|
||||
return tensor.numpy() if is_numpy else tensor
|
||||
|
||||
|
||||
def pad_list_of_tensors(
|
||||
tensors: List[torch.Tensor], pad_dim: int = -1, pad_value: float = 0.0
|
||||
) -> List[torch.Tensor]:
|
||||
max_size = max([v.shape[pad_dim] for v in tensors])
|
||||
return [pad_tensor(tensor, max_size, pad_dim=pad_dim, pad_value=pad_value) for tensor in tensors]
|
||||
|
||||
|
||||
def multidataset_collate_fn(
|
||||
batch: List[Dict[str, torch.Tensor]],
|
||||
pad_dim: int = -1,
|
||||
pad_value: float = 0.0,
|
||||
keys_to_max_dim: dict = {},
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Custom collate function to pad tensors with multiple dimensions.
|
||||
|
||||
Args:
|
||||
batch (List[Dict[str, torch.Tensor]]): List of dataset samples (each sample is a dictionary).
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: Batch with padded tensors.
|
||||
"""
|
||||
batch_keys = batch[0].keys()
|
||||
collated_batch = [{} for _ in range(len(batch))]
|
||||
# FIXME(mshukor): pad to max shape per feature type
|
||||
for key in batch_keys:
|
||||
values = [sample[key] for sample in batch]
|
||||
if (
|
||||
key in keys_to_max_dim
|
||||
and isinstance(values[0], torch.Tensor)
|
||||
and is_batch_need_padding(values, pad_dim=pad_dim)
|
||||
and keys_to_max_dim[key] is not None
|
||||
):
|
||||
max_size = keys_to_max_dim[key]
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = pad_tensor(
|
||||
batch[i][key], max_size, pad_dim=pad_dim, pad_value=pad_value
|
||||
)
|
||||
else:
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = batch[i][key]
|
||||
collated_batch = default_collate(collated_batch)
|
||||
|
||||
return collated_batch
|
||||
@@ -125,9 +125,30 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
counts = np.stack([s["count"] for s in stats_ft_list])
|
||||
# Filter out stats that don't have required keys
|
||||
valid_stats = []
|
||||
for s in stats_ft_list:
|
||||
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
|
||||
valid_stats.append(s)
|
||||
else:
|
||||
# If count is missing, add it with a default value
|
||||
if "count" not in s:
|
||||
s["count"] = np.array([1]) # Default count
|
||||
valid_stats.append(s)
|
||||
|
||||
if not valid_stats:
|
||||
# If no valid stats, return empty stats
|
||||
return {
|
||||
"min": np.array([0]),
|
||||
"max": np.array([0]),
|
||||
"mean": np.array([0]),
|
||||
"std": np.array([0]),
|
||||
"count": np.array([0]),
|
||||
}
|
||||
|
||||
means = np.stack([s["mean"] for s in valid_stats])
|
||||
variances = np.stack([s["std"] ** 2 for s in valid_stats])
|
||||
counts = np.stack([s["count"] for s in valid_stats])
|
||||
total_count = counts.sum(axis=0)
|
||||
|
||||
# Prepare weighted mean by matching number of dimensions
|
||||
@@ -142,12 +163,13 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
delta_means = means - total_mean
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
total_std = np.sqrt(total_variance)
|
||||
|
||||
return {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
|
||||
"mean": total_mean,
|
||||
"std": np.sqrt(total_variance),
|
||||
"std": total_std,
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ IMAGENET_STATS = {
|
||||
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
|
||||
}
|
||||
|
||||
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
@@ -81,37 +83,87 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
image_transforms = (
|
||||
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset.repo_id, str):
|
||||
if "," in cfg.dataset.repo_id:
|
||||
repo_id = cfg.dataset.repo_id.split(",")
|
||||
repo_id = [r for r in repo_id if r]
|
||||
else:
|
||||
repo_id = cfg.dataset.repo_id
|
||||
sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None
|
||||
feature_keys_mapping = FEATURE_KEYS_MAPPING
|
||||
if isinstance(repo_id, str):
|
||||
revision = getattr(cfg.dataset, "revision", None)
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
cfg.dataset.repo_id,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
revision=revision,
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
root=getattr(cfg.dataset, "root", None),
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
revision=revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
download_videos=True,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
max_action_dim=cfg.dataset.max_action_dim,
|
||||
max_state_dim=cfg.dataset.max_state_dim,
|
||||
max_num_images=cfg.dataset.max_num_images,
|
||||
max_image_dim=cfg.dataset.max_image_dim,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
delta_timestamps = {}
|
||||
episodes = {}
|
||||
for i in range(len(repo_id)):
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
repo_id[i],
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
) # FIXME(mshukor): ?
|
||||
delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes)
|
||||
# training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None)
|
||||
# FIXME: (jadechoghari): check support for training features
|
||||
training_features = None
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
episodes=episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
download_videos=True,
|
||||
sampling_weights=sampling_weights,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
max_action_dim=cfg.policy.max_action_dim,
|
||||
max_state_dim=cfg.policy.max_state_dim,
|
||||
max_num_images=cfg.dataset.max_num_images,
|
||||
max_image_dim=cfg.dataset.max_image_dim,
|
||||
train_on_all_features=cfg.dataset.train_on_all_features,
|
||||
training_features=training_features,
|
||||
discard_first_n_frames=cfg.dataset.discard_first_n_frames,
|
||||
min_fps=cfg.dataset.min_fps,
|
||||
max_fps=cfg.dataset.max_fps,
|
||||
discard_first_idle_frames=cfg.dataset.discard_first_idle_frames,
|
||||
motion_threshold=cfg.dataset.motion_threshold,
|
||||
motion_window_size=cfg.dataset.motion_window_size,
|
||||
motion_buffer=cfg.dataset.motion_buffer,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
)
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
# Initialize stats structure if it doesn't exist
|
||||
if dataset.meta.stats is None:
|
||||
dataset.meta.stats = {}
|
||||
|
||||
for key in dataset.meta.camera_keys:
|
||||
# Initialize stats for this camera key if it doesn't exist
|
||||
if key not in dataset.meta.stats or dataset.meta.stats[key] is None:
|
||||
dataset.meta.stats[key] = {}
|
||||
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@@ -30,8 +31,16 @@ from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.constants import (
|
||||
ACTION,
|
||||
HF_LEROBOT_HOME,
|
||||
OBS_ENV_STATE,
|
||||
OBS_STATE,
|
||||
)
|
||||
from lerobot.datasets.compute_stats import ( # aggregate_stats_per_robot_type,
|
||||
aggregate_stats,
|
||||
compute_episode_stats,
|
||||
)
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_FEATURES,
|
||||
@@ -41,7 +50,6 @@ from lerobot.datasets.utils import (
|
||||
_validate_feature_names,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
@@ -58,12 +66,34 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
map_dict_keys,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
# keep_datasets_with_the_same_features_per_robot_type,
|
||||
# map_dict_pad_keys,
|
||||
# keep_datasets_with_valid_fps,
|
||||
# find_start_of_motion,
|
||||
)
|
||||
|
||||
# mustafa stuff here
|
||||
from lerobot.datasets.utils_must import (
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGE_2,
|
||||
OBS_IMAGE_3,
|
||||
ROBOT_TYPE_KEYS_MAPPING,
|
||||
TASKS_KEYS_MAPPING,
|
||||
aggregate_stats_per_robot_type,
|
||||
create_padded_features,
|
||||
find_start_of_motion,
|
||||
keep_datasets_with_the_same_features_per_robot_type,
|
||||
keep_datasets_with_valid_fps,
|
||||
map_dict_keys,
|
||||
pad_tensor,
|
||||
reshape_features_to_max_dim,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
@@ -74,6 +104,15 @@ from lerobot.datasets.video_utils import (
|
||||
)
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
|
||||
|
||||
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
|
||||
for t in range(len(velocities) - window_size):
|
||||
window_mean = velocities[t : t + window_size].mean()
|
||||
if window_mean > threshold:
|
||||
return max(0, t - motion_buffer) # include slight context before motion
|
||||
return 0
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -81,10 +120,13 @@ class LeRobotDatasetMetadata:
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self.local_files_only = local_files_only
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
@@ -99,18 +141,27 @@ class LeRobotDatasetMetadata:
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
# added by mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
self.info["features"] = map_dict_keys(
|
||||
self.info["features"], feature_keys_mapping=self.feature_keys_mapping
|
||||
)
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
# Force all datasets to use v2.1 format (episodes_stats.jsonl) to avoid missing stats.json issues, because I converted all the datasets to v2.1 format.
|
||||
# if self._version < packaging.version.parse("v2.1"):
|
||||
# self.stats = load_stats(self.root)
|
||||
# self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
# else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -177,7 +228,15 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
# changed
|
||||
keys = []
|
||||
for key, ft in self.features.items():
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key) if self.inverse_feature_keys_mapping else key
|
||||
)
|
||||
if ft["dtype"] == "video":
|
||||
keys.append(key_)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
@@ -342,6 +401,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
local_files_only: bool = False,
|
||||
# new thing by M
|
||||
feature_keys_mapping: dict[str, str] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 5e-2,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -455,15 +527,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.delta_indices = None
|
||||
|
||||
# by mshukor
|
||||
self.training_features = training_features
|
||||
self.discard_first_n_frames = discard_first_n_frames
|
||||
self.discard_first_idle_frames = discard_first_idle_frames
|
||||
self.motion_threshold = motion_threshold
|
||||
self.motion_window_size = motion_window_size
|
||||
self.motion_buffer = motion_buffer
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# more mshukor
|
||||
self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None
|
||||
self.inverse_feature_keys_mapping = (
|
||||
{v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {}
|
||||
)
|
||||
|
||||
# Load metadata
|
||||
# TODO: change
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
self.repo_id,
|
||||
self.root,
|
||||
local_files_only=local_files_only,
|
||||
revision=self.revision,
|
||||
force_cache_sync=force_cache_sync,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
@@ -482,17 +574,74 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# mustafa code
|
||||
if self.discard_first_n_frames > 0:
|
||||
print("Discarding first n frames:", self.discard_first_n_frames)
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
# TODO implement advanced strategy
|
||||
self.subset_frame_ids += [
|
||||
frame_idx for frame_idx in range(from_ + int(self.fps * self.discard_first_n_frames), to_)
|
||||
]
|
||||
elif self.discard_first_idle_frames:
|
||||
print(
|
||||
f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}"
|
||||
)
|
||||
self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D]
|
||||
self.subset_frame_ids = []
|
||||
for ep_idx in range(self.num_episodes):
|
||||
from_ = self.episode_data_index["from"][ep_idx]
|
||||
to_ = self.episode_data_index["to"][ep_idx]
|
||||
ep_states = self.robot_states[from_:to_]
|
||||
velocities = np.linalg.norm(np.diff(ep_states, axis=0), axis=1)
|
||||
velocities = np.concatenate([[0.0], velocities])
|
||||
start_idx = find_start_of_motion(
|
||||
velocities, self.motion_window_size, self.motion_threshold, self.motion_buffer
|
||||
)
|
||||
self.subset_frame_ids += list(range(from_ + start_idx, to_))
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
# commented TODO: check why
|
||||
# timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
# episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
# check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
# TODO: check why commented
|
||||
# check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Mustafa
|
||||
self.meta.info["features"] = map_dict_keys(
|
||||
self.meta.info["features"],
|
||||
feature_keys_mapping=self.feature_keys_mapping,
|
||||
training_features=self.training_features,
|
||||
)
|
||||
self.keys_to_max_dim = {
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
}
|
||||
self.meta.info["features"] = reshape_features_to_max_dim(
|
||||
self.meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
)
|
||||
self.meta.stats = map_dict_keys(
|
||||
self.meta.stats,
|
||||
feature_keys_mapping=self.feature_keys_mapping,
|
||||
training_features=self.training_features,
|
||||
)
|
||||
self.robot_type = self.meta.info.get("robot_type", "")
|
||||
# Override tasks
|
||||
print(TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks), "previous", self.meta.tasks)
|
||||
self.meta.tasks = TASKS_KEYS_MAPPING.get(self.repo_id, self.meta.tasks)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
@@ -641,12 +790,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
# Bounds check to prevent IndexError when episode_index is out of range
|
||||
if ep_idx >= len(self.episode_data_index["from"]):
|
||||
# Fall back to the last valid episode
|
||||
ep_idx = len(self.episode_data_index["from"]) - 1
|
||||
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
# FIXME(mshukor): what if we train on multiple datasets with different features
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
@@ -670,12 +825,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return query_timestamps
|
||||
|
||||
# TODO: changed by mustafa
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
queries = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if (
|
||||
key not in self.meta.video_keys
|
||||
and self.inverse_feature_keys_mapping.get(key, key) not in self.meta.video_keys
|
||||
):
|
||||
key_ = (
|
||||
self.inverse_feature_keys_mapping.get(key, key)
|
||||
if self.inverse_feature_keys_mapping
|
||||
else key
|
||||
)
|
||||
queries[key] = torch.stack(self.hf_dataset.select(q_idx)[key_])
|
||||
return queries
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -699,8 +863,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __len__(self):
|
||||
return self.num_frames
|
||||
|
||||
# changed by mshukor
|
||||
def __getitem__(self, idx) -> dict:
|
||||
if self.discard_first_n_frames > 0 or self.discard_first_idle_frames:
|
||||
idx = self.subset_frame_ids[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
item = map_dict_keys(item, feature_keys_mapping=self.feature_keys_mapping)
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
@@ -717,15 +885,27 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
try:
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
except:
|
||||
print(self.meta.tasks, task_idx, self.repo_id)
|
||||
if "robot_type" not in item:
|
||||
item["robot_type"] = self.robot_type
|
||||
item = map_dict_keys(
|
||||
item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features
|
||||
)
|
||||
# Add padded features
|
||||
# item = self._add_padded_features(item, self.training_features)
|
||||
if self.image_transforms is not None:
|
||||
for cam in item:
|
||||
if cam in self.meta.camera_keys or ("image" in cam and "is_pad" not in cam):
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
# Map pad keys
|
||||
# print(item.keys(), "before")
|
||||
# item = map_dict_pad_keys(item, feature_keys_mapping=self.feature_keys_mapping, training_features=self.training_features)
|
||||
# print(item.keys())
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -985,6 +1165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
obj.local_files_only = obj.meta.local_files_only
|
||||
obj.revision = None
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
@@ -1005,6 +1186,106 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return obj
|
||||
|
||||
|
||||
class MultiLeRobotDatasetMeta:
|
||||
def __init__(
|
||||
self,
|
||||
datasets: list[LeRobotDataset],
|
||||
repo_ids: list[str],
|
||||
keys_to_max_dim: dict[str, int],
|
||||
train_on_all_features: bool = False,
|
||||
):
|
||||
self.repo_ids = repo_ids
|
||||
self.keys_to_max_dim = keys_to_max_dim
|
||||
self.train_on_all_features = train_on_all_features
|
||||
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
|
||||
|
||||
# assign robot_type if missing
|
||||
for ds in datasets:
|
||||
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
|
||||
ds.robot_type = ds.meta.info["robot_type"]
|
||||
|
||||
# step 1: compute disabled features
|
||||
self.disabled_features = set()
|
||||
if not self.train_on_all_features:
|
||||
intersection = set(datasets[0].features)
|
||||
for ds in datasets:
|
||||
intersection.intersection_update(ds.features)
|
||||
if not intersection:
|
||||
raise RuntimeError("No common features across datasets.")
|
||||
for repo_id, ds in zip(repo_ids, datasets, strict=False):
|
||||
extra = set(ds.features) - intersection
|
||||
logging.warning(f"Disabling {extra} for repo {repo_id}")
|
||||
self.disabled_features.update(extra)
|
||||
|
||||
# step 2: build union_features excluding disabled
|
||||
self.union_features = {}
|
||||
for ds in datasets:
|
||||
for k, v in ds.features.items():
|
||||
if k not in self.disabled_features:
|
||||
self.union_features[k] = v
|
||||
|
||||
# step 3: reshape feature schema
|
||||
self.features = reshape_features_to_max_dim(
|
||||
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
|
||||
)
|
||||
|
||||
# step 4: aggregate stats
|
||||
self.stats = aggregate_stats_per_robot_type(datasets)
|
||||
for robot_type_, stats_ in self.stats.items():
|
||||
for feat_key, feat_stats in stats_.items():
|
||||
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
|
||||
for k, v in feat_stats.items():
|
||||
pad_value = 0 if k in ["min", "mean"] else 1
|
||||
self.stats[robot_type_][feat_key][k] = pad_tensor(
|
||||
v,
|
||||
max_size=self.keys_to_max_dim.get(feat_key, -1),
|
||||
pad_dim=-1,
|
||||
pad_value=pad_value,
|
||||
)
|
||||
|
||||
# step 5: episodes & tasks
|
||||
self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
|
||||
|
||||
class MultiLeRobotDatasetCleaner:
|
||||
def __init__(
|
||||
self,
|
||||
datasets: list[LeRobotDataset],
|
||||
repo_ids: list[str],
|
||||
sampling_weights: list[float],
|
||||
datasets_repo_ids: list[str],
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 100,
|
||||
):
|
||||
self.original_datasets = datasets
|
||||
self.original_repo_ids = repo_ids
|
||||
self.original_weights = sampling_weights
|
||||
self.original_datasets_repo_ids = datasets_repo_ids
|
||||
|
||||
# step 1: remove datasets with invalid fps
|
||||
valid_fps_datasets = keep_datasets_with_valid_fps(datasets, min_fps=min_fps, max_fps=max_fps)
|
||||
|
||||
# step 2: keep datasets with same features per robot type
|
||||
consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
|
||||
valid_fps_datasets
|
||||
)
|
||||
|
||||
self.cleaned_datasets = consistent_datasets
|
||||
self.keep_mask = keep_mask
|
||||
self.cleaned_weights = [sampling_weights[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
|
||||
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]]
|
||||
self.cleaned_datasets_repo_ids = [
|
||||
datasets_repo_ids[i] for i in range(len(valid_fps_datasets)) if keep_mask[i]
|
||||
]
|
||||
|
||||
self.cumulative_sizes = np.array(
|
||||
[0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
|
||||
)
|
||||
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
@@ -1021,7 +1302,24 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
# add
|
||||
sampling_weights: list[float] | None = None,
|
||||
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||
max_action_dim: int = None,
|
||||
max_state_dim: int = None,
|
||||
max_num_images: int = None,
|
||||
max_image_dim: int = None,
|
||||
train_on_all_features: bool = False,
|
||||
training_features: list | None = None,
|
||||
discard_first_n_frames: int = 0,
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 100,
|
||||
discard_first_idle_frames: bool = False,
|
||||
motion_threshold: float = 0.05,
|
||||
motion_window_size: int = 10,
|
||||
motion_buffer: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -1029,46 +1327,89 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
_datasets = []
|
||||
datasets_repo_ids = []
|
||||
self.sampling_weights = []
|
||||
self.training_features = training_features
|
||||
|
||||
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||
assert len(sampling_weights) == len(repo_ids), (
|
||||
"The number of sampling weights must match the number of datasets. "
|
||||
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||
)
|
||||
for i, repo_id in enumerate(repo_ids):
|
||||
try:
|
||||
# delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
_datasets.append(
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
feature_keys_mapping=feature_keys_mapping,
|
||||
training_features=training_features,
|
||||
discard_first_n_frames=discard_first_n_frames,
|
||||
discard_first_idle_frames=discard_first_idle_frames,
|
||||
motion_threshold=motion_threshold,
|
||||
motion_window_size=motion_window_size,
|
||||
motion_buffer=motion_buffer,
|
||||
)
|
||||
)
|
||||
datasets_repo_ids.append(repo_id)
|
||||
self.sampling_weights.append(float(sampling_weights[i]))
|
||||
except Exception as e:
|
||||
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||
print(
|
||||
f"Finish loading {len(_datasets)} datasets, with sampling weights: {self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||
)
|
||||
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
# FIXME(mshukor): apply mapping to unify used keys
|
||||
# FIXME(mshukor): pad based on types in case we have more than one state?
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
self.delta_timestamps = (
|
||||
delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||
) # delta_timestamps # FIXME(mshukor): last repo?
|
||||
# In case datasets with the same robot_type have different features
|
||||
cleaner = MultiLeRobotDatasetCleaner(
|
||||
datasets=_datasets,
|
||||
repo_ids=repo_ids,
|
||||
sampling_weights=self.sampling_weights,
|
||||
datasets_repo_ids=datasets_repo_ids,
|
||||
min_fps=min_fps,
|
||||
max_fps=max_fps,
|
||||
)
|
||||
self._datasets = cleaner.cleaned_datasets
|
||||
self.sampling_weights = cleaner.cleaned_weights
|
||||
self.repo_ids = cleaner.cleaned_repo_ids
|
||||
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||
# self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets
|
||||
# self.meta.info = {
|
||||
# repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False)
|
||||
# }
|
||||
# self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features
|
||||
self.meta = MultiLeRobotDatasetMeta(
|
||||
datasets=self._datasets,
|
||||
repo_ids=self.repo_ids,
|
||||
keys_to_max_dim={
|
||||
ACTION: max_action_dim,
|
||||
OBS_ENV_STATE: max_state_dim,
|
||||
OBS_STATE: max_state_dim,
|
||||
OBS_IMAGE: max_image_dim,
|
||||
OBS_IMAGE_2: max_image_dim,
|
||||
OBS_IMAGE_3: max_image_dim,
|
||||
},
|
||||
train_on_all_features=train_on_all_features,
|
||||
)
|
||||
self.disabled_features = self.meta.disabled_features
|
||||
self.stats = self.meta.stats
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
@@ -1156,23 +1497,14 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||
item = self._datasets[dataset_idx][local_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_features:
|
||||
item = create_padded_features(item, self.meta.features)
|
||||
for data_key in self.disabled_features: # FIXME(mshukor): not in getitem?
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -858,3 +858,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
|
||||
def map_dict_keys(
|
||||
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
|
||||
) -> dict:
|
||||
"""Maps feature keys from the dataset to the keys used in the model."""
|
||||
if feature_keys_mapping is None:
|
||||
return item
|
||||
features = {}
|
||||
for key in item:
|
||||
if key in feature_keys_mapping:
|
||||
if feature_keys_mapping[key] is not None:
|
||||
if training_features is None or feature_keys_mapping[key] in training_features:
|
||||
features[feature_keys_mapping[key]] = item[key]
|
||||
else:
|
||||
if training_features is None or key in training_features or pad_key in key:
|
||||
features[key] = item[key]
|
||||
return features
|
||||
|
||||
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Utils function by Mustafa to refactor
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGE_2 = "observation.image2"
|
||||
OBS_IMAGE_3 = "observation.image3"
|
||||
|
||||
|
||||
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
|
||||
"""Reshape features to have a maximum dimension of `max_dim`."""
|
||||
reshaped_features = {}
|
||||
for key in features:
|
||||
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
|
||||
reshaped_features[key] = features[key]
|
||||
shape = list(features[key]["shape"])
|
||||
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
|
||||
shape[-3] = keys_to_max_dim[key]
|
||||
shape[-2] = keys_to_max_dim[key]
|
||||
else:
|
||||
shape[reshape_dim] = keys_to_max_dim[key]
|
||||
reshaped_features[key]["shape"] = tuple(shape)
|
||||
else:
|
||||
reshaped_features[key] = features[key]
|
||||
return reshaped_features
|
||||
|
||||
|
||||
def keep_datasets_with_valid_fps(ls_datasets: list, min_fps: int = 1, max_fps: int = 100) -> list:
|
||||
print(
|
||||
f"Keeping datasets with fps between {min_fps} and {max_fps}. Considering {len(ls_datasets)} datasets."
|
||||
)
|
||||
for ds in ls_datasets:
|
||||
if ds.fps < min_fps or ds.fps > max_fps:
|
||||
print(f"Dataset {ds} has invalid fps: {ds.fps}. Removing it.")
|
||||
ls_datasets.remove(ds)
|
||||
print(f"Keeping {len(ls_datasets)} datasets with valid fps.")
|
||||
return ls_datasets
|
||||
|
||||
|
||||
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
|
||||
"""
|
||||
Filters datasets to only keep those with consistent feature shapes per robot type.
|
||||
|
||||
Args:
|
||||
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
|
||||
and `meta.episodes_stats` dictionary.
|
||||
|
||||
Returns:
|
||||
List: Filtered list of datasets with consistent feature shapes.
|
||||
"""
|
||||
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
|
||||
datasets_to_remove = set()
|
||||
|
||||
for robot_type in robot_types:
|
||||
# Collect all stats dicts for this robot type
|
||||
stats_list = [
|
||||
ep_stats
|
||||
for ds in ls_datasets
|
||||
if ds.meta.info["robot_type"] == robot_type
|
||||
for ep_stats in ds.meta.episodes_stats.values()
|
||||
if ep_stats is not None # Filter out None values
|
||||
]
|
||||
if not stats_list:
|
||||
continue
|
||||
|
||||
# Determine the most common shape for each key
|
||||
all_keys = {key for stats in stats_list for key in stats}
|
||||
for ds in ls_datasets:
|
||||
if ds.meta.info["robot_type"] != robot_type:
|
||||
continue
|
||||
for key in all_keys:
|
||||
shape_counter = defaultdict(int)
|
||||
|
||||
for stats in stats_list:
|
||||
value = stats.get(key)
|
||||
if (
|
||||
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||
): # FIXME(mshukor): check all stats; min, mean, max
|
||||
shape_counter[value["mean"].shape] += 1
|
||||
if not shape_counter:
|
||||
continue
|
||||
|
||||
# Identify the most frequent shape
|
||||
main_shape = max(shape_counter, key=shape_counter.get)
|
||||
# Flag datasets that don't match the main shape
|
||||
# for ds in ls_datasets:
|
||||
first_ep_stats = next(iter(ds.meta.episodes_stats.values()), None)
|
||||
if not first_ep_stats:
|
||||
continue
|
||||
value = first_ep_stats.get(key)
|
||||
if (
|
||||
value
|
||||
and "mean" in value
|
||||
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
|
||||
and value["mean"].shape != main_shape
|
||||
):
|
||||
datasets_to_remove.add(ds)
|
||||
break
|
||||
|
||||
# Filter out inconsistent datasets
|
||||
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
|
||||
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
|
||||
print(
|
||||
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
|
||||
)
|
||||
return filtered_datasets, datasets_maks
|
||||
|
||||
|
||||
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
|
||||
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
|
||||
stats = {robot_type: {} for robot_type in robot_types}
|
||||
for robot_type in robot_types:
|
||||
robot_type_datasets = []
|
||||
for ds in ls_datasets:
|
||||
if ds.meta.info["robot_type"] == robot_type:
|
||||
# Filter out None values from episodes_stats to handle missing stats
|
||||
valid_episodes_stats = [stats for stats in ds.meta.episodes_stats.values() if stats is not None]
|
||||
robot_type_datasets.extend(valid_episodes_stats)
|
||||
# robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
|
||||
if robot_type_datasets: # Only aggregate if we have valid stats
|
||||
stat = aggregate_stats(robot_type_datasets)
|
||||
stats[robot_type] = stat
|
||||
else:
|
||||
print(f"Warning: No valid episode stats found for robot type {robot_type}, skipping aggregation")
|
||||
stats[robot_type] = {}
|
||||
return stats
|
||||
|
||||
|
||||
def str_to_torch_dtype(dtype_str):
|
||||
"""Convert a dtype string to a torch dtype."""
|
||||
mapping = {
|
||||
"float32": torch.float32,
|
||||
"int64": torch.int64,
|
||||
"int16": torch.int16,
|
||||
"bool": torch.bool,
|
||||
"video": torch.float32, # Assuming video is stored as uint8 images
|
||||
}
|
||||
return mapping.get(dtype_str, torch.float32) # Default to float32
|
||||
|
||||
|
||||
def create_padded_features(item: dict, features: dict = {}):
|
||||
for key, ft in features.items():
|
||||
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if len(shape) == 3: # images to torch format (C, H, W)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
|
||||
shape = []
|
||||
if key not in item:
|
||||
dtype = str_to_torch_dtype(ft["dtype"])
|
||||
item[key] = torch.zeros(shape, dtype=dtype)
|
||||
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
|
||||
if "image" in key: # FIXME(mshukor): support other observations
|
||||
item[f"{key}_is_pad"] = torch.BoolTensor([False])
|
||||
else:
|
||||
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
|
||||
return item
|
||||
|
||||
|
||||
ROBOT_TYPE_KEYS_MAPPING = {
|
||||
"lerobot/stanford_hydra_dataset": "static_single_arm",
|
||||
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
|
||||
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
|
||||
"lerobot/toto": "static_single_arm",
|
||||
"lerobot/roboturk": "static_single_arm",
|
||||
"lerobot/jaco_play": "static_single_arm",
|
||||
"lerobot/taco_play": "static_single_arm_7statedim",
|
||||
}
|
||||
|
||||
|
||||
def pad_tensor(
|
||||
tensor: torch.Tensor, max_size: int, pad_dim: int = -1, pad_value: float = 0.0
|
||||
) -> torch.Tensor:
|
||||
is_numpy = isinstance(tensor, np.ndarray)
|
||||
if is_numpy:
|
||||
tensor = torch.tensor(tensor)
|
||||
if tensor.ndim == 0:
|
||||
# Scalar — return as-is, no padding needed
|
||||
return tensor
|
||||
pad = max_size - tensor.shape[pad_dim]
|
||||
if pad > 0:
|
||||
pad_sizes = (0, pad) # pad right
|
||||
tensor = torch.nn.functional.pad(tensor, pad_sizes, value=pad_value)
|
||||
return tensor.numpy() if is_numpy else tensor
|
||||
|
||||
|
||||
def map_dict_keys(
|
||||
item: dict, feature_keys_mapping: dict, training_features: list = None, pad_key: str = "is_pad"
|
||||
) -> dict:
|
||||
"""Maps feature keys from the dataset to the keys used in the model."""
|
||||
if feature_keys_mapping is None:
|
||||
return item
|
||||
features = {}
|
||||
for key in item:
|
||||
if key in feature_keys_mapping:
|
||||
if feature_keys_mapping[key] is not None:
|
||||
if training_features is None or feature_keys_mapping[key] in training_features:
|
||||
features[feature_keys_mapping[key]] = item[key]
|
||||
else:
|
||||
if training_features is None or key in training_features or pad_key in key:
|
||||
features[key] = item[key]
|
||||
|
||||
# breakpoint()
|
||||
return features
|
||||
|
||||
|
||||
def find_start_of_motion(velocities, window_size, threshold, motion_buffer):
|
||||
for t in range(len(velocities) - window_size):
|
||||
window_mean = velocities[t : t + window_size].mean()
|
||||
if window_mean > threshold:
|
||||
return max(0, t - motion_buffer) # include slight context before motion
|
||||
return 0
|
||||
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
|
||||
def load_yaml_mapping(name: str) -> dict:
|
||||
"""
|
||||
Loads a YAML mapping from a Hugging Face repo.
|
||||
Example: name='features' → https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/features.yaml
|
||||
"""
|
||||
url = f"https://huggingface.co/jadechoghari/smolvla-keys/resolve/main/{name}.yaml"
|
||||
response = requests.get(url)
|
||||
response.raise_for_status() # raise if the download fails
|
||||
|
||||
return yaml.safe_load(response.text)
|
||||
|
||||
|
||||
# Example usage
|
||||
TASKS_KEYS_MAPPING = load_yaml_mapping("tasks")
|
||||
FEATURE_KEYS_MAPPING = load_yaml_mapping("features")
|
||||
EPISODES_DATASET_MAPPING = {
|
||||
"cadene/droid_1.0.1": list(range(50)),
|
||||
"danaaubakirova/svla_so100_task5_v3": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
36,
|
||||
37,
|
||||
38,
|
||||
39,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
],
|
||||
"danaaubakirova/svla_so100_task4_v3": [
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
35,
|
||||
40,
|
||||
41,
|
||||
42,
|
||||
43,
|
||||
44,
|
||||
45,
|
||||
46,
|
||||
47,
|
||||
48,
|
||||
49,
|
||||
50,
|
||||
51,
|
||||
52,
|
||||
53,
|
||||
],
|
||||
}
|
||||
ACTION = "action"
|
||||
OBS_STATE = "observation.state"
|
||||
TASK = "task"
|
||||
ROBOT = "robot_type"
|
||||
TRAINING_FEATURES = {
|
||||
0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
|
||||
1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
|
||||
2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
|
||||
}
|
||||
|
||||
|
||||
def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int:
|
||||
return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1
|
||||
|
||||
|
||||
def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor:
|
||||
"""Pads a tensor to the target shape (right/bottom only)."""
|
||||
pad = []
|
||||
for actual, target in zip(reversed(tensor.shape), reversed(target_shape), strict=False):
|
||||
pad.extend([0, max(target - actual, 0)])
|
||||
return F.pad(tensor, pad, value=pad_value)
|
||||
|
||||
|
||||
def multidataset_collate_fn(
|
||||
batch: List[Dict[str, torch.Tensor]],
|
||||
keys_to_max_dim: Dict[str, tuple] = {},
|
||||
pad_value: float = 0.0,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Pads tensors to given target shape (if provided), otherwise uses per-batch max.
|
||||
Supports 1D (e.g. action), 3D (e.g. [C,H,W] images).
|
||||
"""
|
||||
collated_batch = [{} for _ in range(len(batch))]
|
||||
batch_keys = batch[0].keys()
|
||||
|
||||
for key in batch_keys:
|
||||
values = [sample[key] for sample in batch]
|
||||
sample = values[0]
|
||||
|
||||
if not isinstance(sample, torch.Tensor):
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = values[i]
|
||||
continue
|
||||
|
||||
# use user-specified shape if available
|
||||
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
|
||||
target_shape = keys_to_max_dim[key]
|
||||
else:
|
||||
# compute per-batch max shape
|
||||
target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim))
|
||||
|
||||
for i in range(len(batch)):
|
||||
collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value)
|
||||
|
||||
return default_collate(collated_batch)
|
||||
@@ -43,14 +43,32 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
# Handle variable dimensions for image/video data
|
||||
# Expected formats: (frames, channels, height, width) or (channels, height, width)
|
||||
if ep_ft_data.ndim == 4:
|
||||
# Standard case: (frames, channels, height, width)
|
||||
axes_to_reduce = (0, 2, 3) # reduce over frames, height, width
|
||||
elif ep_ft_data.ndim == 3:
|
||||
# Squeezed case: (channels, height, width) - single frame
|
||||
axes_to_reduce = (1, 2) # reduce over height, width
|
||||
else:
|
||||
raise ValueError(f"Unexpected dimensions for {ft['dtype']} data: {ep_ft_data.shape}")
|
||||
keepdims = True
|
||||
else:
|
||||
axes_to_reduce = 0
|
||||
keepdims = ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
if ep_ft_data.ndim == 4:
|
||||
# For 4D data, squeeze the first axis (batch/frames)
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
elif ep_ft_data.ndim == 3:
|
||||
# For 3D data, the stats already have correct shape (channels,)
|
||||
pass
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
@@ -16,5 +16,6 @@ from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
@@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
||||
@@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
@@ -147,7 +154,18 @@ def make_policy(
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
# Handle robot-type grouped stats - flatten to feature-level stats
|
||||
if ds_meta.stats and len(ds_meta.stats) == 1:
|
||||
# Single robot type - use its stats directly
|
||||
robot_type = list(ds_meta.stats.keys())[0]
|
||||
kwargs["dataset_stats"] = ds_meta.stats[robot_type]
|
||||
elif ds_meta.stats and len(ds_meta.stats) > 1:
|
||||
# Multiple robot types - need to aggregate across all robot types
|
||||
# For now, use the first robot type (TODO: proper multi-robot handling)
|
||||
robot_type = list(ds_meta.stats.keys())[0]
|
||||
kwargs["dataset_stats"] = ds_meta.stats[robot_type]
|
||||
else:
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
|
||||
@@ -79,7 +79,7 @@ def create_stats_buffers(
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if stats and key in stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PEFTConfig:
|
||||
r: int = 4
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.1
|
||||
target_modules: str = "q_proj,v_proj"
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla2")
|
||||
@dataclass
|
||||
class SmolVLA2Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512)
|
||||
|
||||
# Add empty images. Used by smolvla_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
proj_width: int = 480
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5 # 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10
|
||||
optimizer_lr_vlm: float = 0
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
|
||||
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
checkpoint_path: str = None
|
||||
peft_method: str = ""
|
||||
peft_config: PEFTConfig = field(default_factory=PEFTConfig)
|
||||
peft_target_model: str = ""
|
||||
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
|
||||
|
||||
attention_mode: str = "cross_attn"
|
||||
|
||||
prefix_length: int = -1
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
|
||||
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
|
||||
num_vlm_layers: int = 16
|
||||
past_obs_keys: str = "image"
|
||||
add_local_special_image_tokens: bool = False
|
||||
|
||||
reverse_images_order: bool = False
|
||||
|
||||
state_to_prefix: bool = False
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
causal_action_attention_mask: bool = False
|
||||
|
||||
self_attn_every_n_layers: int = -1 # Number of layers used in the VLM (first num_vlm_layers layers)
|
||||
# self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
|
||||
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
|
||||
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
robot_type: str = ""
|
||||
|
||||
self_attn_only_actions: bool = False
|
||||
|
||||
causal_attention_on_history: bool = False
|
||||
|
||||
predict_relative_actions: bool = False
|
||||
relative_actions_mode: str = "first"
|
||||
|
||||
shuffle_camera_positions: bool = False
|
||||
vlm_img_size: int = -1
|
||||
|
||||
regression_loss: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,605 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
# Disable device_map when using Accelerate for multi-GPU training
|
||||
import os
|
||||
use_accelerate = os.environ.get('ACCELERATE_USE_FSDP', 'false').lower() != 'true' and \
|
||||
'ACCELERATE_CONFIG_FILE' in os.environ
|
||||
device_map = None if use_accelerate else "auto"
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map=device_map,
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
self.set_requires_grad()
|
||||
|
||||
def configure_peft(self, config):
|
||||
# return model
|
||||
self.peft_method = config.peft_method
|
||||
self.peft_target_model = config.peft_target_model
|
||||
if "lora" in self.peft_method:
|
||||
peft_config = config.peft_config
|
||||
target_modules = peft_config.target_modules
|
||||
if not isinstance(target_modules, list):
|
||||
target_modules = target_modules.split(",")
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
|
||||
r=peft_config.r, # The rank of the low-rank adaptation
|
||||
lora_alpha=peft_config.lora_alpha, # Scaling factor
|
||||
lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
|
||||
target_modules=target_modules, # The components where LoRA is applied
|
||||
exclude_modules=[
|
||||
"lm_expert",
|
||||
"model.lm_expert.model.layers",
|
||||
], # FIXME(mshukor): this does not work for now
|
||||
)
|
||||
self.lora_config = lora_config
|
||||
# Apply LoRA and ensure only LoRA parameters are trainable
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
|
||||
else:
|
||||
self.vlm = get_peft_model(self.vlm, lora_config)
|
||||
# assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
|
||||
for name, param in self.vlm.named_parameters():
|
||||
if (
|
||||
"lora" in name and "text_model.model.layers.17" not in name
|
||||
): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
def merge_lora_weights(self):
|
||||
"""
|
||||
Merge LoRA weights into the base model.
|
||||
"""
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
|
||||
else:
|
||||
self.vlm = self.vlm.merge_and_unload()
|
||||
|
||||
def get_vlm_model(
|
||||
self,
|
||||
):
|
||||
if hasattr(self.vlm.model, "model"): # When using peft
|
||||
return self.vlm.model.model
|
||||
else:
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(k in name for k in frozen_layers):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if "lm_head" in name:
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
seq_len = query_states.shape[1]
|
||||
if seq_len < position_ids.shape[1]:
|
||||
_position_ids = position_ids[:, :seq_len]
|
||||
_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
else:
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(query_states, position_ids_)
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_output = attention_interface(
|
||||
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
query_states = apply_rope(query_state, position_id)
|
||||
key_states = apply_rope(key_state, position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
expert_position_id = position_ids
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
|
||||
|
||||
expert_input_shape = expert_hidden_states.shape[:-1]
|
||||
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
att_outputs.append(None)
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list:
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
for i in range(self.num_vlm_layers):
|
||||
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
|
||||
expert_layer = None
|
||||
else:
|
||||
expert_layer_index = i // multiple_of if multiple_of > 0 else i
|
||||
expert_layer = models[1].layers[expert_layer_index]
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
continue
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
att_out = att_output[:, start:end]
|
||||
out_emb = layer.self_attn.o_proj(att_out)
|
||||
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end if len(att_outputs) == 1 else 0
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.num_attention_heads
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -243,7 +243,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
if not isinstance(policy, PreTrainedPolicy):
|
||||
# Handle accelerate-wrapped models by unwrapping them
|
||||
if hasattr(policy, 'module') and isinstance(policy.module, PreTrainedPolicy):
|
||||
# This is likely an accelerate-wrapped model (DistributedDataParallel)
|
||||
policy = policy.module
|
||||
elif not isinstance(policy, PreTrainedPolicy):
|
||||
raise ValueError(
|
||||
f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
@@ -23,12 +24,16 @@ import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.datasets.utils_must import multidataset_collate_fn
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
@@ -52,6 +57,8 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
from lerobot.utils.wandb_utils import WandBLogger
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
@@ -63,41 +70,55 @@ def update_policy(
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
accelerator=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
grad_norm = 0.0 # Initialize grad_norm to avoid undefined variable
|
||||
|
||||
if accelerator:
|
||||
with accelerator.accumulate(policy):
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
# Standard training loop without accelerate
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
with lock if lock is not None else nullcontext():
|
||||
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
optimizer.zero_grad()
|
||||
grad_scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
if accelerator:
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
else:
|
||||
policy.update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
@@ -108,8 +129,34 @@ def update_policy(
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
accelerator = None # Initialize accelerator variable
|
||||
|
||||
if is_launched_with_accelerate():
|
||||
import accelerate
|
||||
|
||||
# For example pi0 has unused params (last llm block)
|
||||
from accelerate import DistributedDataParallelKwargs
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
from accelerate import InitProcessGroupKwargs
|
||||
# Set NCCL timeout (default 30 minutes = 1800 seconds)
|
||||
nccl_timeout = getattr(cfg, 'nccl_timeout', 1800)
|
||||
ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout)) # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time
|
||||
# Set gradient accumulation steps (default 1)
|
||||
gradient_accumulation_steps = getattr(cfg, 'gradient_accumulation_steps', 1)
|
||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_init_kwargs, ddp_kwargs])
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
if accelerator and not accelerator.is_main_process:
|
||||
# Disable logging on non-main processes.
|
||||
cfg.wandb.enable = False
|
||||
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
@@ -173,16 +220,29 @@ def train(cfg: TrainPipelineConfig):
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
|
||||
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
|
||||
keys_to_max_dim = {
|
||||
"action": (32,),
|
||||
"observation.state": (32,),
|
||||
"observation.image": (3, 1080, 1920),
|
||||
"observation.image2": (3, 1080, 1920),
|
||||
}
|
||||
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
) # Most important line
|
||||
if accelerator:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -218,6 +278,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -240,7 +301,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
# Unwrap policy from accelerate if needed
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -252,9 +315,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
# Unwrap policy from accelerate if needed for evaluation
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
unwrapped_policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
@@ -283,7 +348,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
# Unwrap policy from accelerate if needed
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -60,11 +60,39 @@ def load_training_step(save_dir: Path) -> int:
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
import fcntl
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
last_checkpoint_dir.unlink()
|
||||
relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
|
||||
last_checkpoint_dir.symlink_to(relative_target)
|
||||
|
||||
# Use file locking to prevent race conditions in multi-GPU training
|
||||
lock_file = checkpoint_dir.parent / ".symlink_lock"
|
||||
|
||||
try:
|
||||
with open(lock_file, 'w') as f:
|
||||
# Get exclusive lock
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
# Update symlink atomically
|
||||
if last_checkpoint_dir.exists() or last_checkpoint_dir.is_symlink():
|
||||
last_checkpoint_dir.unlink()
|
||||
last_checkpoint_dir.symlink_to(relative_target)
|
||||
|
||||
except (OSError, FileExistsError) as e:
|
||||
# Handle race conditions gracefully - another process may have already updated
|
||||
if not last_checkpoint_dir.exists():
|
||||
try:
|
||||
last_checkpoint_dir.symlink_to(relative_target)
|
||||
except FileExistsError:
|
||||
pass # Another process created it, that's fine
|
||||
finally:
|
||||
# Clean up lock file
|
||||
try:
|
||||
lock_file.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
|
||||
@@ -33,7 +33,16 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
lst.append(f"dataset:{cfg.dataset.repo_id}")
|
||||
# Create shorter dataset tag to avoid wandb 64-char limit
|
||||
repo_id = cfg.dataset.repo_id
|
||||
if "," in repo_id:
|
||||
# Multiple datasets - use count
|
||||
dataset_count = len(repo_id.split(","))
|
||||
lst.append(f"datasets:{dataset_count}")
|
||||
else:
|
||||
# Single dataset - use last part of path
|
||||
dataset_name = repo_id.split("/")[-1][:20] # Truncate to 20 chars
|
||||
lst.append(f"dataset:{dataset_name}")
|
||||
if cfg.env is not None:
|
||||
lst.append(f"env:{cfg.env.type}")
|
||||
return lst if return_list else "-".join(lst)
|
||||
|
||||
@@ -394,37 +394,56 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
# @pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
"""Check that all dataset frames are incorporated and aligned correctly."""
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
|
||||
# dummy padding dimensions (simulate training setup)
|
||||
MAX_ACTION_DIM = 14
|
||||
MAX_STATE_DIM = 30
|
||||
MAX_NUM_IMAGES = 3
|
||||
MAX_IMAGE_DIM = 224
|
||||
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
dataset = MultiLeRobotDataset(
|
||||
repo_ids,
|
||||
max_action_dim=MAX_ACTION_DIM,
|
||||
max_state_dim=MAX_STATE_DIM,
|
||||
max_num_images=MAX_NUM_IMAGES,
|
||||
max_image_dim=MAX_IMAGE_DIM,
|
||||
)
|
||||
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
# check they match.
|
||||
expected_dataset_indices = []
|
||||
for i, sub_dataset in enumerate(sub_datasets):
|
||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||
|
||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
||||
for expected_dataset_index, sub_item, multi_item in zip(
|
||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||
):
|
||||
dataset_index = dataset_item.pop("dataset_index")
|
||||
dataset_index = multi_item.pop("dataset_index")
|
||||
assert dataset_index == expected_dataset_index
|
||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
||||
for k in sub_dataset_item:
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
# we ignore padding_mask and dataset_index keys in multi_item
|
||||
extra_keys = {k for k in multi_item if "padding_mask" in k}
|
||||
filtered_multi_keys = set(multi_item.keys()) - extra_keys
|
||||
assert set(sub_item.keys()) == filtered_multi_keys, "mismatch in keys"
|
||||
|
||||
for k in sub_item:
|
||||
if k not in multi_item:
|
||||
continue
|
||||
v1, v2 = sub_item[k], multi_item[k]
|
||||
if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
|
||||
assert torch.equal(v1, v2), f"tensor mismatch on key: {k}"
|
||||
else:
|
||||
assert v1 == v2, f"value mismatch on key: {k}"
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=smolvla_optimized_8gpu_fresh
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --cpus-per-task=88
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --mem=0
|
||||
#SBATCH --time=72:00:00
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --output=/fsx/dana_aubakirova/vla/logs/smolvla_optimized_8gpu_fresh_%j.out
|
||||
#SBATCH --error=/fsx/dana_aubakirova/vla/logs/smolvla_optimized_8gpu_fresh_%j.err
|
||||
#SBATCH --exclusive
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p /fsx/dana_aubakirova/vla/logs
|
||||
|
||||
# Activate conda environment
|
||||
source /fsx/dana_aubakirova/miniconda/etc/profile.d/conda.sh
|
||||
conda activate lerobot
|
||||
|
||||
# Add local lerobot source to Python path to use development version
|
||||
export PYTHONPATH="/fsx/dana_aubakirova/vla/lerobot/src:$PYTHONPATH"
|
||||
|
||||
# OPTIMIZED 8-GPU CUDA environment - high performance configuration
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512,expandable_segments:True,garbage_collection_threshold:0.8
|
||||
export TORCH_DISTRIBUTED_DEBUG=OFF
|
||||
export NCCL_DEBUG=WARN
|
||||
export CUDA_LAUNCH_BLOCKING=0
|
||||
export ACCELERATE_USE_FSDP=false
|
||||
export ACCELERATE_USE_DEEPSPEED=false
|
||||
export HF_ACCELERATE_DEVICE_MAP=false
|
||||
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
# 8-GPU optimizations
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_P2P_DISABLE=1
|
||||
|
||||
# Change to working directory
|
||||
cd /fsx/dana_aubakirova/vla
|
||||
|
||||
# FRESH START 8-GPU training configuration - NEW OUTPUT DIRECTORY
|
||||
export OUTPUT_DIR="/fsx/dana_aubakirova/vla/outputs/test_smolvla_2datasets_$(date +%Y%m%d_%H%M%S)"
|
||||
# Use ALL datasets from relative_datasets_list.txt - full scale training
|
||||
export REPO_IDS="AndrejOrsula/lerobot_double_ball_stacking_random, koenvanwijk/orange50-variation-2"
|
||||
|
||||
# Model configuration - optimized for 8-GPU with global batch size 32
|
||||
export VLM_REPO_ID=HuggingFaceTB/SmolVLM2-500M-Video-Instruct
|
||||
export STEPS=100 # Quick test run
|
||||
export BATCH_SIZE=8 # 4 per GPU = 32 global batch size (prevent hanging)
|
||||
export EVAL_FREQ=-1 # Disable evaluation for faster training
|
||||
export NUM_WORKERS=0 # MEMORY FIX: Disable workers to prevent memory exhaustion
|
||||
export SAVE_FREQ=10000 # Save every 10k steps
|
||||
|
||||
# Model config - optimized settings inspired by SmolPi0
|
||||
export POLICY=smolvla2
|
||||
export USE_AMP=false # DISABLE AMP for stability
|
||||
export OPTIMIZER_LR=5e-4 # Optimized learning rate
|
||||
export PEFT_METHOD=lora
|
||||
export LOAD_VLM_WEIGHTS=true
|
||||
export MAX_ACTION_DIM=32
|
||||
export MAX_STATE_DIM=32
|
||||
|
||||
# Dataset config - optimized from analysis
|
||||
export USE_IMAGENET_STATS=false
|
||||
export ENABLE_IMG_TRANSFORM=true
|
||||
export MAX_NUM_IMAGES=2 # OPTIMIZED: 2 images for better context
|
||||
export MAX_IMAGE_DIM=256 # OPTIMIZED: 256px resolution
|
||||
export TRAIN_ON_ALL_FEATURES=true
|
||||
export FEATURES_VERSION=2
|
||||
|
||||
# Advanced optimizations for 8-GPU setup
|
||||
export FPS_MIN=30
|
||||
export FPS_MAX=30
|
||||
export GRADIENT_ACCUMULATION_STEPS=1 # Global batch size = 4 × 8 × 2 = 64
|
||||
export PRECISION=no
|
||||
export DROP_LAST=true
|
||||
|
||||
# SmolPi0-inspired VLM parameters
|
||||
export VLM_LAYERS=16
|
||||
export EXPERT_WIDTH_MULTIPLIER=0.75
|
||||
export CAUSAL_ACTION_ATTENTION=true
|
||||
export SELF_ATTN_EVERY_N_LAYERS=2
|
||||
export ATTENTION_MODE=cross_attn
|
||||
export LORA_R=32
|
||||
export LORA_TARGET_MODULES=q_proj,v_proj
|
||||
export PREFIX_LENGTH=0
|
||||
|
||||
# Learning rate schedule inspired by SmolPi0
|
||||
export DECAY_LR=1e-6
|
||||
export DECAY_STEPS=50000
|
||||
export LR_VLM=1e-4
|
||||
export WARMUP_STEPS=1000
|
||||
|
||||
# Set environment variables for model cache and offline mode
|
||||
export HF_LEROBOT_HOME="/fsx/dana_aubakirova/vla"
|
||||
export HF_HOME="/fsx/dana_aubakirova/vla/.cache/huggingface"
|
||||
export HF_HUB_CACHE="/fsx/dana_aubakirova/vla/.cache/huggingface"
|
||||
export TRANSFORMERS_CACHE="/fsx/dana_aubakirova/vla/.cache/huggingface"
|
||||
export HF_HUB_OFFLINE=0
|
||||
export TRANSFORMERS_OFFLINE=0
|
||||
|
||||
# Optimized accelerate config
|
||||
export ACCELERATE_CONFIG_FILE="/fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml"
|
||||
|
||||
# Wandb configuration - FRESH START
|
||||
export WANDB_PROJECT="smolvla2-training"
|
||||
export WANDB_NOTES="8-GPU optimized training FRESH START - same parameters as previous run but from scratch"
|
||||
export WANDB_MODE="disabled"
|
||||
|
||||
# Print comprehensive optimization info
|
||||
echo "🚀 =============================================="
|
||||
echo "🚀 OPTIMIZED 8-GPU FRESH START TRAINING"
|
||||
echo "🚀 =============================================="
|
||||
echo "🆕 FRESH START - No resume, new output directory"
|
||||
echo "📊 Datasets: ALL available datasets (same as previous run)"
|
||||
echo "📁 Output directory: $OUTPUT_DIR"
|
||||
echo "🎯 Policy: $POLICY"
|
||||
echo "🔧 Batch size per GPU: $BATCH_SIZE (GLOBAL BATCH SIZE: $((BATCH_SIZE * 8)))"
|
||||
echo "🔄 Gradient accumulation steps: $GRADIENT_ACCUMULATION_STEPS"
|
||||
echo "📈 Training steps: $STEPS"
|
||||
echo "💾 Save frequency: $SAVE_FREQ"
|
||||
echo "🔬 Evaluation frequency: $EVAL_FREQ"
|
||||
echo "⚡ AMP enabled: $USE_AMP (no mixed precision - stable)"
|
||||
echo "📚 Learning rate: $OPTIMIZER_LR"
|
||||
echo "🎓 VLM Learning rate: $LR_VLM"
|
||||
echo "🔥 Warmup steps: $WARMUP_STEPS"
|
||||
echo "📷 Max images: $MAX_NUM_IMAGES"
|
||||
echo "🖼️ Image dimension: $MAX_IMAGE_DIM"
|
||||
echo "👥 Data workers per GPU: $NUM_WORKERS (memory optimized)"
|
||||
echo "🧠 VLM layers: $VLM_LAYERS"
|
||||
echo "🔄 Expert width multiplier: $EXPERT_WIDTH_MULTIPLIER"
|
||||
echo "🎯 LORA rank: $LORA_R"
|
||||
echo "🖥️ GPUs: 8 (HIGH PERFORMANCE)"
|
||||
echo "📊 Wandb project: $WANDB_PROJECT"
|
||||
echo "🚀 =============================================="
|
||||
|
||||
# Check GPU availability
|
||||
echo "🖥️ GPU Information:"
|
||||
nvidia-smi --list-gpus
|
||||
|
||||
# Create optimized 8-GPU accelerate config
|
||||
mkdir -p /fsx/dana_aubakirova/vla/accelerate_configs
|
||||
cat > /fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml << EOF
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
gpu_ids: '0,1,2,3,4,5,6,7'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
EOF
|
||||
|
||||
echo "📋 Created optimized accelerate config with no mixed precision for stability"
|
||||
|
||||
# Run distributed training with optimized accelerate configuration - FRESH START
|
||||
accelerate launch --config_file /fsx/dana_aubakirova/vla/accelerate_configs/optimized_fresh_config.yaml \
|
||||
lerobot/src/lerobot/scripts/train.py \
|
||||
--policy.type=$POLICY \
|
||||
--dataset.repo_id="$REPO_IDS" \
|
||||
--dataset.root="/fsx/dana_aubakirova/vla/community_dataset_v1" \
|
||||
--dataset.use_imagenet_stats=$USE_IMAGENET_STATS \
|
||||
--dataset.image_transforms.enable=$ENABLE_IMG_TRANSFORM \
|
||||
--dataset.train_on_all_features=$TRAIN_ON_ALL_FEATURES \
|
||||
--dataset.features_version=$FEATURES_VERSION \
|
||||
--policy.max_action_dim=$MAX_ACTION_DIM \
|
||||
--policy.max_state_dim=$MAX_STATE_DIM \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$STEPS \
|
||||
--eval_freq=$EVAL_FREQ \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--policy.use_amp=$USE_AMP \
|
||||
--policy.optimizer_lr=$OPTIMIZER_LR \
|
||||
--policy.optimizer_lr_vlm=$LR_VLM \
|
||||
--policy.scheduler_decay_lr=$DECAY_LR \
|
||||
--policy.scheduler_decay_steps=$DECAY_STEPS \
|
||||
--policy.scheduler_warmup_steps=$WARMUP_STEPS \
|
||||
--policy.peft_method=$PEFT_METHOD \
|
||||
--policy.peft_config.r=$LORA_R \
|
||||
--policy.peft_config.target_modules=$LORA_TARGET_MODULES \
|
||||
--policy.load_vlm_weights=$LOAD_VLM_WEIGHTS \
|
||||
--policy.repo_id=$VLM_REPO_ID \
|
||||
--policy.push_to_hub=false \
|
||||
--dataset.max_num_images=$MAX_NUM_IMAGES \
|
||||
--dataset.max_image_dim=$MAX_IMAGE_DIM \
|
||||
--dataset.video_backend=pyav \
|
||||
--num_workers=$NUM_WORKERS \
|
||||
--wandb.enable=false \
|
||||
--wandb.project=$WANDB_PROJECT \
|
||||
--wandb.notes="$WANDB_NOTES" \
|
||||
--dataset.min_fps=$FPS_MIN \
|
||||
--dataset.max_fps=$FPS_MAX \
|
||||
--policy.num_vlm_layers=$VLM_LAYERS \
|
||||
--policy.expert_width_multiplier=$EXPERT_WIDTH_MULTIPLIER \
|
||||
--policy.causal_action_attention_mask=$CAUSAL_ACTION_ATTENTION \
|
||||
--policy.self_attn_every_n_layers=$SELF_ATTN_EVERY_N_LAYERS \
|
||||
--policy.attention_mode=$ATTENTION_MODE \
|
||||
--policy.prefix_length=$PREFIX_LENGTH
|
||||
|
||||
echo "✅ Optimized 8-GPU FRESH START training completed! Check results in: $OUTPUT_DIR"
|
||||
echo "📊 View training progress at: https://wandb.ai"
|
||||
echo "🆕 FRESH START TRAINING SUMMARY:"
|
||||
echo " • Started from scratch with new output directory"
|
||||
echo " • Training from step 0 to step $STEPS"
|
||||
echo " • Same optimized parameters as previous successful run"
|
||||
echo " • New WandB run will be created automatically"
|
||||
echo ""
|
||||
echo "🚀 Key 8-GPU optimizations applied:"
|
||||
echo " • 8 GPUs with global batch size $((BATCH_SIZE * 8))"
|
||||
echo " • Memory-optimized data loading: 0 workers (prevents OOM)"
|
||||
echo " • STABLE: No mixed precision (matches conservative setup)"
|
||||
echo " • Optimized NCCL settings for 8-GPU communication"
|
||||
echo " • Enhanced memory allocation for high-throughput"
|
||||
echo ""
|
||||
echo "🏃♂️ Expected performance gains:"
|
||||
echo " • ~4x faster training throughput vs single GPU"
|
||||
echo " • Clean start without any checkpoint compatibility issues"
|
||||
echo " • Proven parameter configuration from previous run"
|
||||
Reference in New Issue
Block a user