From bd5c264c498857b218a78dff52f0c95895befc8f Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 10 Sep 2025 19:44:41 +0200 Subject: [PATCH] initial commit --- push_pi0_to_hub.py | 228 +++ src/lerobot/policies/factory.py | 7 + src/lerobot/policies/pi0_openpi/__init__.py | 20 + .../pi0_openpi/configuration_pi0openpi.py | 176 +++ .../policies/pi0_openpi/modeling_pi0openpi.py | 1087 ++++++++++++++ .../models/gemma/configuration_gemma.py | 173 +++ .../models/gemma/modeling_gemma.py | 876 ++++++++++++ .../models/paligemma/modeling_paligemma.py | 647 +++++++++ .../models/siglip/check.py | 5 + .../models/siglip/modeling_siglip.py | 1264 +++++++++++++++++ test_pi0_hub.py | 136 ++ test_pi0_openpi.py | 109 ++ 12 files changed, 4728 insertions(+) create mode 100644 push_pi0_to_hub.py create mode 100644 src/lerobot/policies/pi0_openpi/__init__.py create mode 100644 src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py create mode 100644 src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/configuration_gemma.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py create mode 100644 test_pi0_hub.py create mode 100644 test_pi0_openpi.py diff --git a/push_pi0_to_hub.py b/push_pi0_to_hub.py new file mode 100644 index 000000000..f6b4b1b09 --- /dev/null +++ b/push_pi0_to_hub.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python + +"""Script to create and push a PI0OpenPI model to HuggingFace hub with proper config format.""" + +import tempfile +from pathlib import Path + +import torch +from huggingface_hub import HfApi, create_repo + +from 
def create_and_push_model(
    repo_id: str,
    private: bool = False,
    token: str | None = None,
):
    """Create a PI0OpenPI model with proper config and push to HuggingFace hub.

    The policy is built from a fixed PI0 (not PI0.5) configuration with
    placeholder dataset statistics and small re-initialised weights, saved to
    a temporary directory, uploaded, and finally loaded back from the hub for
    a one-batch training-mode forward pass as a smoke test.

    Args:
        repo_id: HuggingFace repository ID (e.g., "username/model-name")
        private: Whether to create a private repository
        token: HuggingFace API token (optional, will use cached token if not provided)

    Returns:
        The locally created (not the re-loaded) policy instance.
    """
    import traceback

    print("=" * 60)
    print("PI0OpenPI Model Hub Upload")
    print("=" * 60)

    # Create configuration
    print("\nCreating PI0OpenPI configuration...")
    config = PI0OpenPIConfig(
        # Model architecture
        paligemma_variant="gemma_2b",
        action_expert_variant="gemma_300m",
        pi05=False,  # Use PI0 (not PI0.5)
        dtype="float32",  # Use float32 for compatibility
        # Input/output dimensions
        action_dim=32,  # see openpi `Pi0Config`
        state_dim=32,
        action_horizon=50,
        n_action_steps=50,
        # Image inputs, see openpi `model.py, IMAGE_KEYS`
        image_keys=(
            "observation.images.base_0_rgb",
            "observation.images.left_wrist_0_rgb",
            "observation.images.right_wrist_0_rgb",
        ),
        # Training settings
        gradient_checkpointing=False,
        compile_model=False,
        device=None,  # Auto-detect
        # Tokenizer settings
        tokenizer_max_length=200,  # see openpi `__post_init__`, use pi0=200 and pi05=48
    )

    print(f"  - Config type: {config.__class__.__name__}")
    print(f"  - PaliGemma variant: {config.paligemma_variant}")
    print(f"  - Action expert variant: {config.action_expert_variant}")
    print(f"  - Action dim: {config.action_dim}")
    print(f"  - State dim: {config.state_dim}")

    # Create dummy dataset stats for normalization
    print("\nCreating dataset statistics...")
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(config.state_dim),
            "std": torch.ones(config.state_dim),
            "min": torch.full((config.state_dim,), -5.0),
            "max": torch.full((config.state_dim,), 5.0),
        },
        "action": {
            "mean": torch.zeros(config.action_dim),
            "std": torch.ones(config.action_dim),
            "min": torch.full((config.action_dim,), -1.0),
            "max": torch.full((config.action_dim,), 1.0),
        },
    }

    # Add image stats
    for key in config.image_keys:
        dataset_stats[key] = {
            "mean": torch.tensor([0.485, 0.456, 0.406]),  # TODO(pepijn): fix this, now its ImageNet mean
            "std": torch.tensor([0.229, 0.224, 0.225]),  # TODO(pepijn): fix this, now its ImageNet std
            "min": torch.tensor([0.0, 0.0, 0.0]),
            "max": torch.tensor([1.0, 1.0, 1.0]),
        }

    # Create the policy
    print("\nInitializing PI0OpenPI policy...")
    print("  (This may take a moment as it loads the tokenizer and initializes the model)")
    policy = PI0OpenPIPolicy(config, dataset_stats)

    # Initialize with small random weights (optional - for testing)
    # Note: In practice, you would load your trained weights here
    print("\nInitializing model weights...")
    for name, param in policy.named_parameters():
        if "weight" in name:
            # "layernorm" always contains "norm", so one substring test is enough.
            if "norm" in name.lower():
                torch.nn.init.ones_(param)
            elif param.dim() >= 2:
                torch.nn.init.xavier_uniform_(param, gain=0.01)
            else:
                torch.nn.init.normal_(param, mean=0.0, std=0.01)
        elif "bias" in name:
            torch.nn.init.zeros_(param)

    total_params = 0
    trainable_params = 0
    for param in policy.parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"  - Total parameters: {total_params:,}")
    print(f"  - Trainable parameters: {trainable_params:,}")

    # Create temporary directory for saving
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = Path(tmpdir) / "model"
        save_path.mkdir(exist_ok=True)

        print(f"\nSaving model to temporary directory: {save_path}")

        # Save the model using LeRobot's save_pretrained method
        # This ensures the config is saved in the correct format
        policy.save_pretrained(save_path)

        # List saved files
        print("\nSaved files:")
        for file in save_path.glob("*"):
            print(f"  - {file.name}: {file.stat().st_size:,} bytes")

        # Create or get repository
        print(f"\nCreating/accessing repository: {repo_id}")
        api = HfApi(token=token)

        try:
            # Create repo if it doesn't exist
            create_repo(
                repo_id,
                private=private,
                token=token,
                exist_ok=True,
            )
            print(f"  ✓ Repository ready: https://huggingface.co/{repo_id}")
        except Exception as e:
            # Deliberately best-effort: if the repository really is unusable,
            # the upload below raises with the authoritative error.
            print(f"  ⚠️  Note: {e}")

        # Upload to hub (must happen while the temporary directory still exists)
        print("\nUploading to HuggingFace hub...")
        api.upload_folder(
            folder_path=str(save_path),
            repo_id=repo_id,
            repo_type="model",
            token=token,
            commit_message="Upload PI0OpenPI model with proper LeRobot config format",
        )

        print(f"\n✓ Model successfully uploaded to: https://huggingface.co/{repo_id}")

    # Test loading the model back
    print("\n" + "-" * 60)
    print("Testing model loading from hub...")

    try:
        loaded_policy = PI0OpenPIPolicy.from_pretrained(
            repo_id,
            token=token,
        )
    except Exception as e:
        print(f"✗ Failed to load model: {e}")
        traceback.print_exc()
    else:
        print("✓ Model loaded successfully from hub")

        # Quick validation; reported separately so that a forward-pass problem
        # is not mislabelled as a loading problem.
        try:
            batch_size = 1
            device = next(loaded_policy.parameters()).device
            test_batch = {
                "observation.state": torch.randn(batch_size, config.state_dim, device=device),
                "action": torch.randn(batch_size, config.action_horizon, config.action_dim, device=device),
                "task": ["Test task"],
            }

            # Add images
            for key in config.image_keys:
                test_batch[key] = torch.rand(batch_size, 3, 224, 224, device=device)

            # Test forward pass
            loaded_policy.train()
            loss, loss_dict = loaded_policy.forward(test_batch)
            print(f"✓ Forward pass successful - Loss: {loss_dict['loss']:.4f}")
        except Exception as e:
            print(f"✗ Forward pass failed: {e}")
            traceback.print_exc()

    print("\n" + "=" * 60)
    print("✓ Process complete!")
    print("=" * 60)

    return policy
