diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx
index 661ecc566..3ff795be0 100644
--- a/docs/source/libero.mdx
+++ b/docs/source/libero.mdx
@@ -127,4 +127,43 @@ LeRobot uses MuJoCo for simulation. You need to set the rendering backend before
 ## Reproducing π₀ and π₀.₅ results
 
-We can also reproduce the results of π₀ and π₀.₅ on the Libero benchmark by using the finetuned libero models.
+We reproduce the results of π₀ and π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base models (`pi0_libero` and `pi05_libero`) and finetune them for an additional 6k steps in bfloat16 with a batch size of 256 on 8 H100 GPUs, using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
+
+The finetuned models can be found here:
+
+- **π₀ LIBERO**: [pepijn223/pi0_libero_fp32](https://huggingface.co/pepijn223/pi0_libero_fp32)
+- **π₀.₅ LIBERO**: [pepijn223/pi05_libero_fp32](https://huggingface.co/pepijn223/pi05_libero_fp32)
+
+We then evaluate the finetuned models with the LeRobot LIBERO implementation by running the following command:
+
+```bash
+python src/lerobot/scripts/eval.py \
+    --output_dir=./eval_logs/ \
+    --env.type=libero \
+    --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
+    --eval.batch_size=1 \
+    --eval.n_episodes=5 \
+    --policy.path=pepijn223/pi0_libero_fp32 \
+    --env.multitask_eval=true \
+    --policy.compile_model=false \
+    --policy.gradient_checkpointing=false \
+    --env.max_parallel_tasks=1
+```
+
+**Note:** We set `n_action_steps=10`, which otherwise defaults to 50.
+
+### Results
+
+We obtain the following results on the LIBERO benchmark:
+
+| Model    | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
+| -------- | -------------- | ------------- | ----------- | --------- | ------- |
+| **π₀**   | x              | x             | x           | x         | **x**   |
+| **π₀.₅** | 98.0           | 99.0          | 97.0        | 93.0      | **x**   |
+
+These results are consistent with the original results reported by Physical Intelligence:
+
+| Model    | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average   |
+| -------- | -------------- | ------------- | ----------- | --------- | --------- |
+| **π₀**   | 96.8           | 98.8          | 95.8        | 85.2      | **94.15** |
+| **π₀.₅** | 98.8           | 98.2          | 98.0        | 92.4      | **96.85** |
diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx
index 2d8887dc7..b737bfd91 100644
--- a/docs/source/pi0.mdx
+++ b/docs/source/pi0.mdx
@@ -21,44 +21,13 @@ As described by Physical Intelligence, while AI has achieved remarkable success
 
 ## Installation Requirements
 
-⚠️ **Warning**: This policy requires patching the Hugging Face `transformers` library.
-
-### Prerequisites
-
-1. Ensure you have the exact version installed:
+1. Install LeRobot by following our [Installation Guide](./installation).
+2. Install Pi0 dependencies by running:
 
    ```bash
-   pip show transformers
+   pip install -e ".[pi]"
   ```
-
-   It must be version **4.53.2**.
-
-2. 
Apply the custom patches: - ```bash - cp -r ./src/lerobot/policies/pi0/transformers_replace/* \ - $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))") - ``` - -### What the patches do: - -- Support the **AdaRMS optimizer** -- Correctly control the precision of activations -- Allow the KV cache to be used without updates - -**Important Notes:** - -- This permanently modifies your `transformers` installation -- The changes survive reinstalls unless you explicitly remove the patched files or recreate the environment - -### Restoring Clean State - -To undo the patches and restore a clean state: - -```bash -pip uninstall transformers -pip install transformers==4.53.2 -``` - ## Training Data and Capabilities π₀ is trained on the largest robot interaction dataset to date, combining three key data sources: diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index 01320dc88..ec718d25d 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -29,22 +29,11 @@ This diverse training mixture creates a "curriculum" that enables generalization ## Installation Requirements -⚠️ **Warning**: This policy requires patching the Hugging Face `transformers` library. - -### Prerequisites - -1. Ensure you have the exact version installed: +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install Pi0.5 dependencies by running: ```bash - pip show transformers - ``` - - It must be version **4.53.2**. - -2. Apply the custom patches: - ```bash - cp -r ./src/lerobot/policies/pi05/transformers_replace/* \ - $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))") + pip install -e ".[pi]" ``` ### What the patches do: diff --git a/pyproject.toml b/pyproject.toml index c5b3d0185..54b23ea22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,7 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # ] # TODO: Currently not supported # Policies -pi = ["lerobot[transformers-dep]"] +pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_pi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 7be238889..58f1bf893 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -33,7 +33,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi from lerobot.configs.policies import PreTrainedConfig from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig +from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -492,7 +492,7 @@ class PaliGemmaWithExpertModel( class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI0 PyTorch model.""" - def __init__(self, config: PI0OpenPIConfig): + def __init__(self, config: PI0Config): super().__init__() self.config = config @@ -523,10 +523,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) - msg = """transformers_replace is not installed correctly. 
-Please install it with `pip install transformers==4.53.2` -and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \ -$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" try: from transformers.models.siglip import check @@ -842,15 +839,15 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return self.action_out_proj(suffix_out) -class PI0OpenPIPolicy(PreTrainedPolicy): +class PI0Policy(PreTrainedPolicy): """PI0 OpenPI Policy for LeRobot.""" - config_class = PI0OpenPIConfig + config_class = PI0Config name = "pi0" def __init__( # see lerobot pi0 `__init__` self, - config: PI0OpenPIConfig, + config: PI0Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ diff --git a/src/lerobot/policies/pi0/transformers_replace/models/gemma/configuration_gemma.py b/src/lerobot/policies/pi0/transformers_replace/models/gemma/configuration_gemma.py deleted file mode 100644 index 72eb2a36c..000000000 --- a/src/lerobot/policies/pi0/transformers_replace/models/gemma/configuration_gemma.py +++ /dev/null @@ -1,173 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...configuration_utils import PretrainedConfig - - -class GemmaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma-7B. - e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 256000): - Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`GemmaModel`] - hidden_size (`int`, *optional*, defaults to 3072): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 24576): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer decoder. 
- num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 16): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_adarms (`bool`, *optional*, defaults to `False`): - Whether to use ADARMS. - adarms_cond_dim (`int`, *optional*, defaults to `None`): - The dimension of the ADARMS condition. 
- ```python - >>> from transformers import GemmaModel, GemmaConfig - >>> # Initializing a Gemma gemma-7b style configuration - >>> configuration = GemmaConfig() - >>> # Initializing a model from the gemma-7b style configuration - >>> model = GemmaModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "gemma" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=256000, - hidden_size=3072, - intermediate_size=24576, - num_hidden_layers=28, - num_attention_heads=16, - num_key_value_heads=16, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation=None, - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - bos_token_id=2, - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - use_adarms: bool = False, - adarms_cond_dim: int | None = None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.hidden_activation = hidden_activation - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.use_adarms = use_adarms - self.adarms_cond_dim = adarms_cond_dim - - # Set default for adarms_cond_dim if use_adarms is True - if self.use_adarms and self.adarms_cond_dim is None: - self.adarms_cond_dim = self.hidden_size - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -__all__ = ["GemmaConfig"] diff --git a/src/lerobot/policies/pi0/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0/transformers_replace/models/gemma/modeling_gemma.py deleted file mode 100644 index 05066afc5..000000000 --- a/src/lerobot/policies/pi0/transformers_replace/models/gemma/modeling_gemma.py +++ /dev/null @@ -1,895 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
-# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Callable - -import torch -from torch import nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging -from .configuration_gemma import GemmaConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -class GemmaRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): - super().__init__() - self.eps = eps - self.dim = dim - self.cond_dim = cond_dim - - # Dense layer for adaptive normalization (if cond_dim is provided) - if cond_dim is not None: - # self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) - self.dense = nn.Linear(cond_dim, dim * 3, bias=True) - # Initialize with zeros (matches source implementation) - nn.init.zeros_(self.dense.weight) - else: - self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) - self.dense = None - - def _norm(self, x): - # Compute variance in float32 (like the source implementation) - var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) - # Compute normalization in float32 - normed_inputs = x * torch.rsqrt(var + self.eps) - return normed_inputs - - def forward(self, x, cond=None): - dtype = x.dtype # original dtype, could be half-precision - normed_inputs = self._norm(x) - - if cond is None or self.dense is None: - # regular RMSNorm - # scale by learned parameter in float32 (matches source implementation) - normed_inputs = normed_inputs * (1.0 + self.weight.float()) - return normed_inputs.to(dtype), None # return in original dtype with None gate - - # adaptive RMSNorm (if cond is provided and dense layer exists) - if cond.shape[-1] != self.cond_dim: - raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}") - - # 
self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32) - modulation = self.dense(cond) - # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features] - if len(x.shape) == 3: # [batch, seq, features] - modulation = modulation.unsqueeze(1) - - scale, shift, gate = torch.chunk(modulation, 3, dim=-1) - - # Apply adaptive normalization: use model weight dtype to ensure compatibility - # model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16) - # scale = scale.to(model_dtype) - # shift = shift.to(model_dtype) - # gate = gate.to(model_dtype) - # normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype - - normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32) - - return normed_inputs.to(dtype), gate.to(dtype) - - def extra_repr(self): - repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}" - if self.dense is not None: - repr_str += f", adaptive=True, cond_dim={self.cond_dim}" - return repr_str - - -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, config: GemmaConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - ) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. 
- k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def _gated_residual(x, y, gate): - """ - Applies gated residual connection with optional gate parameter. 
- - Args: - x: Input tensor (residual) - y: Output tensor to be added - gate: Optional gate tensor to modulate the addition - - Returns: - x + y if gate is None, otherwise x + y * gate - """ - if x is None and y is None: - return None - if x is None or y is None: - return x if x is not None else y - if gate is None: - return x + y - return x + y * gate - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class GemmaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_value: Cache | None = None, - cache_position: torch.LongTensor | None = None, - use_cache: bool = False, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - # Use cache if provided - if past_key_value is not None: - if use_cache: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - else: - key_states = torch.cat([past_key_value[self.layer_idx][0], 
key_states], dim=2) - value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class GemmaDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) - - self.mlp = GemmaMLP(config) - cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.post_attention_layernorm = GemmaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: None - | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - residual = hidden_states - hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = _gated_residual(residual, hidden_states, gate) - - # Fully Connected - residual = hidden_states - hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) - hidden_states = self.mlp(hidden_states) - hidden_states = _gated_residual(residual, hidden_states, gate) - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -@safe_auto_docstring -class GemmaPreTrainedModel(PreTrainedModel): - config_class = GemmaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_3 = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, 
std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GemmaRMSNorm): - if hasattr(module, "weight"): - module.weight.data.fill_(1.0) - - -@safe_auto_docstring -class GemmaModel(GemmaPreTrainedModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - - cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.rotary_emb = GemmaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - """ - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
- ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - # embed positions - hidden_states = inputs_embeds - # Convert to bfloat16 if the first layer uses bfloat16 - if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.bfloat16) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # normalized - # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - # hidden_states = hidden_states * normalizer - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states, _ = self.norm(hidden_states, adarms_cond) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - -@safe_auto_docstring -class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - logits_to_keep: int | torch.Tensor = 0, - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[KwargsForCausalLM], - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@safe_auto_docstring( - custom_intro=""" - The Gemma Model transformer with a sequence classification head on top (linear layer). - - [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """ -) -class GemmaForSequenceClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - adarms_cond: torch.Tensor | None = None, - ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - transformer_outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config - ) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@safe_auto_docstring -class GemmaForTokenClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - adarms_cond: torch.Tensor | None = None, - ) -> TokenClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - sequence_output = outputs.last_hidden_state - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "GemmaModel", - "GemmaForCausalLM", - "GemmaForSequenceClassification", - "GemmaForTokenClassification", - "GemmaPreTrainedModel", -] diff --git a/src/lerobot/policies/pi0/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi0/transformers_replace/models/paligemma/modeling_paligemma.py deleted file mode 100644 index b2a36b5ca..000000000 --- a/src/lerobot/policies/pi0/transformers_replace/models/paligemma/modeling_paligemma.py +++ /dev/null @@ -1,666 +0,0 @@ -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch PaliGemmamodel.""" - -from dataclasses import dataclass - -import torch -import torch.utils.checkpoint -from torch import nn - -from ...cache_utils import Cache, HybridCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack -from ...utils import ( - LossKwargs, - ModelOutput, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, -) -from ..auto import AutoModel -from .configuration_paligemma import PaliGemmaConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for Paligemma outputs, with hidden states and attentions. - """ -) -class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. - """ - - image_hidden_states: torch.FloatTensor | None = None - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for PaliGemma causal language model (or autoregressive) outputs. - """ -) -class PaliGemmaCausalLMOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. 
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - past_key_values: list[torch.FloatTensor] | Cache | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - image_hidden_states: torch.FloatTensor | None = None - - -class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, config: PaliGemmaConfig): - super().__init__() - self.linear = nn.Linear( - config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True - ) - - def forward(self, image_features): - hidden_states = self.linear(image_features) - - return hidden_states - - -@safe_auto_docstring -class PaliGemmaPreTrainedModel(PreTrainedModel): - config_class = PaliGemmaConfig - base_model_prefix = "" - supports_gradient_checkpointing = True - _no_split_modules = ["PaliGemmaMultiModalProjector"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported version of PaliGemmaisn't meant for training from scratch - only - # inference and fine-tuning - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - - -@safe_auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., - """ -) -class PaliGemmaModel(PaliGemmaPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch - accepts_loss_kwargs = False - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.vision_tower = AutoModel.from_config(config=config.vision_config) - self.multi_modal_projector = PaliGemmaMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - - language_model = AutoModel.from_config(config=config.text_config) - self.language_model = language_model - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.post_init() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - - def _update_causal_mask( - self, - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - is_training: bool | None = None, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - is_training = is_training if is_training is not None else self.training - 
using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=self.dtype, - device=cache_position.device, - ) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # First unmask prefix tokens during training - if is_training: - if token_type_ids is None: - raise ValueError("Token type ids must be provided during training") - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def get_image_features(self, pixel_values: torch.FloatTensor): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
- """ - image_outputs = self.vision_tower(pixel_values) - selected_image_feature = image_outputs.last_hidden_state - image_features = self.multi_modal_projector(selected_image_feature) - return image_features - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - token_type_ids: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple | PaligemmaModelOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" - >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed - - # 
Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - return PaligemmaModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - -@safe_auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., - """ -) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", - } - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.model = PaliGemmaModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - def get_image_features(self, pixel_values): - return self.model.get_image_features(pixel_values) - - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - token_type_ids: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> tuple | PaliGemmaCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" 
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids=input_ids, - pixel_values=pixel_values, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs - ) - - return PaliGemmaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - logits_to_keep=None, - labels=None, - **kwargs, - ): - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self.model._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=cache_position.device, - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/src/lerobot/policies/pi0/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi0/transformers_replace/models/siglip/check.py deleted file mode 100644 index d899dc1b9..000000000 --- a/src/lerobot/policies/pi0/transformers_replace/models/siglip/check.py +++ /dev/null @@ -1,5 +0,0 @@ -import transformers - - -def check_whether_transformers_replace_is_installed_correctly(): - return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi0/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0/transformers_replace/models/siglip/modeling_siglip.py deleted file mode 100644 index 0fc0bba0f..000000000 --- a/src/lerobot/policies/pi0/transformers_replace/models/siglip/modeling_siglip.py +++ /dev/null @@ -1,1283 +0,0 @@ -# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Siglip model.""" - -import math -import warnings -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -import numpy as np -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn.init import _calculate_fan_in_and_fan_out - -from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int -from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) # noqa: E741 - u = norm_cdf((b - mean) / std) # noqa: E741 - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. 
- - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip -class SiglipVisionModelOutput(ModelOutput): - r""" - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - """ - - image_embeds: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for text model's outputs that also contains a pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip -class SiglipTextModelOutput(ModelOutput): - r""" - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - """ - - text_embeds: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -@safe_auto_docstring -# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip -class SiglipOutput(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. 
- logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. - text_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipTextModel`]. - vision_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipVisionModel`]. - """ - - loss: torch.FloatTensor | None = None - logits_per_image: torch.FloatTensor | None = None - logits_per_text: torch.FloatTensor | None = None - text_embeds: torch.FloatTensor | None = None - image_embeds: torch.FloatTensor | None = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - self.register_buffer( - "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False - ) - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing and no class embeddings. 
- - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - num_positions = self.position_embedding.weight.shape[0] - - # always interpolate when tracing to ensure the exported model works for dynamic input shapes - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embedding(self.position_ids) - - patch_pos_embed = self.position_embedding.weight.unsqueeze(0) - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return patch_pos_embed - - def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: - _, _, height, width = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype) - ) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip -class SiglipTextEmbeddings(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - embed_dim = config.hidden_size - - self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) - self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer( - "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False - ) - - def forward( - self, - input_ids: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - max_position_embedding = self.position_embedding.weight.shape[0] - - if seq_length > max_position_embedding: - raise ValueError( - f"Sequence length must be less than max_position_embeddings (got `sequence length`: " - f"{seq_length} and max_position_embeddings: {max_position_embedding}" - ) - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - attn_weights = 
torch.matmul(query, key.transpose(-1, -2)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - self.is_causal = False - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - batch_size, seq_length, embed_dim = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - is_causal=self.is_causal, - scaling=self.scale, - dropout=0.0 if not self.training else self.dropout, - ) - - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip -class SiglipMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class SiglipEncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SiglipVisionConfig | SiglipTextConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.self_attn = SiglipAttention(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: bool | None = False, - ) -> tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -@safe_auto_docstring -class SiglipPreTrainedModel(PreTrainedModel): - config_class = SiglipConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - _no_split_modules = [ - "SiglipTextEmbeddings", - "SiglipEncoderLayer", - "SiglipVisionEmbeddings", - "SiglipEncoderLayer", - "SiglipMultiheadAttentionPoolingHead", - ] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, SiglipVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, SiglipConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, SiglipMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) - elif isinstance(module, SiglipModel): - logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.data.fill_(logit_scale_init) - module.logit_bias.data.zero_() - elif isinstance(module, SiglipForImageClassification): - nn.init.normal_( - module.classifier.weight, - std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, - ) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip -class SiglipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. 
- - Args: - config: SiglipConfig - """ - - def __init__(self, config: SiglipConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - # Ignore copy - @can_return_tuple - def forward( - self, - inputs_embeds, - attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutput: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) - - -class SiglipTextTransformer(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - self.embeddings = SiglipTextEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - - self.head = nn.Linear(embed_dim, config.projection_size) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutputWithPooling: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. - # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: - # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.final_layer_norm(last_hidden_state) - - # Assuming "sticky" EOS tokenization, last token is always EOS. - pooled_output = last_hidden_state[:, -1, :] - pooled_output = self.head(pooled_output) - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -@safe_auto_docstring( - custom_intro=""" - The text model from SigLIP without any head or projection on top. - """ -) -class SiglipTextModel(SiglipPreTrainedModel): - config_class = SiglipTextConfig - - def __init__(self, config: SiglipTextConfig): - super().__init__(config) - self.text_model = SiglipTextTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, value): - self.text_model.embeddings.token_embedding = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from transformers import AutoTokenizer, SiglipTextModel - - >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - -class SiglipVisionTransformer(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head - 
if self.use_head: - self.head = SiglipMultiheadAttentionPoolingHead(config) - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = False, - ) -> BaseModelOutputWithPooling: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - # Convert to bfloat16 if the encoder uses bfloat16 - if ( - len(self.encoder.layers) > 0 - and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 - ): - hidden_states = hidden_states.to(torch.bfloat16) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooler_output = self.head(last_hidden_state) if self.use_head else None - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooler_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class SiglipMultiheadAttentionPoolingHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - - self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, batch_first=True - ) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward(self, hidden_state): - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) - - hidden_state = self.attention(probe, hidden_state, hidden_state)[0] - - residual = hidden_state - hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) - - return hidden_state[:, 0] - - -@safe_auto_docstring( - custom_intro=""" - The vision model from SigLIP without any head or projection on top. 
- """ -) -class SiglipVisionModel(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - - self.vision_model = SiglipVisionTransformer(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, SiglipVisionModel - - >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled features - ```""" - - return self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - -@safe_auto_docstring -class SiglipModel(SiglipPreTrainedModel): - config_class = SiglipConfig - - def __init__(self, config: SiglipConfig): - super().__init__(config) - - if not isinstance(config.text_config, SiglipTextConfig): - raise TypeError( - "config.text_config is expected to be of type SiglipTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, SiglipVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type SiglipVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - text_config = config.text_config - vision_config = config.vision_config - - # First, initialize the text and vision models with proper attention implementation - text_model = SiglipTextModel._from_config(text_config) - vision_model = SiglipVisionModel._from_config(vision_config) - - # Second, get the text and vision submodules (for backward compatibility) - self.text_model = text_model.text_model - self.vision_model = vision_model.vision_model - - self.logit_scale = nn.Parameter(torch.randn(1)) - self.logit_bias = nn.Parameter(torch.randn(1)) - - # Initialize weights and apply final processing - self.post_init() - - @safe_auto_docstring - def get_text_features( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> torch.FloatTensor: - r""" - Returns: - text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by - applying the projection layer to the pooled output of [`SiglipTextModel`]. 
- - Examples: - - ```python - >>> from transformers import AutoTokenizer, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - >>> with torch.no_grad(): - ... text_features = model.get_text_features(**inputs) - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - pooled_output = text_outputs.pooler_output - - return pooled_output - - @safe_auto_docstring - def get_image_features( - self, - pixel_values: torch.FloatTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> torch.FloatTensor: - r""" - Returns: - image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`SiglipVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> with torch.no_grad(): - ... image_features = model.get_image_features(**inputs) - ```""" - # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - pooled_output = vision_outputs.pooler_output - - return pooled_output - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> SiglipOutput: - r""" - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. 
- - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] - >>> # important: we pass `padding=max_length` since the model was trained with this - >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") - - >>> with torch.no_grad(): - ... outputs = model(**inputs) - - >>> logits_per_image = outputs.logits_per_image - >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities - >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") - 31.9% that image 0 is 'a photo of 2 cats' - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - image_embeds = vision_outputs.pooler_output - text_embeds = text_outputs.pooler_output - - # normalized features - image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) - - # cosine similarity as logits - logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) - - logit_scale, logit_bias = ( - self.logit_scale.to(text_embeds.device), - self.logit_bias.to(text_embeds.device), - ) - logits_per_text = logits_per_text * logit_scale.exp() + logit_bias - - logits_per_image = logits_per_text.t() - - loss = None - if return_loss: - # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 - eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) - m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye - loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) - nll = -torch.sum(loglik, dim=-1) - loss = nll.mean() - - return SiglipOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -@safe_auto_docstring( - custom_intro=""" - SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of - the patch tokens) e.g. for ImageNet. 
- """ -) -class SiglipForImageClassification(SiglipPreTrainedModel): - main_input_name = "pixel_values" - - def __init__(self, config: SiglipConfig) -> None: - super().__init__(config) - - self.num_labels = config.num_labels - - # Create the vision model with proper attention - # and take only vision_model submodule (for backward compatibility) - vision_model = SiglipVisionModel._from_config(config.vision_config) - self.vision_model = vision_model.vision_model - - # Classifier head - self.classifier = ( - nn.Linear(config.vision_config.hidden_size, config.num_labels) - if config.num_labels > 0 - else nn.Identity() - ) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values: torch.Tensor | None = None, - labels: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> ImageClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, SiglipForImageClassification - >>> import torch - >>> from PIL import Image - >>> import requests - - >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> # note: we are loading a `SiglipModel` from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
- >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") - >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the two classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) - Predicted class: LABEL_1 - ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - sequence_output = outputs.last_hidden_state - - # average pool the patch tokens - sequence_output = torch.mean(sequence_output, dim=1) - # apply classifier - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "SiglipModel", - "SiglipPreTrainedModel", - "SiglipTextModel", - "SiglipVisionModel", - "SiglipForImageClassification", -] diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index eb6f95934..8fa738f03 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -34,7 +34,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi from lerobot.configs.policies import PreTrainedConfig from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig +from lerobot.policies.pi05.configuration_pi05openpi import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -495,7 +495,7 @@ class PaliGemmaWithExpertModel( class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI05 PyTorch model.""" - def __init__(self, config: PI05OpenPIConfig): + def __init__(self, config: PI05Config): super().__init__() self.config = config @@ -523,10 +523,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` torch.set_float32_matmul_precision("high") self.sample_actions = 
torch.compile(self.sample_actions, mode=config.compile_mode) - msg = """transformers_replace is not installed correctly. -Please install it with `pip install transformers==4.53.2` -and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \ -$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" try: from transformers.models.siglip import check @@ -816,15 +813,15 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_ return self.action_out_proj(suffix_out) -class PI05OpenPIPolicy(PreTrainedPolicy): +class PI05Policy(PreTrainedPolicy): """PI05 OpenPI Policy for LeRobot.""" - config_class = PI05OpenPIConfig + config_class = PI05Config name = "pi05" def __init__( # see lerobot pi0 `__init__` self, - config: PI05OpenPIConfig, + config: PI05Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ diff --git a/src/lerobot/policies/pi05/transformers_replace/models/gemma/configuration_gemma.py b/src/lerobot/policies/pi05/transformers_replace/models/gemma/configuration_gemma.py deleted file mode 100644 index 72eb2a36c..000000000 --- a/src/lerobot/policies/pi05/transformers_replace/models/gemma/configuration_gemma.py +++ /dev/null @@ -1,173 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...configuration_utils import PretrainedConfig - - -class GemmaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Gemma-7B. - e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 256000): - Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`GemmaModel`] - hidden_size (`int`, *optional*, defaults to 3072): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 24576): - Dimension of the MLP representations. 
- num_hidden_layers (`int`, *optional*, defaults to 28): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 16): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - head_dim (`int`, *optional*, defaults to 256): - The attention head dimension. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The legacy activation function. It is overwritten by the `hidden_activation`. - hidden_activation (`str` or `function`, *optional*): - The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` - if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. - max_position_embeddings (`int`, *optional*, defaults to 8192): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - bos_token_id (`int`, *optional*, defaults to 2): - Beginning of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_adarms (`bool`, *optional*, defaults to `False`): - Whether to use ADARMS. - adarms_cond_dim (`int`, *optional*, defaults to `None`): - The dimension of the ADARMS condition. 
- ```python - >>> from transformers import GemmaModel, GemmaConfig - >>> # Initializing a Gemma gemma-7b style configuration - >>> configuration = GemmaConfig() - >>> # Initializing a model from the gemma-7b style configuration - >>> model = GemmaModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "gemma" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - - def __init__( - self, - vocab_size=256000, - hidden_size=3072, - intermediate_size=24576, - num_hidden_layers=28, - num_attention_heads=16, - num_key_value_heads=16, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation=None, - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - eos_token_id=1, - bos_token_id=2, - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - use_adarms: bool = False, - adarms_cond_dim: int | None = None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.head_dim = head_dim - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.hidden_activation = hidden_activation - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.use_adarms = use_adarms - self.adarms_cond_dim = adarms_cond_dim - - # Set default for adarms_cond_dim if use_adarms is True - if self.use_adarms and self.adarms_cond_dim is None: - self.adarms_cond_dim = self.hidden_size - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -__all__ = ["GemmaConfig"] diff --git a/src/lerobot/policies/pi05/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi05/transformers_replace/models/gemma/modeling_gemma.py deleted file mode 100644 index 05066afc5..000000000 --- a/src/lerobot/policies/pi05/transformers_replace/models/gemma/modeling_gemma.py +++ /dev/null @@ -1,895 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_gemma.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# coding=utf-8 -# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
-# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Callable - -import torch -from torch import nn - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging -from .configuration_gemma import GemmaConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -class GemmaRMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): - super().__init__() - self.eps = eps - self.dim = dim - self.cond_dim = cond_dim - - # Dense layer for adaptive normalization (if cond_dim is provided) - if cond_dim is not None: - # self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) - self.dense = nn.Linear(cond_dim, dim * 3, bias=True) - # Initialize with zeros (matches source implementation) - nn.init.zeros_(self.dense.weight) - else: - self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) - self.dense = None - - def _norm(self, x): - # Compute variance in float32 (like the source implementation) - var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) - # Compute normalization in float32 - normed_inputs = x * torch.rsqrt(var + self.eps) - return normed_inputs - - def forward(self, x, cond=None): - dtype = x.dtype # original dtype, could be half-precision - normed_inputs = self._norm(x) - - if cond is None or self.dense is None: - # regular RMSNorm - # scale by learned parameter in float32 (matches source implementation) - normed_inputs = normed_inputs * (1.0 + self.weight.float()) - return normed_inputs.to(dtype), None # return in original dtype with None gate - - # adaptive RMSNorm (if cond is provided and dense layer exists) - if cond.shape[-1] != self.cond_dim: - raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}") - - # 
self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32) - modulation = self.dense(cond) - # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features] - if len(x.shape) == 3: # [batch, seq, features] - modulation = modulation.unsqueeze(1) - - scale, shift, gate = torch.chunk(modulation, 3, dim=-1) - - # Apply adaptive normalization: use model weight dtype to ensure compatibility - # model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16) - # scale = scale.to(model_dtype) - # shift = shift.to(model_dtype) - # gate = gate.to(model_dtype) - # normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype - - normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32) - - return normed_inputs.to(dtype), gate.to(dtype) - - def extra_repr(self): - repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}" - if self.dense is not None: - repr_str += f", adaptive=True, cond_dim={self.cond_dim}" - return repr_str - - -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, config: GemmaConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - ) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. 
- k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def _gated_residual(x, y, gate): - """ - Applies gated residual connection with optional gate parameter. 
- - Args: - x: Input tensor (residual) - y: Output tensor to be added - gate: Optional gate tensor to modulate the addition - - Returns: - x + y if gate is None, otherwise x + y * gate - """ - if x is None and y is None: - return None - if x is None or y is None: - return x if x is not None else y - if gate is None: - return x + y - return x + y * gate - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class GemmaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_value: Cache | None = None, - cache_position: torch.LongTensor | None = None, - use_cache: bool = False, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - # Use cache if provided - if past_key_value is not None: - if use_cache: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - else: - key_states = torch.cat([past_key_value[self.layer_idx][0], 
key_states], dim=2) - value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class GemmaDecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) - - self.mlp = GemmaMLP(config) - cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.post_attention_layernorm = GemmaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: None - | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - residual = hidden_states - hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = _gated_residual(residual, hidden_states, gate) - - # Fully Connected - residual = hidden_states - hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) - hidden_states = self.mlp(hidden_states) - hidden_states = _gated_residual(residual, hidden_states, gate) - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - -@safe_auto_docstring -class GemmaPreTrainedModel(PreTrainedModel): - config_class = GemmaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_3 = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_attention_backend = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, 
std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, GemmaRMSNorm): - if hasattr(module, "weight"): - module.weight.data.fill_(1.0) - - -@safe_auto_docstring -class GemmaModel(GemmaPreTrainedModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - - cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim) - self.rotary_emb = GemmaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> BaseModelOutputWithPast: - """ - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
- ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - # embed positions - hidden_states = inputs_embeds - # Convert to bfloat16 if the first layer uses bfloat16 - if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.bfloat16) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # normalized - # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - # hidden_states = hidden_states * normalizer - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states, _ = self.norm(hidden_states, adarms_cond) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
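
> **Reference note.** The `GemmaRMSNorm` / `_gated_residual` pair removed above implements the AdaRMS conditioning that the π₀/π₀.₅ action expert relies on (now supplied by the pinned `transformers` branch instead of these vendored files). For readers tracing the deletion, here is a minimal, self-contained sketch of the mechanism with toy shapes; the class and helper names below are illustrative only and are not the upstream API.

```python
import torch
from torch import nn


class AdaptiveRMSNorm(nn.Module):
    """Sketch of adaptive RMSNorm: a condition vector yields per-feature scale/shift and a residual gate."""

    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-init so the layer starts out as a plain RMSNorm with a fully gated-off residual update.
        self.dense = nn.Linear(cond_dim, dim * 3)
        nn.init.zeros_(self.dense.weight)
        nn.init.zeros_(self.dense.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        # Normalize in float32 for stability, then modulate with the condition.
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x.float() * torch.rsqrt(var + self.eps)
        scale, shift, gate = self.dense(cond).unsqueeze(1).chunk(3, dim=-1)  # each [batch, 1, dim]
        normed = normed * (1.0 + scale.float()) + shift.float()
        return normed.to(x.dtype), gate.to(x.dtype)


def gated_residual(residual: torch.Tensor, update: torch.Tensor, gate: torch.Tensor | None) -> torch.Tensor:
    # Plain residual addition when no gate is supplied, otherwise the update is modulated by the gate.
    return residual + update if gate is None else residual + update * gate


# Toy usage: one conditioned sub-block, e.g. conditioning on a flow-matching timestep embedding.
x = torch.randn(2, 5, 8)     # [batch, seq, hidden]
cond = torch.randn(2, 16)    # [batch, cond_dim]
norm = AdaptiveRMSNorm(dim=8, cond_dim=16)
normed, gate = norm(x, cond)
out = gated_residual(x, normed, gate)  # same shape as x
```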
- - -@safe_auto_docstring -class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - logits_to_keep: int | torch.Tensor = 0, - adarms_cond: torch.Tensor | None = None, - **kwargs: Unpack[KwargsForCausalLM], - ) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - adarms_cond=adarms_cond, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@safe_auto_docstring( - custom_intro=""" - The Gemma Model transformer with a sequence classification head on top (linear layer). - - [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """ -) -class GemmaForSequenceClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - adarms_cond: torch.Tensor | None = None, - ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - transformer_outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config - ) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@safe_auto_docstring -class GemmaForTokenClassification(GemmaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = GemmaModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - adarms_cond: torch.Tensor | None = None, - ) -> TokenClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): - Condition for ADARMS. - """ - - outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - adarms_cond=adarms_cond, - ) - sequence_output = outputs.last_hidden_state - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "GemmaModel", - "GemmaForCausalLM", - "GemmaForSequenceClassification", - "GemmaForTokenClassification", - "GemmaPreTrainedModel", -] diff --git a/src/lerobot/policies/pi05/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi05/transformers_replace/models/paligemma/modeling_paligemma.py deleted file mode 100644 index b2a36b5ca..000000000 --- a/src/lerobot/policies/pi05/transformers_replace/models/paligemma/modeling_paligemma.py +++ /dev/null @@ -1,666 +0,0 @@ -# Copyright 2024 the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch PaliGemmamodel.""" - -from dataclasses import dataclass - -import torch -import torch.utils.checkpoint -from torch import nn - -from ...cache_utils import Cache, HybridCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...processing_utils import Unpack -from ...utils import ( - LossKwargs, - ModelOutput, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, -) -from ..auto import AutoModel -from .configuration_paligemma import PaliGemmaConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for Paligemma outputs, with hidden states and attentions. - """ -) -class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. - image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. - """ - - image_hidden_states: torch.FloatTensor | None = None - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for PaliGemma causal language model (or autoregressive) outputs. - """ -) -class PaliGemmaCausalLMOutputWithPast(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - image_hidden_states (`torch.FloatTensor`, *optional*): - A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. 
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state. - """ - - loss: torch.FloatTensor | None = None - logits: torch.FloatTensor | None = None - past_key_values: list[torch.FloatTensor] | Cache | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - image_hidden_states: torch.FloatTensor | None = None - - -class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, config: PaliGemmaConfig): - super().__init__() - self.linear = nn.Linear( - config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True - ) - - def forward(self, image_features): - hidden_states = self.linear(image_features) - - return hidden_states - - -@safe_auto_docstring -class PaliGemmaPreTrainedModel(PreTrainedModel): - config_class = PaliGemmaConfig - base_model_prefix = "" - supports_gradient_checkpointing = True - _no_split_modules = ["PaliGemmaMultiModalProjector"] - _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - # important: this ported version of PaliGemmaisn't meant for training from scratch - only - # inference and fine-tuning - std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) - - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - - -@safe_auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., - """ -) -class PaliGemmaModel(PaliGemmaPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} - # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch - accepts_loss_kwargs = False - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.vision_tower = AutoModel.from_config(config=config.vision_config) - self.multi_modal_projector = PaliGemmaMultiModalProjector(config) - self.vocab_size = config.text_config.vocab_size - - language_model = AutoModel.from_config(config=config.text_config) - self.language_model = language_model - - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.post_init() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - - def _update_causal_mask( - self, - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - is_training: bool | None = None, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - is_training = is_training if is_training is not None else self.training - 
using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(self.dtype).min - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=self.dtype, - device=cache_position.device, - ) - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - if is_training: - causal_mask = torch.triu(causal_mask, diagonal=1) - else: - causal_mask[:, :sequence_length] = 0.0 - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # First unmask prefix tokens during training - if is_training: - if token_type_ids is None: - raise ValueError("Token type ids must be provided during training") - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - def get_image_features(self, pixel_values: torch.FloatTensor): - """ - Obtains image last hidden states from the vision tower and apply multimodal projection. - - Args: - pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) - The tensors corresponding to the input images. - Returns: - image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
- """ - image_outputs = self.vision_tower(pixel_values) - selected_image_feature = image_outputs.last_hidden_state - image_features = self.multi_modal_projector(selected_image_feature) - return image_features - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - token_type_ids: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple | PaligemmaModelOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" - >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_id >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_id - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed - - # 
Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - return PaligemmaModelOutputWithPast( - last_hidden_state=outputs.last_hidden_state, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
- - -@safe_auto_docstring( - custom_intro=""" - The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., - """ -) -class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", - } - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: PaliGemmaConfig): - super().__init__(config) - self.model = PaliGemmaModel(config) - self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) - self.post_init() - - def get_input_embeddings(self): - return self.model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - - def get_image_features(self, pixel_values): - return self.model.get_image_features(pixel_values) - - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - token_type_ids: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> tuple | PaliGemmaCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration - - >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") - >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") - - >>> prompt = "Where is the cat standing?" 
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs,) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Where is the cat standing?\nsnow" - ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids=input_ids, - pixel_values=pixel_values, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function( - logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs - ) - - return PaliGemmaCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - logits_to_keep=None, - labels=None, - **kwargs, - ): - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = super().prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - logits_to_keep=logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. 
NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self.model._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), - fill_value=min_dtype, - dtype=dtype, - device=cache_position.device, - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] diff --git a/src/lerobot/policies/pi05/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi05/transformers_replace/models/siglip/check.py deleted file mode 100644 index d899dc1b9..000000000 --- a/src/lerobot/policies/pi05/transformers_replace/models/siglip/check.py +++ /dev/null @@ -1,5 +0,0 @@ -import transformers - - -def check_whether_transformers_replace_is_installed_correctly(): - return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi05/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi05/transformers_replace/models/siglip/modeling_siglip.py deleted file mode 100644 index 0fc0bba0f..000000000 --- a/src/lerobot/policies/pi05/transformers_replace/models/siglip/modeling_siglip.py +++ /dev/null @@ -1,1283 +0,0 @@ -# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch Siglip model.""" - -import math -import warnings -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -import numpy as np -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn.init import _calculate_fan_in_and_fan_out - -from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int -from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig - -logger = logging.get_logger(__name__) - - -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - -def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) # noqa: E741 - u = norm_cdf((b - mean) / std) # noqa: E741 - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. 
- - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsequently scaled and shifted by the mean and std args. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip -class SiglipVisionModelOutput(ModelOutput): - r""" - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - """ - - image_embeds: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -@safe_auto_docstring( - custom_intro=""" - Base class for text model's outputs that also contains a pooling of the last hidden states. - """ -) -# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip -class SiglipTextModelOutput(ModelOutput): - r""" - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - """ - - text_embeds: torch.FloatTensor | None = None - last_hidden_state: torch.FloatTensor | None = None - hidden_states: tuple[torch.FloatTensor, ...] | None = None - attentions: tuple[torch.FloatTensor, ...] | None = None - - -@dataclass -@safe_auto_docstring -# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip -class SiglipOutput(ModelOutput): - r""" - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): - Contrastive loss for image-text similarity. 
- logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): - The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text - similarity scores. - logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): - The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image - similarity scores. - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): - The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. - text_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipTextModel`]. - vision_model_output (`BaseModelOutputWithPooling`): - The output of the [`SiglipVisionModel`]. - """ - - loss: torch.FloatTensor | None = None - logits_per_image: torch.FloatTensor | None = None - logits_per_text: torch.FloatTensor | None = None - text_embeds: torch.FloatTensor | None = None - image_embeds: torch.FloatTensor | None = None - text_model_output: BaseModelOutputWithPooling = None - vision_model_output: BaseModelOutputWithPooling = None - - def to_tuple(self) -> tuple[Any]: - return tuple( - self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() - for k in self.keys() - ) - - -class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - self.register_buffer( - "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False - ) - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution - images. This method is also adapted to support torch.jit tracing and no class embeddings. 
- - Adapted from: - - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 - """ - - num_patches = embeddings.shape[1] - num_positions = self.position_embedding.weight.shape[0] - - # always interpolate when tracing to ensure the exported model works for dynamic input shapes - if not torch.jit.is_tracing() and num_patches == num_positions and height == width: - return self.position_embedding(self.position_ids) - - patch_pos_embed = self.position_embedding.weight.unsqueeze(0) - - dim = embeddings.shape[-1] - - new_height = height // self.patch_size - new_width = width // self.patch_size - - sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - size=(new_height, new_width), - mode="bicubic", - align_corners=False, - ) - - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return patch_pos_embed - - def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: - _, _, height, width = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype) - ) # shape = [*, width, grid, grid] - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip -class SiglipTextEmbeddings(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - embed_dim = config.hidden_size - - self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) - self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer( - "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False - ) - - def forward( - self, - input_ids: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - ) -> torch.Tensor: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] - max_position_embedding = self.position_embedding.weight.shape[0] - - if seq_length > max_position_embedding: - raise ValueError( - f"Sequence length must be less than max_position_embeddings (got `sequence length`: " - f"{seq_length} and max_position_embeddings: {max_position_embedding}" - ) - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - - if inputs_embeds is None: - inputs_embeds = self.token_embedding(input_ids) - - position_embeddings = self.position_embedding(position_ids) - embeddings = inputs_embeds + position_embeddings - - return embeddings - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - attn_weights = 
torch.matmul(query, key.transpose(-1, -2)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - self.is_causal = False - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = False, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Input shape: Batch x Time x Channel""" - - batch_size, seq_length, embed_dim = hidden_states.shape - - queries = self.q_proj(hidden_states) - keys = self.k_proj(hidden_states) - values = self.v_proj(hidden_states) - - queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - queries, - keys, - values, - attention_mask, - is_causal=self.is_causal, - scaling=self.scale, - dropout=0.0 if not self.training else self.dropout, - ) - - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip -class SiglipMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -class SiglipEncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: SiglipVisionConfig | SiglipTextConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.self_attn = SiglipAttention(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: bool | None = False, - ) -> tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- """ - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -@safe_auto_docstring -class SiglipPreTrainedModel(PreTrainedModel): - config_class = SiglipConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - _no_split_modules = [ - "SiglipTextEmbeddings", - "SiglipEncoderLayer", - "SiglipVisionEmbeddings", - "SiglipEncoderLayer", - "SiglipMultiheadAttentionPoolingHead", - ] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True - _supports_attention_backend = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, SiglipVisionEmbeddings): - width = ( - self.config.vision_config.hidden_size - if isinstance(self.config, SiglipConfig) - else self.config.hidden_size - ) - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.xavier_uniform_(module.q_proj.weight) - nn.init.xavier_uniform_(module.k_proj.weight) - nn.init.xavier_uniform_(module.v_proj.weight) - nn.init.xavier_uniform_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.xavier_uniform_(module.fc1.weight) - nn.init.xavier_uniform_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, SiglipMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) - elif isinstance(module, SiglipModel): - logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.data.fill_(logit_scale_init) - module.logit_bias.data.zero_() - elif isinstance(module, SiglipForImageClassification): - nn.init.normal_( - module.classifier.weight, - std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor, - ) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip -class SiglipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. 
- - Args: - config: SiglipConfig - """ - - def __init__(self, config: SiglipConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - # Ignore copy - @can_return_tuple - def forward( - self, - inputs_embeds, - attention_mask: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutput: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, - ) - - -class SiglipTextTransformer(nn.Module): - def __init__(self, config: SiglipTextConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - self.embeddings = SiglipTextEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - - self.head = nn.Linear(embed_dim, config.projection_size) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutputWithPooling: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) - - # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. - # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: - # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.final_layer_norm(last_hidden_state) - - # Assuming "sticky" EOS tokenization, last token is always EOS. - pooled_output = last_hidden_state[:, -1, :] - pooled_output = self.head(pooled_output) - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -@safe_auto_docstring( - custom_intro=""" - The text model from SigLIP without any head or projection on top. - """ -) -class SiglipTextModel(SiglipPreTrainedModel): - config_class = SiglipTextConfig - - def __init__(self, config: SiglipTextConfig): - super().__init__(config) - self.text_model = SiglipTextTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, value): - self.text_model.embeddings.token_embedding = value - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from transformers import AutoTokenizer, SiglipTextModel - - >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled (EOS token) states - ```""" - - return self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - -class SiglipVisionTransformer(nn.Module): - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head - 
if self.use_head: - self.head = SiglipMultiheadAttentionPoolingHead(config) - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool | None = False, - ) -> BaseModelOutputWithPooling: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - # Convert to bfloat16 if the encoder uses bfloat16 - if ( - len(self.encoder.layers) > 0 - and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 - ): - hidden_states = hidden_states.to(torch.bfloat16) - - encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - last_hidden_state = encoder_outputs.last_hidden_state - last_hidden_state = self.post_layernorm(last_hidden_state) - - pooler_output = self.head(last_hidden_state) if self.use_head else None - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooler_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class SiglipMultiheadAttentionPoolingHead(nn.Module): - """Multihead Attention Pooling.""" - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - - self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.attention = torch.nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, batch_first=True - ) - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - - def forward(self, hidden_state): - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) - - hidden_state = self.attention(probe, hidden_state, hidden_state)[0] - - residual = hidden_state - hidden_state = self.layernorm(hidden_state) - hidden_state = residual + self.mlp(hidden_state) - - return hidden_state[:, 0] - - -@safe_auto_docstring( - custom_intro=""" - The vision model from SigLIP without any head or projection on top. 
- """ -) -class SiglipVisionModel(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - - self.vision_model = SiglipVisionTransformer(config) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> BaseModelOutputWithPooling: - r""" - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, SiglipVisionModel - - >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> outputs = model(**inputs) - >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooler_output # pooled features - ```""" - - return self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - -@safe_auto_docstring -class SiglipModel(SiglipPreTrainedModel): - config_class = SiglipConfig - - def __init__(self, config: SiglipConfig): - super().__init__(config) - - if not isinstance(config.text_config, SiglipTextConfig): - raise TypeError( - "config.text_config is expected to be of type SiglipTextConfig but is of type" - f" {type(config.text_config)}." - ) - - if not isinstance(config.vision_config, SiglipVisionConfig): - raise TypeError( - "config.vision_config is expected to be of type SiglipVisionConfig but is of type" - f" {type(config.vision_config)}." - ) - - text_config = config.text_config - vision_config = config.vision_config - - # First, initialize the text and vision models with proper attention implementation - text_model = SiglipTextModel._from_config(text_config) - vision_model = SiglipVisionModel._from_config(vision_config) - - # Second, get the text and vision submodules (for backward compatibility) - self.text_model = text_model.text_model - self.vision_model = vision_model.vision_model - - self.logit_scale = nn.Parameter(torch.randn(1)) - self.logit_bias = nn.Parameter(torch.randn(1)) - - # Initialize weights and apply final processing - self.post_init() - - @safe_auto_docstring - def get_text_features( - self, - input_ids: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - ) -> torch.FloatTensor: - r""" - Returns: - text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by - applying the projection layer to the pooled output of [`SiglipTextModel`]. 
- - Examples: - - ```python - >>> from transformers import AutoTokenizer, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") - - >>> # important: make sure to set padding="max_length" as that's how the model was trained - >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") - >>> with torch.no_grad(): - ... text_features = model.get_text_features(**inputs) - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - pooled_output = text_outputs.pooler_output - - return pooled_output - - @safe_auto_docstring - def get_image_features( - self, - pixel_values: torch.FloatTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> torch.FloatTensor: - r""" - Returns: - image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`SiglipVisionModel`]. - - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, return_tensors="pt") - - >>> with torch.no_grad(): - ... image_features = model.get_image_features(**inputs) - ```""" - # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - pooled_output = vision_outputs.pooler_output - - return pooled_output - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - return_loss: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> SiglipOutput: - r""" - return_loss (`bool`, *optional*): - Whether or not to return the contrastive loss. 
- - Examples: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, AutoModel - >>> import torch - - >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") - >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] - >>> # important: we pass `padding=max_length` since the model was trained with this - >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") - - >>> with torch.no_grad(): - ... outputs = model(**inputs) - - >>> logits_per_image = outputs.logits_per_image - >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities - >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") - 31.9% that image 0 is 'a photo of 2 cats' - ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - image_embeds = vision_outputs.pooler_output - text_embeds = text_outputs.pooler_output - - # normalized features - image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) - - # cosine similarity as logits - logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) - - logit_scale, logit_bias = ( - self.logit_scale.to(text_embeds.device), - self.logit_bias.to(text_embeds.device), - ) - logits_per_text = logits_per_text * logit_scale.exp() + logit_bias - - logits_per_image = logits_per_text.t() - - loss = None - if return_loss: - # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287 - eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) - m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye - loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) - nll = -torch.sum(loglik, dim=-1) - loss = nll.mean() - - return SiglipOutput( - loss=loss, - logits_per_image=logits_per_image, - logits_per_text=logits_per_text, - text_embeds=text_embeds, - image_embeds=image_embeds, - text_model_output=text_outputs, - vision_model_output=vision_outputs, - ) - - -@safe_auto_docstring( - custom_intro=""" - SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of - the patch tokens) e.g. for ImageNet. 
- """ -) -class SiglipForImageClassification(SiglipPreTrainedModel): - main_input_name = "pixel_values" - - def __init__(self, config: SiglipConfig) -> None: - super().__init__(config) - - self.num_labels = config.num_labels - - # Create the vision model with proper attention - # and take only vision_model submodule (for backward compatibility) - vision_model = SiglipVisionModel._from_config(config.vision_config) - self.vision_model = vision_model.vision_model - - # Classifier head - self.classifier = ( - nn.Linear(config.vision_config.hidden_size, config.num_labels) - if config.num_labels > 0 - else nn.Identity() - ) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @safe_auto_docstring - def forward( - self, - pixel_values: torch.Tensor | None = None, - labels: torch.Tensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - interpolate_pos_encoding: bool = False, - ) -> ImageClassifierOutput: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the image classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Examples: - - ```python - >>> from transformers import AutoImageProcessor, SiglipForImageClassification - >>> import torch - >>> from PIL import Image - >>> import requests - - >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> # note: we are loading a `SiglipModel` from the hub here, - >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. 
- >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") - >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") - - >>> inputs = image_processor(images=image, return_tensors="pt") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the two classes - >>> predicted_class_idx = logits.argmax(-1).item() - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) - Predicted class: LABEL_1 - ```""" - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - ) - - sequence_output = outputs.last_hidden_state - - # average pool the patch tokens - sequence_output = torch.mean(sequence_output, dim=1) - # apply classifier - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - return ImageClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -__all__ = [ - "SiglipModel", - "SiglipPreTrainedModel", - "SiglipTextModel", - "SiglipVisionModel", - "SiglipForImageClassification", -]