mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
initial commit
This commit is contained in:
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Script to create and push a PI0OpenPI model to HuggingFace hub with proper config format."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, create_repo
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
|
||||
|
||||
def create_and_push_model(
|
||||
repo_id: str,
|
||||
private: bool = False,
|
||||
token: str = None,
|
||||
):
|
||||
"""Create a PI0OpenPI model with proper config and push to HuggingFace hub.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace repository ID (e.g., "username/model-name")
|
||||
private: Whether to create a private repository
|
||||
token: HuggingFace API token (optional, will use cached token if not provided)
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI Model Hub Upload")
|
||||
print("=" * 60)
|
||||
|
||||
# Create configuration
|
||||
print("\nCreating PI0OpenPI configuration...")
|
||||
config = PI0OpenPIConfig(
|
||||
# Model architecture
|
||||
paligemma_variant="gemma_2b",
|
||||
action_expert_variant="gemma_300m",
|
||||
pi05=False, # Use PI0 (not PI0.5)
|
||||
dtype="float32", # Use float32 for compatibility
|
||||
# Input/output dimensions
|
||||
action_dim=32, # see openpi `Pi0Config`
|
||||
state_dim=32,
|
||||
action_horizon=50,
|
||||
n_action_steps=50,
|
||||
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
||||
image_keys=(
|
||||
"observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist_0_rgb",
|
||||
),
|
||||
# Training settings
|
||||
gradient_checkpointing=False,
|
||||
compile_model=False,
|
||||
device=None, # Auto-detect
|
||||
# Tokenizer settings
|
||||
tokenizer_max_length=200, # see openpi `__post_init__`, use pi0=200 and pi05=48
|
||||
)
|
||||
|
||||
print(f" - Config type: {config.__class__.__name__}")
|
||||
print(f" - PaliGemma variant: {config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {config.action_expert_variant}")
|
||||
print(f" - Action dim: {config.action_dim}")
|
||||
print(f" - State dim: {config.state_dim}")
|
||||
|
||||
# Create dummy dataset stats for normalization
|
||||
print("\nCreating dataset statistics...")
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(config.state_dim),
|
||||
"std": torch.ones(config.state_dim),
|
||||
"min": torch.full((config.state_dim,), -5.0),
|
||||
"max": torch.full((config.state_dim,), 5.0),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(config.action_dim),
|
||||
"std": torch.ones(config.action_dim),
|
||||
"min": torch.full((config.action_dim,), -1.0),
|
||||
"max": torch.full((config.action_dim,), 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
# Add image stats
|
||||
for key in config.image_keys:
|
||||
dataset_stats[key] = {
|
||||
"mean": torch.tensor([0.485, 0.456, 0.406]), # TODO(pepijn): fix this, now its ImageNet mean
|
||||
"std": torch.tensor([0.229, 0.224, 0.225]), # TODO(pepijn): fix this, now its ImageNet std
|
||||
"min": torch.tensor([0.0, 0.0, 0.0]),
|
||||
"max": torch.tensor([1.0, 1.0, 1.0]),
|
||||
}
|
||||
|
||||
# Create the policy
|
||||
print("\nInitializing PI0OpenPI policy...")
|
||||
print(" (This may take a moment as it loads the tokenizer and initializes the model)")
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
# Initialize with small random weights (optional - for testing)
|
||||
# Note: In practice, you would load your trained weights here
|
||||
print("\nInitializing model weights...")
|
||||
for name, param in policy.named_parameters():
|
||||
if "weight" in name:
|
||||
if "norm" in name.lower() or "layernorm" in name.lower():
|
||||
torch.nn.init.ones_(param)
|
||||
elif len(param.shape) >= 2:
|
||||
torch.nn.init.xavier_uniform_(param, gain=0.01)
|
||||
else:
|
||||
torch.nn.init.normal_(param, mean=0.0, std=0.01)
|
||||
elif "bias" in name:
|
||||
torch.nn.init.zeros_(param)
|
||||
|
||||
print(f" - Total parameters: {sum(p.numel() for p in policy.parameters()):,}")
|
||||
print(f" - Trainable parameters: {sum(p.numel() for p in policy.parameters() if p.requires_grad):,}")
|
||||
|
||||
# Create temporary directory for saving
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
save_path = Path(tmpdir) / "model"
|
||||
save_path.mkdir(exist_ok=True)
|
||||
|
||||
print(f"\nSaving model to temporary directory: {save_path}")
|
||||
|
||||
# Save the model using LeRobot's save_pretrained method
|
||||
# This ensures the config is saved in the correct format
|
||||
policy.save_pretrained(save_path)
|
||||
|
||||
# List saved files
|
||||
saved_files = list(save_path.glob("*"))
|
||||
print("\nSaved files:")
|
||||
for file in saved_files:
|
||||
size = file.stat().st_size
|
||||
print(f" - {file.name}: {size:,} bytes")
|
||||
|
||||
# Create or get repository
|
||||
print(f"\nCreating/accessing repository: {repo_id}")
|
||||
api = HfApi(token=token)
|
||||
|
||||
try:
|
||||
# Create repo if it doesn't exist
|
||||
create_repo(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f" ✓ Repository ready: https://huggingface.co/{repo_id}")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Note: {e}")
|
||||
|
||||
# Upload to hub
|
||||
print("\nUploading to HuggingFace hub...")
|
||||
api.upload_folder(
|
||||
folder_path=str(save_path),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
commit_message="Upload PI0OpenPI model with proper LeRobot config format",
|
||||
)
|
||||
|
||||
print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")
|
||||
|
||||
# Test loading the model back
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing model loading from hub...")
|
||||
|
||||
try:
|
||||
loaded_policy = PI0OpenPIPolicy.from_pretrained(
|
||||
repo_id,
|
||||
token=token,
|
||||
)
|
||||
print("✓ Model loaded successfully from hub")
|
||||
|
||||
# Quick validation
|
||||
batch_size = 1
|
||||
device = next(loaded_policy.parameters()).device
|
||||
test_batch = {
|
||||
"observation.state": torch.randn(batch_size, config.state_dim, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device),
|
||||
"task": ["Test task"],
|
||||
}
|
||||
|
||||
# Add images
|
||||
for key in config.image_keys:
|
||||
test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device)
|
||||
|
||||
# Test forward pass
|
||||
loaded_policy.train()
|
||||
loss, loss_dict = loaded_policy.forward(test_batch)
|
||||
print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ Process complete!")
|
||||
print("=" * 60)
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Push PI0OpenPI model to HuggingFace hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="test-user/pi0-openpi-test",
|
||||
help="HuggingFace repository ID (e.g., 'username/model-name')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Create a private repository",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="HuggingFace API token (optional, uses cached token if not provided)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run the upload
|
||||
create_and_push_model(
|
||||
repo_id=args.repo_id,
|
||||
private=args.private,
|
||||
token=args.token,
|
||||
)
|
||||
@@ -27,6 +27,7 @@ from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
@@ -62,6 +63,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "pi0_openpi":
|
||||
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
|
||||
|
||||
return PI0OpenPIPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -91,6 +96,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "pi0_openpi":
|
||||
return PI0OpenPIConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_pi0openpi import PI0OpenPIConfig
|
||||
from .modeling_pi0openpi import PI0OpenPIPolicy
|
||||
|
||||
__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"]
|
||||
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
# ### ⚠️ WARNING ⚠️ ###
|
||||
# This project requires patching the Hugging Face `transformers` library.
|
||||
#
|
||||
# 1. Make sure you have the exact version installed:
|
||||
# pip show transformers
|
||||
# It must be version 4.53.2.
|
||||
#
|
||||
# 2. Apply the custom patches by copying the modified files into your conda environment (make sure your environment is activated!)
|
||||
# cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
|
||||
#
|
||||
# These patches overwrite parts of `transformers` to:
|
||||
# (a) support AdaRMS optimizer,
|
||||
# (b) correctly control the precision of activations,
|
||||
# (c) allow the KV cache to be used without updates.
|
||||
#
|
||||
# IMPORTANT:
|
||||
# - This permanently modifies the `transformers` installation in your conda environment.
|
||||
# - The changes will survive reinstalls of `transformers` unless you explicitly remove
|
||||
# the patched files or recreate the environment.
|
||||
#
|
||||
# To undo the operation and restore a clean state, run:
|
||||
# pip uninstall transformers
|
||||
# pip install transformers==4.53.2
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0_openpi")
|
||||
@dataclass
|
||||
class PI0OpenPIConfig(PreTrainedConfig):
|
||||
# Model architecture
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
action_expert_variant: str = "gemma_300m"
|
||||
pi05: bool = False # Whether to use PI0.5 variant with AdaRMS
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
action_horizon: int = 50 # Number of action steps to predict
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||
state_dim: int = 32 # State dimension (will be padded to 32)
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10 # Number of denoising steps during inference
|
||||
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
|
||||
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
|
||||
min_period: float = 4e-3 # Min period for sinusoidal positional encoding
|
||||
max_period: float = 4.0 # Max period for sinusoidal positional encodingis my
|
||||
|
||||
# Image preprocessing
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
image_keys: tuple[str, ...] = (
|
||||
"observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY, # Images are normalized to [-1, 1] in preprocessing
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Training settings
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW` and
|
||||
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
tokenizer_max_length: int = 200 # pi0=200 and pi05=48, see openpi `__post_init__`
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
|
||||
|
||||
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
|
||||
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features."""
|
||||
# Add image features
|
||||
for key in self.image_keys:
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Default shape, will be resized
|
||||
)
|
||||
|
||||
# Ensure state and action features exist
|
||||
if "observation.state" not in self.input_features:
|
||||
self.input_features["observation.state"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.state_dim,),
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
self.output_features["action"] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.action_dim,),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
+173
@@ -0,0 +1,173 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class GemmaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 24576):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 28):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
head_dim (`int`, *optional*, defaults to 256):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The legacy activation function. It is overwritten by the `hidden_activation`.
|
||||
hidden_activation (`str` or `function`, *optional*):
|
||||
The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
|
||||
if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 2):
|
||||
Beginning of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
use_adarms (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use ADARMS.
|
||||
adarms_cond_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the ADARMS condition.
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "gemma"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=3072,
|
||||
intermediate_size=24576,
|
||||
num_hidden_layers=28,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=256,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
hidden_activation=None,
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
use_adarms: bool = False,
|
||||
adarms_cond_dim: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_activation = hidden_activation
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.use_adarms = use_adarms
|
||||
self.adarms_cond_dim = adarms_cond_dim
|
||||
|
||||
# Set default for adarms_cond_dim if use_adarms is True
|
||||
if self.use_adarms and self.adarms_cond_dim is None:
|
||||
self.adarms_cond_dim = self.hidden_size
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["GemmaConfig"]
|
||||
@@ -0,0 +1,876 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_gemma.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_gemma import GemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GemmaRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.dim = dim
|
||||
self.cond_dim = cond_dim
|
||||
|
||||
# Dense layer for adaptive normalization (if cond_dim is provided)
|
||||
if cond_dim is not None:
|
||||
# self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16)
|
||||
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
|
||||
# Initialize with zeros (matches source implementation)
|
||||
nn.init.zeros_(self.dense.weight)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
|
||||
self.dense = None
|
||||
|
||||
def _norm(self, x):
|
||||
# Compute variance in float32 (like the source implementation)
|
||||
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
|
||||
# Compute normalization in float32
|
||||
normed_inputs = x * torch.rsqrt(var + self.eps)
|
||||
return normed_inputs
|
||||
|
||||
def forward(self, x, cond=None):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
normed_inputs = self._norm(x)
|
||||
|
||||
if cond is None or self.dense is None:
|
||||
# regular RMSNorm
|
||||
# scale by learned parameter in float32 (matches source implementation)
|
||||
normed_inputs = normed_inputs * (1.0 + self.weight.float())
|
||||
return normed_inputs.to(dtype), None # return in original dtype with None gate
|
||||
|
||||
# adaptive RMSNorm (if cond is provided and dense layer exists)
|
||||
if cond.shape[-1] != self.cond_dim:
|
||||
raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
|
||||
|
||||
# self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32)
|
||||
modulation = self.dense(cond)
|
||||
# Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
|
||||
if len(x.shape) == 3: # [batch, seq, features]
|
||||
modulation = modulation.unsqueeze(1)
|
||||
|
||||
scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
|
||||
|
||||
# Apply adaptive normalization: use model weight dtype to ensure compatibility
|
||||
# model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16)
|
||||
# scale = scale.to(model_dtype)
|
||||
# shift = shift.to(model_dtype)
|
||||
# gate = gate.to(model_dtype)
|
||||
# normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype
|
||||
|
||||
normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
|
||||
|
||||
return normed_inputs.to(dtype), gate.to(dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
if self.dense is not None:
|
||||
repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
|
||||
return repr_str
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class GemmaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: GemmaConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = (
|
||||
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
"""
|
||||
Applies gated residual connection with optional gate parameter.
|
||||
|
||||
Args:
|
||||
x: Input tensor (residual)
|
||||
y: Output tensor to be added
|
||||
gate: Optional gate tensor to modulate the addition
|
||||
|
||||
Returns:
|
||||
x + y if gate is None, otherwise x + y * gate
|
||||
"""
|
||||
if x is None and y is None:
|
||||
return None
|
||||
if x is None or y is None:
|
||||
return x if x is not None else y
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_value: Cache | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Use cache if provided
|
||||
if past_key_value is not None:
|
||||
if use_cache:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
else:
|
||||
key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GemmaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: GemmaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GemmaMLP(config)
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
position_embeddings: None
|
||||
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = _gated_residual(residual, hidden_states, gate)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = GemmaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GemmaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, GemmaRMSNorm):
|
||||
if hasattr(module, "weight"):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GemmaModel(GemmaPreTrainedModel):
|
||||
def __init__(self, config: GemmaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
"""
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# normalized
|
||||
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841
|
||||
# hidden_states = hidden_states * normalizer
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, adarms_cond)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = GemmaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||
|
||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
adarms_cond=adarms_cond,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Gemma Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-2) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
"""
|
||||
)
|
||||
class GemmaForSequenceClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
transformer_outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config
|
||||
)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class GemmaForTokenClassification(GemmaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = GemmaModel(config)
|
||||
if getattr(config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = config.classifier_dropout
|
||||
elif getattr(config, "hidden_dropout", None) is not None:
|
||||
classifier_dropout = config.hidden_dropout
|
||||
else:
|
||||
classifier_dropout = 0.1
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
adarms_cond: torch.Tensor | None = None,
|
||||
) -> TokenClassifierOutput:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
|
||||
Condition for ADARMS.
|
||||
"""
|
||||
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
sequence_output = outputs.last_hidden_state
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.score(sequence_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.config)
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GemmaModel",
|
||||
"GemmaForCausalLM",
|
||||
"GemmaForSequenceClassification",
|
||||
"GemmaForTokenClassification",
|
||||
"GemmaPreTrainedModel",
|
||||
]
|
||||
+647
@@ -0,0 +1,647 @@
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch PaliGemmamodel."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from ..auto import AutoModel
|
||||
from .configuration_paligemma import PaliGemmaConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for Paligemma outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PaliGemma causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor | None = None
|
||||
logits: torch.FloatTensor | None = None
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None
|
||||
hidden_states: tuple[torch.FloatTensor] | None = None
|
||||
attentions: tuple[torch.FloatTensor] | None = None
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class PaliGemmaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(
|
||||
config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear(image_features)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
config_class = PaliGemmaConfig
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PaliGemmaMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
|
||||
# inference and fine-tuning
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
language_model = AutoModel.from_config(config=config.text_config)
|
||||
self.language_model = language_model
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
self.post_init()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids=None,
|
||||
past_key_values=None,
|
||||
cache_position=None,
|
||||
input_tensor=None,
|
||||
is_training: bool | None = None,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
is_training = is_training if is_training is not None else self.training
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
if input_tensor is None:
|
||||
input_tensor = attention_mask
|
||||
|
||||
inputs_lead_dim, sequence_length = input_tensor.shape[:2]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else cache_position[0] + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=self.dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
|
||||
# First unmask prefix tokens during training
|
||||
if is_training:
|
||||
if token_type_ids is None:
|
||||
raise ValueError("Token type ids must be provided during training")
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
|
||||
# Then apply padding mask (will mask pad tokens)
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple | PaligemmaModelOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id worth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return PaligemmaModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base Paligemma model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
self.model = PaliGemmaModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
def get_image_features(self, pixel_values):
|
||||
return self.model.get_image_features(pixel_values)
|
||||
|
||||
# Make modules available conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
logits_to_keep: int | torch.Tensor = 0,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> tuple | PaliGemmaCausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
||||
|
||||
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
|
||||
)
|
||||
|
||||
return PaliGemmaCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# position_ids in Paligemma are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length),
|
||||
fill_value=min_dtype,
|
||||
dtype=dtype,
|
||||
device=cache_position.device,
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
|
||||
@@ -0,0 +1,5 @@
|
||||
import transformers
|
||||
|
||||
|
||||
def check_whether_transformers_replace_is_installed_correctly():
|
||||
return transformers.__version__ == "4.53.2"
|
||||
+1264
File diff suppressed because it is too large
Load Diff
+136
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to load PI0OpenPI model from HuggingFace hub and run inference."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy
|
||||
|
||||
|
||||
def test_hub_loading():
|
||||
"""Test loading model from HuggingFace hub."""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI HuggingFace Hub Loading Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Model ID on HuggingFace hub
|
||||
model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config.
|
||||
|
||||
print(f"\nLoading model from: {model_id}")
|
||||
print("-" * 60)
|
||||
|
||||
try:
|
||||
# Load the model from HuggingFace hub with strict mode
|
||||
policy = PI0OpenPIPolicy.from_pretrained(
|
||||
model_id,
|
||||
strict=True, # Ensure all weights are loaded correctly
|
||||
)
|
||||
print("✓ Model loaded successfully from HuggingFace hub")
|
||||
|
||||
# Get model info
|
||||
print("\nModel configuration:")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||
print(f" - Device: {next(policy.parameters()).device}")
|
||||
print(f" - Dtype: {next(policy.parameters()).dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to load model: {e}")
|
||||
return False
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing forward pass with loaded model...")
|
||||
|
||||
# Create dummy batch for testing
|
||||
batch_size = 1
|
||||
device = next(policy.parameters()).device
|
||||
|
||||
# Create dummy dataset stats if not loaded with the model
|
||||
if not hasattr(policy, "normalize_inputs") or policy.normalize_inputs is None:
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(policy.config.state_dim, device=device),
|
||||
"std": torch.ones(policy.config.state_dim, device=device),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(policy.config.action_dim, device=device),
|
||||
"std": torch.ones(policy.config.action_dim, device=device),
|
||||
},
|
||||
}
|
||||
policy.normalize_inputs = Normalize(
|
||||
policy.config.input_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
policy.normalize_targets = Normalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
policy.unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(
|
||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
policy.config.action_horizon,
|
||||
policy.config.action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
# Add images if they're in the config
|
||||
for key in policy.config.image_keys:
|
||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||
|
||||
try:
|
||||
# Test forward pass
|
||||
policy.train() # Set to training mode for forward pass with loss
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print("✓ Forward pass successful")
|
||||
print(f" - Loss: {loss_dict['loss']:.4f}")
|
||||
print(f" - Loss shape: {loss.shape if hasattr(loss, 'shape') else 'scalar'}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
print("Testing inference with loaded model...")
|
||||
|
||||
try:
|
||||
# Test action prediction
|
||||
policy.eval() # Set to evaluation mode for inference
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print("✓ Action prediction successful")
|
||||
print(f" - Action shape: {action.shape}")
|
||||
print(f" - Action range: [{action.min().item():.3f}, {action.max().item():.3f}]")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ All tests passed!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_hub_loading()
|
||||
exit(0 if success else 1)
|
||||
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
|
||||
|
||||
|
||||
def test_policy_instantiation():
|
||||
"""Test basic policy instantiation."""
|
||||
print("Testing PI0OpenPI policy instantiation...")
|
||||
|
||||
# Create config
|
||||
config = PI0OpenPIConfig(action_dim=7, state_dim=14, device="cpu", dtype="float32")
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
print(f"Policy created successfully: {policy.name}")
|
||||
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
print("\nTesting forward pass...")
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
except Exception as e:
|
||||
print(f"✗ Forward pass failed: {e}")
|
||||
return False
|
||||
|
||||
print("\nTesting action prediction...")
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
print(f"✓ Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"✗ Action prediction failed: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
print("\nTesting config creation through factory...")
|
||||
|
||||
try:
|
||||
config = make_policy_config(
|
||||
policy_type="pi0_openpi",
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
)
|
||||
print("✓ Config created successfully through factory")
|
||||
print(f" Config type: {type(config).__name__}")
|
||||
print(f" PaliGemma variant: {config.paligemma_variant}")
|
||||
print(f" Action expert variant: {config.action_expert_variant}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Config creation failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("=" * 60)
|
||||
print("PI0OpenPI Policy Integration Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Test config creation
|
||||
config_test = test_config_creation()
|
||||
|
||||
print("\n" + "-" * 60)
|
||||
|
||||
# Test policy instantiation
|
||||
policy_test = test_policy_instantiation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if config_test and policy_test:
|
||||
print("✓ All tests passed!")
|
||||
else:
|
||||
print("✗ Some tests failed.")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user