default="test-user/pi0-openpi-test", + help="HuggingFace repository ID (e.g., 'username/model-name')", + ) + parser.add_argument( + "--private", + action="store_true", + help="Create a private repository", + ) + parser.add_argument( + "--token", + type=str, + default=None, + help="HuggingFace API token (optional, uses cached token if not provided)", + ) + + args = parser.parse_args() + + # Run the upload + create_and_push_model( + repo_id=args.repo_id, + private=args.private, + token=args.token, + ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ef56bdb61..a19ee4737 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -27,6 +27,7 @@ from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig @@ -62,6 +63,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "pi0_openpi": + from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy + + return PI0OpenPIPolicy elif name == "sac": from lerobot.policies.sac.modeling_sac import SACPolicy @@ -91,6 +96,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "pi0_openpi": + return PI0OpenPIConfig(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla": diff --git 
a/src/lerobot/policies/pi0_openpi/__init__.py b/src/lerobot/policies/pi0_openpi/__init__.py new file mode 100644 index 000000000..a8d916f96 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_pi0openpi import PI0OpenPIConfig +from .modeling_pi0openpi import PI0OpenPIPolicy + +__all__ = ["PI0OpenPIConfig", "PI0OpenPIPolicy"] diff --git a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py new file mode 100644 index 000000000..4d4d70071 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + +# ### ⚠️ WARNING ⚠️ ### +# This project requires patching the Hugging Face `transformers` library. +# +# 1. Make sure you have the exact version installed: +# pip show transformers +# It must be version 4.53.2. +# +# 2. Apply the custom patches by copying the modified files into your conda environment (make sure your environment is activated!) +# cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))") +# +# These patches overwrite parts of `transformers` to: +# (a) support AdaRMS (conditioned RMS normalization layers; this is not an optimizer), +# (b) correctly control the precision of activations, +# (c) allow the KV cache to be used without updates. +# +# IMPORTANT: +# - This permanently modifies the `transformers` installation in your conda environment. +# - The changes will survive reinstalls of `transformers` unless you explicitly remove +# the patched files or recreate the environment. 
#
# To undo the operation and restore a clean state, run:
#   pip uninstall transformers
#   pip install transformers==4.53.2


@PreTrainedConfig.register_subclass("pi0_openpi")
@dataclass
class PI0OpenPIConfig(PreTrainedConfig):
    """Configuration of the PI0 / PI0.5 policy ported from openpi.

    Registered under the policy type ``"pi0_openpi"``. Besides holding the
    hyper-parameters, the class validates them in ``__post_init__``, fills in
    missing image/state/action features in ``validate_features`` and provides
    the optimizer and LR-scheduler presets.
    """

    # Model architecture
    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    pi05: bool = False  # Whether to use PI0.5 variant with AdaRMS
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    # Input / output structure
    n_obs_steps: int = 1
    action_horizon: int = 50  # Number of action steps to predict
    n_action_steps: int = 50  # Number of action steps to execute
    action_dim: int = 32  # Action dimension (will be padded to 32)
    state_dim: int = 32  # State dimension (will be padded to 32)

    # Flow matching parameters: see openpi `PI0Pytorch`
    num_inference_steps: int = 10  # Number of denoising steps during inference
    time_sampling_beta_alpha: float = 1.5  # Beta distribution alpha parameter for time sampling
    time_sampling_beta_beta: float = 1.0  # Beta distribution beta parameter for time sampling
    min_period: float = 4e-3  # Min period for sinusoidal positional encoding
    max_period: float = 4.0  # Max period for sinusoidal positional encoding

    # Image preprocessing
    image_resolution: tuple[int, int] = (224, 224)  # see openpi `preprocessing_pytorch.py`
    image_keys: tuple[str, ...] = (
        "observation.images.base_0_rgb",
        "observation.images.left_wrist_0_rgb",
        "observation.images.right_wrist_0_rgb",
    )

    # Normalization
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,  # Images are normalized to [-1, 1] in preprocessing
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Training settings
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
    compile_model: bool = False  # Whether to use torch.compile for model optimization
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Optimizer settings: see openpi `AdamW` and `CosineDecaySchedule`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    # Scheduler settings: see openpi `CosineDecaySchedule`
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    tokenizer_max_length: int = 200  # pi0=200 and pi05=48, see openpi `__post_init__`

    def __post_init__(self):
        """Run the base-class initialisation, then validate this config.

        Raises:
            ValueError: If ``n_action_steps`` exceeds ``action_horizon``, or if
                a model variant or ``dtype`` is not one of the supported values.
        """
        super().__post_init__()

        # Validate configuration
        if self.n_action_steps > self.action_horizon:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
            )

        if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")

        if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
            raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")

        if self.dtype not in ["bfloat16", "float32"]:
            raise ValueError(f"Invalid dtype: {self.dtype}")

    def validate_features(self) -> None:
        """Validate and set up input/output features.

        Features that are already present are left untouched; only missing
        image, state and action entries are added with default shapes.
        """
        # Add image features
        for key in self.image_keys:
            if key not in self.input_features:
                self.input_features[key] = PolicyFeature(
                    type=FeatureType.VISUAL,
                    # Follow the configured resolution instead of a hard-coded 224x224.
                    shape=(3, *self.image_resolution),
                )

        # Ensure state and action features exist
        if "observation.state" not in self.input_features:
            self.input_features["observation.state"] = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.state_dim,),
            )

        if "action" not in self.output_features:
            self.output_features["action"] = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.action_dim,),
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW settings built from the ``optimizer_*`` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Return the cosine-decay-with-warmup schedule; its peak LR is ``optimizer_lr``."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        # One index per predicted step. This config names the chunk length
        # `action_horizon`; it defines no `chunk_size` attribute, so reading
        # `self.chunk_size` here raised AttributeError.
        return list(range(self.action_horizon))

    @property
    def reward_delta_indices(self) -> None:
        return None
