diff --git a/lerobot/common/policies/smolvla2/configuration_smolvla2.py b/lerobot/common/policies/smolvla2/configuration_smolvla2.py deleted file mode 100644 index 9b839c62a..000000000 --- a/lerobot/common/policies/smolvla2/configuration_smolvla2.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.common.optim.optimizers import AdamWConfig -from lerobot.common.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - - -@dataclass -class PEFTConfig: - r: int = 4 - lora_alpha: int = 16 - lora_dropout: float = 0.1 - target_modules: str = "q_proj,v_proj" - - -@PreTrainedConfig.register_subclass("smolvla2") -@dataclass -class SmolVLA2Config(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 50 - n_action_steps: int = 50 - - normalization_mapping: dict[str, NormalizationMode] = field( - default_factory=lambda: { - "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, - } - ) - - # Shorter state and action vectors will be padded - max_state_dim: int = 32 - max_action_dim: int = 32 - - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (512, 512) - - # Add empty images. Used by smolvla_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - proj_width: int = 480 - # Decoding - num_steps: int = 10 - - # Attention utils - use_cache: bool = True - - # Finetuning settings - freeze_vision_encoder: bool = True - train_expert_only: bool = False - train_state_proj: bool = True - - # Training presets - optimizer_lr: float = 2.5e-5 # 1e-4 - optimizer_betas: tuple[float, float] = (0.9, 0.95) - optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-10 - optimizer_grad_clip_norm: float = 10 - optimizer_lr_vlm: float = 0 - - scheduler_warmup_steps: int = 1_000 - scheduler_decay_steps: int = 30_000 - scheduler_decay_lr: float = 2.5e-6 - - vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. - load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights - checkpoint_path: str = None - peft_method: str = "" - peft_config: PEFTConfig = field(default_factory=PEFTConfig) - peft_target_model: str = "" - add_image_special_tokens: bool = False # Whether to use special image tokens around image features. - - attention_mode: str = "cross_attn" - - prefix_length: int = -1 - - pad_language_to: str = "longest" # "max_length" - - num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers. - num_vlm_layers: int = 16 - past_obs_keys: str = "image" - add_local_special_image_tokens: bool = False - - reverse_images_order: bool = False - - state_to_prefix: bool = False - - pad_language_to: str = "longest" # "max_length" - causal_action_attention_mask: bool = False - - self_attn_every_n_layers: int = -1 # Number of layers used in the VLM (first num_vlm_layers layers) - # self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers - expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM) - - min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding - max_period: float = 4.0 - - robot_type: str = "" - - self_attn_only_actions: bool = False - - causal_attention_on_history: bool = False - - predict_relative_actions: bool = False - relative_actions_mode: str = "first" - - shuffle_camera_positions: bool = False - vlm_img_size: int = -1 - - regression_loss: bool = False - - def __post_init__(self): - super().__post_init__() - - """Input validation (not exhaustive).""" - if self.n_action_steps > self.chunk_size: - raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - if self.use_delta_joint_actions_aloha: - raise NotImplementedError( - "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot." - ) - - def validate_features(self) -> None: - for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" - empty_camera = PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 480, 640), - ) - self.input_features[key] = empty_camera - - def get_optimizer_preset(self) -> AdamWConfig: - return AdamWConfig( - lr=self.optimizer_lr, - betas=self.optimizer_betas, - eps=self.optimizer_eps, - weight_decay=self.optimizer_weight_decay, - grad_clip_norm=self.optimizer_grad_clip_norm, - ) - - def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) - - @property - def observation_delta_indices(self) -> list: - return [0] - - @property - def action_delta_indices(self) -> list: - return list(range(self.chunk_size)) - - @property - def reward_delta_indices(self) -> None: - return None diff --git a/lerobot/common/policies/smolvla2/modeling_smolvla2.py b/lerobot/common/policies/smolvla2/modeling_smolvla2.py deleted file mode 100644 index 3e85ba0ff..000000000 --- a/lerobot/common/policies/smolvla2/modeling_smolvla2.py +++ /dev/null @@ -1,1085 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -SmolVLA: - -[Paper](https://huggingface.co/papers/2506.01844) - -Designed by Hugging Face. - -Install smolvla extra dependencies: -```bash -pip install -e ".[smolvla]" -``` - -Example of finetuning the smolvla pretrained model (`smolvla_base`): -```bash -python lerobot/scripts/train.py \ ---policy.path=lerobot/smolvla_base \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ ---batch_size=64 \ ---steps=200000 -``` - -Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, -and an action expert. -```bash -python lerobot/scripts/train.py \ ---policy.type=smolvla \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ ---batch_size=64 \ ---steps=200000 -``` - -Example of using the smolvla pretrained model outside LeRobot training framework: -```python -policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") -``` - -""" - -import math -import os -import random -import re -from collections import deque - -import safetensors -import torch -import torch.nn.functional as F # noqa: N812 -from torch import Tensor, nn -from transformers import AutoProcessor - -from lerobot.common.constants import ACTION, OBS_STATE -from lerobot.common.policies.normalize import ( - Normalize, - Unnormalize, -) -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel -from lerobot.common.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config -from lerobot.common.policies.utils import ( - populate_queues, -) -from lerobot.common.utils.utils import get_safe_dtype -from lerobot.datasets import IMAGES_ORDER - -# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker -_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -def canonicalise(k: str) -> str: - """ - Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a - normalisation-buffer key. - """ - return _VARIANT_RE.sub(".buffer_", k) - - -def standardise_state_dict( - checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -) -> tuple[dict[str, torch.Tensor], list[str]]: - """ - • Re-keys `checkpoint ` so that every entry matches the *reference* key set. - • If several variant keys collapse to the same canonical name we keep the - first one and log the collision. - • Returns the new dict + a list of entries that could not be matched. - """ - out, collisions, unmatched = {}, {}, [] - - for k, v in checkpoint.items(): - canon = canonicalise(k) - if canon in ref_keys: - if canon in out: # duplicate after collapsing - collisions.setdefault(canon, []).append(k) - else: - out[canon] = v - else: - unmatched.append(k) - - if verbose: - for canon, variants in collisions.items(): - print(f"[standardise_state_dict] '{canon}' ← {variants}") - if unmatched: - print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - - out.update({k: checkpoint[k] for k in unmatched}) - return out, unmatched - - -def rename_checkpoint_keys(checkpoint: dict, rename_str: str): - """ - Renames keys in a checkpoint dictionary based on the given rename string. - - Args: - checkpoint (dict): The checkpoint dictionary. - rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - - Returns: - dict: The modified checkpoint with renamed keys. - """ - - rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - - new_checkpoint = {} - for k, v in checkpoint.items(): - for old_key, new_key in rename_dict.items(): - if old_key in k: - k = k.replace(old_key, new_key) - new_checkpoint[k] = v - return new_checkpoint - - -def load_smolvla( - model: torch.nn.Module, - filename: str | os.PathLike, - *, - device: str = "cpu", - checkpoint_keys_mapping: str = "", -) -> torch.nn.Module: - state_dict = safetensors.torch.load_file(filename, device=device) - - # Optional user-supplied renames (e.g. "model._orig_mod.//model.") - if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: - state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - - state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - - # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset - norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") - state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - - missing, unexpected = model.load_state_dict(state_dict, strict=False) - - if not all(key.startswith(norm_keys) for key in missing) or unexpected: - raise RuntimeError( - "SmolVLA %d missing / %d unexpected keys", - len(missing), - len(unexpected), - ) - - return model - - -def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" -) -> Tensor: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if dimension % 2 != 0: - raise ValueError(f"dimension ({dimension}) must be divisible by 2") - - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - - dtype = get_safe_dtype(torch.float64, device.type) - fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) - period = min_period * (max_period / min_period) ** fraction - - # Compute the outer product - scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - return pos_emb - - -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - -def make_att_2d_masks(pad_masks, att_masks): - """Copied from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on - it and 0 where it shares the same attention mask as the previous token. - """ - if att_masks.ndim != 2: - raise ValueError(att_masks.ndim) - if pad_masks.ndim != 2: - raise ValueError(pad_masks.ndim) - - cumsum = torch.cumsum(att_masks, dim=1) - att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - att_2d_masks = att_2d_masks & pad_2d_masks - return att_2d_masks - - -def resize_with_pad(img, width, height, pad_value=-1): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img - - -def pad_vector(vector, new_dim): - """Can be (batch_size x sequence_length x features_dimension) - or (batch_size x features_dimension) - """ - if vector.shape[-1] == new_dim: - return vector - shape = list(vector.shape) - current_dim = shape[-1] - shape[-1] = new_dim - new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) - new_vector[..., :current_dim] = vector - return new_vector - - -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with smolvla which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) - - -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class SmolVLA2Policy(PreTrainedPolicy): - """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" - - config_class = SmolVLA2Config - name = "smolvla2" - - def __init__( - self, - config: SmolVLA2Config, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. - """ - - super().__init__(config) - config.validate_features() - self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - - self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer - self.model = VLAFlowMatching(config) - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - self._queues = { - ACTION: deque(maxlen=self.config.n_action_steps), - } - if self.config.n_obs_steps > 1: - for k in self.config.input_features: - if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]): - self._queues[k] = deque(maxlen=self.config.n_obs_steps) - - # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues - @classmethod - def _load_as_safetensor( - cls, - model: "SmolVLA2Policy", - model_file: str, - map_location: str, - strict: bool, - ): - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) - return load_smolvla( - model, - model_file, - device=map_location, - checkpoint_keys_mapping="model._orig_mod.//model.", - ) - - def get_optim_params(self) -> dict: - return self.parameters() - - def merge_peft_model_weights(self) -> None: - if "lora" in self.config.peft_method: - self.model.vlm_with_expert.merge_lora_weights() - - @torch.no_grad - def select_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - - actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) - # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions, "robot_type": batch["robot_type"]})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - return actions - - @torch.no_grad - def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - - self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. - if len(self._queues[ACTION]) == 0: - for k in batch: - if k in self._queues: - batch[k] = torch.stack(list(self._queues[k]), dim=1) - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - - actions = self.model.sample_actions( - images, img_masks, lang_tokens, lang_masks, state, noise=noise - ) - if self.config.predict_relative_actions and actions.ndim == 3: - # If the model predicts relative actions, we need to unpad the actions - # and then convert them to absolute actions. - if self.config.relative_actions_mode == "first": - actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :1]), dim=1) - elif self.config.relative_actions_mode == "state": - actions = actions + state.unsqueeze(1) - else: - actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :-1]), dim=1) - # Unpad actions - - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) - return self._queues[ACTION].popleft() - - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: - """Do a full training forward pass to compute the loss""" - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - actions = self.prepare_action(batch, state=state) - actions_is_pad = batch.get("actions_id_pad") - loss_dict = {} - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() - - if actions_is_pad is not None: - in_episode_bound = ~actions_is_pad - losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() - - # Remove padding - losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() - - # For backward pass - loss = losses.mean() - # For backward pass - loss_dict["loss"] = loss.item() - return loss, loss_dict - - def prepare_images(self, batch): - """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and - convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. - """ - images = [] - img_masks = [] - present_img_keys = [key for key in self.config.image_features if key in batch] - missing_img_keys = [key for key in self.config.image_features if key not in batch] - - present_img_keys = sorted( - present_img_keys, - key=lambda k: IMAGES_ORDER.get(k, float("inf")), - reverse=self.config.reverse_images_order, - ) - if self.config.shuffle_camera_positions and ACTION in batch: # only during training - present_img_keys = random.sample(present_img_keys, len(present_img_keys)) - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" - ) - for i in range(self.num_past_images): - # Preprocess image features present in the batch - for key in present_img_keys: - img = batch[key][:, i, :, :, :] if batch[key].ndim == 5 else batch[key] - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) - - # Normalize from range [0,1] to [-1,1] as expacted by siglip - img = img * 2.0 - 1.0 - - bsize = img.shape[0] - device = img.device - if f"{key}_padding_mask" in batch: - mask = batch[f"{key}_padding_mask"].bool() - else: - mask = torch.ones(bsize, dtype=torch.bool, device=device) - images.append(img) - img_masks.append(mask) - - # Create image features not present in the batch - # as fully 0 padded images. - for num_empty_cameras in range(len(missing_img_keys)): - if num_empty_cameras >= self.config.empty_cameras: - break - img = torch.ones_like(img) * -1 - mask = torch.zeros_like(mask) - images.append(img) - img_masks.append(mask) - return images, img_masks - - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - - if len(tasks) == 1: - tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding=self.config.pad_language_to, - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - truncation=True, # FIXME(mshukor) - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - def prepare_state(self, batch): - """Pad state""" - state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE] - state = pad_vector(state, self.config.max_state_dim) - return state - - def prepare_action(self, batch, state=None): - """Pad action""" - actions = pad_vector(batch[ACTION], self.config.max_action_dim) - if self.config.predict_relative_actions and actions.ndim == 3: - if self.config.relative_actions_mode == "first": - actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1) - elif self.config.relative_actions_mode == "state": - assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], ( - "Relative action mode 'state' requires the action and state to have the same dimension." - ) - if state.ndim == 2: - state = state.unsqueeze(1) - actions = actions - state - else: - actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1) - return actions - - -def pad_tensor(tensor, max_len, pad_value=0): - """ - Efficiently pads a tensor along sequence dimension to match max_len. - - Args: - tensor (torch.Tensor): Shape (B, L, ...) or (B, L). - max_len (int): Fixed sequence length. - pad_value (int/float): Value for padding. - - Returns: - torch.Tensor: Shape (B, max_len, ...) or (B, max_len). - """ - b, d = tensor.shape[:2] - - # Create a padded tensor of max_len and copy the existing values - padded_tensor = torch.full( - (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device - ) - padded_tensor[:, :d] = tensor # Efficient in-place copy - - return padded_tensor - - -class VLAFlowMatching(nn.Module): - """ - SmolVLA - - [Paper]() - - Designed by Hugging Face. - ┌──────────────────────────────┐ - │ actions │ - │ ▲ │ - │ ┌─────────┐ ┌─|────┐ │ - │ | │────► │ │ │ - │ | │ kv │ │ │ - │ | │────► │Action│ │ - │ | VLM │cache │Expert│ | - │ │ │────► | │ │ - │ │ │ │ │ │ - │ └▲──▲───▲─┘ └───▲──┘ | - │ │ | | │ | - │ | | | noise │ - │ │ │ state │ - │ │ language tokens │ - │ image(s) │ - └──────────────────────────────┘ - """ - - def __init__(self, config): - super().__init__() - self.config = config - - self.vlm_with_expert = SmolVLMWithExpertModel( - model_id=self.config.vlm_model_name, - freeze_vision_encoder=self.config.freeze_vision_encoder, - train_expert_only=self.config.train_expert_only, - attention_implementation=self.config.attention_implementation, - load_vlm_weights=self.config.load_vlm_weights, - attention_mode=self.config.attention_mode, - num_expert_layers=self.config.num_expert_layers, - num_vlm_layers=self.config.num_vlm_layers, - self_attn_every_n_layers=self.config.self_attn_every_n_layers, - expert_width_multiplier=self.config.expert_width_multiplier, - ) - self.vlm_with_expert.configure_peft(config=self.config) - # Projections are float32 - self.state_to_prefix = self.config.state_to_prefix - if self.state_to_prefix: - self.state_proj = nn.Linear( - self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size - ) - else: - self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size) - self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) - self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) - - self.action_time_mlp_in = nn.Linear( - self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size - ) - self.action_time_mlp_out = nn.Linear( - self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size - ) - - self.set_requires_grad() - # SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ... - self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id - self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id - self.global_image_start_token = torch.tensor( - [self.fake_image_token, self.global_image_token], dtype=torch.long - ) - - self.add_image_special_tokens = self.config.add_image_special_tokens - self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) - self.prefix_length = self.config.prefix_length - self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split( - "," - ) - self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 - self.causal_attention_on_history = self.config.causal_attention_on_history - - def set_requires_grad(self): - for params in self.state_proj.parameters(): - params.requires_grad = self.config.train_state_proj - - def sample_noise(self, shape, device): - noise = torch.normal( - mean=0.0, - std=1.0, - size=shape, - dtype=torch.float32, - device=device, - ) - return noise - - def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) - time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) - - def embed_prefix( - self, - images, - img_masks, - lang_tokens, - lang_masks, - state: torch.Tensor = None, - pointtrackers=None, - pt_masks=None, - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed multiple modalities for vlm processing. - - Simple, extensible approach using list + torch.cat. - Easy to add new information/modalities like point trackers, audio, etc. - - Args: - images: List of image tensors - img_masks: List of image masks - lang_tokens: Language token tensor - lang_masks: Language mask tensor - state: Optional state tensor - pointtrackers: Optional point tracker tensors (future extension) - pt_masks: Optional point tracker masks (future extension) - **kwargs: Additional modalities for future extensions - """ - embs = [] - pad_masks = [] - att_masks = [] - - # Process each modality type - self._add_image_embeddings(images, img_masks, embs, pad_masks, att_masks) - self._add_language_embeddings(lang_tokens, lang_masks, embs, pad_masks, att_masks) - - if state is not None and self.state_to_prefix: - self._add_state_embeddings(state, embs, pad_masks, att_masks) - - # Future extensions - easy to add new modalities - if pointtrackers is not None: - self._add_pointtracker_embeddings(pointtrackers, pt_masks, embs, pad_masks, att_masks) - - # Add more modalities here as needed: - # if audio is not None: - # self._add_audio_embeddings(audio, audio_masks, embs, pad_masks, att_masks) - - # Concatenate all embeddings - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) - - # Handle prefix length padding - seq_len = pad_masks.shape[1] - if seq_len < self.prefix_length: - embs = pad_tensor(embs, self.prefix_length, pad_value=0) - pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) - att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) - - # Expand attention masks to batch size - bsize = pad_masks.shape[0] - att_masks = att_masks[None, :].expand(bsize, -1) - - return embs, pad_masks, att_masks - - def _add_image_embeddings(self, images, img_masks, embs, pad_masks, att_masks): - """Add image embeddings with special tokens to the lists.""" - for img, img_mask in zip(images, img_masks, strict=False): - # Add image start tokens if enabled - if self.add_image_special_tokens: - start_emb = ( - self.vlm_with_expert.embed_language_tokens( - self.global_image_start_token.to(device=img.device) - ) - .unsqueeze(0) - .expand(img.shape[0], -1, -1) - ) - - start_mask = torch.ones_like(start_emb[:, :, 0], dtype=torch.bool) - embs.append(start_emb) - pad_masks.append(start_mask) - att_masks += [0] * start_emb.shape[1] - - # Process image embedding - img_emb = self.vlm_with_expert.embed_image(img) - - # Normalize image embeddings - img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) - - # Expand mask to match image embedding sequence length - bsize, num_img_embs = img_emb.shape[:2] - expanded_mask = img_mask[:, None].expand(bsize, num_img_embs) - - embs.append(img_emb) - pad_masks.append(expanded_mask) - att_masks += [0] * num_img_embs - - # Add image end tokens if enabled - if self.add_image_special_tokens: - end_emb = ( - self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=img.device)) - .unsqueeze(0) - .expand(img.shape[0], -1, -1) - ) - - end_mask = torch.ones_like(end_emb[:, :, 0], dtype=torch.bool) - embs.append(end_emb) - pad_masks.append(end_mask) - att_masks += [0] * end_emb.shape[1] - - def _add_language_embeddings(self, lang_tokens, lang_masks, embs, pad_masks, att_masks): - """Add language embeddings to the lists.""" - lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) - - # Normalize language embeddings - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) - - embs.append(lang_emb) - pad_masks.append(lang_masks) - att_masks += [0] * lang_emb.shape[1] - - def _add_state_embeddings(self, state, embs, pad_masks, att_masks): - """Add state embeddings to the lists.""" - state_emb = self.state_proj(state) - state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb - - bsize, states_seq_len = state_emb.shape[:2] - state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=state_emb.device) - - embs.append(state_emb) - pad_masks.append(state_mask) - att_masks += [1] * states_seq_len # State tokens get causal attention - - def _add_pointtracker_embeddings(self, pointtrackers, pt_masks, embs, pad_masks, att_masks): - """Add point tracker embeddings to the lists (future extension).""" - # TODO: Implement point tracker processing - # Example implementation: - # for pt, pt_mask in zip(pointtrackers, pt_masks): - # pt_emb = self.pointtracker_encoder(pt) # Need to add this - # embs.append(pt_emb) - # pad_masks.append(pt_mask) - # att_masks += [0] * pt_emb.shape[1] - pass - - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" - embs = [] - pad_masks = [] - att_masks = [] - # Embed state - if not self.state_to_prefix: - state_emb = self.state_proj(state) - state_emb = ( - state_emb[:, None, :] if state_emb.ndim == 2 else state_emb - ) # .to(dtype=self.vlm_with_expert.type) - embs.append(state_emb) - bsize = state_emb.shape[0] - dtype = state_emb.dtype - device = state_emb.device - - states_seq_len = state_emb.shape[1] - state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] + [0] * (states_seq_len - 1) - # Fuse timestep + action information using an MLP - action_emb = self.action_in_proj(noisy_actions) - device = action_emb.device - bsize = action_emb.shape[0] - dtype = action_emb.dtype - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = create_sinusoidal_pos_embedding( - timestep, - self.vlm_with_expert.expert_hidden_size, - self.config.min_period, - self.config.max_period, - device=device, - ) - time_emb = time_emb.type(dtype=dtype) - - time_emb = time_emb[:, None, :].expand_as(action_emb) - action_time_emb = torch.cat([action_emb, time_emb], dim=2) - - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) - - # Add to input tokens - embs.append(action_time_emb) - - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) - pad_masks.append(action_time_mask) - - # Set attention masks so that image, language and state inputs do not attend to action tokens - att_masks += [1] * self.config.chunk_size - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - return embs, pad_masks, att_masks - - def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None - ) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - (_, suffix_out), _ = self.vlm_with_expert.forward( - attention_mask=att_2d_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, - ) - suffix_out = suffix_out[:, -self.config.chunk_size :] - # Original openpi code, upcast attention output - suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - losses = F.mse_loss(u_t, v_t, reduction="none") - return losses - - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - bsize = state.shape[0] - device = state.device - - if noise is None: - actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) - noise = self.sample_noise(actions_shape, device) - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - # Compute image and language key value cache - _, past_key_values = self.vlm_with_expert.forward( - attention_mask=prefix_att_2d_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=self.config.use_cache, - fill_kv_cache=True, - ) - if self.config.regression_loss: - x_t = torch.zeros_like(noise, dtype=torch.float32, device=device) - expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device) - x_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - else: - dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) - - x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - # Euler step - x_t += dt * v_t - time += dt - return x_t - - def denoise_step( - self, - state, - prefix_pad_masks, - past_key_values, - x_t, - timestep, - ): - """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) - - suffix_len = suffix_pad_masks.shape[1] - batch_size = prefix_pad_masks.shape[0] - prefix_len = prefix_pad_masks.shape[1] - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) - prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - - outputs_embeds, _ = self.vlm_with_expert.forward( - attention_mask=full_att_2d_masks, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=[None, suffix_embs], - use_cache=self.config.use_cache, - fill_kv_cache=False, - ) - suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.chunk_size :] - suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - return v_t diff --git a/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py b/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py deleted file mode 100644 index f1030cc4b..000000000 --- a/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py +++ /dev/null @@ -1,599 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -from typing import List, Optional - -import torch -from torch import nn -from transformers import ( - AutoConfig, - AutoModel, - AutoModelForImageTextToText, - AutoProcessor, - SmolVLMForConditionalGeneration, -) - - -def apply_rope(x, positions, max_wavelength=10_000): - """ - Applies RoPE positions [B, L] to x [B, L, H, D]. - """ - d_half = x.shape[-1] // 2 - device = x.device - dtype = x.dtype - x = x.to(torch.float32) - - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) - timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) - - radians = radians[..., None, :] - - sin = torch.sin(radians) # .to(dtype=dtype) - cos = torch.cos(radians) # .to(dtype=dtype) - - x1, x2 = x.split(d_half, dim=-1) - res = torch.empty_like(x) - res[..., :d_half] = x1 * cos - x2 * sin - res[..., d_half:] = x2 * cos + x1 * sin - - return res.to(dtype) - - -def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - return hidden_dim - - -class SmolVLMWithExpertModel(nn.Module): - def __init__( - self, - model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", - load_vlm_weights: bool = True, - train_expert_only: bool = True, - freeze_vision_encoder: bool = False, - attention_mode: str = "self_attn", - num_expert_layers: int = -1, - num_vlm_layers: int = -1, - self_attn_every_n_layers: int = -1, - expert_width_multiplier: float = 0.5, - ): - super().__init__() - if load_vlm_weights: - print(f"Loading {model_id} weights ...") - self.vlm = AutoModelForImageTextToText.from_pretrained( - model_id, - device_map="auto", - torch_dtype="bfloat16", - low_cpu_mem_usage=True, - ) - config = self.vlm.config - else: - config = AutoConfig.from_pretrained(model_id) - self.vlm = SmolVLMForConditionalGeneration(config=config) - self.processor = AutoProcessor.from_pretrained(model_id) - if num_vlm_layers > 0: - print(f"Reducing the number of VLM layers to {num_vlm_layers} ...") - self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers] - self.num_vlm_layers = len(self.get_vlm_model().text_model.layers) - self.config = config - # Smaller lm expert - lm_expert_config = copy.deepcopy(config.text_config) - hidden_size = lm_expert_config.hidden_size - lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 - lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) - lm_expert_config.num_hidden_layers = self.num_vlm_layers - if num_expert_layers > 0: - assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( - f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" - ) - lm_expert_config.num_hidden_layers = num_expert_layers - self.lm_expert = AutoModel.from_config(lm_expert_config) - - self.num_expert_layers = len(self.lm_expert.layers) - self.self_attn_every_n_layers = self_attn_every_n_layers - if "cross" in attention_mode: - # Reshape qkv projections to have the same input dimension as the vlm - for layer_idx in range(len(self.lm_expert.layers)): - if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: - continue - self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, - lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, - bias=lm_expert_config.attention_bias, - ) - self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, - lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, - bias=lm_expert_config.attention_bias, - ) - # Remove unused embed_tokens - self.lm_expert.embed_tokens = None - - self.num_attention_heads = self.config.text_config.num_attention_heads - self.num_key_value_heads = self.config.text_config.num_key_value_heads - - self.freeze_vision_encoder = freeze_vision_encoder - self.train_expert_only = train_expert_only - self.attention_mode = attention_mode - self.expert_hidden_size = lm_expert_config.hidden_size - self.set_requires_grad() - - def configure_peft(self, config): - # return model - self.peft_method = config.peft_method - self.peft_target_model = config.peft_target_model - if "lora" in self.peft_method: - peft_config = config.peft_config - target_modules = peft_config.target_modules - if not isinstance(target_modules, list): - target_modules = target_modules.split(",") - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) - r=peft_config.r, # The rank of the low-rank adaptation - lora_alpha=peft_config.lora_alpha, # Scaling factor - lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers - target_modules=target_modules, # The components where LoRA is applied - exclude_modules=[ - "lm_expert", - "model.lm_expert.model.layers", - ], # FIXME(mshukor): this does not work for now - ) - self.lora_config = lora_config - # Apply LoRA and ensure only LoRA parameters are trainable - if "text" in self.peft_target_model: - self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config) - else: - self.vlm = get_peft_model(self.vlm, lora_config) - # assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here? - for name, param in self.vlm.named_parameters(): - if ( - "lora" in name and "text_model.model.layers.17" not in name - ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer - param.requires_grad = True - else: - param.requires_grad = False - - def merge_lora_weights(self): - """ - Merge LoRA weights into the base model. - """ - if "text" in self.peft_target_model: - self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload() - else: - self.vlm = self.vlm.merge_and_unload() - - def get_vlm_model( - self, - ): - if hasattr(self.vlm.model, "model"): # When using peft - return self.vlm.model.model - else: - return self.vlm.model - - def set_requires_grad(self): - if self.freeze_vision_encoder: - self.get_vlm_model().vision_model.eval() - for params in self.get_vlm_model().vision_model.parameters(): - params.requires_grad = False - if self.train_expert_only: - self.vlm.eval() - for params in self.vlm.parameters(): - params.requires_grad = False - else: - # To avoid unused params issue with distributed training - last_layers = [self.num_vlm_layers - 1] - if ( - self.num_vlm_layers != self.num_expert_layers - and self.num_vlm_layers % self.num_expert_layers == 0 - ): - last_layers.append(self.num_vlm_layers - 2) - frozen_layers = [ - "lm_head", - "text_model.model.norm.weight", - ] - for layer in last_layers: - frozen_layers.append(f"text_model.model.layers.{layer}.") - - for name, params in self.vlm.named_parameters(): - if any(k in name for k in frozen_layers): - params.requires_grad = False - # To avoid unused params issue with distributed training - for name, params in self.lm_expert.named_parameters(): - if "lm_head" in name: - params.requires_grad = False - - def train(self, mode: bool = True): - super().train(mode) - - if self.freeze_vision_encoder: - self.get_vlm_model().vision_model.eval() - - if self.train_expert_only: - self.vlm.eval() - - def embed_image(self, image: torch.Tensor): - patch_attention_mask = None - # Get sequence from the vision encoder - image_hidden_states = ( - self.get_vlm_model() - .vision_model( - pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), - patch_attention_mask=patch_attention_mask, - ) - .last_hidden_state - ) - # Modality projection & resampling - image_hidden_states = self.get_vlm_model().connector(image_hidden_states) - return image_hidden_states - - def embed_language_tokens(self, tokens: torch.Tensor): - return self.get_vlm_model().text_model.get_input_embeddings()(tokens) - - def forward_attn_layer( - self, - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache: bool = True, - fill_kv_cache: bool = True, - past_key_values=None, - ) -> list[torch.Tensor]: - query_states = [] - key_states = [] - value_states = [] - for i, hidden_states in enumerate(inputs_embeds): - layer = model_layers[i][layer_idx] - if hidden_states is None or layer is None: - continue - hidden_states = layer.input_layernorm(hidden_states) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - query_states.append(query_state) - key_states.append(key_state) - value_states.append(value_state) - - # B,L,H,D with L sequence length, H number of heads, D head dim - # concatenate on the number of embeddings/tokens - query_states = torch.cat(query_states, dim=1) - key_states = torch.cat(key_states, dim=1) - value_states = torch.cat(value_states, dim=1) - seq_len = query_states.shape[1] - if seq_len < position_ids.shape[1]: - _position_ids = position_ids[:, :seq_len] - _attention_mask = attention_mask[:, :seq_len, :seq_len] - else: - _position_ids = position_ids - _attention_mask = attention_mask - - attention_mask_ = _attention_mask - position_ids_ = _position_ids - - query_states = apply_rope(query_states, position_ids_) - key_states = apply_rope(key_states, position_ids_) - - if use_cache and past_key_values is None: - past_key_values = {} - - if use_cache: - if fill_kv_cache: - past_key_values[layer_idx] = { - "key_states": key_states, - "value_states": value_states, - } - else: - # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # the max len, then we (for instance) double the cache size. This implementation already exists - # in `transformers`. (molbap) - key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) - value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) - - attention_interface = self.get_attention_interface() - - att_output = attention_interface( - attention_mask_, batch_size, head_dim, query_states, key_states, value_states - ) - return [att_output], past_key_values - - def forward_cross_attn_layer( - self, - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache: bool = True, - fill_kv_cache: bool = True, - past_key_values=None, - ) -> list[torch.Tensor]: - attention_interface = self.get_attention_interface() - - att_outputs = [] - assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( - f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" - ) - - if len(inputs_embeds) == 2 and not past_key_values: - # Prefix attention - seq_len = inputs_embeds[0].shape[1] - position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] - prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] - - layer = model_layers[0][layer_idx] - - hidden_states = layer.input_layernorm(inputs_embeds[0]) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - # B,L,H,D with L sequence length, H number of heads, D head dim - query_states = apply_rope(query_state, position_id) - key_states = apply_rope(key_state, position_id) - - att_output = attention_interface( - prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states - ) - att_outputs.append(att_output) - else: - expert_position_id = position_ids - - if use_cache and past_key_values is None: - past_key_values = {} - - if use_cache: - if fill_kv_cache: - past_key_values[layer_idx] = { - "key_states": key_states, - "value_states": value_states, - } - else: - # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # the max len, then we (for instance) double the cache size. This implementation already exists - # in `transformers`. (molbap) - key_states = past_key_values[layer_idx]["key_states"] - value_states = past_key_values[layer_idx]["value_states"] - - # Expert - expert_layer = model_layers[1][layer_idx] - if expert_layer is not None: - expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1]) - - expert_input_shape = expert_hidden_states.shape[:-1] - expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim) - - expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) - expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) - - _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( - *key_states.shape[:2], -1 - ) - expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( - *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim - ) # k_proj should have same dim as kv - - _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( - *value_states.shape[:2], -1 - ) - expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( - *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim - ) - - expert_position_id = ( - expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values - ) # start from 0 - expert_attention_mask = attention_mask[ - :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : - ] # take into account kv - - expert_query_states = apply_rope(expert_query_state, expert_position_id) - - att_output = attention_interface( - expert_attention_mask, - batch_size, - head_dim, - expert_query_states, - expert_key_states, - expert_value_states, - ) - att_outputs.append(att_output) - else: - att_outputs.append(None) - - # att_output = att_output.to(dtype=models[i].dtype) - return att_outputs, past_key_values - - def get_model_layers(self, models: list) -> list: - vlm_layers = [] - expert_layers = [] - multiple_of = self.num_vlm_layers // self.num_expert_layers - for i in range(self.num_vlm_layers): - if multiple_of > 0 and i > 0 and i % multiple_of != 0: - expert_layer = None - else: - expert_layer_index = i // multiple_of if multiple_of > 0 else i - expert_layer = models[1].layers[expert_layer_index] - vlm_layers.append(models[0].layers[i]) - expert_layers.append(expert_layer) - return [vlm_layers, expert_layers] - - def forward( - self, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: List[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - fill_kv_cache: Optional[bool] = None, - ): - models = [self.get_vlm_model().text_model, self.lm_expert] - model_layers = self.get_model_layers(models) - for hidden_states in inputs_embeds: - # TODO this is very inefficient - # dtype is always the same, batch size too (if > 1 len) - # device could be trickier in multi gpu edge cases but that's it - if hidden_states is None: - continue - batch_size = hidden_states.shape[0] - - # RMSNorm - num_layers = self.num_vlm_layers - head_dim = self.vlm.config.text_config.head_dim - for layer_idx in range(num_layers): - if ( - fill_kv_cache - or "cross" not in self.attention_mode - or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) - ): - att_outputs, past_key_values = self.forward_attn_layer( - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache=use_cache, - fill_kv_cache=fill_kv_cache, - past_key_values=past_key_values, - ) - else: - att_outputs, past_key_values = self.forward_cross_attn_layer( - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache=use_cache, - fill_kv_cache=fill_kv_cache, - past_key_values=past_key_values, - ) - outputs_embeds = [] - start = 0 - for i, hidden_states in enumerate(inputs_embeds): - layer = model_layers[i][layer_idx] - att_output = ( - att_outputs[i] if i < len(att_outputs) else att_outputs[0] - ) # in case of self_attn - if hidden_states is not None: - if layer is None: - outputs_embeds.append(hidden_states) - continue - end = start + hidden_states.shape[1] - - if att_output.dtype != layer.self_attn.o_proj.weight.dtype: - att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - att_out = att_output[:, start:end] - out_emb = layer.self_attn.o_proj(att_out) - - out_emb += hidden_states - after_first_residual = out_emb.clone() - - out_emb = layer.post_attention_layernorm(out_emb) - out_emb = layer.mlp(out_emb) - - out_emb += after_first_residual - - outputs_embeds.append(out_emb) - - start = end if len(att_outputs) == 1 else 0 - else: - outputs_embeds.append(None) - - inputs_embeds = outputs_embeds - - # final norm - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - if hidden_states is not None: - out_emb = models[i].norm(hidden_states) - outputs_embeds.append(out_emb) - else: - outputs_embeds.append(None) - return outputs_embeds, past_key_values - - def get_attention_interface(self): - attention_interface = self.eager_attention_forward - return attention_interface - - def eager_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states - ): - num_att_heads = self.num_attention_heads - num_key_value_heads = self.num_key_value_heads - num_key_value_groups = num_att_heads // num_key_value_heads - - sequence_length = key_states.shape[1] - - key_states = key_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - # Attention here is upcasted to float32 to match the original eager implementation. - query_states = query_states.to(dtype=torch.float32) - key_states = key_states.to(dtype=torch.float32) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - - att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - att_weights *= head_dim**-0.5 - - att_weights = att_weights.to(dtype=torch.float32) - big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py - masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) - probs = nn.functional.softmax(masked_att_weights, dim=-1) - probs = probs.to(dtype=value_states.dtype) - - att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) - - att_output = att_output.permute(0, 2, 1, 3) - # we use -1 because sequence length can change - att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) - - return att_output diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index a1e0dac84..046c60848 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -66,58 +66,6 @@ def resolve_delta_timestamps( return delta_timestamps - -def make_dataset1(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: - """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. - - Args: - cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig. - - Raises: - NotImplementedError: The MultiLeRobotDataset is currently deactivated. - - Returns: - LeRobotDataset | MultiLeRobotDataset - """ - image_transforms = ( - ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None - ) - - if isinstance(cfg.dataset.repo_id, str): - ds_meta = LeRobotDatasetMetadata( - cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision - ) - delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - episodes=cfg.dataset.episodes, - delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - revision=cfg.dataset.revision, - video_backend=cfg.dataset.video_backend, - ) - else: - raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") - dataset = MultiLeRobotDataset( - cfg.dataset.repo_id, - # TODO(aliberts): add proper support for multi dataset - # delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - video_backend=cfg.dataset.video_backend, - ) - logging.info( - "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " - f"{pformat(dataset.repo_id_to_index, indent=2)}" - ) - - if cfg.dataset.use_imagenet_stats: - for key in dataset.meta.camera_keys: - for stats_type, stats in IMAGENET_STATS.items(): - dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) - - return dataset - def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. @@ -144,7 +92,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas revision = getattr(cfg.dataset, "revision", None) ds_meta = LeRobotDatasetMetadata( cfg.dataset.repo_id, - local_files_only=cfg.dataset.local_files_only, feature_keys_mapping=feature_keys_mapping, revision=revision, ) @@ -157,7 +104,6 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms=image_transforms, revision=revision, video_backend=cfg.dataset.video_backend, - local_files_only=cfg.dataset.local_files_only, feature_keys_mapping=feature_keys_mapping, max_action_dim=cfg.dataset.max_action_dim, max_state_dim=cfg.dataset.max_state_dim, @@ -170,12 +116,13 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas for i in range(len(repo_id)): ds_meta = LeRobotDatasetMetadata( repo_id[i], - local_files_only=cfg.dataset.local_files_only, feature_keys_mapping=feature_keys_mapping, ) # FIXME(mshukor): ? delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta) episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes) - training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None) + # training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None) + #FIXME: (jadechoghari): check support for training features + training_features = None dataset = MultiLeRobotDataset( repo_id, # TODO(aliberts): add proper support for multi dataset @@ -183,11 +130,10 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas delta_timestamps=delta_timestamps, image_transforms=image_transforms, video_backend=cfg.dataset.video_backend, - local_files_only=cfg.dataset.local_files_only, sampling_weights=sampling_weights, feature_keys_mapping=feature_keys_mapping, - max_action_dim=cfg.dataset.max_action_dim, - max_state_dim=cfg.dataset.max_state_dim, + max_action_dim=cfg.policy.max_action_dim, + max_state_dim=cfg.policy.max_state_dim, max_num_images=cfg.dataset.max_num_images, max_image_dim=cfg.dataset.max_image_dim, train_on_all_features=cfg.dataset.train_on_all_features, diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 98cb32387..bb3477b6c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -82,7 +82,7 @@ from lerobot.datasets.video_utils import ( ) # mustafa stuff here -from lerobot.common.datasets.utils_must import ( +from lerobot.datasets.utils_must import ( reshape_features_to_max_dim, keep_datasets_with_valid_fps, keep_datasets_with_the_same_features_per_robot_type, @@ -97,7 +97,7 @@ from lerobot.common.datasets.utils_must import ( OBS_IMAGE_3, TASKS_KEYS_MAPPING, ) -from lerobot.common.constants import ( +from lerobot.constants import ( ACTION, OBS_ENV_STATE, OBS_STATE, @@ -124,7 +124,6 @@ class LeRobotDatasetMetadata: feature_keys_mapping: dict[str, str] | None = None, revision: str | None = None, force_cache_sync: bool = False, - feature_keys_mapping: dict[str, str] | None = None, ): self.repo_id = repo_id self.local_files_only = local_files_only diff --git a/src/lerobot/datasets/utils_must.py b/src/lerobot/datasets/utils_must.py index 040627dfa..1ae3e2ae2 100644 --- a/src/lerobot/datasets/utils_must.py +++ b/src/lerobot/datasets/utils_must.py @@ -3,10 +3,19 @@ Utils function by Mustafa to refactor """ import torch import numpy as np -from lerobot.common.datasets.compute_stats import ( +from lerobot.datasets.compute_stats import ( aggregate_stats ) from collections import defaultdict +from torch.utils.data.dataloader import default_collate +import torch.nn.functional as F + +import torch +from typing import Dict, List + + + +from typing import Dict, List OBS_IMAGE = "observation.image" OBS_IMAGE_2 = "observation.image2" OBS_IMAGE_3 = "observation.image3" @@ -170,6 +179,9 @@ def pad_tensor( is_numpy = isinstance(tensor, np.ndarray) if is_numpy: tensor = torch.tensor(tensor) + if tensor.ndim == 0: + # Scalar — return as-is, no padding needed + return tensor pad = max_size - tensor.shape[pad_dim] if pad > 0: pad_sizes = (0, pad) # pad right @@ -189,6 +201,8 @@ def map_dict_keys(item: dict, feature_keys_mapping: dict, training_features: lis else: if training_features is None or key in training_features or pad_key in key: features[key] = item[key] + + # breakpoint() return features def find_start_of_motion(velocities, window_size, threshold, motion_buffer): @@ -228,3 +242,48 @@ TRAINING_FEATURES = { 1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2], 2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3], } + +def is_batch_need_padding(values: list[torch.Tensor], pad_dim: int = -1) -> int: + return len(values[0].shape) > 0 # and len(set([v.shape[pad_dim] for v in values])) > 1 + + +def pad_tensor_to_shape(tensor: torch.Tensor, target_shape: tuple, pad_value: float = 0.0) -> torch.Tensor: + """Pads a tensor to the target shape (right/bottom only).""" + pad = [] + for actual, target in zip(reversed(tensor.shape), reversed(target_shape)): + pad.extend([0, max(target - actual, 0)]) + return F.pad(tensor, pad, value=pad_value) + + +def multidataset_collate_fn( + batch: List[Dict[str, torch.Tensor]], + keys_to_max_dim: Dict[str, tuple] = {}, + pad_value: float = 0.0, +) -> Dict[str, torch.Tensor]: + """ + Pads tensors to given target shape (if provided), otherwise uses per-batch max. + Supports 1D (e.g. action), 3D (e.g. [C,H,W] images). + """ + collated_batch = [{} for _ in range(len(batch))] + batch_keys = batch[0].keys() + + for key in batch_keys: + values = [sample[key] for sample in batch] + sample = values[0] + + if not isinstance(sample, torch.Tensor): + for i in range(len(batch)): + collated_batch[i][key] = values[i] + continue + + # use user-specified shape if available + if key in keys_to_max_dim and keys_to_max_dim[key] is not None: + target_shape = keys_to_max_dim[key] + else: + # compute per-batch max shape + target_shape = tuple(max(v.shape[i] for v in values) for i in range(sample.ndim)) + + for i in range(len(batch)): + collated_batch[i][key] = pad_tensor_to_shape(values[i], target_shape, pad_value=pad_value) + + return default_collate(collated_batch) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 2f2e88de6..d049c0203 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -51,7 +51,8 @@ from lerobot.utils.utils import ( init_logging, ) from lerobot.utils.wandb_utils import WandBLogger - +from lerobot.datasets.utils_must import multidataset_collate_fn +from functools import partial def update_policy( train_metrics: MetricsTracker, @@ -173,7 +174,9 @@ def train(cfg: TrainPipelineConfig): else: shuffle = True sampler = None - + + keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {}) + collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim) dataloader = torch.utils.data.DataLoader( dataset, num_workers=cfg.num_workers,