From 3cb14248a46d5d1b4fb2fe411e032c6fa88b4d0b Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 7 Nov 2025 14:28:36 +0100 Subject: [PATCH] add franka action --- src/lerobot/policies/factory.py | 2 +- src/lerobot/policies/xvla/action_hub.py | 67 +- .../policies/xvla/configuration_florence2.py | 14 +- .../policies/xvla/configuration_xvla.py | 62 +- .../policies/xvla/modeling_florence2.py | 720 +++++++++--------- src/lerobot/policies/xvla/modeling_xvla.py | 16 +- src/lerobot/policies/xvla/processing_xvla.py | 27 +- src/lerobot/policies/xvla/transformer.py | 52 +- test_xvla.py | 5 +- train.sh | 1 + 10 files changed, 508 insertions(+), 458 deletions(-) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 11a73c489..0f8afafe4 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -39,8 +39,8 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import validate_visual_features_consistency -from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor.converters import ( batch_to_transition, diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py index 3a6f82170..8e71cd3cf 100644 --- a/src/lerobot/policies/xvla/action_hub.py +++ b/src/lerobot/policies/xvla/action_hub.py @@ -15,18 +15,21 @@ # ------------------------------------------------------------------------------ from __future__ import annotations -from typing import Iterable, Tuple, Dict, Type + +from collections.abc import Iterable + import torch import torch.nn as nn # ============================================================================= # Registry # ============================================================================= -ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {} +ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {} def register_action(name: str): """Decorator for registering a new action space.""" + def _wrap(cls): key = name.lower() if key in ACTION_REGISTRY: @@ -34,10 +37,11 @@ def register_action(name: str): ACTION_REGISTRY[key] = cls cls.name = key return cls + return _wrap -def build_action_space(name: str, **kwargs) -> "BaseActionSpace": +def build_action_space(name: str, **kwargs) -> BaseActionSpace: """Instantiate a registered action space by name.""" key = name.lower() if key not in ACTION_REGISTRY: @@ -62,7 +66,7 @@ class BaseActionSpace(nn.Module): name: str = "base" dim_action: int = 0 - gripper_idx: Tuple[int, ...] = () + gripper_idx: tuple[int, ...] = () def __init__(self): super().__init__() @@ -70,10 +74,10 @@ class BaseActionSpace(nn.Module): # --------------------------------------------------------------------- # Core supervised loss # --------------------------------------------------------------------- - def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: + def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]: raise NotImplementedError - def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]: """Alias for compute_loss.""" return self.compute_loss(pred, target) @@ -85,7 +89,7 @@ class BaseActionSpace(nn.Module): proprio: torch.Tensor, action: torch.Tensor, mode: str = "train", - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Default: return unchanged.""" return proprio, action @@ -137,14 +141,14 @@ class EE6DActionSpace(BaseActionSpace): # XYZ position pos_loss = ( - self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + - self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) + self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) ) * self.XYZ_SCALE # Rotation 6D rot_loss = ( - self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + - self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) + self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) ) * self.ROT_SCALE return { @@ -236,14 +240,16 @@ class AGIBOTEE6DActionSpace(BaseActionSpace): B, T, D = pred.shape _ensure_indices_valid(D, self.gripper_idx, "gripper_idx") - gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE + gripper_loss = ( + self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE + ) pos_loss = ( - self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + - self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) + self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) + + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2]) ) * self.XYZ_SCALE rot_loss = ( - self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + - self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) + self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) + + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2]) ) * self.ROT_SCALE return { @@ -261,6 +267,32 @@ class AGIBOTEE6DActionSpace(BaseActionSpace): return action +@register_action("franka_joint7") +class FrankaJoint7ActionSpace(BaseActionSpace): + """Franka Panda joint-space: 7 joints, no gripper.""" + + dim_action = 7 + JOINTS_SCALE = 1.0 + + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def compute_loss(self, pred, target): + assert pred.shape == target.shape, "pred/target shapes must match" + B, T, D = pred.shape + joints_loss = self.mse(pred, target) * self.JOINTS_SCALE + return {"joints_loss": joints_loss} + + def preprocess(self, proprio, action, mode="train"): + """No preprocessing needed for 7 joint actions.""" + return proprio, action + + def postprocess(self, action: torch.Tensor) -> torch.Tensor: + """Return directly (no sigmoid since no gripper).""" + return action + + # ============================================================================= # Exports # ============================================================================= @@ -271,5 +303,6 @@ __all__ = [ "EE6DActionSpace", "JointActionSpace", "AGIBOTEE6DActionSpace", + "FrankaJoint7ActionSpace", "ACTION_REGISTRY", -] \ No newline at end of file +] diff --git a/src/lerobot/policies/xvla/configuration_florence2.py b/src/lerobot/policies/xvla/configuration_florence2.py index ed895ea68..7dec9d68e 100644 --- a/src/lerobot/policies/xvla/configuration_florence2.py +++ b/src/lerobot/policies/xvla/configuration_florence2.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,20 +11,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings + """ Florence-2 configuration""" -from typing import Optional -from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) + class Florence2VisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel - according to the specified arguments, defining the model architecture. Instantiating a configuration with the + according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Florence2VisionModel architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the @@ -118,7 +117,6 @@ class Florence2VisionConfig(PretrainedConfig): super().__init__(**kwargs) - class Florence2LanguageConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART @@ -269,10 +267,11 @@ class Florence2LanguageConfig(PretrainedConfig): "The config can simply be saved and uploaded again to be fixed." ) + class Florence2Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an - Florence-2 model according to the specified arguments, defining the model architecture. + Florence-2 model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -281,7 +280,7 @@ class Florence2Config(PretrainedConfig): vision_config (`Florence2VisionConfig`, *optional*): Custom vision config or dict text_config (`Union[AutoConfig, dict]`, *optional*): - The config object of the text backbone. + The config object of the text backbone. ignore_index (`int`, *optional*, defaults to -100): The ignore index for the loss function. vocab_size (`int`, *optional*, defaults to 51289): @@ -335,5 +334,4 @@ class Florence2Config(PretrainedConfig): if text_config is not None: self.text_config = Florence2LanguageConfig(**text_config) - super().__init__(**kwargs) diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index 3383dde6a..2658b87b0 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -124,43 +124,37 @@ class XVLAConfig(PreTrainedConfig): # TODO: jadechoghari: provide default way, and do not hardcode # Ensure vision_config and text_config are provided with defaults if not specified config_dict = dict(self.florence_config) - if 'vision_config' not in config_dict or config_dict['vision_config'] is None: + if "vision_config" not in config_dict or config_dict["vision_config"] is None: # Provide default vision config - config_dict['vision_config'] = { - 'model_type': 'davit', - 'drop_path_rate': 0.1, - 'patch_size': [14, 7, 7, 7], - 'patch_stride': [4, 2, 2, 2], - 'patch_padding': [3, 1, 1, 1], - 'patch_prenorm': [False, True, True, True], - 'dim_embed': [256, 512, 1024, 2048], - 'num_heads': [8, 16, 32, 64], - 'num_groups': [8, 16, 32, 64], - 'depths': [1, 1, 9, 1], - 'window_size': 12, - 'projection_dim': 1024, - 'visual_temporal_embedding': { - 'type': 'COSINE', - 'max_temporal_embeddings': 100 - }, - 'image_pos_embed': { - 'type': 'learned_abs_2d', - 'max_pos_embeddings': 50 - }, - 'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool'] + config_dict["vision_config"] = { + "model_type": "davit", + "drop_path_rate": 0.1, + "patch_size": [14, 7, 7, 7], + "patch_stride": [4, 2, 2, 2], + "patch_padding": [3, 1, 1, 1], + "patch_prenorm": [False, True, True, True], + "dim_embed": [256, 512, 1024, 2048], + "num_heads": [8, 16, 32, 64], + "num_groups": [8, 16, 32, 64], + "depths": [1, 1, 9, 1], + "window_size": 12, + "projection_dim": 1024, + "visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100}, + "image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50}, + "image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"], } - if 'text_config' not in config_dict or config_dict['text_config'] is None: + if "text_config" not in config_dict or config_dict["text_config"] is None: # Provide default text config - config_dict['text_config'] = { - 'model_type': 'florence2_language', - 'vocab_size': 51289, - 'd_model': 1024, - 'encoder_layers': 12, - 'decoder_layers': 12, - 'encoder_attention_heads': 16, - 'decoder_attention_heads': 16, - 'encoder_ffn_dim': 4096, - 'decoder_ffn_dim': 4096, + config_dict["text_config"] = { + "model_type": "florence2_language", + "vocab_size": 51289, + "d_model": 1024, + "encoder_layers": 12, + "decoder_layers": 12, + "encoder_attention_heads": 16, + "decoder_attention_heads": 16, + "encoder_ffn_dim": 4096, + "decoder_ffn_dim": 4096, } self._florence_config_obj = Florence2Config(**config_dict) return self._florence_config_obj diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py index e020c1b11..84f075807 100644 --- a/src/lerobot/policies/xvla/modeling_florence2.py +++ b/src/lerobot/policies/xvla/modeling_florence2.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,39 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Florence-2 model.""" -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +"""PyTorch Florence-2 model.""" import math -import torch -import torch.utils.checkpoint -from torch import nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from torch.nn import CrossEntropyLoss from collections import OrderedDict +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torch.utils.checkpoint as checkpoint from einops import rearrange -from timm.layers import DropPath, trunc_normal_ - -from transformers.modeling_utils import PreTrainedModel -from transformers.generation.utils import GenerationMixin -from transformers.utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, -) -from .configuration_florence2 import Florence2Config -from .configuration_florence2 import Florence2LanguageConfig -from .configuration_florence2 import Florence2VisionConfig - - +from timm.layers import DropPath +from torch import nn +from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN +from transformers.generation.utils import GenerationMixin from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -58,7 +40,18 @@ from transformers.modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_florence2 import Florence2Config, Florence2LanguageConfig, Florence2VisionConfig if is_flash_attn_2_available(): from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -67,6 +60,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Florence2Config" + class LearnedAbsolutePositionEmbedding2D(nn.Module): """ This module learns positional embeddings up to a fixed maximum size. @@ -79,18 +73,20 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module): def forward(self, pixel_values): """ - pixel_values: (batch_size, height, width, num_channels) + pixel_values: (batch_size, height, width, num_channels) returns: (batch_size, height, width, embedding_dim * 2) """ if len(pixel_values.shape) != 4: - raise ValueError('pixel_values must be a 4D tensor') + raise ValueError("pixel_values must be a 4D tensor") height, width = pixel_values.shape[1:3] width_values = torch.arange(width, device=pixel_values.device) height_values = torch.arange(height, device=pixel_values.device) x_emb = self.column_embeddings(width_values) y_emb = self.row_embeddings(height_values) # (height, width, embedding_dim * 2) - pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) + pos = torch.cat( + [x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1 + ) # (embedding_dim * 2, height, width) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) @@ -100,6 +96,7 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module): pos = pos.permute(0, 2, 3, 1) return pos + class PositionalEmbeddingCosine1D(nn.Module): """ This class implements a very simple positional encoding. It follows closely @@ -111,22 +108,17 @@ class PositionalEmbeddingCosine1D(nn.Module): dropout_prob: The dropout probability. max_seq_len: The maximum length to precompute the positional encodings. """ - def __init__( - self, - embed_dim: int = 512, - max_seq_len: int = 1024) -> None: - super(PositionalEmbeddingCosine1D, self).__init__() + + def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: + super().__init__() self.embed_dim = embed_dim self.max_seq_len = max_seq_len # Generate the sinusoidal arrays. factor = math.log(10000) - denominator = torch.exp( - -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) + denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) # Matrix where rows correspond to a positional embedding as a function # of the position index (i.e., the row index). - frequencies = \ - torch.arange(0, self.max_seq_len) \ - .reshape(self.max_seq_len, 1) * denominator + frequencies = torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) # Populate uneven entries. pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) @@ -150,11 +142,10 @@ class PositionalEmbeddingCosine1D(nn.Module): assert 2 <= shape_len <= 3 len_seq = seq_embeds.size(-2) assert len_seq <= self.max_seq_len - pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :] # Adapt pre-computed positional embeddings to the input. if shape_len == 3: - pos_embeds = pos_embeds.view( - (1, pos_embeds.size(0), pos_embeds.size(1))) + pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1))) return pos_embeds @@ -166,11 +157,9 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module): embed_dim: The dimension of the embeddings. max_seq_len: The maximum length to precompute the positional encodings. """ - def __init__( - self, - embedding_dim: int = 512, - num_pos: int = 1024) -> None: - super(LearnedAbsolutePositionEmbedding1D, self).__init__() + + def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None: + super().__init__() self.embeddings = nn.Embedding(num_pos, embedding_dim) self.num_pos = num_pos @@ -194,12 +183,10 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module): pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device)) # Adapt pre-computed positional embeddings to the input. if shape_len == 3: - pos_embeds = pos_embeds.view( - (1, pos_embeds.size(0), pos_embeds.size(1))) + pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1))) return pos_embeds - class MySequential(nn.Sequential): def forward(self, *inputs): for module in self._modules.values(): @@ -243,11 +230,15 @@ class Mlp(nn.Module): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.net = nn.Sequential(OrderedDict([ - ("fc1", nn.Linear(in_features, hidden_features)), - ("act", act_layer()), - ("fc2", nn.Linear(hidden_features, out_features)) - ])) + self.net = nn.Sequential( + OrderedDict( + [ + ("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features)), + ] + ) + ) def forward(self, x, size): return self.net(x), size @@ -264,12 +255,7 @@ class DepthWiseConv2d(nn.Module): ): super().__init__() self.dw = nn.Conv2d( - dim_in, dim_in, - kernel_size=kernel_size, - padding=padding, - groups=dim_in, - stride=stride, - bias=bias + dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias ) def forward(self, x, size): @@ -284,28 +270,15 @@ class DepthWiseConv2d(nn.Module): class ConvEmbed(nn.Module): - """ Image to Patch Embedding - """ + """Image to Patch Embedding""" def __init__( - self, - patch_size=7, - in_chans=3, - embed_dim=64, - stride=4, - padding=2, - norm_layer=None, - pre_norm=True + self, patch_size=7, in_chans=3, embed_dim=64, stride=4, padding=2, norm_layer=None, pre_norm=True ): super().__init__() self.patch_size = patch_size - self.proj = nn.Conv2d( - in_chans, embed_dim, - kernel_size=patch_size, - stride=stride, - padding=padding - ) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) dim_norm = in_chans if pre_norm else embed_dim self.norm = norm_layer(dim_norm) if norm_layer else None @@ -317,15 +290,12 @@ class ConvEmbed(nn.Module): if len(x.size()) == 3: if self.norm and self.pre_norm: x = self.norm(x) - x = rearrange( - x, 'b (h w) c -> b c h w', - h=H, w=W - ) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) x = self.proj(x) _, _, H, W = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") if self.norm and not self.pre_norm: x = self.norm(x) @@ -333,7 +303,6 @@ class ConvEmbed(nn.Module): class ChannelAttention(nn.Module): - def __init__(self, dim, groups=8, qkv_bias=True): super().__init__() @@ -357,25 +326,31 @@ class ChannelAttention(nn.Module): class ChannelBlock(nn.Module): - - def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, - drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - conv_at_attn=True, conv_at_ffn=True): + def __init__( + self, + dim, + groups, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True, + ): super().__init__() - drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None self.channel_attn = PreNorm( - norm_layer(dim), - ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), - drop_path + norm_layer(dim), ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), drop_path ) self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None self.ffn = PreNorm( norm_layer(dim), - Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), - drop_path + Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer), + drop_path, ) def forward(self, x, size): @@ -398,9 +373,9 @@ def window_partition(x, window_size: int): def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): - B = batch_size + B = batch_size # this will cause onnx conversion failed for dynamic axis, because treated as constant - # int(windows.shape[0] / (H * W / window_size / window_size)) + # int(windows.shape[0] / (H * W / window_size / window_size)) x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -408,7 +383,6 @@ def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): class WindowAttention(nn.Module): def __init__(self, dim, num_heads, window_size, qkv_bias=True): - super().__init__() self.dim = dim self.window_size = window_size @@ -422,7 +396,6 @@ class WindowAttention(nn.Module): self.softmax = nn.Softmax(dim=-1) def forward(self, x, size): - H, W = size B, L, C = x.shape assert L == H * W, "input feature has wrong size" @@ -446,16 +419,14 @@ class WindowAttention(nn.Module): q, k, v = qkv[0], qkv[1], qkv[2] q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) attn = self.softmax(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) # merge windows - x = x.view( - -1, self.window_size, self.window_size, C - ) + x = x.view(-1, self.window_size, self.window_size, C) x = window_reverse(x, B, self.window_size, Hp, Wp) if pad_r > 0 or pad_b > 0: @@ -467,25 +438,32 @@ class WindowAttention(nn.Module): class SpatialBlock(nn.Module): - - def __init__(self, dim, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, - norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): + def __init__( + self, + dim, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True, + ): super().__init__() - drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None self.window_attn = PreNorm( - norm_layer(dim), - WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), - drop_path + norm_layer(dim), WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), drop_path ) self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None self.ffn = PreNorm( norm_layer(dim), - Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), - drop_path + Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer), + drop_path, ) def forward(self, x, size): @@ -500,7 +478,7 @@ class SpatialBlock(nn.Module): class DaViT(nn.Module): - """ DaViT: Dual-Attention Transformer + """DaViT: Dual-Attention Transformer Args: in_chans (int): Number of input image channels. Default: 3. @@ -535,14 +513,14 @@ class DaViT(nn.Module): num_heads=(3, 6, 12, 24), num_groups=(3, 6, 12, 24), window_size=7, - mlp_ratio=4., + mlp_ratio=4.0, qkv_bias=True, drop_path_rate=0.1, norm_layer=nn.LayerNorm, enable_checkpoint=False, conv_at_attn=True, conv_at_ffn=True, - ): + ): super().__init__() self.num_classes = num_classes @@ -554,7 +532,7 @@ class DaViT(nn.Module): assert self.num_stages == len(self.num_heads) == len(self.num_groups) num_stages = len(embed_dims) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)] depth_offset = 0 convs = [] @@ -567,41 +545,48 @@ class DaViT(nn.Module): in_chans=in_chans if i == 0 else self.embed_dims[i - 1], embed_dim=self.embed_dims[i], norm_layer=norm_layer, - pre_norm=patch_prenorm[i] + pre_norm=patch_prenorm[i], ) convs.append(conv_embed) block = MySequential( *[ - MySequential(OrderedDict([ - ( - 'spatial_block', SpatialBlock( - embed_dims[i], - num_heads[i], - window_size, - drop_path_rate=dpr[depth_offset+j*2], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - ) - ), - ( - 'channel_block', ChannelBlock( - embed_dims[i], - num_groups[i], - drop_path_rate=dpr[depth_offset+j*2+1], - qkv_bias=qkv_bias, - mlp_ratio=mlp_ratio, - conv_at_attn=conv_at_attn, - conv_at_ffn=conv_at_ffn, - ) + MySequential( + OrderedDict( + [ + ( + "spatial_block", + SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset + j * 2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ), + ), + ( + "channel_block", + ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset + j * 2 + 1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ), + ), + ] ) - ])) for j in range(depths[i]) + ) + for j in range(depths[i]) ] ) blocks.append(block) - depth_offset += depths[i]*2 + depth_offset += depths[i] * 2 self.convs = nn.ModuleList(convs) self.blocks = nn.ModuleList(blocks) @@ -616,7 +601,7 @@ class DaViT(nn.Module): def forward_features_unpool(self, x): """ - forward until avg pooling + forward until avg pooling Args: x (_type_): input image tensor """ @@ -644,7 +629,7 @@ class DaViT(nn.Module): x = self.forward_features(x) x = self.head(x) return x - + @classmethod def from_config(cls, config): return cls( @@ -661,12 +646,11 @@ class DaViT(nn.Module): ) - - if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -712,7 +696,10 @@ class Florence2LearnedPositionalEmbedding(nn.Embedding): bsz, seq_len = input_ids.shape[:2] positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device, ).expand(bsz, -1) return super().forward(positions + self.offset) @@ -723,7 +710,9 @@ class Florence2ScaledWordEmbedding(nn.Embedding): This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + def __init__( + self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0 + ): super().__init__(num_embeddings, embedding_dim, padding_idx) self.embed_scale = embed_scale @@ -742,7 +731,7 @@ class Florence2Attention(nn.Module): is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config: Optional[Florence2LanguageConfig] = None, + config: Florence2LanguageConfig | None = None, ): super().__init__() self.embed_dim = embed_dim @@ -771,12 +760,12 @@ class Florence2Attention(nn.Module): def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -854,7 +843,9 @@ class Florence2Attention(nn.Module): f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: @@ -911,12 +902,12 @@ class Florence2FlashAttention2(Florence2Attention): def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: # Florence2FlashAttention2 attention does not support output_attentions if output_attentions: raise ValueError("Florence2FlashAttention2 attention does not support output_attentions") @@ -1010,7 +1001,14 @@ class Florence2FlashAttention2(Florence2Attention): # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -1096,7 +1094,9 @@ class Florence2FlashAttention2(Florence2Attention): else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) return ( query_layer, @@ -1112,12 +1112,12 @@ class Florence2SdpaAttention(Florence2Attention): def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. @@ -1245,8 +1245,8 @@ class Florence2EncoderLayer(nn.Module): hidden_states: torch.FloatTensor, attention_mask: torch.FloatTensor, layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + output_attentions: bool | None = False, + ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1271,7 +1271,9 @@ class Florence2EncoderLayer(nn.Module): residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1324,15 +1326,15 @@ class Florence2DecoderLayer(nn.Module): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + cross_attn_layer_head_mask: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = True, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1394,7 +1396,9 @@ class Florence2DecoderLayer(nn.Module): # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -1411,7 +1415,6 @@ class Florence2DecoderLayer(nn.Module): return outputs - class Florence2LanguagePreTrainedModel(PreTrainedModel): config_class = Florence2LanguageConfig base_model_prefix = "model" @@ -1437,10 +1440,7 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel): for name, _ in module.named_parameters(): if name == "bias": nn.init.constant_(module.bias, 0) - elif isinstance(module, nn.LayerNorm): - nn.init.constant_(module.weight, 1.0) - nn.init.constant_(module.bias, 0) - elif isinstance(module, nn.BatchNorm2d): + elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm2d): nn.init.constant_(module.weight, 1.0) nn.init.constant_(module.bias, 0) @@ -1465,7 +1465,7 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: Florence2LanguageConfig, embed_tokens: nn.Embedding | None = None): super().__init__(config) self.dropout = config.dropout @@ -1505,13 +1505,13 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: + attention_mask: torch.Tensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutput: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1548,7 +1548,9 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1652,7 +1654,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: Florence2LanguageConfig, embed_tokens: nn.Embedding | None = None): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -1690,18 +1692,18 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.FloatTensor | None = None, + encoder_attention_mask: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + cross_attn_head_mask: torch.Tensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPastAndCrossAttentions: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1767,7 +1769,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1776,7 +1780,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) elif input_ids is not None: input = input_ids input_shape = input.shape @@ -1853,7 +1859,9 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel): next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): if attn_mask is not None: if attn_mask.size()[0] != (len(self.layers)): raise ValueError( @@ -1967,21 +1975,21 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqModelOutput]: + attention_mask: torch.Tensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + decoder_head_mask: torch.Tensor | None = None, + cross_attn_head_mask: torch.Tensor | None = None, + encoder_outputs: list[torch.FloatTensor] | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + decoder_inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | Seq2SeqModelOutput: # different to other models, Florence2 automatically creates decoder_input_ids from # input_ids if no decoder_input_ids are provided if decoder_input_ids is None and decoder_inputs_embeds is None: @@ -1996,7 +2004,9 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel): input_ids, self.config.pad_token_id, self.config.decoder_start_token_id ) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -2078,7 +2088,9 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel def get_decoder(self): return self.model.get_decoder() - def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, **kwargs) -> nn.Embedding: + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: int | None = None, **kwargs + ) -> nn.Embedding: new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs) self._resize_final_logits_bias(new_embeddings.weight.shape[0]) return new_embeddings @@ -2088,7 +2100,9 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel if new_num_tokens <= old_num_tokens: new_bias = self.final_logits_bias[:, :new_num_tokens] else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + extra_bias = torch.zeros( + (1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device + ) new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) self.register_buffer("final_logits_bias", new_bias) @@ -2101,22 +2115,22 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqLMOutput]: + attention_mask: torch.Tensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + decoder_head_mask: torch.Tensor | None = None, + cross_attn_head_mask: torch.Tensor | None = None, + encoder_outputs: list[torch.FloatTensor] | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + decoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | Seq2SeqLMOutput: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -2227,11 +2241,15 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past[:2] + ) + layer_past[2:], ) return reordered_past + @dataclass class Florence2Seq2SeqLMOutput(ModelOutput): """ @@ -2291,17 +2309,18 @@ class Florence2Seq2SeqLMOutput(ModelOutput): image_hidden_states of the model produced by the vision encoder """ - loss: Optional[torch.FloatTensor] = None + + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: tuple[tuple[torch.FloatTensor]] | None = None + decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None + decoder_attentions: tuple[torch.FloatTensor, ...] | None = None + cross_attentions: tuple[torch.FloatTensor, ...] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None + encoder_attentions: tuple[torch.FloatTensor, ...] | None = None + image_hidden_states: tuple[torch.FloatTensor, ...] | None = None FLORENCE2_START_DOCSTRING = r""" @@ -2413,6 +2432,7 @@ FLORENCE2_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ + @add_start_docstrings( """The FLORENCE2 vision model without any head""", FLORENCE2_START_DOCSTRING, @@ -2420,16 +2440,16 @@ FLORENCE2_INPUTS_DOCSTRING = r""" class Florence2VisionModel(Florence2PreTrainedModel): def __init__(self, config: Florence2VisionConfig): super().__init__(config) - assert config.model_type == 'davit', 'only DaViT is supported for now' + assert config.model_type == "davit", "only DaViT is supported for now" self.vision_tower = DaViT.from_config(config=config) self.post_init() - + def forward(self, pixel_values): if len(pixel_values.shape) == 4: x = self.vision_tower.forward_features_unpool(pixel_values) else: - raise ValueError(f'invalid image shape {pixel_values.shape}') + raise ValueError(f"invalid image shape {pixel_values.shape}") return x @@ -2440,40 +2460,37 @@ class Florence2VisionModel(Florence2PreTrainedModel): class Florence2VisionModelWithProjection(Florence2PreTrainedModel): def __init__(self, config: Florence2VisionConfig): super().__init__(config) - assert config.model_type == 'davit', 'only DaViT is supported for now' + assert config.model_type == "davit", "only DaViT is supported for now" self.vision_tower = DaViT.from_config(config=config) self._build_image_projection_layers(config) self.post_init() - + def _build_image_projection_layers(self, config): image_dim_out = config.dim_embed[-1] dim_projection = config.projection_dim - self.image_projection = nn.Parameter( - torch.empty(image_dim_out, dim_projection) - ) + self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection)) self.image_proj_norm = nn.LayerNorm(dim_projection) image_pos_embed_config = config.image_pos_embed - if image_pos_embed_config['type'] == 'learned_abs_2d': + if image_pos_embed_config["type"] == "learned_abs_2d": self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, - num_pos=image_pos_embed_config['max_pos_embeddings'] + embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"] ) else: - raise NotImplementedError('Not implemented yet') + raise NotImplementedError("Not implemented yet") self.image_feature_source = config.image_feature_source # temporal embedding visual_temporal_embedding_config = config.visual_temporal_embedding - if visual_temporal_embedding_config['type'] == 'COSINE': + if visual_temporal_embedding_config["type"] == "COSINE": self.visual_temporal_embed = PositionalEmbeddingCosine1D( embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] + max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"], ) else: - raise NotImplementedError('Not implemented yet') + raise NotImplementedError("Not implemented yet") def forward(self, pixel_values): if len(pixel_values.shape) == 4: @@ -2481,37 +2498,39 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel): T = 1 x = self.vision_tower.forward_features_unpool(pixel_values) else: - raise ValueError(f'invalid image shape {pixel_values.shape}') - + raise ValueError(f"invalid image shape {pixel_values.shape}") + if self.image_pos_embed is not None: x = x.view(batch_size * T, -1, x.shape[-1]) num_tokens = x.shape[-2] - h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) - assert h * w == num_tokens, 'only support square feature maps for now' + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, "only support square feature maps for now" x = x.view(batch_size * T, h, w, x.shape[-1]) pos_embed = self.image_pos_embed(x) x = x + pos_embed - x = x.view(batch_size, T * h*w, x.shape[-1]) + x = x.view(batch_size, T * h * w, x.shape[-1]) if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + ) x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) x_feat_dict = {} spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) - x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) - x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] - x_feat_dict['last_frame'] = x + x_feat_dict["last_frame"] = x new_x = [] for _image_feature_source in self.image_feature_source: if _image_feature_source not in x_feat_dict: - raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) + raise ValueError(f"invalid image feature source: {_image_feature_source}") new_x.append(x_feat_dict[_image_feature_source]) x = torch.cat(new_x, dim=1) @@ -2519,23 +2538,25 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel): x = x @ self.image_projection x = self.image_proj_norm(x) - return x - @add_start_docstrings( """The FLORENCE2 model which consists of a vision backbone and a language model.""", FLORENCE2_START_DOCSTRING, ) class Florence2ForConditionalGeneration(Florence2PreTrainedModel): - _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"] + _tied_weights_keys = [ + "language_model.encoder.embed_tokens.weight", + "language_model.decoder.embed_tokens.weight", + "language_model.lm_head.weight", + ] def __init__(self, config: Florence2Config): super().__init__(config) - assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now' + assert config.vision_config.model_type == "davit", "only DaViT is supported for now" self.vision_tower = DaViT.from_config(config=config.vision_config) - # remove unused layers + # remove unused layers del self.vision_tower.head del self.vision_tower.norms @@ -2549,34 +2570,31 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() - + def _build_image_projection_layers(self, config): image_dim_out = config.vision_config.dim_embed[-1] dim_projection = config.vision_config.projection_dim - self.image_projection = nn.Parameter( - torch.empty(image_dim_out, dim_projection) - ) + self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection)) self.image_proj_norm = nn.LayerNorm(dim_projection) image_pos_embed_config = config.vision_config.image_pos_embed - if image_pos_embed_config['type'] == 'learned_abs_2d': + if image_pos_embed_config["type"] == "learned_abs_2d": self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( - embedding_dim=image_dim_out, - num_pos=image_pos_embed_config['max_pos_embeddings'] + embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"] ) else: - raise NotImplementedError('Not implemented yet') + raise NotImplementedError("Not implemented yet") self.image_feature_source = config.vision_config.image_feature_source # temporal embedding visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding - if visual_temporal_embedding_config['type'] == 'COSINE': + if visual_temporal_embedding_config["type"] == "COSINE": self.visual_temporal_embed = PositionalEmbeddingCosine1D( embed_dim=image_dim_out, - max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] + max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"], ) else: - raise NotImplementedError('Not implemented yet') + raise NotImplementedError("Not implemented yet") def get_encoder(self): return self.language_model.get_encoder() @@ -2587,51 +2605,57 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): def get_input_embeddings(self): return self.language_model.get_input_embeddings() - def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, **kwargs) -> nn.Embedding: - model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs) + def resize_token_embeddings( + self, new_num_tokens: int | None = None, pad_to_multiple_of=None, **kwargs + ) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings( + new_num_tokens, pad_to_multiple_of, **kwargs + ) # update vocab size self.config.text_config.vocab_size = model_embeds.num_embeddings self.config.vocab_size = model_embeds.num_embeddings self.vocab_size = model_embeds.num_embeddings return model_embeds - + def _encode_image(self, pixel_values): if len(pixel_values.shape) == 4: batch_size, C, H, W = pixel_values.shape T = 1 x = self.vision_tower.forward_features_unpool(pixel_values) else: - raise ValueError(f'invalid image shape {pixel_values.shape}') - + raise ValueError(f"invalid image shape {pixel_values.shape}") + if self.image_pos_embed is not None: x = x.view(batch_size * T, -1, x.shape[-1]) num_tokens = x.shape[-2] - h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) - assert h * w == num_tokens, 'only support square feature maps for now' + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, "only support square feature maps for now" x = x.view(batch_size * T, h, w, x.shape[-1]) pos_embed = self.image_pos_embed(x) x = x + pos_embed - x = x.view(batch_size, T * h*w, x.shape[-1]) + x = x.view(batch_size, T * h * w, x.shape[-1]) if self.visual_temporal_embed is not None: - visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0] + ) x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) x_feat_dict = {} spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) - x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) - x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] - x_feat_dict['last_frame'] = x + x_feat_dict["last_frame"] = x new_x = [] for _image_feature_source in self.image_feature_source: if _image_feature_source not in x_feat_dict: - raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) + raise ValueError(f"invalid image feature source: {_image_feature_source}") new_x.append(x_feat_dict[_image_feature_source]) x = torch.cat(new_x, dim=1) @@ -2639,11 +2663,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): x = x @ self.image_projection x = self.image_proj_norm(x) - return x + return x - def _merge_input_ids_with_image_features( - self, image_features, inputs_embeds - ): + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds): batch_size, image_token_length = image_features.size()[:-1] device = image_features.device image_attention_mask = torch.ones(batch_size, image_token_length, device=device) @@ -2665,29 +2687,28 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): return inputs_embeds, attention_mask - @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Florence2Seq2SeqLMOutput]: + attention_mask: torch.Tensor | None = None, + decoder_input_ids: torch.LongTensor | None = None, + decoder_attention_mask: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + decoder_head_mask: torch.Tensor | None = None, + cross_attn_head_mask: torch.Tensor | None = None, + encoder_outputs: list[torch.FloatTensor] | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + decoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | Florence2Seq2SeqLMOutput: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -2718,7 +2739,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "A green car parked in front of a yellow building." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -2733,7 +2756,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): if pixel_values is not None: # (batch_size, num_image_tokens, hidden_size) image_features = self._encode_image(pixel_values) - inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) + inputs_embeds, attention_mask = self._merge_input_ids_with_image_features( + image_features, inputs_embeds + ) if inputs_embeds is not None: attention_mask = attention_mask.to(inputs_embeds.dtype) @@ -2772,17 +2797,10 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, - image_hidden_states=image_features + image_hidden_states=image_features, ) - def generate( - self, - input_ids, - inputs_embeds=None, - pixel_values=None, - **kwargs - ): - + def generate(self, input_ids, inputs_embeds=None, pixel_values=None, **kwargs): if inputs_embeds is None: # 1. Extra the input embeddings if input_ids is not None: @@ -2790,13 +2808,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): # 2. Merge text and images if pixel_values is not None: image_features = self._encode_image(pixel_values) - inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) - - return self.language_model.generate( - input_ids=None, - inputs_embeds=inputs_embeds, - **kwargs - ) + inputs_embeds, attention_mask = self._merge_input_ids_with_image_features( + image_features, inputs_embeds + ) + + return self.language_model.generate(input_ids=None, inputs_embeds=inputs_embeds, **kwargs) def prepare_inputs_for_generation( self, @@ -2824,7 +2840,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): remove_prefix_length = decoder_input_ids.shape[1] - 1 decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - + return { "input_ids": None, # encoder_outputs is defined. input_ids not needed "encoder_outputs": encoder_outputs, @@ -2838,9 +2854,9 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel): "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } - + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self.language_model.shift_tokens_right(labels) def _reorder_cache(self, *args, **kwargs): - return self.language_model._reorder_cache(*args, **kwargs) \ No newline at end of file + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index b2a9f1096..b973e713c 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -19,7 +19,6 @@ from __future__ import annotations from collections import deque -from typing import Dict import torch import torch.nn.functional as F # noqa: N812 @@ -87,7 +86,7 @@ class XVLAModel(nn.Module): input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_mask: torch.Tensor, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """ Encode text and multi-view images via Florence2 encoder. """ @@ -129,13 +128,14 @@ class XVLAModel(nn.Module): domain_id: torch.LongTensor, proprio: torch.Tensor, action: torch.Tensor, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: enc = self.forward_vlm(input_ids, image_input, image_mask) batch_size = input_ids.shape[0] - t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % ( - 1 - 1e-5 - ) + t = ( + torch.rand(1, device=input_ids.device) + + torch.arange(batch_size, device=input_ids.device) / batch_size + ) % (1 - 1e-5) action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1) proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy) @@ -350,7 +350,9 @@ def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float ratio = max(current_width / width, current_height / height) resized_height = int(current_height / ratio) resized_width = int(current_width / ratio) - resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) pad_height = max(0, height - resized_height) pad_width = max(0, width - resized_width) diff --git a/src/lerobot/policies/xvla/processing_xvla.py b/src/lerobot/policies/xvla/processing_xvla.py index 198b20d71..b68d1a38c 100644 --- a/src/lerobot/policies/xvla/processing_xvla.py +++ b/src/lerobot/policies/xvla/processing_xvla.py @@ -14,7 +14,7 @@ # limitations under the License. # ------------------------------------------------------------------------------ -from typing import Any, Dict, List, Optional, Union +from typing import Any import torch from transformers import ProcessorMixin @@ -88,7 +88,7 @@ class XVLAProcessor(ProcessorMixin): super().__init__(image_processor, tokenizer) # ================== LANGUAGE ENCODING ================== - def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]: + def encode_language(self, language_instruction: str | list[str]) -> dict[str, torch.Tensor]: """ Tokenize one or more language instructions. @@ -117,11 +117,7 @@ class XVLAProcessor(ProcessorMixin): return {"input_ids": inputs["input_ids"]} # ================== IMAGE ENCODING ================== - def encode_image( - self, - images: Union[List, List[List]], - **kwargs - ) -> Dict[str, torch.Tensor]: + def encode_image(self, images: list | list[list], **kwargs) -> dict[str, torch.Tensor]: """ Preprocess one or more sets of multi-view images. @@ -157,8 +153,7 @@ class XVLAProcessor(ProcessorMixin): # Pad to self.num_views if V_exist < self.num_views: processed = torch.cat( - [processed, - processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])], + [processed, processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])], dim=0, ) @@ -177,10 +172,10 @@ class XVLAProcessor(ProcessorMixin): # ================== COMBINED CALL ================== def __call__( self, - images: Optional[Union[List, List[List]]] = None, - language_instruction: Optional[Union[str, List[str]]] = None, - **kwargs - ) -> Dict[str, torch.Tensor]: + images: list | list[list] | None = None, + language_instruction: str | list[str] | None = None, + **kwargs, + ) -> dict[str, torch.Tensor]: """ Combine image and text encoding into a unified multimodal input. @@ -202,7 +197,7 @@ class XVLAProcessor(ProcessorMixin): "image_mask": [B, num_views], optional } """ - outputs: Dict[str, Any] = {} + outputs: dict[str, Any] = {} # Encode language if provided if language_instruction is not None: @@ -243,7 +238,9 @@ def make_xvla_pre_post_processors( padding_side=config.tokenizer_padding_side, ), DeviceProcessorStep(device=config.device), - NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats), + NormalizerProcessorStep( + features=features, norm_map=config.normalization_mapping, stats=dataset_stats + ), ] output_steps = [ UnnormalizerProcessorStep( diff --git a/src/lerobot/policies/xvla/transformer.py b/src/lerobot/policies/xvla/transformer.py index a6bf36518..3e43b446e 100644 --- a/src/lerobot/policies/xvla/transformer.py +++ b/src/lerobot/policies/xvla/transformer.py @@ -17,17 +17,18 @@ from __future__ import annotations import math +from collections.abc import Iterable from functools import partial -from typing import Final, Iterable, Tuple +from typing import Final import torch import torch.nn as nn import torch.nn.functional as F - # ------------------------------- Small utils ---------------------------------- -def _to_2tuple(x) -> Tuple: + +def _to_2tuple(x) -> tuple: """Minimal replacement for timm.layers.to_2tuple.""" if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): t = tuple(x) @@ -42,6 +43,7 @@ def _has_sdp_attention() -> bool: # ---------------------------------- MLP -------------------------------------- + class Mlp(nn.Module): """ MLP used in ViT-style blocks. @@ -55,8 +57,8 @@ class Mlp(nn.Module): hidden_features: int | None = None, out_features: int | None = None, norm_layer: type[nn.Module] | None = None, - bias: bool | Tuple[bool, bool] = True, - drop: float | Tuple[float, float] = 0.0, + bias: bool | tuple[bool, bool] = True, + drop: float | tuple[float, float] = 0.0, use_conv: bool = False, ) -> None: super().__init__() @@ -86,6 +88,7 @@ class Mlp(nn.Module): # -------------------------------- Attention ---------------------------------- + class Attention(nn.Module): """ Multi-Head Self-Attention with optional fused SDPA fallback. @@ -110,7 +113,7 @@ class Attention(nn.Module): assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.fused_attn = _has_sdp_attention() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -143,17 +146,19 @@ class Attention(nn.Module): if self.fused_attn: x = F.scaled_dot_product_attention( - q, k, v, + q, + k, + v, dropout_p=self.attn_drop.p if self.training else 0.0, ) # [B, H, T, Dh] else: q = q * self.scale - attn = q @ k.transpose(-2, -1) # [B, H, T, T] + attn = q @ k.transpose(-2, -1) # [B, H, T, T] attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = attn @ v # [B, H, T, Dh] + x = attn @ v # [B, H, T, Dh] - x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C] + x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C] x = self.proj(x) x = self.proj_drop(x) return x @@ -161,6 +166,7 @@ class Attention(nn.Module): # ------------------------------- Utilities ----------------------------------- + def basic_init(module: nn.Module) -> None: """ Apply a basic initialization scheme to Linear layers. @@ -194,9 +200,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc """ half = dim // 2 freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) - / half + -math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half ) args = t[:, None] * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -207,6 +211,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc # ------------------------------- Core Layers ---------------------------------- + class DomainAwareLinear(nn.Module): """ Linear layer with domain-conditioned parameters (per-sample). @@ -283,6 +288,7 @@ class TransformerBlock(nn.Module): # --------------------------- Main Model --------------------------------------- + class SoftPromptedTransformer(nn.Module): """ Multi-modal, domain-aware Transformer with optional soft prompts. @@ -318,7 +324,9 @@ class SoftPromptedTransformer(nn.Module): if use_hetero_proj: self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains) - self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains) + self.aux_visual_proj = DomainAwareLinear( + multi_modal_input_size, hidden_size, num_domains=num_domains + ) else: self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size) self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size) @@ -367,16 +375,20 @@ class SoftPromptedTransformer(nn.Module): B, num_actions = action_with_noise.shape[:2] # Encode (action + proprio + time) → tokens - time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time] + time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time] time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time) proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1]) action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1) - x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H] + x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H] # Project visual streams and concatenate if self.use_hetero_proj: x = torch.cat( - [x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)], + [ + x, + self.vlm_proj(vlm_features, domain_id), + self.aux_visual_proj(aux_visual_inputs, domain_id), + ], dim=1, ) else: @@ -385,9 +397,7 @@ class SoftPromptedTransformer(nn.Module): # Add positional embeddings (truncate if needed) seq_len = x.shape[1] if seq_len > self.pos_emb.shape[1]: - raise ValueError( - f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}." - ) + raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.") x = x + self.pos_emb[:, :seq_len, :] # Append soft prompts @@ -400,4 +410,4 @@ class SoftPromptedTransformer(nn.Module): x = block(x) # Decode only the action segment - return self.action_decoder(self.norm(x[:, :num_actions]), domain_id) \ No newline at end of file + return self.action_decoder(self.norm(x[:, :num_actions]), domain_id) diff --git a/test_xvla.py b/test_xvla.py index fe0888d1a..5cf8817f3 100644 --- a/test_xvla.py +++ b/test_xvla.py @@ -1,6 +1,5 @@ -from lerobot.policies.factory import make_policy -from lerobot.policies.factory import make_policy_config from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.policies.factory import make_policy, make_policy_config cfg = make_policy_config("xvla") @@ -8,4 +7,4 @@ dataset_id = "lerobot/svla_so101_pickplace" # This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets dataset_metadata = LeRobotDatasetMetadata(dataset_id) policy = make_policy(cfg=cfg, ds_meta=dataset_metadata) -print(policy) \ No newline at end of file +print(policy) diff --git a/train.sh b/train.sh index 1a9d8371a..4683936ae 100644 --- a/train.sh +++ b/train.sh @@ -4,6 +4,7 @@ lerobot-train \ --output_dir=outputs/train/act_your_dataset \ --job_name=xvla_so101_pickplace \ --policy.device=cuda \ + --policy.action_mode=franka_joint7 \ --wandb.enable=true \ --policy.repo_id=jadechoghari/xvla_policy \ --steps=10000