# Helper functions
def get_safe_dtype(target_dtype, device_type):  # see openpi `get_safe_dtype`
    """Get a safe dtype for the given device type.

    Args:
        target_dtype: The dtype the caller would like to use.
        device_type: Device type string (``"cpu"``, ``"cuda"``, ``"mps"``, ...) or a
            ``torch.device``.

    Returns:
        ``torch.float32`` for bfloat16 on CPU and for float64 on MPS (the MPS
        backend has no float64 support); otherwise ``target_dtype`` unchanged.
    """
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    if device_type == "cpu" and target_dtype == torch.bfloat16:
        # Kept from the openpi helper: bfloat16 is replaced by float32 on CPU.
        return torch.float32
    if device_type == "mps" and target_dtype == torch.float64:
        return torch.float32
    return target_dtype


def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding`
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of shape ``(batch_size,)`` with the scalar positions.
        dimension: Size of the returned embedding; must be even.
        min_period: Shortest period of the sinusoids.
        max_period: Longest period of the sinusoids.
        device: Device (string or ``torch.device``) on which the frequencies are created.

    Returns:
        Tensor of shape ``(batch_size, dimension)``: the sine half followed by the
        cosine half.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Normalise first: the default is the *string* "cpu", which has no `.type`.
    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)


def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta`
    """Draw ``bsize`` float32 samples from ``Beta(alpha, beta)`` on ``device``."""
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))


def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks`
    """Copied from big_vision.

    Tokens can attend to valid inputs tokens which have a cumulative `att_masks`
    smaller or equal to theirs. This way `att_masks` int[B, N] can be used to
    setup several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if its part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.

    Returns:
      [B, N, N] mask whose entry [b, i, j] is set when token i may attend to token j.

    Raises:
      ValueError: If either mask is not 2-D.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks


def resize_with_pad_torch(  # see openpi `resize_with_pad_torch`
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    The layout is guessed from the last dimension: a size of at most 4 is treated
    as channels-last, so channels-first images narrower than 5 pixels are
    misdetected.

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] (uint8 or float32)
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with the same layout, dtype and leading
        (batch) dimensions as the input. uint8 images are padded with 0,
        float32 images with -1.

    Raises:
        ValueError: If the input has fewer than 3 dimensions or an unsupported dtype.
    """
    if images.dim() < 3:
        raise ValueError(f"Expected an image tensor with at least 3 dimensions, got {images.dim()}")
    orig_dtype = images.dtype
    if orig_dtype not in (torch.uint8, torch.float32):
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    channels_last = images.shape[-1] <= 4  # Assume channels-last format
    # Remember the caller's leading dims so exactly that rank is returned;
    # squeezing whenever the batch size is 1 also dropped genuine batch dims.
    lead_shape = images.shape[:-3]
    if images.dim() == 3:
        images = images.unsqueeze(0)
    elif images.dim() > 4:
        images = images.flatten(0, images.dim() - 4)
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    cur_height, cur_width = images.shape[-2:]

    # Scale by the more constraining side so the result fits inside the target.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = min(height, max(1, int(cur_height / ratio)))
    resized_width = min(width, max(1, int(cur_width / ratio)))

    # Interpolate in float32 so the uint8 path does not depend on integer
    # kernels for interpolation and rounding.
    resized_images = F.interpolate(
        images.to(torch.float32),
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if orig_dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255)
        constant_value = 0.0
    else:
        resized_images = resized_images.clamp(-1.0, 1.0)
        constant_value = -1.0

    # Split the padding evenly; the odd pixel goes to the bottom/right.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    ).to(orig_dtype)

    # Convert back to original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    return padded_images.reshape(*lead_shape, *padded_images.shape[1:])


class GemmaConfig:  # see openpi `gemma.py: Config`
    """Configuration for Gemma model variants."""

    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
        self.width = width  # hidden size
        self.depth = depth  # number of transformer layers
        self.mlp_dim = mlp_dim  # feed-forward intermediate size
        self.num_heads = num_heads  # attention (query) heads
        self.num_kv_heads = num_kv_heads  # key/value heads
        self.head_dim = head_dim  # per-head dimension

    def __repr__(self):
        return (
            f"{type(self).__name__}(width={self.width}, depth={self.depth}, mlp_dim={self.mlp_dim}, "
            f"num_heads={self.num_heads}, num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim})"
        )


def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Returns config for specified gemma variant.

    Raises:
        ValueError: If ``variant`` is neither ``"gemma_300m"`` nor ``"gemma_2b"``.
    """
    variants = {
        "gemma_300m": {"width": 1024, "mlp_dim": 4096},
        "gemma_2b": {"width": 2048, "mlp_dim": 16_384},
    }
    if variant not in variants:
        raise ValueError(f"Unknown variant: {variant}")
    # Both variants share depth and attention geometry; a fresh object is
    # returned on every call so callers may mutate it freely.
    return GemmaConfig(depth=18, num_heads=8, num_kv_heads=1, head_dim=256, **variants[variant])
[False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = 
"bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is 
not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Define the complete layer computation function for gradient checkpointing + def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond): + models = [self.paligemma.language_model, self.gemma_expert.model] + + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = 
self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + + batch_size = query_states.shape[0] + scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling + + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + self.paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + + return outputs_embeds + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + 
layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond + ) + + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI0 PyTorch model.""" + + def __init__(self, config: PI0OpenPIConfig): + super().__init__() + self.config = config + self.pi05 = config.pi05 + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True] if self.pi05 else [False, False], + precision=config.dtype, + ) + + self.action_in_proj = nn.Linear(32, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, 32) + + if self.pi05: + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + else: + self.state_proj = nn.Linear(32, action_expert_config.width) + self.action_time_mlp_in = 
def _prepare_attention_masks_4d(self, att_2d_masks):
    """Convert a boolean attention mask into a float mask with an extra axis.

    A singleton axis is inserted at position 1 of the three-axis input. The
    result holds ``0.0`` where the mask is True and a very large negative
    value where it is False.
    """
    # Large negative fill used for masked-out positions.
    masked_fill = -2.3819763e38
    with_extra_axis = att_2d_masks[:, None, :, :]
    return torch.where(with_extra_axis, 0.0, masked_fill)
mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + embs.append(lang_emb) + pad_masks.append(lang_masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + if not self.pi05: + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + def 
state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + if not self.pi05: + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) + return self.action_time_mlp_out(x) + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None + else: + + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + embs.append(action_time_emb) + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + pad_masks.append(action_time_mask) + + att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, 
def forward(
    self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
    """Run one training pass and return the element-wise loss.

    The actions are blended with noise at the given (or sampled) time, the
    model predicts a value from the blended sample, and that prediction is
    compared with the target ``noise - actions``.

    Args:
        images: Images forwarded to ``embed_prefix``.
        img_masks: Image masks forwarded to ``embed_prefix``.
        lang_tokens: Language tokens forwarded to ``embed_prefix``.
        lang_masks: Language mask forwarded to ``embed_prefix``.
        state: State input forwarded to ``embed_suffix``.
        actions: Target actions (assumed (batch, horizon, dim) - TODO confirm).
        noise: Optional noise with the shape of ``actions``; sampled when None.
        time: Optional per-sample time; sampled when None.

    Returns:
        The unreduced MSE (``reduction="none"``) between target and prediction.
    """
    # Noise is drawn before time, the same order as before this rewrite.
    if noise is None:
        noise = self.sample_noise(actions.shape, actions.device)
    if time is None:
        time = self.sample_time(actions.shape[0], actions.device)

    blend = time[:, None, None]
    noisy_actions = blend * noise + (1 - blend) * actions
    target = noise - actions

    prefix_embs, prefix_pad, prefix_att = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
    suffix_embs, suffix_pad, suffix_att, adarms_cond = self.embed_suffix(state, noisy_actions, time)

    # Feed bfloat16 embeddings when the language model weights are bfloat16.
    backbone_dtype = self.paligemma_with_expert.paligemma.language_model.layers[
        0
    ].self_attn.q_proj.weight.dtype
    if backbone_dtype == torch.bfloat16:
        suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
        prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

    pad_masks = torch.cat([prefix_pad, suffix_pad], dim=1)
    att_masks = torch.cat([prefix_att, suffix_att], dim=1)
    position_ids = torch.cumsum(pad_masks, dim=1) - 1
    attention_mask = self._prepare_attention_masks_4d(make_att_2d_masks(pad_masks, att_masks))

    def run_backbone(prefix, suffix, mask, positions, cond):
        (_, expert_out), _ = self.paligemma_with_expert.forward(
            attention_mask=mask,
            position_ids=positions,
            past_key_values=None,
            inputs_embeds=[prefix, suffix],
            use_cache=False,
            adarms_cond=[None, cond],
        )
        return expert_out

    expert_out = self._apply_checkpoint(
        run_backbone, prefix_embs, suffix_embs, attention_mask, position_ids, adarms_cond
    )

    # Only the trailing `action_horizon` positions are projected, in float32.
    action_tokens = expert_out[:, -self.config.action_horizon :].to(dtype=torch.float32)

    def project_to_actions(hidden):
        return self.action_out_proj(hidden)

    predicted = self._apply_checkpoint(project_to_actions, action_tokens)

    return F.mse_loss(target, predicted, reduction="none")
def __init__(  # see lerobot pi0 `__init__`
    self,
    config: PI0OpenPIConfig,
    dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
    """Build the normalizers, the tokenizer and the core PI0 model.

    Args:
        config: Policy configuration class instance.
        dataset_stats: Dataset statistics to be used for normalization.
    """
    super().__init__(config)
    config.validate_features()
    self.config = config

    # Attribute assignment order is kept as-is so that module registration
    # order does not change.
    norm_mapping = config.normalization_mapping
    self.normalize_inputs = Normalize(config.input_features, norm_mapping, dataset_stats)
    self.normalize_targets = Normalize(config.output_features, norm_mapping, dataset_stats)
    self.unnormalize_outputs = Unnormalize(config.output_features, norm_mapping, dataset_stats)

    # Tokenizer for the language input, with the max token length from the config (from OpenPI).
    self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
    self.max_token_len = config.tokenizer_max_length

    # Core PI0 model, optionally with gradient checkpointing switched on.
    self.model = PI0Pytorch(config)
    if config.gradient_checkpointing:
        self.model.gradient_checkpointing_enable()

    self.reset()
def _tokenize_language(
    self, batch: dict[str, Tensor]
) -> tuple[Tensor, Tensor]:  # see lerobot pi0 `prepare_language`
    """Tokenize the task description(s) of a batch with the PaliGemma tokenizer.

    A single description, given as a string or a one-element list, is
    repeated for every sample in the batch. When the batch has no ``"task"``
    entry, the default description ``"Pick up the object"`` is used.

    Args:
        batch: Batch dictionary; ``"task"`` may hold a string or a list of strings.

    Returns:
        ``(lang_tokens, lang_masks)``: the token ids and a boolean attention
        mask, both on the device of the model parameters.

    Raises:
        ValueError: If no task is given and the batch contains no tensor from
            which the batch size could be read.
    """
    device = next(self.parameters()).device

    # Read the batch size from the first tensor entry. Indexing whatever key
    # happens to come first breaks when that key maps to a non-tensor value
    # such as the "task" string/list itself.
    batch_size = next(
        (
            value.shape[0]
            for value in batch.values()
            if isinstance(value, torch.Tensor) and value.dim() > 0
        ),
        None,
    )

    # Get task description
    if "task" in batch:
        tasks = batch["task"]
        if isinstance(tasks, str):
            tasks = [tasks]
        # A lone description (string or one-element list) is shared by every
        # sample; previously a bare string was left un-expanded.
        if isinstance(tasks, list) and len(tasks) == 1 and batch_size is not None:
            tasks = tasks * batch_size
    else:
        if batch_size is None:
            raise ValueError(
                "Cannot infer the batch size for the default task: the batch contains no tensor entries."
            )
        # Default task if not provided
        tasks = ["Pick up the object"] * batch_size

    # Tokenize with max_length padding to match OpenPI's expected format
    tokenized = self.tokenizer(
        tasks,
        padding="max_length",  # Use max_length padding as per OpenPI
        padding_side="right",  # from lerobot pi0 `prepare_language`
        truncation=True,
        max_length=self.max_token_len,  # Use the max token length from config
        return_tensors="pt",
    )

    lang_tokens = tokenized["input_ids"].to(device)
    lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)

    return lang_tokens, lang_masks
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:  # see lerobot pi0 `forward`
    """Normalize a training batch, run the core model and reduce its loss.

    State and actions are zero-padded on their last axis to width 32 before
    the model call, and the loss is cut back to ``config.action_dim`` entries
    on the last axis afterwards.

    Returns:
        ``(loss, loss_dict)``: the mean loss tensor, plus a dict with
        ``"loss"`` (Python float) and ``"loss_per_dim"`` (list, mean over the
        first two axes).

    Raises:
        ValueError: If the last axis of the state or of the actions is larger than 32.
    """
    batch = self.normalize_targets(self.normalize_inputs(batch))

    # Prepare inputs
    images, img_masks = self._preprocess_images(batch)
    lang_tokens, lang_masks = self._tokenize_language(batch)
    state, actions = batch[OBS_STATE], batch[ACTION]

    # Reject anything wider than the fixed internal width.
    if state.shape[-1] > 32:
        raise ValueError(
            f"State dimension {state.shape[-1]} exceeds maximum of 32. "
            f"Please reduce state dimension or modify the model."
        )
    if actions.shape[-1] > 32:
        raise ValueError(
            f"Action dimension {actions.shape[-1]} exceeds maximum of 32. "
            f"Please reduce action dimension or modify the model."
        )

    # Zero-pad narrower inputs up to width 32 (PI0 expects fixed 32-dim).
    missing_state = 32 - state.shape[-1]
    if missing_state > 0:
        state_filler = torch.zeros(state.shape[0], missing_state, device=state.device, dtype=state.dtype)
        state = torch.cat([state, state_filler], dim=-1)

    missing_action = 32 - actions.shape[-1]
    if missing_action > 0:
        action_filler = torch.zeros(
            *actions.shape[:-1], missing_action, device=actions.device, dtype=actions.dtype
        )
        actions = torch.cat([actions, action_filler], dim=-1)

    # Compute loss
    losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)

    # Drop the loss entries that belong to the padded action dimensions.
    if self.config.action_dim < 32:
        losses = losses[:, :, : self.config.action_dim]

    loss = losses.mean()
    per_dim = losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist()

    return loss, {"loss": loss.item(), "loss_per_dim": per_dim}
class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie the weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_adarms (`bool`, *optional*, defaults to `False`):
            Whether to use ADARMS.
        adarms_cond_dim (`int`, *optional*, defaults to `None`):
            The dimension of the ADARMS condition. When `use_adarms` is `True` and this is left as `None`, it is set
            to `hidden_size`.
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Per-projection sharding styles, keyed by module-name pattern.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # For each stage: (names of its inputs, names of its outputs).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        use_adarms: bool = False,
        adarms_cond_dim: int | None = None,
        **kwargs,
    ):
        # Model-specific fields are stored before the base-class initializer runs;
        # the token ids, the tying flag and any extra kwargs are handed to it below.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_adarms = use_adarms
        self.adarms_cond_dim = adarms_cond_dim

        # Set default for adarms_cond_dim if use_adarms is True
        if self.use_adarms and self.adarms_cond_dim is None:
            self.adarms_cond_dim = self.hidden_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from .configuration_gemma import GemmaConfig + +logger = logging.get_logger(__name__) + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): + super().__init__() + self.eps = eps + self.dim = dim + self.cond_dim = cond_dim + + # Dense layer for adaptive normalization (if cond_dim is provided) + if cond_dim is not None: + # self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) + self.dense 
= nn.Linear(cond_dim, dim * 3, bias=True) + # Initialize with zeros (matches source implementation) + nn.init.zeros_(self.dense.weight) + else: + self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) + self.dense = None + + def _norm(self, x): + # Compute variance in float32 (like the source implementation) + var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) + # Compute normalization in float32 + normed_inputs = x * torch.rsqrt(var + self.eps) + return normed_inputs + + def forward(self, x, cond=None): + dtype = x.dtype # original dtype, could be half-precision + normed_inputs = self._norm(x) + + if cond is None or self.dense is None: + # regular RMSNorm + # scale by learned parameter in float32 (matches source implementation) + normed_inputs = normed_inputs * (1.0 + self.weight.float()) + return normed_inputs.to(dtype), None # return in original dtype with None gate + + # adaptive RMSNorm (if cond is provided and dense layer exists) + if cond.shape[-1] != self.cond_dim: + raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}") + + # self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32) + modulation = self.dense(cond) + # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features] + if len(x.shape) == 3: # [batch, seq, features] + modulation = modulation.unsqueeze(1) + + scale, shift, gate = torch.chunk(modulation, 3, dim=-1) + + # Apply adaptive normalization: use model weight dtype to ensure compatibility + # model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16) + # scale = scale.to(model_dtype) + # shift = shift.to(model_dtype) + # gate = gate.to(model_dtype) + # normed_inputs = normed_inputs.to(model_dtype) # Convert normed_inputs to model dtype + + normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32) + + return normed_inputs.to(dtype), gate.to(dtype) + + def extra_repr(self): + repr_str = f"{tuple(self.weight.shape)}, 
eps={self.eps}" + if self.dense is not None: + repr_str += f", adaptive=True, cond_dim={self.cond_dim}" + return repr_str + + +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class GemmaRotaryEmbedding(nn.Module): + def __init__(self, config: GemmaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the half moved to the front."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Index used to address this layer's entry in the key/value cache.
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config has no ``head_dim``.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Number of query heads per key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_value: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Project to Q/K/V, apply rotary embeddings, optionally extend K/V from a cache, and attend.

        Args:
            hidden_states: Input activations. All leading dimensions are kept and the
                last one is projected. The ``transpose(1, 2)`` calls presumably expect
                a ``(batch, seq, hidden)`` layout - TODO confirm against callers.
            position_embeddings: ``(cos, sin)`` pair handed to ``apply_rotary_pos_emb``.
            attention_mask: Passed unchanged to the attention implementation.
            past_key_value: Optional cache. With ``use_cache=True`` it is updated via
                its ``update`` method. With ``use_cache=False`` it is only read, via
                ``past_key_value[self.layer_idx]``.
            cache_position: Forwarded to the cache update.
            use_cache: Selects between updating the cache and only reading from it.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``: the output after ``o_proj`` and the second
            value returned by the selected attention implementation.
        """
        input_shape = hidden_states.shape[:-1]
        # -1 lets ``view`` infer the head count, which differs between Q and K/V.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Use cache if provided
        if past_key_value is not None:
            if use_cache:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )
            else:
                # Read-only use: prepend the cached keys/values along dim 2 without
                # writing the new states back into the cache.
                key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
                value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)

        # Eager attention is the default; any other configured implementation is looked up by name.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the head and head_dim axes back into one feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
@auto_docstring
class GemmaPreTrainedModel(PreTrainedModel):
    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize the parameters of a single submodule.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with standard deviation ``config.initializer_range``;
        linear biases and the embedding padding row are zeroed.
        ``GemmaRMSNorm`` parameters are zeroed (see the comments below).

        Args:
            module: The submodule to initialize in place.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GemmaRMSNorm):
            # GemmaRMSNorm scales by ``1 + weight``, so the identity scale is a
            # zero weight (which is also what its constructor creates). Filling
            # with 1.0 would make a freshly initialized norm scale by 2.
            if hasattr(module, "weight"):
                module.weight.data.zero_()
            # The adaptive variant zero-initializes its ``dense`` projection in the
            # constructor, but the generic ``nn.Linear`` branch above overwrites it
            # with normal noise. Restore the zeros here.
            # NOTE(review): relies on child modules being initialized before their
            # parent - confirm against the initialization traversal order.
            dense = getattr(module, "dense", None)
            if dense is not None:
                dense.weight.data.zero_()
self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + """ + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + """ + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + # embed positions + hidden_states = inputs_embeds + # Convert to bfloat16 if the first layer uses bfloat16 + if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.bfloat16) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841 + # hidden_states = hidden_states * normalizer + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + 
@auto_docstring
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
    # A ``GemmaModel`` decoder with a bias-free linear language-modeling head on top.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped decoder."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder."""
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        adarms_cond: torch.Tensor | None = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        # Explicit arguments win; otherwise fall back to the values stored on the config.
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Run the decoder; ``adarms_cond`` is forwarded to it unchanged.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            adarms_cond=adarms_cond,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last ``logits_to_keep`` positions; the default 0 yields ``slice(0, None)``,
        # i.e. every position. A tensor is used directly as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class GemmaForSequenceClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + adarms_cond: torch.Tensor | None = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. 
+ """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + adarms_cond=adarms_cond, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
@auto_docstring
class GemmaForTokenClassification(GemmaPreTrainedModel):
    # ``GemmaModel`` backbone followed by dropout and a per-position linear classifier.
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = GemmaModel(config)

        # Dropout probability: ``classifier_dropout`` if set, else ``hidden_dropout``
        # if set, else 0.1.
        dropout_prob = getattr(config, "classifier_dropout", None)
        if dropout_prob is None:
            dropout_prob = getattr(config, "hidden_dropout", None)
        if dropout_prob is None:
            dropout_prob = 0.1
        self.dropout = nn.Dropout(dropout_prob)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embed_tokens = value

    # NOTE(review): the ``labels`` entry below describes a `(batch_size,)` sequence-level
    # label, while ``logits`` here are produced per position - confirm the intended shape.
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        adarms_cond: torch.Tensor | None = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
            Condition for ADARMS.
        """

        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            adarms_cond=adarms_cond,
        )
        # Dropout on the final hidden states, then one score vector per position.
        logits = self.score(self.dropout(backbone_out.last_hidden_state))

        # The loss is only computed when labels are supplied.
        loss = self.loss_function(logits, labels, self.config) if labels is not None else None

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Paligemma outputs, with hidden states and attentions.
    """
)
class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    # The only field added on top of ``BaseModelOutputWithPast``; defaults to None.
    image_hidden_states: torch.FloatTensor | None = None
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + +class PaliGemmaMultiModalProjector(nn.Module): + def __init__(self, config: PaliGemmaConfig): + super().__init__() + self.linear = nn.Linear( + config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True + ) + + def forward(self, image_features): + hidden_states = self.linear(image_features) + + return hidden_states + + +@auto_docstring +class PaliGemmaPreTrainedModel(PreTrainedModel): + config_class = PaliGemmaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module): + # important: this ported version of PaliGemmaisn't meant for training from scratch - only + # inference and fine-tuning + std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., + """ +) +class PaliGemmaModel(PaliGemmaPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch + accepts_loss_kwargs = False + + def 
__init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def _update_causal_mask( + self, + attention_mask, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + is_training: bool | None = None, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + is_training = is_training if is_training is not None else self.training + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(self.dtype).min + if input_tensor is None: + input_tensor = attention_mask + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length 
+ 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=self.dtype, + device=cache_position.device, + ) + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. 
+ Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | Cache | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | PaligemmaModelOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id worth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == 
self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if ( + not is_torchdynamo_compiling() + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return PaligemmaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
+ + +@auto_docstring( + custom_intro=""" + The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., + """ +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PaliGemmaConfig): + super().__init__(config) + self.model = PaliGemmaModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features(self, pixel_values): + return self.model.get_image_features(pixel_values) + + # Make modules available conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | Cache | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + 
inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> tuple | PaliGemmaCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224") + >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224") + + >>> prompt = "Where is the cat standing?" 
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs,) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Where is the cat standing?\nsnow" + ```""" + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + 
past_key_values=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + pixel_values=None, + attention_mask=None, + token_type_ids=None, + use_cache=True, + logits_to_keep=None, + labels=None, + **kwargs, + ): + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self.model._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + 
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=cache_position.device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"] 
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py new file mode 100644 index 000000000..d899dc1b9 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py @@ -0,0 +1,5 @@ +import transformers + + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py new file mode 100644 index 000000000..98b280977 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -0,0 +1,1264 @@ +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Siglip model.""" + +import math +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + +logger = logging.get_logger(__name__) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) # noqa: E741 + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + r""" + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. 
+ """ + + image_embeds: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model's outputs that also contains a pooling of the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + r""" + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + """ + + text_embeds: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. 
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: torch.FloatTensor | None = None + logits_per_image: torch.FloatTensor | None = None + logits_per_text: torch.FloatTensor | None = None + text_embeds: torch.FloatTensor | None = None + image_embeds: torch.FloatTensor | None = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: 
SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return 
attn_output, attn_weights + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. 
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from the config's `hidden_act`.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2`; the last dimension of the output equals that of the input."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
@auto_docstring
class SiglipPreTrainedModel(PreTrainedModel):
    # Common base for the SigLIP models in this file: declares the config class, the attention
    # backends the implementation supports, and the per-module weight initialization scheme.
    config_class = SiglipConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    # Each module type is listed once ("SiglipEncoderLayer" used to appear twice).
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize the weights of `module` according to its type.

        The first matching branch wins, so the SigLIP-specific module types are checked before the
        generic `nn.Embedding`, `nn.Linear`/`nn.Conv2d` and `nn.LayerNorm` fallbacks.

        Args:
            module: The (sub)module whose parameters should be initialized in place.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # A composite SiglipConfig keeps the vision width under `vision_config`;
            # a standalone vision config exposes `hidden_size` directly.
            width = (
                self.config.vision_config.hidden_size
                if isinstance(self.config, SiglipConfig)
                else self.config.hidden_size
            )
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, SiglipModel):
            # log(1.0) == 0, i.e. the logit scale starts at exp(0) = 1.
            logit_scale_init = torch.log(torch.tensor(1.0))
            module.logit_scale.data.fill_(logit_scale_init)
            module.logit_bias.data.zero_()
        elif isinstance(module, SiglipForImageClassification):
            # The classifier is `nn.Identity()` when `config.num_labels == 0` and then has no
            # `weight`; only a real linear head is initialized.
            if isinstance(module.classifier, nn.Linear):
                nn.init.normal_(
                    module.classifier.weight,
                    std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
                )
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+ + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + ) -> BaseModelOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
class SiglipTextTransformer(nn.Module):
    """Text tower: token + position embeddings, transformer encoder, final layer norm and a linear head."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        self.head = nn.Linear(hidden_size, config.projection_size)
        # With flash attention 2 the 2D mask is handed to the encoder unexpanded (see `forward`).
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> BaseModelOutputWithPooling:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse any leading dimensions so the ids are (batch, seq_len).
        flat_input_ids = input_ids.view(-1, input_ids.size()[-1])
        token_states = self.embeddings(input_ids=flat_input_ids, position_ids=position_ids)

        # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
        needs_4d_mask = attention_mask is not None and not self._use_flash_attention_2
        if needs_4d_mask:
            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, token_states.dtype)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=token_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        normed_states = self.final_layer_norm(encoder_outputs.last_hidden_state)

        # Assuming "sticky" EOS tokenization, last token is always EOS.
        pooled = self.head(normed_states[:, -1, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=normed_states,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class SiglipVisionTransformer(nn.Module):
    """Vision tower: patch embeddings, transformer encoder, post layer norm and an optional pooling head."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        # The pooling head is built unless the config explicitly carries `vision_use_head`.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool | None = False,
    ) -> BaseModelOutputWithPooling:
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        patch_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # Convert to bfloat16 if the encoder uses bfloat16 (judged by the first layer's q_proj weight).
        layers = self.encoder.layers
        if len(layers) > 0 and layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
            patch_states = patch_states.to(torch.bfloat16)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=patch_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        normed_states = self.post_layernorm(encoder_outputs.last_hidden_state)

        # Without a head there is no pooled representation to return.
        if self.use_head:
            pooled = self.head(normed_states)
        else:
            pooled = None

        return BaseModelOutputWithPooling(
            last_hidden_state=normed_states,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The vision model from SigLIP without any head or projection on top.
    """
)
class SiglipVisionModel(SiglipPreTrainedModel):
    # Wrapper that exposes the bare vision transformer as a standalone pretrained model.
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)

        self.vision_model = SiglipVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The patch convolution is the vision tower's input embedding layer.
        vision_embeddings = self.vision_model.embeddings
        return vision_embeddings.patch_embedding

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, SiglipVisionModel

        >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""

        # All arguments are forwarded unchanged to the wrapped vision transformer.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return vision_outputs
@auto_docstring
class SiglipModel(SiglipPreTrainedModel):
    # Dual-tower SigLIP model: a text transformer and a vision transformer whose pooled outputs are
    # compared through a learned logit scale and bias.
    config_class = SiglipConfig

    def __init__(self, config: SiglipConfig):
        super().__init__(config)

        text_config = config.text_config
        vision_config = config.vision_config

        # Validate both sub-configs before building anything.
        if not isinstance(text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        # First, initialize the text and vision models with proper attention implementation
        text_wrapper = SiglipTextModel._from_config(text_config)
        vision_wrapper = SiglipVisionModel._from_config(vision_config)

        # Second, get the text and vision submodules (for backward compatibility)
        self.text_model = text_wrapper.text_model
        self.vision_model = vision_wrapper.vision_model

        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return text_outputs.pooler_output

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`SiglipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        return vision_outputs.pooler_output

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        return_loss: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        interpolate_pos_encoding: bool = False,
    ) -> SiglipOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```"""
        # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        text_embeds = text_outputs.pooler_output

        # normalized features (unit L2 norm along the feature dimension)
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits; everything is moved onto the text embeddings' device
        target_device = text_embeds.device
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(target_device))

        logit_scale = self.logit_scale.to(target_device)
        logit_bias = self.logit_bias.to(target_device)
        logits_per_text = logits_per_text * logit_scale.exp() + logit_bias

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
            identity = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
            # +1 on the diagonal (matching pairs), -1 everywhere else.
            pair_signs = -torch.ones_like(logits_per_text) + 2 * identity
            log_likelihood = torch.nn.functional.logsigmoid(pair_signs * logits_per_text)
            per_text_nll = -torch.sum(log_likelihood, dim=-1)
            loss = per_text_nll.mean()

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
+ + Examples: + + ```python + >>> from transformers import AutoImageProcessor, SiglipForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a `SiglipModel` from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above. + >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the two classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + Predicted class: LABEL_1 + ```""" + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + sequence_output = outputs.last_hidden_state + + # average pool the patch tokens + sequence_output = torch.mean(sequence_output, dim=1) + # apply classifier + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif 
self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + "SiglipForImageClassification", +] diff --git a/test_pi0_hub.py b/test_pi0_hub.py new file mode 100644 index 000000000..f4729fbb0 --- /dev/null +++ b/test_pi0_hub.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python + +"""Test script to load PI0OpenPI model from HuggingFace hub and run inference.""" + +import torch + +from lerobot.policies.pi0_openpi import PI0OpenPIPolicy + + +def test_hub_loading(): + """Test loading model from HuggingFace hub.""" + print("=" * 60) + print("PI0OpenPI HuggingFace Hub Loading Test") + print("=" * 60) + + # Model ID on HuggingFace hub + model_id = "pepijn223/pi0_base_fp32" # We made sure this config matches our code and `PI0OpenPIConfig` by uploading a model with push_pi0_to_hub.py and copying that config. 
def _load_hub_policy(model_id):
    """Load the policy for ``model_id`` with ``from_pretrained`` and print its key settings."""
    # strict=True: ensure all weights are loaded correctly.
    policy = PI0OpenPIPolicy.from_pretrained(model_id, strict=True)
    print("✓ Model loaded successfully from HuggingFace hub")

    # Get model info
    cfg = policy.config
    first_param = next(policy.parameters())
    print("\nModel configuration:")
    print(f" - PaliGemma variant: {cfg.paligemma_variant}")
    print(f" - Action expert variant: {cfg.action_expert_variant}")
    print(f" - Action dimension: {cfg.action_dim}")
    print(f" - State dimension: {cfg.state_dim}")
    print(f" - Action horizon: {cfg.action_horizon}")
    print(f" - Device: {first_param.device}")
    print(f" - Dtype: {first_param.dtype}")
    return policy


def _attach_dummy_normalization(policy, device):
    """Give the policy zero-mean / unit-std normalization modules if it has none."""
    if getattr(policy, "normalize_inputs", None) is not None:
        return

    from lerobot.policies.normalize import Normalize, Unnormalize

    cfg = policy.config
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(cfg.state_dim, device=device),
            "std": torch.ones(cfg.state_dim, device=device),
        },
        "action": {
            "mean": torch.zeros(cfg.action_dim, device=device),
            "std": torch.ones(cfg.action_dim, device=device),
        },
    }
    policy.normalize_inputs = Normalize(cfg.input_features, cfg.normalization_mapping, dataset_stats)
    policy.normalize_targets = Normalize(cfg.output_features, cfg.normalization_mapping, dataset_stats)
    policy.unnormalize_outputs = Unnormalize(cfg.output_features, cfg.normalization_mapping, dataset_stats)


def _build_dummy_batch(policy, device, batch_size=1):
    """Create a random batch (state, action chunk, task text, one image per configured key)."""
    cfg = policy.config
    batch = {
        "observation.state": torch.randn(batch_size, cfg.state_dim, dtype=torch.float32, device=device),
        "action": torch.randn(
            batch_size,
            cfg.action_horizon,
            cfg.action_dim,
            dtype=torch.float32,
            device=device,
        ),
        "task": ["Pick up the object"] * batch_size,
    }

    # Add images if they're in the config; `rand` keeps pixel values in [0, 1).
    for key in cfg.image_keys:
        batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
    return batch


def test_hub_loading(model_id: str = "pepijn223/pi0_base_fp32") -> bool:
    """Load a PI0OpenPI policy from the HuggingFace hub and smoke-test training and inference.

    Three stages run in order: strict loading of the checkpoint, one training-mode
    ``forward`` pass on a random batch, and one ``select_action`` call in eval mode
    under ``torch.no_grad()``. Each stage reports its outcome on stdout.

    Args:
        model_id: Hub repository to load. The default is the checkpoint whose config
            was made to match `PI0OpenPIConfig` by uploading a model with
            push_pi0_to_hub.py and copying that config.

    Returns:
        True if all three stages succeed; False as soon as one of them raises
        (the error message and traceback are printed). Errors raised while
        preparing the dummy normalization or batch are not caught.
    """
    import traceback

    print("=" * 60)
    print("PI0OpenPI HuggingFace Hub Loading Test")
    print("=" * 60)

    print(f"\nLoading model from: {model_id}")
    print("-" * 60)

    try:
        policy = _load_hub_policy(model_id)
    except Exception as e:
        print(f"✗ Failed to load model: {e}")
        # Show where loading failed, not just the message.
        traceback.print_exc()
        return False

    print("\n" + "-" * 60)
    print("Testing forward pass with loaded model...")

    # Create dummy inputs on whatever device the loaded weights live on.
    device = next(policy.parameters()).device
    _attach_dummy_normalization(policy, device)
    batch = _build_dummy_batch(policy, device, batch_size=1)

    try:
        policy.train()  # Set to training mode for forward pass with loss
        loss, loss_dict = policy.forward(batch)
        print("✓ Forward pass successful")
        print(f" - Loss: {loss_dict['loss']:.4f}")
        print(f" - Loss shape: {loss.shape if hasattr(loss, 'shape') else 'scalar'}")
    except Exception as e:
        print(f"✗ Forward pass failed: {e}")
        traceback.print_exc()
        return False

    print("\n" + "-" * 60)
    print("Testing inference with loaded model...")

    try:
        policy.eval()  # Set to evaluation mode for inference
        with torch.no_grad():
            action = policy.select_action(batch)
        print("✓ Action prediction successful")
        print(f" - Action shape: {action.shape}")
        print(f" - Action range: [{action.min().item():.3f}, {action.max().item():.3f}]")
    except Exception as e:
        print(f"✗ Action prediction failed: {e}")
        traceback.print_exc()
        return False

    print("\n" + "=" * 60)
    print("✓ All tests passed!")
    print("=" * 60)
    return True


if __name__ == "__main__":
    import sys

    # sys.exit instead of the site-provided `exit`, which is meant for interactive use.
    sys.exit(0 if test_hub_loading() else 1)
def test_policy_instantiation() -> bool:
    """Instantiate a PI0OpenPI policy on CPU and smoke-test ``forward`` and ``select_action``.

    Returns:
        True if both the forward pass and the action prediction succeed; False if
        either raises (message and traceback are printed). Errors raised while
        building the config or the policy itself are not caught.
    """
    import traceback

    print("Testing PI0OpenPI policy instantiation...")

    # Single source of truth for the dimensions shared by config, stats and batch.
    state_dim, action_dim = 14, 7

    # Create config
    config = PI0OpenPIConfig(action_dim=action_dim, state_dim=state_dim, device="cpu", dtype="float32")

    # Create dummy dataset stats (zero mean, unit std)
    dataset_stats = {
        "observation.state": {
            "mean": torch.zeros(state_dim),
            "std": torch.ones(state_dim),
        },
        "action": {
            "mean": torch.zeros(action_dim),
            "std": torch.ones(action_dim),
        },
    }

    # Instantiate policy
    policy = PI0OpenPIPolicy(config, dataset_stats)
    print(f"Policy created successfully: {policy.name}")

    # Dummy batch on the policy's device when it exposes one, else on CPU.
    batch_size = 1
    device = getattr(policy, "device", "cpu")
    batch = {
        "observation.state": torch.randn(batch_size, state_dim, dtype=torch.float32, device=device),
        "action": torch.randn(
            batch_size, config.action_horizon, action_dim, dtype=torch.float32, device=device
        ),
        "observation.images.base_0_rgb": torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=device
        ),  # Use rand for [0,1] range
        "task": ["Pick up the object"] * batch_size,
    }

    print("\nTesting forward pass...")
    try:
        loss, loss_dict = policy.forward(batch)
        print(f"✓ Forward pass successful. Loss: {loss_dict['loss']:.4f}")
    except Exception as e:
        print(f"✗ Forward pass failed: {e}")
        traceback.print_exc()
        return False

    print("\nTesting action prediction...")
    try:
        with torch.no_grad():
            action = policy.select_action(batch)
        print(f"✓ Action prediction successful. Action shape: {action.shape}")
    except Exception as e:
        print(f"✗ Action prediction failed: {e}")
        traceback.print_exc()
        return False

    return True


def test_config_creation() -> bool:
    """Create a ``pi0_openpi`` config through ``make_policy_config`` and print a summary.

    Returns:
        True if the config is created and summarized without error, False otherwise
        (message and traceback are printed).
    """
    import traceback

    print("\nTesting config creation through factory...")

    try:
        config = make_policy_config(
            policy_type="pi0_openpi",
            action_dim=7,
            state_dim=14,
        )
        print("✓ Config created successfully through factory")
        print(f" Config type: {type(config).__name__}")
        print(f" PaliGemma variant: {config.paligemma_variant}")
        print(f" Action expert variant: {config.action_expert_variant}")
        return True
    except Exception as e:
        print(f"✗ Config creation failed: {e}")
        traceback.print_exc()
        return False


def main() -> bool:
    """Run all tests, print a summary, and return True only if every test passed."""
    print("=" * 60)
    print("PI0OpenPI Policy Integration Test")
    print("=" * 60)

    # Test config creation
    config_test = test_config_creation()

    print("\n" + "-" * 60)

    # Test policy instantiation
    policy_test = test_policy_instantiation()

    all_passed = bool(config_test and policy_test)

    print("\n" + "=" * 60)
    if all_passed:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed.")
    print("=" * 60)
    return all_passed


if __name__ == "__main__":
    import sys

    # Propagate failure through the exit status so shells and CI can detect it.
    sys.exit(0 if main() else 1)