diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 562da14f4..c5ad8dfd8 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -39,6 +39,8 @@
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
+ - local: xvla
+ title: X-VLA
title: "Policies"
- sections:
- local: async
diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx
index 14f51ef3b..3617f3b25 100644
--- a/docs/source/libero.mdx
+++ b/docs/source/libero.mdx
@@ -62,6 +62,11 @@ lerobot-eval \
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
+### Control Mode
+
+LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained to output actions in different control parameterizations, so the environment must match the mode a given policy expects.
+You can switch between them with `--env.control_mode=relative` or `--env.control_mode=absolute`, as shown in the sketch below.
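+
+The same switch can be made when constructing the environment configuration in Python. A minimal sketch (`LiberoEnv` here is the env config dataclass from `lerobot.envs.configs`):
+
+```python
+from lerobot.envs.configs import LiberoEnv
+
+# Absolute end-effector control; the default is "relative" (delta actions)
+cfg = LiberoEnv(task="libero_goal", control_mode="absolute")
+```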
+
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
diff --git a/docs/source/xvla.mdx b/docs/source/xvla.mdx
new file mode 100644
index 000000000..aa974477f
--- /dev/null
+++ b/docs/source/xvla.mdx
@@ -0,0 +1,570 @@
+# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
+
+## Overview
+
+For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
+
+But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
+
+Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
+
+**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
+
+
+
+
+
+Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
+
+
+
+
+
+With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
+
+
+
+
+
+X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
+
+## Installation
+
+After installing LeRobot, install the X-VLA dependencies:
+
+```bash
+pip install -e .[xvla]
+```
+
+After the new release, you'll be able to do:
+
+```bash
+pip install lerobot[xvla]
+```
+
+## Quick Start
+
+### Basic Usage
+
+To use X-VLA in your LeRobot configuration, specify the policy type as:
+
+```bash
+policy.type=xvla
+```
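+
+Equivalently, the policy configuration can be built in Python through the policy factory. A minimal sketch (the `action_mode="auto"` argument is just an illustrative choice):
+
+```python
+from lerobot.policies.factory import make_policy_config
+
+# Sketch: programmatic equivalent of --policy.type=xvla
+cfg = make_policy_config("xvla", action_mode="auto")
+print(cfg.action_mode)  # "auto"
+```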
+
+### Evaluating Pre-trained Checkpoints
+
+Example evaluation with LIBERO:
+
+```bash
+lerobot-eval \
+ --policy.path="lerobot/xvla-libero" \
+ --env.type=libero \
+ --env.task=libero_spatial,libero_goal,libero_10 \
+ --env.control_mode=absolute \
+ --eval.batch_size=1 \
+ --eval.n_episodes=1 \
+ --env.episode_length=800 \
+ --seed=142
+```
+
+## Available Checkpoints
+
+### 🎯 Base Model
+
+**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
+
+A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
+
+- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
+
+- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
+
+### Simulation Checkpoints
+
+**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
+
+Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
+
+**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
+
+Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
+
+### 🤖 Real-World Checkpoints
+
+**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
+
+A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
+
+**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
+
+Optimized for AgileX robot dexterous manipulation tasks.
+
+**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
+
+Adapted for Google Robot platforms.
+
+## Training X-VLA
+
+### Recommended Training Configuration
+
+When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
+
+```bash
+lerobot-train \
+ --dataset.repo_id=YOUR_DATASET \
+ --output_dir=./outputs/xvla_training \
+ --job_name=xvla_training \
+ --policy.path="lerobot/xvla-base" \
+ --policy.repo_id="HF_USER/xvla-your-robot" \
+ --steps=3000 \
+ --policy.device=cuda \
+ --policy.freeze_vision_encoder=True \
+ --policy.freeze_language_encoder=True \
+ --policy.train_policy_transformer=True \
+ --policy.train_soft_prompts=True \
+ --policy.action_mode=YOUR_ACTION_MODE
+```
+
+### Training Parameters Explained
+
+| Parameter | Default | Description |
+| -------------------------- | ------- | ---------------------------------------- |
+| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
+| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
+| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
+| `train_soft_prompts` | `True` | Allow soft prompts to train |
+
+**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
+
+### Example: Training on Bimanual Robot
+
+```bash
+lerobot-train \
+ --dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
+ --output_dir=./outputs/xvla_bimanual \
+ --job_name=xvla_so101_training \
+ --policy.path="lerobot/xvla-base" \
+ --policy.repo_id="YOUR_USERNAME/xvla-biso101" \
+ --steps=3000 \
+ --policy.device=cuda \
+ --policy.action_mode=so101_bimanual \
+ --policy.freeze_vision_encoder=True \
+ --policy.freeze_language_encoder=True \
+ --policy.train_policy_transformer=True \
+ --policy.train_soft_prompts=True
+```
+
+💡 **Best Performance:** If you have sufficient computational resources and want the best X-VLA finetuning performance, follow the official finetuning strategy:
+
+**🔥 Full-finetune all components with a custom learning-rate scheme**
+
+To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
+This LR ratio is crucial for achieving strong and stable finetuning performance.
+To enable this behavior, you must:
+
+1. Implement a custom optimizer and register it in your training config
+
+```python
+from dataclasses import dataclass, asdict
+from lerobot.optim.optimizers import OptimizerConfig
+import torch
+
+@OptimizerConfig.register_subclass("xvla-adamw")
+@dataclass
+class XVLAAdamW(OptimizerConfig):
+ lr: float = 1e-4
+ betas: tuple[float, float] = (0.9, 0.99)
+ eps: float = 1e-8
+ weight_decay: float = 0.0
+ grad_clip_norm: float = 10.0
+
+ def build(self, params: dict) -> torch.optim.Optimizer:
+ """
+ Expect `named_parameters()` as input.
+ Apply lr = lr / 10 for all VLM-related parameters.
+ """
+ assert isinstance(params, dict), \
+ "Custom LR optimizer requires `named_parameters()` as inputs."
+ kwargs = asdict(self)
+ kwargs.pop("grad_clip_norm")
+ vlm_group, other_group = [], []
+ for name, p in params.items():
+ if not p.requires_grad:
+ continue
+ if "vlm" in name.lower():
+ vlm_group.append(p)
+ else:
+ other_group.append(p)
+
+ param_groups = [
+ {"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
+ {"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
+ ]
+
+ return torch.optim.AdamW(param_groups, **kwargs)
+```
+
+2. Modify X-VLA’s `get_optim_params` to return named parameters
+
+Replace:
+
+```python
+def get_optim_params(self) -> dict:
+ """Return only trainable parameters for optimization."""
+ return filter(lambda p: p.requires_grad, self.parameters())
+```
+
+with:
+
+```python
+def get_optim_params(self) -> dict:
+    """Return trainable parameters keyed by name."""
+    return {name: p for name, p in self.named_parameters() if p.requires_grad}
+```
+
+This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
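+
+Once both pieces are in place, the optimizer can be exercised directly. A minimal sketch, assuming `policy` is an instantiated X-VLA policy with the modified `get_optim_params`:
+
+```python
+# Hypothetical usage sketch: wire the custom optimizer to the policy's named parameters
+optimizer_cfg = XVLAAdamW(lr=1e-4, weight_decay=0.0)
+optimizer = optimizer_cfg.build(policy.get_optim_params())
+
+# The VLM group now trains at 1e-5 while every other group stays at 1e-4
+for group in optimizer.param_groups:
+    print(group["lr"], len(group["params"]))
+```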
+
+❕Note
+
+Fully matching the officially reported performance may require an additional warm-up LR schedule for the soft prompts, which can bring minor improvements.
+We encourage implementing this in your customized training pipeline for optimal results; a sketch follows below.
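+
+A minimal warm-up sketch, under the assumption that the soft-prompt parameters were split into their own optimizer parameter group (taken to be the last group here) and that `warmup_steps` is a value you pick:
+
+```python
+import torch
+
+warmup_steps = 1000  # assumption: choose for your run
+soft_prompt_group_idx = len(optimizer.param_groups) - 1  # assumption: soft prompts are the last group
+
+# Linearly ramp the soft-prompt LR up to its target over `warmup_steps`,
+# leaving every other parameter group untouched.
+lr_lambdas = [
+    (lambda step: min(1.0, (step + 1) / warmup_steps))
+    if i == soft_prompt_group_idx
+    else (lambda step: 1.0)
+    for i in range(len(optimizer.param_groups))
+]
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambdas)
+```
+
+Call `scheduler.step()` once per optimizer step, as with any PyTorch LR scheduler.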
+
+## Core Concepts
+
+### 1. Action Modes
+
+X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
+
+#### Available Action Modes
+
+| Action Mode | Action Dim | Description | Use Case |
+| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
+| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
+| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
+| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
+| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
+| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
+
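+Action modes can also be inspected programmatically through the registry. A minimal sketch using the helpers from `action_hub.py`:
+
+```python
+from lerobot.policies.xvla.action_hub import ACTION_REGISTRY, build_action_space
+
+# List every registered action mode
+print(sorted(ACTION_REGISTRY.keys()))
+
+# Instantiate one and check its model-facing dimensions
+joint_space = build_action_space("joint")
+print(joint_space.dim_action, joint_space.gripper_idx)  # 14 (6, 13)
+```
+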
+#### Why Action Modes Matter
+
+When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
+
+1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
+2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
+3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
+
+#### Example: BimanualSO101 Action Space
+
+The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
+
+```python
+# Model outputs 20 dimensions for compatibility
+dim_action = 20
+
+# Real robot only needs 12 dimensions
+# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
+REAL_DIM = 12
+
+# Preprocessing: Pad 12D actions to 20D for training
+# Postprocessing: Trim 20D predictions to 12D for deployment
+```
+
+See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
+
+#### Auto Action Mode (Recommended)
+
+The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
+
+```bash
+lerobot-train \
+ --policy.path="lerobot/xvla-base" \
+ --policy.action_mode=auto \
+ --policy.max_action_dim=20 \
+ ...
+```
+
+**How it works:**
+
+- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
+- Model outputs `max_action_dim` (default 20) for pretrained compatibility
+- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
+- Postprocess trims output back to `real_dim` for robot control
+
+This eliminates the need to create custom action modes for most robots.
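+
+Under the hood, `auto` maps to the `AutoActionSpace` in the action registry. A minimal sketch of the pad/trim round trip, assuming a 7-dimensional action space such as a single Franka arm:
+
+```python
+import torch
+
+from lerobot.policies.xvla.action_hub import build_action_space
+
+auto_space = build_action_space("auto", real_dim=7, max_dim=20)
+
+proprio = torch.zeros(1, 7)     # proprio is passed through unchanged
+action = torch.randn(1, 32, 7)  # action chunk as stored in the dataset
+_, padded = auto_space.preprocess(proprio, action)  # zero-padded to [1, 32, 20] for the model
+trimmed = auto_space.postprocess(padded)            # trimmed back to [1, 32, 7] for the robot
+```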
+
+### 2. Domain IDs
+
+Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
+
+- Different robots (Robot 1 vs Robot 2)
+- Different camera configurations (cam1 vs cam2)
+- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
+
+#### Setting Domain IDs
+
+**During Training**: By default, domain_id is set to 0 for general training.
+
+**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
+
+```python
+# Example: LIBERO checkpoint uses domain_id=3
+domain_id = 3
+```
+
+The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
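+
+Concretely, the step is constructed with the desired id when building a preprocessing pipeline (a minimal sketch):
+
+```python
+from lerobot.policies.xvla.processor_xvla import XVLAAddDomainIdProcessorStep
+
+# Inject the domain id matching the checkpoint (the LIBERO example above uses 3)
+domain_step = XVLAAddDomainIdProcessorStep(domain_id=3)
+```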
+
+### 3. Processor Steps
+
+X-VLA requires specific preprocessing and postprocessing steps for proper operation.
+
+#### Required Preprocessing Steps
+
+1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
+2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
+3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
+
+#### Example Custom Processor
+
+For LIBERO environments, a custom processor handles the specific observation format:
+
+```python
+from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
+
+processor = LiberoProcessorStep()
+# Handles robot_state dictionary, converts rotation matrices to 6D representation
+# Applies 180° image rotation for camera convention
+```
+
+### 4. Configuration Parameters
+
+Key configuration parameters for X-VLA:
+
+```python
+# Observation and action
+n_obs_steps: int = 1 # Number of observation timesteps
+chunk_size: int = 32 # Action sequence length
+n_action_steps: int = 32 # Number of action steps to execute
+
+# Model architecture
+hidden_size: int = 1024 # Transformer hidden dimension
+depth: int = 24 # Number of transformer layers
+num_heads: int = 16 # Number of attention heads
+num_domains: int = 30 # Maximum number of domain IDs
+len_soft_prompts: int = 32 # Length of soft prompt embeddings
+
+# Action space
+action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
+use_proprio: bool = True # Use proprioceptive state
+max_state_dim: int = 32 # Maximum state dimension
+max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
+
+# Vision
+num_image_views: int | None # Number of camera views
+resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
+
+# Training
+num_denoising_steps: int = 10 # Flow matching denoising steps
+```
+
+## Creating Custom Action Modes
+
+If your robot has a unique action space, you can create a custom action mode:
+
+### Step 1: Define Your Action Space
+
+```python
+from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
+import torch
+import torch.nn as nn
+
+@register_action("my_custom_robot")
+class MyCustomActionSpace(BaseActionSpace):
+ """Custom action space for my robot."""
+
+ dim_action = 15 # Your robot's action dimension
+ gripper_idx = (7, 14) # Gripper channel indices
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+ self.bce = nn.BCEWithLogitsLoss()
+
+ def compute_loss(self, pred, target):
+ """Define your loss computation."""
+ # Example: MSE for joints, BCE for grippers
+ joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
+ gripper_loss = self.bce(pred[:, :, self.gripper_idx],
+ target[:, :, self.gripper_idx])
+
+ return {
+ "joints_loss": joints_loss,
+ "gripper_loss": gripper_loss,
+ }
+
+ def preprocess(self, proprio, action, mode="train"):
+ """Preprocess actions before training."""
+ # Example: Zero out grippers in proprioception
+ proprio_m = proprio.clone()
+ action_m = action.clone() if action is not None else None
+ proprio_m[..., self.gripper_idx] = 0.0
+ if action_m is not None:
+ action_m[..., self.gripper_idx] = 0.0
+ return proprio_m, action_m
+
+ def postprocess(self, action):
+ """Post-process predictions for deployment."""
+ # Example: Apply sigmoid to gripper logits
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
+ return action
+```
+
+### Step 2: Use Your Custom Action Mode
+
+```bash
+lerobot-train \
+ --policy.action_mode=my_custom_robot \
+ --dataset.repo_id=YOUR_DATASET \
+ --policy.path="lerobot/xvla-base" \
+ ...
+```
+
+## Advanced Topics
+
+### Multi-Camera Support
+
+X-VLA supports multiple camera views through the `num_image_views` parameter:
+
+```python
+# Configure for 3 camera views
+policy.num_image_views=3
+
+# Add empty cameras if you have fewer physical cameras
+policy.empty_cameras=1 # Adds 1 zero-padded camera view
+```
+
+### Custom Preprocessing Pipeline
+
+Create a custom preprocessing pipeline for your environment:
+
+```python
+from lerobot.processor import PolicyProcessorPipeline
+from lerobot.policies.xvla.processor_xvla import (
+ XVLAImageToFloatProcessorStep,
+ XVLAImageNetNormalizeProcessorStep,
+ XVLAAddDomainIdProcessorStep,
+)
+
+# Build custom pipeline
+preprocessor = PolicyProcessorPipeline(
+ steps=[
+ YourCustomProcessorStep(), # Your custom processing
+ XVLAImageToFloatProcessorStep(), # Required: convert to float
+ XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
+ XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
+ ]
+)
+```
+
+### Handling Different Action Dimensions
+
+When your dataset has fewer action dimensions than the pretrained model:
+
+**Option 1 (Recommended)**: Use `auto` action mode
+
+```bash
+# Automatically detects your dataset's action dimension
+# Works with any robot without custom code
+policy.action_mode=auto
+policy.max_action_dim=20 # Match pretrained model
+```
+
+**Option 2**: Use a predefined action mode with built-in padding
+
+```python
+# Model expects 20D, dataset has 12D
+# Action mode handles padding internally
+action_mode = "so101_bimanual" # Pads 12 → 20
+```
+
+**Option 3**: Create a custom action mode that maps dimensions explicitly
+
+```python
+@register_action("my_mapped_action")
+class MappedActionSpace(BaseActionSpace):
+ dim_action = 20
+ REAL_DIM = 12
+
+ def _pad_to_model_dim(self, x):
+ # Custom padding logic
+ ...
+```
+
+## Troubleshooting
+
+### Common Issues
+
+**Issue**: "Action dimension mismatch"
+
+- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
+
+**Issue**: "Image values outside [0, 1] range"
+
+- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
+
+**Issue**: "Domain ID not found"
+
+- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
+
+**Issue**: "Low success rate on new embodiment"
+
+- **Solution**:
+ 1. Verify your action_mode is correct
+ 2. Check that soft prompts are being trained (`train_soft_prompts=True`)
+ 3. Ensure proper preprocessing (ImageNet normalization, domain_id)
+ 4. Consider increasing training steps
+
+**Issue**: "Out of memory during training"
+
+- **Solution**:
+ 1. Reduce `chunk_size` (e.g., from 32 to 16)
+ 2. Enable gradient checkpointing
+ 3. Reduce batch size
+ 4. Freeze more components
+
+## Citation
+
+If you use X-VLA in your research, please cite:
+
+```bibtex
+@article{zheng2025x,
+ title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
+ author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
+ and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
+ journal = {arXiv preprint arXiv:2510.10274},
+ year = {2025}
+}
+```
+
+## Additional Resources
+
+- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
+- [LeRobot Documentation](https://github.com/huggingface/lerobot)
+- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
+- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
+- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
+
+## Contributing
+
+We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
diff --git a/pyproject.toml b/pyproject.toml
index 638b2326f..050b604e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -133,6 +133,7 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
+xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -161,6 +162,7 @@ all = [
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
+ "lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index e696f683e..4323f3316 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
- episode_length: int = 520
+ episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
+ control_mode: str = "relative" # or "absolute"
def __post_init__(self):
if self.obs_type == "pixels":
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index bda84fd78..b39cfee71 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -19,8 +19,10 @@ from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
+from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
def make_env_pre_post_processors(
env_cfg: EnvConfig,
+ policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
+ if isinstance(policy_cfg, XVLAConfig):
+ from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
+
+ return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
@@ -136,6 +143,8 @@ def make_env(
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
+ control_mode=cfg.control_mode,
+ episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py
index 35bc58e07..b1eb37377 100644
--- a/src/lerobot/envs/libero.py
+++ b/src/lerobot/envs/libero.py
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]
-OBS_STATE_DIM = 8
ACTION_DIM = 7
-AGENT_POS_LOW = -1000.0
-AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
task_suite: Any,
task_id: int,
task_suite_name: str,
+ episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels",
render_mode: str = "rgb_array",
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
+ control_mode: str = "relative",
):
super().__init__()
self.task_id = task_id
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait
self.episode_index = episode_index
+ self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
- self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
-
+ self._max_episode_steps = (
+ TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
+ if self.episode_length is None
+ else self.episode_length
+ )
+ self.control_mode = control_mode
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
+
+ if self.control_mode == "absolute":
+ for robot in self._env.robots:
+ robot.controller.use_delta = False
+ elif self.control_mode == "relative":
+ for robot in self._env.robots:
+ robot.controller.use_delta = True
+ else:
+ raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
@@ -341,8 +354,10 @@ def _make_env_fns(
task_id: int,
n_envs: int,
camera_names: list[str],
+ episode_length: int | None,
init_states: bool,
gym_kwargs: Mapping[str, Any],
+ control_mode: str,
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
@@ -354,7 +369,9 @@ def _make_env_fns(
task_suite_name=suite_name,
camera_name=camera_names,
init_states=init_states,
+ episode_length=episode_length,
episode_index=episode_index,
+ control_mode=control_mode,
**local_kwargs,
)
@@ -374,6 +391,8 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
+ control_mode: str = "relative",
+ episode_length: int | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.
@@ -415,12 +434,14 @@ def create_libero_envs(
for tid in selected:
fns = _make_env_fns(
suite=suite,
+ episode_length=episode_length,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
+ control_mode=control_mode,
)
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py
index f2bd0df42..5120f828c 100644
--- a/src/lerobot/optim/optimizers.py
+++ b/src/lerobot/optim/optimizers.py
@@ -104,6 +104,107 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
+@OptimizerConfig.register_subclass("xvla-adamw")
+@dataclass
+class XVLAAdamWConfig(OptimizerConfig):
+ """Custom AdamW optimizer for XVLA with differential learning rates.
+
+ The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
+ for stable optimization, while all other components use the full LR.
+
+ This LR ratio is crucial for achieving strong and stable finetuning performance.
+
+ Soft-prompts can optionally use a separate learning rate with warm-up support.
+ Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
+ at a lower LR. Combine with a warmup scheduler for optimal results.
+
+ Note:
+ Completely matching official reported performance may require an additional
+ warm-up LR schedule for soft-prompts, which can bring minor improvements.
+ When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
+ `lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
+
+ Parameter Groups:
+ - Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
+ - Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
+ - Group 2 (other): All other parameters at full lr
+ """
+
+ lr: float = 1e-4
+ betas: tuple[float, float] = (0.9, 0.99)
+ eps: float = 1e-8
+ weight_decay: float = 0.0
+ grad_clip_norm: float = 10.0
+ # Soft-prompt specific settings
+ soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
+ soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
+
+ def build(self, params: dict) -> torch.optim.Optimizer:
+ """
+ Build AdamW optimizer with differential learning rates.
+
+ Expects `named_parameters()` as input (dict of name -> param).
+ Applies:
+ - lr * 0.1 for all VLM-related parameters
+ - lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
+ - full lr for all other parameters
+
+ Args:
+ params: Dictionary of parameter names to parameters (from named_parameters())
+
+ Returns:
+ AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
+ """
+ assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
+
+ vlm_group, soft_prompt_group, other_group = [], [], []
+ for name, p in params.items():
+ if not p.requires_grad:
+ continue
+ if "vlm" in name.lower():
+ vlm_group.append(p)
+ elif "soft_prompt" in name.lower():
+ soft_prompt_group.append(p)
+ else:
+ other_group.append(p)
+
+ # Determine soft-prompt LR
+ soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
+ if self.soft_prompt_warmup_lr_scale is not None:
+ # Start at warmup scale, scheduler will warm up to soft_prompt_lr
+ soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
+
+ param_groups = [
+ {
+ "params": vlm_group,
+ "lr": self.lr * 0.1,
+ "weight_decay": self.weight_decay * 0.1,
+ "name": "vlm",
+ },
+ {
+ "params": soft_prompt_group,
+ "lr": soft_prompt_lr,
+ "weight_decay": self.weight_decay,
+ "name": "soft_prompts",
+ },
+ {
+ "params": other_group,
+ "lr": self.lr,
+ "weight_decay": self.weight_decay,
+ "name": "other",
+ },
+ ]
+
+ # Filter out empty groups
+ param_groups = [g for g in param_groups if len(g["params"]) > 0]
+
+ return torch.optim.AdamW(
+ param_groups,
+ betas=self.betas,
+ eps=self.eps,
+ )
+
+
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 4cdc89ea9..788542d49 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
+from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
@@ -31,4 +32,5 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
+ "XVLAConfig",
]
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index a80bf2ee6..3d17fa7dc 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -41,6 +41,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
@@ -108,6 +109,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
+ elif name == "xvla":
+ from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
+
+ return XVLAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -154,6 +159,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
+ elif policy_type == "xvla":
+ return XVLAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -337,6 +344,15 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
+ elif isinstance(policy_cfg, XVLAConfig):
+ from lerobot.policies.xvla.processor_xvla import (
+ make_xvla_pre_post_processors,
+ )
+
+ processors = make_xvla_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ )
else:
try:
@@ -414,8 +430,7 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
- if not cfg.output_features:
- cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+ cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
diff --git a/src/lerobot/policies/xvla/__init__.py b/src/lerobot/policies/xvla/__init__.py
new file mode 100644
index 000000000..71b04e76f
--- /dev/null
+++ b/src/lerobot/policies/xvla/__init__.py
@@ -0,0 +1,6 @@
+# register the processor steps
+from lerobot.policies.xvla.processor_xvla import (
+ XVLAAddDomainIdProcessorStep,
+ XVLAImageNetNormalizeProcessorStep,
+ XVLAImageToFloatProcessorStep,
+)
diff --git a/src/lerobot/policies/xvla/action_hub.py b/src/lerobot/policies/xvla/action_hub.py
new file mode 100644
index 000000000..e8411de9d
--- /dev/null
+++ b/src/lerobot/policies/xvla/action_hub.py
@@ -0,0 +1,588 @@
+# ------------------------------------------------------------------------------
+# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+
+# =============================================================================
+# Registry
+# =============================================================================
+ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
+
+
+def register_action(name: str):
+ """Decorator for registering a new action space."""
+
+ def _wrap(cls):
+ key = name.lower()
+ if key in ACTION_REGISTRY:
+ raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
+ ACTION_REGISTRY[key] = cls
+ cls.name = key
+ return cls
+
+ return _wrap
+
+
+def build_action_space(name: str, **kwargs) -> BaseActionSpace:
+ """Instantiate a registered action space by name."""
+ key = name.lower()
+ if key not in ACTION_REGISTRY:
+ raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
+ return ACTION_REGISTRY[key](**kwargs)
+
+
+# =============================================================================
+# Base class
+# =============================================================================
+class BaseActionSpace(nn.Module):
+ """
+ Abstract base class for all action-space definitions.
+
+ Each subclass defines:
+ - `dim_action`: dimension of the action vector.
+ - `gripper_idx`: indices of gripper channels.
+ - `compute_loss(pred, target)`: supervised loss for this space.
+ - `preprocess(proprio, action, mode)`: pre-step modifications.
+ - `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
+ """
+
+ name: str = "base"
+ dim_action: int = 0
+ gripper_idx: tuple[int, ...] = ()
+
+ def __init__(self):
+ super().__init__()
+
+ # ---------------------------------------------------------------------
+ # Core supervised loss
+ # ---------------------------------------------------------------------
+ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
+ raise NotImplementedError
+
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
+ """Alias for compute_loss."""
+ return self.compute_loss(pred, target)
+
+ # ---------------------------------------------------------------------
+ # Space-level hooks
+ # ---------------------------------------------------------------------
+ def preprocess(
+ self,
+ proprio: torch.Tensor,
+ action: torch.Tensor,
+ mode: str = "train",
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Default: return unchanged."""
+ return proprio, action
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """Default: return unchanged."""
+ return action
+
+
+# =============================================================================
+# Utilities
+# =============================================================================
+def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
+ bad = [i for i in idx if i < 0 or i >= dim_action]
+ if bad:
+ raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
+
+
+# =============================================================================
+# Implementations
+# =============================================================================
+@register_action("ee6d")
+class EE6DActionSpace(BaseActionSpace):
+ """End-effector layout with xyz, 6D rotation, and gripper channels."""
+
+ dim_action = 20
+ gripper_idx = (9, 19)
+ GRIPPER_SCALE = 1.0
+ XYZ_SCALE = 500.0
+ ROT_SCALE = 10.0
+
+ POS_IDX_1 = (0, 1, 2)
+ POS_IDX_2 = (10, 11, 12)
+ ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
+ ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+ self.bce = nn.BCEWithLogitsLoss()
+
+ def compute_loss(self, pred, target):
+ assert pred.shape == target.shape, "pred/target shapes must match"
+ batch_size, seq_len, action_dim = pred.shape
+ _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
+
+ # Gripper BCE
+ g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
+ gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
+
+ # XYZ position
+ pos_loss = (
+ self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
+ ) * self.XYZ_SCALE
+
+ # Rotation 6D
+ rot_loss = (
+ self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
+ ) * self.ROT_SCALE
+
+ return {
+ "position_loss": pos_loss,
+ "rotate6D_loss": rot_loss,
+ "gripper_loss": gripper_loss,
+ }
+
+ def preprocess(self, proprio, action, mode="train"):
+ """Zero-out gripper channels in proprio/action."""
+ proprio_m = proprio.clone()
+ action_m = action.clone()
+ proprio_m[..., self.gripper_idx] = 0.0
+ action_m[..., self.gripper_idx] = 0.0
+ return proprio_m, action_m
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """Apply sigmoid to gripper logits."""
+ if action.size(-1) > max(self.gripper_idx):
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
+ return action
+
+
+@register_action("joint")
+class JointActionSpace(BaseActionSpace):
+ """Joint-space layout with joints + gripper only."""
+
+ dim_action = 14
+ gripper_idx = (6, 13)
+ GRIPPER_SCALE = 0.1
+ JOINTS_SCALE = 1.0
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+ self.bce = nn.BCEWithLogitsLoss()
+
+ def compute_loss(self, pred, target):
+ assert pred.shape == target.shape
+ batch_size, seq_len, action_dim = pred.shape
+ _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
+
+ g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
+ gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
+
+ joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
+ joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
+
+ return {
+ "joints_loss": joints_loss,
+ "gripper_loss": gripper_loss,
+ }
+
+ def preprocess(self, proprio, action, mode="train"):
+ """Zero-out gripper channels in proprio/action."""
+ proprio_m = proprio.clone()
+ action_m = action.clone()
+ proprio_m[..., self.gripper_idx] = 0.0
+ action_m[..., self.gripper_idx] = 0.0
+ return proprio_m, action_m
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """Apply sigmoid to gripper logits."""
+ if action.size(-1) > max(self.gripper_idx):
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
+ return action
+
+
+@register_action("agibot_ee6d")
+class AGIBOTEE6DActionSpace(BaseActionSpace):
+ """AGI-bot variant of EE6DActionSpace using MSE for all components."""
+
+ dim_action = 20
+ gripper_idx = (9, 19)
+ GRIPPER_SCALE = 10.0
+ XYZ_SCALE = 500.0
+ ROT_SCALE = 10.0
+ POS_IDX_1 = (0, 1, 2)
+ POS_IDX_2 = (10, 11, 12)
+ ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
+ ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+
+ def compute_loss(self, pred, target):
+ assert pred.shape == target.shape
+ batch_size, seq_len, action_dim = pred.shape
+ _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
+
+ gripper_loss = (
+ self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
+ )
+ pos_loss = (
+ self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
+ ) * self.XYZ_SCALE
+ rot_loss = (
+ self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
+ ) * self.ROT_SCALE
+
+ return {
+ "position_loss": pos_loss,
+ "rotate6D_loss": rot_loss,
+ "gripper_loss": gripper_loss,
+ }
+
+ def preprocess(self, proprio, action, mode="train"):
+ """No preprocessing applied in AGIBOT variant."""
+ return proprio, action
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """AGIBOT does not postprocess."""
+ return action
+
+
+@register_action("franka_joint7")
+class FrankaJoint7ActionSpace(BaseActionSpace):
+ """
+ Franka Panda joint-space: 7 joints, with gripper.
+
+ - Real robot action dim: 7
+ - Model-facing dim: 20 (padded with zeros)
+ compatible with pretrained VLA models expecting 20D.
+ """
+
+ dim_action = 20 # model dimension
+ REAL_DIM = 7 # actual Franka joints
+
+ JOINTS_SCALE = 1.0
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+
+ def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """Pad 7 → 20 dims (zeros for the dummy channels)."""
+ if x is None:
+ return None
+ if x.size(-1) == self.dim_action:
+ return x
+ if x.size(-1) != self.REAL_DIM:
+ raise ValueError(
+ f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
+ )
+
+ pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
+ pad = x.new_zeros(pad_shape)
+ return torch.cat([x, pad], dim=-1)
+
+ def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """Trim model output 20 → 7 dims."""
+ return x[..., : self.REAL_DIM]
+
+ def compute_loss(self, pred, target):
+ """
+ pred : [B, T, 20]
+ target : [B, T, 7] or [B, T, 20]
+
+ Only compute MSE on the first 7 dims.
+ """
+ pred = self._pad_to_model_dim(pred)
+ target = self._pad_to_model_dim(target)
+
+ assert pred.shape == target.shape
+
+ joints_loss = (
+ self.mse(
+ pred[:, :, : self.REAL_DIM], # use only the first 7 joints
+ target[:, :, : self.REAL_DIM],
+ )
+ * self.JOINTS_SCALE
+ )
+
+ return {"joints_loss": joints_loss}
+
+ def preprocess(self, proprio, action, mode="train"):
+ """
+ During training:
+ - Pad [7] → [20]
+ """
+ return proprio, self._pad_to_model_dim(action)
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """
+ After model prediction:
+ - Trim [20] → [7] for real robot control.
+ """
+ return self._trim_to_real_dim(action)
+
+
+@register_action("auto")
+class AutoActionSpace(BaseActionSpace):
+ """
+ Auto-detecting action space that adapts to any action dimension.
+
+ - Auto-detects the real action dimension from the policy feature
+ - Model outputs max_dim for compatibility with pretrained models
+ - Loss is computed only on the first real_dim dimensions
+ - Postprocess trims output back to real_dim
+
+ Args:
+ real_dim: The actual action dimension from the dataset/policy feature
+ max_dim: The model's output dimension for pretrained VLA compatibility
+ """
+
+ JOINTS_SCALE = 1.0
+
+ def __init__(self, real_dim: int, max_dim: int):
+ super().__init__()
+ self.real_dim = real_dim
+ self.dim_action = max_dim # Model-facing dimension
+ self.mse = nn.MSELoss()
+
+ def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """Pad real_dim → max_dim (zeros for the dummy channels)."""
+ if x is None:
+ return None
+ if x.size(-1) == self.dim_action:
+ return x
+ if x.size(-1) != self.real_dim:
+ # If dimension doesn't match either, pad/trim to real_dim first
+ if x.size(-1) < self.real_dim:
+ pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
+ pad = x.new_zeros(pad_shape)
+ x = torch.cat([x, pad], dim=-1)
+ else:
+ x = x[..., : self.real_dim]
+
+ pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
+ pad = x.new_zeros(pad_shape)
+ return torch.cat([x, pad], dim=-1)
+
+ def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """Trim model output max_dim → real_dim."""
+ return x[..., : self.real_dim]
+
+ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
+ """
+ Compute loss only on the first real_dim dimensions.
+
+ pred: [B, T, max_dim] from the model
+ target: [B, T, real_dim] or [B, T, max_dim]
+
+ Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
+ """
+ pred = self._pad_to_model_dim(pred)
+ target = self._pad_to_model_dim(target)
+ assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
+
+ # only compute loss on the real dimensions
+ joints_loss = (
+ self.mse(
+ pred[:, :, : self.real_dim],
+ target[:, :, : self.real_dim],
+ )
+ * self.JOINTS_SCALE
+ )
+
+ return {"joints_loss": joints_loss}
+
+ def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
+ """
+ Pad action from real_dim to max_dim for the model.
+ """
+ return proprio, self._pad_to_model_dim(action)
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """
+ Trim model output from max_dim to real_dim for real robot control.
+ """
+ return self._trim_to_real_dim(action)
+
+
+@register_action("so101_bimanual")
+class BimanualSO101ActionSpace(BaseActionSpace):
+ """
+ Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
+
+ Layout (real robot):
+ [left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
+ - Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
+ - Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
+
+ Real action dim: 12
+ Model-facing dim: 20 (extra 8 dummy dims at the end)
+ """
+
+ # Model output / training dimension (to match pretrained policy)
+ dim_action = 20
+
+ # Real robot action dimension
+ REAL_DIM = 12
+
+ # Indices of real vs dummy channels
+ REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
+ DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
+
+ # Grippers live in the real part
+ gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
+ GRIPPER_SCALE = 1.0
+ JOINTS_SCALE = 1.0
+
+ # Indices for left and right arm joints (excluding grippers)
+ LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
+ RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
+
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+ self.bce = nn.BCEWithLogitsLoss()
+
+ # ---------- helpers ----------
+
+ def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
+ if x is None:
+ return None
+ if x.size(-1) == self.dim_action:
+ return x
+ if x.size(-1) != self.REAL_DIM:
+ raise ValueError(
+ f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
+ )
+ pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
+ pad = x.new_zeros(pad_shape)
+ return torch.cat([x, pad], dim=-1)
+
+ def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
+ """Keep only the first REAL_DIM (12) dims for the real robot."""
+ return x[..., : self.REAL_DIM]
+
+ # ---------- loss ----------
+
+ def compute_loss(self, pred, target):
+ """
+ pred: [B, T, 20] from the model
+ target: [B, T, 12] or [B, T, 20]
+ We pad target → 20 and compute loss only on the real dims.
+ """
+ # Ensure both are [B, T, 20]
+ pred = self._pad_to_model_dim(pred)
+ target = self._pad_to_model_dim(target)
+ assert pred.shape == target.shape
+
+ # ---- MSE for all real dims (0–11) ----
+ real_dims = 12
+
+ joints_loss = (
+ self.mse(
+ pred[:, :, :real_dims],
+ target[:, :, :real_dims],
+ )
+ * self.JOINTS_SCALE
+ )
+
+ left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
+ right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
+
+ gripper_loss = (
+ self.mse(
+ pred[:, :, [5, 11]],
+ target[:, :, [5, 11]],
+ )
+ * self.GRIPPER_SCALE
+ )
+
+ return {
+ "joints_loss": joints_loss,
+ "gripper_loss": gripper_loss,
+ "left_arm_loss": left_arm_loss,
+ "right_arm_loss": right_arm_loss,
+ }
+
+ # ---------- preprocess / postprocess ----------
+
+ def preprocess(self, proprio, action, mode="train"):
+ """
+ - If proprio/action are 12-dim, pad them to 20 for the model.
+ - Zero-out gripper channels in proprio/action to focus learning on joints.
+ """
+ proprio_m = self._pad_to_model_dim(proprio.clone())
+ action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
+
+ proprio_m[..., self.gripper_idx] = 0.0
+ if action_m is not None:
+ action_m[..., self.gripper_idx] = 0.0
+
+ return proprio_m, action_m
+
+ def postprocess(self, action: torch.Tensor) -> torch.Tensor:
+ """
+ - Model outputs [*, 20]
+ - Apply sigmoid to gripper logits
+ - Return only the first 12 dims for the real robot:
+ ["left_shoulder_pan.pos",
+ "left_shoulder_lift.pos",
+ "left_elbow_flex.pos",
+ "left_wrist_flex.pos",
+ "left_wrist_roll.pos",
+ "left_gripper.pos",
+ "right_shoulder_pan.pos",
+ "right_shoulder_lift.pos",
+ "right_elbow_flex.pos",
+ "right_wrist_flex.pos",
+ "right_wrist_roll.pos",
+ "right_gripper.pos"]
+ """
+ # Ensure we at least have the real dims + grippers
+ if action.size(-1) < self.REAL_DIM:
+ raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
+
+ # Apply sigmoid on gripper channels in model space (indices 5 and 11)
+ if action.size(-1) > max(self.gripper_idx):
+ action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
+
+ # Return only the real 12-dim control vector for the env
+ return self._trim_to_real_dim(action)
+
+
+# =============================================================================
+# Exports
+# =============================================================================
+__all__ = [
+ "BaseActionSpace",
+ "build_action_space",
+ "register_action",
+ "EE6DActionSpace",
+ "JointActionSpace",
+ "AGIBOTEE6DActionSpace",
+ "FrankaJoint7ActionSpace",
+ "AutoActionSpace",
+ "BimanualSO101ActionSpace",
+ "ACTION_REGISTRY",
+]
diff --git a/src/lerobot/policies/xvla/configuration_florence2.py b/src/lerobot/policies/xvla/configuration_florence2.py
new file mode 100644
index 000000000..35c006ee0
--- /dev/null
+++ b/src/lerobot/policies/xvla/configuration_florence2.py
@@ -0,0 +1,353 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+""" Florence-2 configuration"""
+
+logger = logging.get_logger(__name__)
+
+
+class Florence2VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ The dropout rate of the drop path layer.
+ patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
+ The patch size of the image.
+ patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
+ The patch stride of the image.
+ patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
+ The patch padding of the image.
+ patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
+ Whether to apply layer normalization before the patch embedding layer.
+ enable_checkpoint (`bool`, *optional*, defaults to False):
+ Whether to enable checkpointing.
+ dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
+ The dimension of the embedding layer.
+ num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
+ The number of attention heads.
+ num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
+ The number of groups.
+ depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
+ The depth of the model.
+ window_size (`int`, *optional*, defaults to 12):
+ The window size of the model.
+ projection_dim (`int`, *optional*, defaults to 1024):
+ The dimension of the projection layer.
+ visual_temporal_embedding (`dict`, *optional*):
+ The configuration of the visual temporal embedding.
+ image_pos_embed (`dict`, *optional*):
+ The configuration of the image position embedding.
+ image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
+ The source of the image feature.
+ Example:
+
+ ```python
+ >>> from transformers import Florence2VisionConfig, Florence2VisionModel
+
+ >>> # Initializing a Florence2 Vision style configuration
+ >>> configuration = Florence2VisionConfig()
+
+ >>> # Initializing a model (with random weights)
+ >>> model = Florence2VisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "davit"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ drop_path_rate=0.1,
+ patch_size=None,
+ patch_stride=None,
+ patch_padding=None,
+ patch_prenorm=None,
+ enable_checkpoint=False,
+ dim_embed=None,
+ num_heads=None,
+ num_groups=None,
+ depths=None,
+ window_size=12,
+ projection_dim=1024,
+ visual_temporal_embedding=None,
+ image_pos_embed=None,
+ image_feature_source=None,
+ **kwargs,
+ ):
+ self.drop_path_rate = drop_path_rate
+ self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
+ self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
+ self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
+ self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
+ self.enable_checkpoint = enable_checkpoint
+ self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
+ self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
+ self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
+ self.depths = depths if depths is not None else [1, 1, 9, 1]
+ self.window_size = window_size
+ self.projection_dim = projection_dim
+
+ if visual_temporal_embedding is None:
+ visual_temporal_embedding = {
+ "type": "COSINE",
+ "max_temporal_embeddings": 100,
+ }
+ self.visual_temporal_embedding = visual_temporal_embedding
+
+ if image_pos_embed is None:
+ image_pos_embed = {
+ "type": "learned_abs_2d",
+ "max_pos_embeddings": 1000,
+ }
+ self.image_pos_embed = image_pos_embed
+
+ self.image_feature_source = (
+ image_feature_source
+ if image_feature_source is not None
+ else ["spatial_avg_pool", "temporal_avg_pool"]
+ )
+
+ super().__init__(**kwargs)
+
+
+class Florence2LanguageConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the BART
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 51289):
+ Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`Florence2LanguageModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for classifier.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ num_labels (`int`, *optional*, defaults to 3):
+ The number of labels to use in [`Florence2LanguageForSequenceClassification`].
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+
+ Example:
+
+ ```python
+ >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
+
+ >>> # Initializing a Florence2 Language style configuration
+ >>> configuration = Florence2LanguageConfig()
+
+ >>> # Initializing a model (with random weights)
+ >>> model = Florence2LanguageModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "florence2_language"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=51289,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ num_labels=3,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ is_encoder_decoder=True,
+ decoder_start_token_id=2,
+ forced_eos_token_id=2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(
+ num_labels=num_labels,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ # ensure backward compatibility for BART CNN models
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+ self.forced_bos_token_id = self.bos_token_id
+ warnings.warn(
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+ "The config can simply be saved and uploaded again to be fixed.",
+ stacklevel=2,
+ )
+
+
+class Florence2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate a
+ Florence-2 model according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Florence2VisionConfig`, *optional*):
+ Custom vision config or dict
+ text_config (`Union[AutoConfig, dict]`, *optional*):
+ The config object of the text backbone.
+ ignore_index (`int`, *optional*, defaults to -100):
+ The ignore index for the loss function.
+ vocab_size (`int`, *optional*, defaults to 51289):
+ Vocabulary size of the Florence2 model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`~Florence2ForConditionalGeneration`]
+ projection_dim (`int`, *optional*, defaults to 1024):
+ Dimension of the multimodal projection space.
+
+ Example:
+
+ ```python
+ >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
+
+ >>> # Initializing a clip-like vision config
+ >>> vision_config = CLIPVisionConfig()
+
+ >>> # Initializing a Bart config
+ >>> text_config = BartConfig()
+
+ >>> # Initializing a Florence-2 configuration
+ >>> configuration = Florence2Config(vision_config, text_config)
+
+ >>> # Initializing a model from the florence-2 configuration
+ >>> model = Florence2ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "florence2"
+ is_composition = False
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ ignore_index=-100,
+ vocab_size=51289,
+ projection_dim=1024,
+ **kwargs,
+ ):
+ self.ignore_index = ignore_index
+ self.vocab_size = vocab_size
+ self.projection_dim = projection_dim
+ if vision_config is not None:
+ vision_config = Florence2VisionConfig(**vision_config)
+ self.vision_config = vision_config
+
+ self.text_config = text_config
+ if text_config is not None:
+ self.text_config = Florence2LanguageConfig(**text_config)
+
+ super().__init__(**kwargs)
diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py
new file mode 100644
index 000000000..30700b042
--- /dev/null
+++ b/src/lerobot/policies/xvla/configuration_xvla.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python
+
+# ------------------------------------------------------------------------------
+# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.optim.optimizers import XVLAAdamWConfig
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.utils.constants import OBS_IMAGES
+
+# Conditional import for type checking and lazy loading
+from lerobot.utils.import_utils import _transformers_available
+
+if TYPE_CHECKING or _transformers_available:
+ from .configuration_florence2 import Florence2Config
+else:
+ Florence2Config = None
+
+
+@PreTrainedConfig.register_subclass("xvla")
+@dataclass
+class XVLAConfig(PreTrainedConfig):
+ """
+ Configuration class for the XVLA (soft-prompted Vision-Language-Action) policy so it can
+ plug into the LeRobot training stack.
+
+ The config mirrors the knobs exposed in the original XVLA repository but also
+ declares the input/output feature contract required by LeRobot.
+ """
+
+ # Input / output structure
+ n_obs_steps: int = 1
+ chunk_size: int = 32
+ n_action_steps: int = 32
+ dtype: str = "float32" # Options: "bfloat16", "float32"
+
+ normalization_mapping: dict[str, NormalizationMode] = field(
+ default_factory=lambda: {
+ "VISUAL": NormalizationMode.IDENTITY,
+ "STATE": NormalizationMode.IDENTITY,
+ "ACTION": NormalizationMode.IDENTITY,
+ }
+ )
+
+ # Florence2 backbone and tokenizer configuration
+ florence_config: dict[str, Any] = field(default_factory=dict)
+ tokenizer_name: str = "facebook/bart-large"
+ tokenizer_max_length: int = 64
+ tokenizer_padding_side: str = "right"
+ pad_language_to: str = "max_length"
+
+ # Transformer head
+ hidden_size: int = 1024
+ depth: int = 24
+ num_heads: int = 16
+ mlp_ratio: float = 4.0
+ num_domains: int = 30
+ len_soft_prompts: int = 32
+ dim_time: int = 32
+ max_len_seq: int = 512
+ use_hetero_proj: bool = False
+
+ # Action & proprioception
+ action_mode: str = "ee6d"
+ num_denoising_steps: int = 10
+ use_proprio: bool = True
+ max_state_dim: int = 32
+ max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
+ domain_feature_key: str | None = None
+
+ # Vision preprocessing
+ resize_imgs_with_padding: tuple[int, int] | None = None
+ num_image_views: int | None = None
+ empty_cameras: int = 0
+
+ # Freezing options for VLM components
+ # By default, VLM encoders are frozen and only policy transformer + soft prompts train
+ freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
+ freeze_language_encoder: bool = False # Freeze VLM language encoder weights
+ train_policy_transformer: bool = True # Allow policy transformer to train
+ train_soft_prompts: bool = True # Allow soft prompts to train
+
+ # Training presets
+ optimizer_lr: float = 1e-4
+ optimizer_betas: tuple[float, float] = (0.9, 0.99)
+ optimizer_eps: float = 1e-8
+ optimizer_weight_decay: float = 0.0
+ optimizer_grad_clip_norm: float = 10.0
+ # Soft-prompt LR settings (for optional warm-up)
+ optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
+ optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
+
+ scheduler_warmup_steps: int = 1_000
+ scheduler_decay_steps: int = 30_000
+ scheduler_decay_lr: float = 2.5e-6
+
+ def __post_init__(self) -> None:
+ super().__post_init__()
+
+ if self.chunk_size <= 0:
+ raise ValueError("`chunk_size` must be strictly positive.")
+ if self.n_action_steps > self.chunk_size:
+ raise ValueError(
+ f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
+ )
+ if self.num_image_views is not None and self.num_image_views <= 0:
+ raise ValueError("`num_image_views` must be > 0 when specified.")
+ if self.dtype not in ["bfloat16", "float32"]:
+ raise ValueError(f"Invalid dtype: {self.dtype}")
+ self._florence_config_obj: Florence2Config | None = None
+
+ def get_florence_config(self) -> Florence2Config:
+ """
+ Build (and cache) the Florence2 transformer config that should back the VLM.
+ """
+ if self._florence_config_obj is None:
+ config_dict = dict(self.florence_config)
+ if "vision_config" not in config_dict or config_dict["vision_config"] is None:
+ raise ValueError("vision_config is required")
+
+ if "text_config" not in config_dict or config_dict["text_config"] is None:
+ raise ValueError("text_config is required")
+ self._florence_config_obj = Florence2Config(**config_dict)
+ return self._florence_config_obj
+
+ def validate_features(self) -> None:
+ if not self.image_features:
+ raise ValueError("XVLA requires at least one visual feature in the inputs.")
+ if self.use_proprio and self.robot_state_feature is None:
+ raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
+ if self.num_image_views is None:
+ self.num_image_views = len(self.image_features) + self.empty_cameras
+ else:
+ self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
+
+ if self.empty_cameras > 0:
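+ # Register placeholder visual features so the policy always sees the expected number of
+ # image slots, even when fewer physical cameras are configured.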
+ height, width = (480, 640)
+ if self.resize_imgs_with_padding is not None:
+ height, width = self.resize_imgs_with_padding
+ for idx in range(self.empty_cameras):
+ key = f"{OBS_IMAGES}.empty_camera_{idx}"
+ if key not in self.input_features:
+ self.input_features[key] = PolicyFeature(
+ type=FeatureType.VISUAL,
+ shape=(3, height, width),
+ )
+
+ def get_optimizer_preset(self) -> XVLAAdamWConfig:
+ """Return the XVLA-specific optimizer with differential learning rates.
+
+ This optimizer applies:
+ - 1/10 LR for VLM parameters (stable optimization)
+ - Full LR for transformer/action head
+ - Configurable LR for soft-prompts (with optional warm-up)
+ """
+ return XVLAAdamWConfig(
+ lr=self.optimizer_lr,
+ betas=self.optimizer_betas,
+ eps=self.optimizer_eps,
+ weight_decay=self.optimizer_weight_decay,
+ grad_clip_norm=self.optimizer_grad_clip_norm,
+ soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
+ soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
+ )
+
+ def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
+ return CosineDecayWithWarmupSchedulerConfig(
+ peak_lr=self.optimizer_lr,
+ decay_lr=self.scheduler_decay_lr,
+ num_warmup_steps=self.scheduler_warmup_steps,
+ num_decay_steps=self.scheduler_decay_steps,
+ )
+
+ @property
+ def observation_delta_indices(self) -> list[int] | None:
+ return None
+
+ @property
+ def action_delta_indices(self) -> list[int]:
+ return list(range(self.chunk_size))
+
+ @property
+ def reward_delta_indices(self) -> list[int] | None:
+ return None
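+
+# A minimal sketch (illustrative values only) of the shape `get_florence_config` expects
+# `florence_config` to have; both nested dicts are mandatory, and a ValueError is raised otherwise:
+#
+#   cfg = XVLAConfig(
+#       florence_config={
+#           "vision_config": {...},   # forwarded to Florence2VisionConfig
+#           "text_config": {...},     # forwarded to Florence2LanguageConfig
+#       },
+#   )
+#   florence_cfg = cfg.get_florence_config()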
diff --git a/src/lerobot/policies/xvla/modeling_florence2.py b/src/lerobot/policies/xvla/modeling_florence2.py
new file mode 100644
index 000000000..2b5316fae
--- /dev/null
+++ b/src/lerobot/policies/xvla/modeling_florence2.py
@@ -0,0 +1,2757 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""PyTorch Florence-2 model."""
+
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as functional
+import torch.utils.checkpoint as checkpoint
+from einops import rearrange
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationMixin
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from .configuration_florence2 import Florence2Config, Florence2LanguageConfig
+from .utils import drop_path
+
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Florence2Config"
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
+class LearnedAbsolutePositionEmbedding2D(nn.Module):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, embedding_dim=256, num_pos=50):
+ super().__init__()
+ self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
+ self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))
+
+ def forward(self, pixel_values):
+ """
+ pixel_values: (batch_size, height, width, num_channels)
+ returns: (batch_size, height, width, embedding_dim * 2)
+ """
+ if len(pixel_values.shape) != 4:
+ raise ValueError("pixel_values must be a 4D tensor")
+ height, width = pixel_values.shape[1:3]
+ width_values = torch.arange(width, device=pixel_values.device)
+ height_values = torch.arange(height, device=pixel_values.device)
+ x_emb = self.column_embeddings(width_values)
+ y_emb = self.row_embeddings(height_values)
+ # (height, width, embedding_dim * 2)
+ pos = torch.cat(
+ [x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1
+ )
+ # (embedding_dim * 2, height, width)
+ pos = pos.permute(2, 0, 1)
+ pos = pos.unsqueeze(0)
+ # (batch_size, embedding_dim * 2, height, width)
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+ # (batch_size, height, width, embedding_dim * 2)
+ pos = pos.permute(0, 2, 3, 1)
+ return pos
+
+
+class PositionalEmbeddingCosine1D(nn.Module):
+ """
+ This class implements a very simple positional encoding. It follows closely
+ the encoder from the link below:
+ https://pytorch.org/tutorials/beginner/translation_transformer.html
+
+ Args:
+ embed_dim: The dimension of the embeddings.
+ max_seq_len: The maximum length to precompute the positional encodings.
+ """
+
+ def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.max_seq_len = max_seq_len
+ # Generate the sinusoidal arrays.
+ factor = math.log(10000)
+ denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim)
+ # Matrix where rows correspond to a positional embedding as a function
+ # of the position index (i.e., the row index).
+ frequencies = torch.arange(0, self.max_seq_len).reshape(self.max_seq_len, 1) * denominator
+ pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
+ # Populate even entries with sine and odd entries with cosine.
+ pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
+ pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
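+ # The buffer therefore stores the standard sinusoidal encoding:
+ #   PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
+ #   PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))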
+ # Save the positional embeddings in a constant buffer.
+ self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
+
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ seq_embeds: The sequence embeddings in order. Allowed size:
+ 1. [T, D], where T is the length of the sequence, and D is the
+ frame embedding dimension.
+ 2. [B, T, D], where B is the batch size and T and D are the
+ same as above.
+
+ Returns a tensor with the same dimensions as the input: i.e.,
+ [1, T, D] or [T, D].
+ """
+ shape_len = len(seq_embeds.shape)
+ assert 2 <= shape_len <= 3
+ len_seq = seq_embeds.size(-2)
+ assert len_seq <= self.max_seq_len
+ pos_embeds = self.pos_idx_to_embed[0 : seq_embeds.size(-2), :]
+ # Adapt pre-computed positional embeddings to the input.
+ if shape_len == 3:
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
+ return pos_embeds
+
+
+class LearnedAbsolutePositionEmbedding1D(nn.Module):
+ """
+ Learnable absolute positional embeddings for 1D sequences.
+
+ Args:
+ embedding_dim: The dimension of the embeddings.
+ num_pos: The maximum number of positions for which embeddings are learned.
+ """
+
+ def __init__(self, embedding_dim: int = 512, num_pos: int = 1024) -> None:
+ super().__init__()
+ self.embeddings = nn.Embedding(num_pos, embedding_dim)
+ self.num_pos = num_pos
+
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ seq_embeds: The sequence embeddings in order. Allowed size:
+ 1. [T, D], where T is the length of the sequence, and D is the
+ frame embedding dimension.
+ 2. [B, T, D], where B is the batch size and T and D are the
+ same as above.
+
+ Returns a tensor with the same dimensions as the input: i.e.,
+ [1, T, D] or [T, D].
+ """
+ shape_len = len(seq_embeds.shape)
+ assert 2 <= shape_len <= 3
+ len_seq = seq_embeds.size(-2)
+ assert len_seq <= self.num_pos
+ # [T, D]
+ pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
+ # Adapt pre-computed positional embeddings to the input.
+ if shape_len == 3:
+ pos_embeds = pos_embeds.view((1, pos_embeds.size(0), pos_embeds.size(1)))
+ return pos_embeds
+
+
+class MySequential(nn.Sequential):
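+ # nn.Sequential variant that unpacks tuple outputs (e.g. (x, size)) as positional
+ # arguments for the next module instead of passing a single tensor through.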
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
+ return inputs
+
+
+class PreNorm(nn.Module):
+ def __init__(self, norm, fn, drop_path=None):
+ super().__init__()
+ self.norm = norm
+ self.fn = fn
+ self.drop_path = drop_path
+
+ def forward(self, x, *args, **kwargs):
+ shortcut = x
+ if self.norm is not None:
+ x, size = self.fn(self.norm(x), *args, **kwargs)
+ else:
+ x, size = self.fn(x, *args, **kwargs)
+
+ if self.drop_path:
+ x = self.drop_path(x)
+
+ x = shortcut + x
+
+ return x, size
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.net = nn.Sequential(
+ OrderedDict(
+ [
+ ("fc1", nn.Linear(in_features, hidden_features)),
+ ("act", act_layer()),
+ ("fc2", nn.Linear(hidden_features, out_features)),
+ ]
+ )
+ )
+
+ def forward(self, x, size):
+ return self.net(x), size
+
+
+class DepthWiseConv2d(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ kernel_size,
+ padding,
+ stride,
+ bias=True,
+ ):
+ super().__init__()
+ self.dw = nn.Conv2d(
+ dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias
+ )
+
+ def forward(self, x, size):
+ batch_size, num_tokens, channels = x.shape
+ height, width = size
+ assert num_tokens == height * width
+
+ x = self.dw(x.transpose(1, 2).view(batch_size, channels, height, width))
+ size = (x.size(-2), x.size(-1))
+ x = x.flatten(2).transpose(1, 2)
+ return x, size
+
+
+class ConvEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(
+ self, patch_size=7, in_chans=3, embed_dim=64, stride=4, padding=2, norm_layer=None, pre_norm=True
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
+
+ dim_norm = in_chans if pre_norm else embed_dim
+ self.norm = norm_layer(dim_norm) if norm_layer else None
+
+ self.pre_norm = pre_norm
+
+ def forward(self, x, size):
+ height, width = size
+ if len(x.size()) == 3:
+ if self.norm and self.pre_norm:
+ x = self.norm(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=height, w=width)
+
+ x = self.proj(x)
+
+ _, _, height, width = x.shape
+ x = rearrange(x, "b c h w -> b (h w) c")
+ if self.norm and not self.pre_norm:
+ x = self.norm(x)
+
+ return x, (height, width)
+
+
+class ChannelAttention(nn.Module):
+ def __init__(self, dim, groups=8, qkv_bias=True):
+ super().__init__()
+
+ self.groups = groups
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x, size):
+ batch_size, num_tokens, channels = x.shape
+
+ qkv = (
+ self.qkv(x)
+ .reshape(batch_size, num_tokens, 3, self.groups, channels // self.groups)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
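+ # Channel-group attention: the transposed query makes the attention matrix
+ # (channels/groups x channels/groups), so the cost scales with channel count rather than token count.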
+ q = q * (float(num_tokens) ** -0.5)
+ attention = q.transpose(-1, -2) @ k
+ attention = attention.softmax(dim=-1)
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
+ x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
+ x = self.proj(x)
+ return x, size
+
+
+class ChannelBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ groups,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path_rate=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ conv_at_attn=True,
+ conv_at_ffn=True,
+ ):
+ super().__init__()
+
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+ self.channel_attn = PreNorm(
+ norm_layer(dim), ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), drop_path
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
+ drop_path,
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.channel_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+
+ return x, size
+
+
+def window_partition(x, window_size: int):
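+ # Split a (batch, height, width, channels) map into non-overlapping windows of shape
+ # (num_windows * batch, window_size, window_size, channels); height and width must be divisible by window_size.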
+ batch_size, height, width, channels = x.shape
+ x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
+ return windows
+
+
+def window_reverse(windows, batch_size: int, window_size: int, height: int, width: int):
+ # batch_size is passed in explicitly rather than computed as
+ # int(windows.shape[0] / (height * width / window_size / window_size)),
+ # which would be traced as a constant and break ONNX conversion for dynamic axes
+ x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ def __init__(self, dim, num_heads, window_size, qkv_bias=True):
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = float(head_dim) ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, size):
+ height, width = size
+ batch_size, seq_len, channels = x.shape
+ assert seq_len == height * width, "input feature has wrong size"
+
+ x = x.view(batch_size, height, width, channels)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - width % self.window_size) % self.window_size
+ pad_b = (self.window_size - height % self.window_size) % self.window_size
+ x = functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, height_padded, width_padded, _ = x.shape
+
+ x = window_partition(x, self.window_size)
+ x = x.view(-1, self.window_size * self.window_size, channels)
+
+ # W-MSA/SW-MSA
+ # attn_windows = self.attn(x_windows)
+
+ batch_windows, num_tokens, channels = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(batch_windows, num_tokens, 3, self.num_heads, channels // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = self.softmax(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(batch_windows, num_tokens, channels)
+ x = self.proj(x)
+
+ # merge windows
+ x = x.view(-1, self.window_size, self.window_size, channels)
+ x = window_reverse(x, batch_size, self.window_size, height_padded, width_padded)
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :height, :width, :].contiguous()
+
+ x = x.view(batch_size, height * width, channels)
+
+ return x, size
+
+
+class SpatialBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ window_size,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path_rate=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ conv_at_attn=True,
+ conv_at_ffn=True,
+ ):
+ super().__init__()
+
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
+ self.window_attn = PreNorm(
+ norm_layer(dim), WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), drop_path
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
+ drop_path,
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.window_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+ return x, size
+
+
+class DaViT(nn.Module):
+ """DaViT: Dual-Attention Transformer
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3.
+ num_classes (int): Number of classes for classification head. Default: 1000.
+ patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
+ patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
+ patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
+ patch_prenorm (tuple(bool)): If True, perform norm before convolution layer. Default: (False, False, False, False).
+ embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
+ num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
+ num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ enable_checkpoint (bool): If True, enable checkpointing. Default: False.
+ conv_at_attn (bool): If True, perform depthwise convolution before attention layer. Default: True.
+ conv_at_ffn (bool): If True, perform depthwise convolution before ffn layer. Default: True.
+ """
+
+ def __init__(
+ self,
+ in_chans=3,
+ num_classes=1000,
+ depths=(1, 1, 3, 1),
+ patch_size=(7, 2, 2, 2),
+ patch_stride=(4, 2, 2, 2),
+ patch_padding=(3, 0, 0, 0),
+ patch_prenorm=(False, False, False, False),
+ embed_dims=(64, 128, 192, 256),
+ num_heads=(3, 6, 12, 24),
+ num_groups=(3, 6, 12, 24),
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ enable_checkpoint=False,
+ conv_at_attn=True,
+ conv_at_ffn=True,
+ ):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.num_groups = num_groups
+ self.num_stages = len(self.embed_dims)
+ self.enable_checkpoint = enable_checkpoint
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
+
+ num_stages = len(embed_dims)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
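+ # Each depth unit pairs a spatial block with a channel block, so the stochastic-depth
+ # schedule needs 2 * sum(depths) drop rates.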
+
+ depth_offset = 0
+ convs = []
+ blocks = []
+ for i in range(num_stages):
+ conv_embed = ConvEmbed(
+ patch_size=patch_size[i],
+ stride=patch_stride[i],
+ padding=patch_padding[i],
+ in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
+ embed_dim=self.embed_dims[i],
+ norm_layer=norm_layer,
+ pre_norm=patch_prenorm[i],
+ )
+ convs.append(conv_embed)
+
+ block = MySequential(
+ *[
+ MySequential(
+ OrderedDict(
+ [
+ (
+ "spatial_block",
+ SpatialBlock(
+ embed_dims[i],
+ num_heads[i],
+ window_size,
+ drop_path_rate=dpr[depth_offset + j * 2],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ ),
+ ),
+ (
+ "channel_block",
+ ChannelBlock(
+ embed_dims[i],
+ num_groups[i],
+ drop_path_rate=dpr[depth_offset + j * 2 + 1],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ ),
+ ),
+ ]
+ )
+ )
+ for j in range(depths[i])
+ ]
+ )
+ blocks.append(block)
+ depth_offset += depths[i] * 2
+
+ self.convs = nn.ModuleList(convs)
+ self.blocks = nn.ModuleList(blocks)
+
+ self.norms = norm_layer(self.embed_dims[-1])
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ @property
+ def dim_out(self):
+ return self.embed_dims[-1]
+
+ def forward_features_unpool(self, x):
+ """
+ Forward pass up to (but not including) average pooling.
+ Args:
+ x (torch.Tensor): input image tensor
+ """
+ input_size = (x.size(2), x.size(3))
+ for conv, block in zip(self.convs, self.blocks, strict=False):
+ x, input_size = conv(x, input_size)
+ if self.enable_checkpoint:
+ x, input_size = checkpoint.checkpoint(block, x, input_size)
+ else:
+ x, input_size = block(x, input_size)
+ return x
+
+ def forward_features(self, x):
+ x = self.forward_features_unpool(x)
+
+ # (batch_size, num_tokens, token_dim)
+ x = self.avgpool(x.transpose(1, 2))
+ # (batch_size, 1, num_tokens)
+ x = torch.flatten(x, 1)
+ x = self.norms(x)
+
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(
+ depths=config.depths,
+ embed_dims=config.dim_embed,
+ num_heads=config.num_heads,
+ num_groups=config.num_groups,
+ patch_size=config.patch_size,
+ patch_stride=config.patch_stride,
+ patch_padding=config.patch_padding,
+ patch_prenorm=config.patch_prenorm,
+ drop_path_rate=config.drop_path_rate,
+ window_size=config.window_size,
+ )
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
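+ # From a (batch, seq_len) padding mask, compute the flattened indices of real tokens,
+ # the cumulative sequence lengths expected by flash-attn varlen kernels, and the longest
+ # unpadded sequence in the batch.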
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
+
+
+class Florence2LearnedPositionalEmbedding(nn.Embedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # Florence2 is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+ """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+ bsz, seq_len = input_ids.shape[:2]
+ positions = torch.arange(
+ past_key_values_length,
+ past_key_values_length + seq_len,
+ dtype=torch.long,
+ device=self.weight.device,
+ ).expand(bsz, -1)
+
+ return super().forward(positions + self.offset)
+
+
+class Florence2ScaledWordEmbedding(nn.Embedding):
+ """
+ This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+ """
+
+ def __init__(
+ self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0
+ ):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
+
+
+class Florence2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ is_causal: bool = False,
+ config: Florence2LanguageConfig | None = None,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.is_causal = is_causal
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor | None = None,
+ past_key_value: tuple[torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ layer_head_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
+ bsz, self.num_heads, tgt_len, src_len
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class Florence2FlashAttention2(Florence2Attention):
+ """
+ Florence2 flash attention module. This module inherits from `Florence2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor | None = None,
+ past_key_value: tuple[torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ layer_head_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+ # Florence2FlashAttention2 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("Florence2FlashAttention2 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+ query_layer, attention_mask
+ )
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class Florence2SdpaAttention(Florence2Attention):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor | None = None,
+ past_key_value: tuple[torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ layer_head_mask: torch.Tensor | None = None,
+ output_attentions: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = bool(self.is_causal and attention_mask is None and tgt_len > 1)
+
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
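+ # Map each supported `config._attn_implementation` value to its attention class.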
+FLORENCE2_ATTENTION_CLASSES = {
+ "eager": Florence2Attention,
+ "sdpa": Florence2SdpaAttention,
+ "flash_attention_2": Florence2FlashAttention2,
+}
+
+
+class Florence2EncoderLayer(nn.Module):
+ def __init__(self, config: Florence2LanguageConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: bool | None = False,
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(
+ hidden_states, p=self.activation_dropout, training=self.training
+ )
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class Florence2DecoderLayer(nn.Module):
+ def __init__(self, config: Florence2LanguageConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation](
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor | None = None,
+ encoder_hidden_states: torch.Tensor | None = None,
+ encoder_attention_mask: torch.Tensor | None = None,
+ layer_head_mask: torch.Tensor | None = None,
+ cross_attn_layer_head_mask: torch.Tensor | None = None,
+ past_key_value: tuple[torch.Tensor] | None = None,
+ output_attentions: bool | None = False,
+ use_cache: bool | None = True,
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(
+ hidden_states, p=self.activation_dropout, training=self.training
+ )
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class Florence2LanguagePreTrainedModel(PreTrainedModel):
+ config_class = Florence2LanguageConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
+ _no_split_modules = [r"Florence2EncoderLayer", r"Florence2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
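+ # Linear/embedding weights are drawn from N(0, config.init_std), conv weights from N(0, 0.02),
+ # and normalization layers are reset to weight=1, bias=0.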
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.Conv2d):
+ nn.init.normal_(module.weight, std=0.02)
+ for name, _ in module.named_parameters():
+ if name == "bias":
+ nn.init.constant_(module.bias, 0)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
+ nn.init.constant_(module.weight, 1.0)
+ nn.init.constant_(module.bias, 0)
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class Florence2Encoder(Florence2LanguagePreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
+ [`Florence2EncoderLayer`].
+
+ Args:
+ config: Florence2LanguageConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: Florence2LanguageConfig, embed_tokens: nn.Embedding | None = None):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ self.embed_tokens = Florence2ScaledWordEmbedding(
+ config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+ )
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = Florence2LearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ )
+ self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ head_mask: torch.Tensor | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | BaseModelOutput:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif inputs_embeds is not None:
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ embed_pos = self.embed_positions(input)
+ embed_pos = embed_pos.to(inputs_embeds.device)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ if self._use_flash_attention_2:
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self._use_sdpa and head_mask is None and not output_attentions:
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None and head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class Florence2Decoder(Florence2LanguagePreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Florence2DecoderLayer`]
+
+ Args:
+ config: Florence2LanguageConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: Florence2LanguageConfig, embed_tokens: nn.Embedding | None = None):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ self.padding_idx = config.pad_token_id
+ self.max_target_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+ self.embed_tokens = Florence2ScaledWordEmbedding(
+ config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+ )
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = Florence2LearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ )
+ self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)])
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_sdpa = config._attn_implementation == "sdpa"
+
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ encoder_hidden_states: torch.FloatTensor | None = None,
+ encoder_attention_mask: torch.LongTensor | None = None,
+ head_mask: torch.Tensor | None = None,
+ cross_attn_head_mask: torch.Tensor | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self._use_flash_attention_2:
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input, past_key_values_length)
+ positions = positions.to(inputs_embeds.device)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip(
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"], strict=False
+ ):
+ if attn_mask is not None and attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
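+ # With `use_cache`, the present key/value cache is the last element of the layer output:
+ # index 3 when attention weights are also returned, index 1 otherwise.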
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: Florence2LanguageConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = Florence2Encoder(config, self.shared)
+ self.decoder = Florence2Decoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+ # self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ decoder_input_ids: torch.LongTensor | None = None,
+ decoder_attention_mask: torch.LongTensor | None = None,
+ head_mask: torch.Tensor | None = None,
+ decoder_head_mask: torch.Tensor | None = None,
+ cross_attn_head_mask: torch.Tensor | None = None,
+ encoder_outputs: list[torch.FloatTensor] | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | Seq2SeqModelOutput:
+ # Unlike other models, Florence2 automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
+ base_model_prefix = "model"
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+
+ def __init__(self, config: Florence2LanguageConfig):
+ super().__init__(config)
+ self.model = Florence2LanguageModel(config)
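+ # Per-token bias added to the LM logits; kept in sync with the vocabulary size via
+ # `_resize_final_logits_bias` when token embeddings are resized.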
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
+ # self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
+ # self._tie_or_clone_weights(self.lm_head, self.model.shared)
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: int | None = None, **kwargs
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
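+ # Truncate the bias when the vocabulary shrinks, zero-pad it when the vocabulary grows.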
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros(
+ (1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device
+ )
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ decoder_input_ids: torch.LongTensor | None = None,
+ decoder_attention_mask: torch.LongTensor | None = None,
+ head_mask: torch.Tensor | None = None,
+ decoder_head_mask: torch.Tensor | None = None,
+ cross_attn_head_mask: torch.Tensor | None = None,
+ encoder_outputs: list[torch.FloatTensor] | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | Seq2SeqLMOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ lm_logits = self.lm_head(outputs[0])
+ lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
+
+ masked_lm_loss = None
+ if labels is not None:
+ labels = labels.to(lm_logits.device)
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past[:2]
+ )
+ + layer_past[2:],
+ )
+ return reordered_past
+
+
+@dataclass
+class Florence2Seq2SeqLMOutput(ModelOutput):
+ """
+ Base class for Florence-2 model outputs that also contain pre-computed hidden states which can speed up
+ sequential decoding.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used, only the last hidden state of the sequences, of shape `(batch_size, 1,
+ hidden_size)`, is output.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+ weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
+ num_image_tokens, hidden_size)`.
+
+ Image hidden states of the model produced by the vision encoder.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor = None
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: tuple[tuple[torch.FloatTensor]] | None = None
+ decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
+ decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
+ cross_attentions: tuple[torch.FloatTensor, ...] | None = None
+ encoder_last_hidden_state: torch.FloatTensor | None = None
+ encoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
+ encoder_attentions: tuple[torch.FloatTensor, ...] | None = None
+ image_hidden_states: tuple[torch.FloatTensor, ...] | None = None
+
+
+FLORENCE2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Florence2Config`] or [`Florence2VisionConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Florence-2 Model outputting raw hidden-states without any specific head on top.",
+ FLORENCE2_START_DOCSTRING,
+)
+class Florence2PreTrainedModel(PreTrainedModel):
+ config_class = Florence2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+
+FLORENCE2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`Florence2Processor`] uses
+ [`CLIPImageProcessor`] for processing images).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ """The FLORENCE2 model which consists of a vision backbone and a language model.""",
+ FLORENCE2_START_DOCSTRING,
+)
+class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
+ _tied_weights_keys = [
+ "language_model.encoder.embed_tokens.weight",
+ "language_model.decoder.embed_tokens.weight",
+ "language_model.lm_head.weight",
+ ]
+
+ def __init__(self, config: Florence2Config):
+ super().__init__(config)
+ assert config.vision_config.model_type == "davit", "only DaViT is supported for now"
+ self.vision_tower = DaViT.from_config(config=config.vision_config)
+ # remove unused layers
+ del self.vision_tower.head
+ del self.vision_tower.norms
+
+ self.vocab_size = config.vocab_size
+ self._attn_implementation = config._attn_implementation
+ self._build_image_projection_layers(config)
+
+ language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
+
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ def _build_image_projection_layers(self, config):
+ image_dim_out = config.vision_config.dim_embed[-1]
+ dim_projection = config.vision_config.projection_dim
+ self.image_projection = nn.Parameter(torch.empty(image_dim_out, dim_projection))
+ self.image_proj_norm = nn.LayerNorm(dim_projection)
+ image_pos_embed_config = config.vision_config.image_pos_embed
+ if image_pos_embed_config["type"] == "learned_abs_2d":
+ self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
+ embedding_dim=image_dim_out, num_pos=image_pos_embed_config["max_pos_embeddings"]
+ )
+ else:
+ raise NotImplementedError("Not implemented yet")
+
+ self.image_feature_source = config.vision_config.image_feature_source
+
+ # temporal embedding
+ visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding
+ if visual_temporal_embedding_config["type"] == "COSINE":
+ self.visual_temporal_embed = PositionalEmbeddingCosine1D(
+ embed_dim=image_dim_out,
+ max_seq_len=visual_temporal_embedding_config["max_temporal_embeddings"],
+ )
+ else:
+ raise NotImplementedError("Not implemented yet")
+
+ def get_encoder(self):
+ return self.language_model.get_encoder()
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def resize_token_embeddings(
+ self, new_num_tokens: int | None = None, pad_to_multiple_of=None, **kwargs
+ ) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(
+ new_num_tokens, pad_to_multiple_of, **kwargs
+ )
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _encode_image(self, pixel_values):
+ # Cast pixel_values to model's dtype
+ pixel_values = pixel_values.to(dtype=self.vision_tower.convs[0].proj.weight.dtype)
+
+ if len(pixel_values.shape) == 4:
+ batch_size, channels, height, width = pixel_values.shape
+ num_frames = 1
+ x = self.vision_tower.forward_features_unpool(pixel_values)
+ else:
+ raise ValueError(f"invalid image shape {pixel_values.shape}")
+
+ if self.image_pos_embed is not None:
+ x = x.view(batch_size * num_frames, -1, x.shape[-1])
+ num_tokens = x.shape[-2]
+ h, w = int(num_tokens**0.5), int(num_tokens**0.5)
+ assert h * w == num_tokens, "only support square feature maps for now"
+ x = x.view(batch_size * num_frames, h, w, x.shape[-1])
+ pos_embed = self.image_pos_embed(x)
+ x = x + pos_embed
+ x = x.view(batch_size, num_frames * h * w, x.shape[-1])
+
+ if self.visual_temporal_embed is not None:
+ visual_temporal_embed = self.visual_temporal_embed(
+ x.view(batch_size, num_frames, -1, x.shape[-1])[:, :, 0]
+ )
+ x = x.view(batch_size, num_frames, -1, x.shape[-1]) + visual_temporal_embed.view(
+ 1, num_frames, 1, x.shape[-1]
+ )
+
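+ # Build the candidate pooled features; `config.vision_config.image_feature_source` selects which of
+ # them are concatenated along the token dimension to form the final image token sequence.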
+ x_feat_dict = {}
+
+ spatial_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=2)
+ x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x
+
+ temporal_avg_pool_x = x.view(batch_size, num_frames, -1, x.shape[-1]).mean(dim=1)
+ x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x
+
+ x = x.view(batch_size, num_frames, -1, x.shape[-1])[:, -1]
+ x_feat_dict["last_frame"] = x
+
+ new_x = []
+ for _image_feature_source in self.image_feature_source:
+ if _image_feature_source not in x_feat_dict:
+ raise ValueError(f"invalid image feature source: {_image_feature_source}")
+ new_x.append(x_feat_dict[_image_feature_source])
+
+ x = torch.cat(new_x, dim=1)
+
+ x = x @ self.image_projection
+ x = self.image_proj_norm(x)
+
+ return x
+
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds):
+ batch_size, image_token_length = image_features.size()[:-1]
+ device = image_features.device
+ image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
+
+ # task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
+ # task_prefix_attention_mask: [batch_size, context_length]
+ if inputs_embeds is None:
+ return image_features, image_attention_mask
+
+ task_prefix_embeds = inputs_embeds
+ task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+ if len(task_prefix_attention_mask.shape) == 3:
+ task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+
+ # concat [image embeds, task prefix embeds]
+ inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
+ attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
+
+ return inputs_embeds, attention_mask
+
+ @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: torch.Tensor | None = None,
+ decoder_input_ids: torch.LongTensor | None = None,
+ decoder_attention_mask: torch.LongTensor | None = None,
+ head_mask: torch.Tensor | None = None,
+ decoder_head_mask: torch.Tensor | None = None,
+ cross_attn_head_mask: torch.Tensor | None = None,
+ encoder_outputs: list[torch.FloatTensor] | None = None,
+ past_key_values: list[torch.FloatTensor] | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ decoder_inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ use_cache: bool | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ ) -> tuple | Florence2Seq2SeqLMOutput:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+ >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large")
+ >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large")
+
+ >>> prompt = ""
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=100)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "A green car parked in front of a yellow building."
+ ```"""
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ image_features = None
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ # 2. Merge text and images
+ if pixel_values is not None:
+ # (batch_size, num_image_tokens, hidden_size)
+ image_features = self._encode_image(pixel_values)
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds
+ )
+
+ if inputs_embeds is not None:
+ attention_mask = attention_mask.to(inputs_embeds.dtype)
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ labels=labels,
+ inputs_embeds=inputs_embeds,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs.logits
+ logits = logits.float()
+ loss = outputs.loss
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Florence2Seq2SeqLMOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ image_hidden_states=image_features,
+ )
+
+ def generate(self, input_ids, inputs_embeds=None, pixel_values=None, **kwargs):
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ if input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ # 2. Merge text and images
+ if pixel_values is not None:
+ image_features = self._encode_image(pixel_values)
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds
+ )
+
+ return self.language_model.generate(input_ids=None, inputs_embeds=inputs_embeds, **kwargs)
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ pixel_values=None,
+ decoder_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if decoder_input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "decoder_attention_mask": decoder_attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return self.language_model.shift_tokens_right(labels)
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py
new file mode 100644
index 000000000..27c7c6e1b
--- /dev/null
+++ b/src/lerobot/policies/xvla/modeling_xvla.py
@@ -0,0 +1,548 @@
+#!/usr/bin/env python
+
+# ------------------------------------------------------------------------------
+# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import builtins
+import logging
+import os
+from collections import deque
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from torch import Tensor, nn
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.policies.pretrained import PreTrainedPolicy, T
+from lerobot.policies.utils import populate_queues
+from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
+
+from .action_hub import build_action_space
+from .configuration_florence2 import Florence2Config
+from .configuration_xvla import XVLAConfig
+from .modeling_florence2 import Florence2ForConditionalGeneration
+from .soft_transformer import SoftPromptedTransformer
+
+
+class XVLAModel(nn.Module):
+ """
+ XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
+ """
+
+ def __init__(
+ self,
+ config: XVLAConfig,
+ florence_config: Florence2Config,
+ proprio_dim: int,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.chunk_size: int = config.chunk_size
+ self.use_proprio: bool = config.use_proprio
+
+ # Build action space with auto-detection for "auto" mode
+ if config.action_mode.lower() == "auto":
+ # Auto-detect real action dim from config.action_feature
+ real_dim = (
+ config.action_feature.shape[-1]
+ if config.action_feature is not None
+ else config.max_action_dim
+ )
+ self.action_space = build_action_space(
+ config.action_mode.lower(),
+ real_dim=real_dim,
+ max_dim=config.max_action_dim,
+ )
+ else:
+ self.action_space = build_action_space(config.action_mode.lower())
+
+ self.dim_action = self.action_space.dim_action
+ self.dim_proprio = proprio_dim
+
+ self.vlm = Florence2ForConditionalGeneration(florence_config)
+ if hasattr(self.vlm, "language_model"):
+ lm = self.vlm.language_model
+ if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
+ del lm.model.decoder
+ if hasattr(lm, "lm_head"):
+ del lm.lm_head
+
+ projection_dim = getattr(self.vlm.config, "projection_dim", None)
+ if projection_dim is None:
+ raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
+
+ self.transformer = SoftPromptedTransformer(
+ hidden_size=config.hidden_size,
+ multi_modal_input_size=projection_dim,
+ depth=config.depth,
+ num_heads=config.num_heads,
+ mlp_ratio=config.mlp_ratio,
+ num_domains=config.num_domains,
+ dim_action=self.dim_action,
+ dim_propio=self.dim_proprio,
+ len_soft_prompts=config.len_soft_prompts,
+ dim_time=config.dim_time,
+ max_len_seq=config.max_len_seq,
+ use_hetero_proj=config.use_hetero_proj,
+ )
+
+ # Apply freezing based on config
+ self._apply_freezing()
+
+ # Apply dtype casting based on config
+ self._apply_dtype()
+
+ def _get_target_dtype(self) -> torch.dtype:
+ """Get the target dtype based on config."""
+ if self.config.dtype == "bfloat16":
+ return torch.bfloat16
+ return torch.float32
+
+ def _apply_dtype(self) -> None:
+ """
+ Apply dtype casting to model components based on config.
+ """
+ target_dtype = self._get_target_dtype()
+ self.to(dtype=target_dtype)
+
+ def _apply_freezing(self) -> None:
+ """
+ Freeze VLM vision and language encoders based on config options.
+ Keep only policy transformer and soft prompts trainable.
+ """
+ # Freeze vision encoder
+ if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
+ for param in self.vlm.vision_tower.parameters():
+ param.requires_grad = False
+
+ # Freeze language encoder
+ if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
+ lm = self.vlm.language_model
+ # Freeze encoder
+ if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
+ for param in lm.model.encoder.parameters():
+ param.requires_grad = False
+ # Freeze shared embeddings
+ if hasattr(lm, "model") and hasattr(lm.model, "shared"):
+ for param in lm.model.shared.parameters():
+ param.requires_grad = False
+
+ # Freeze or unfreeze policy transformer
+ if not self.config.train_policy_transformer:
+ for name, param in self.transformer.named_parameters():
+ if "soft_prompts" not in name:
+ param.requires_grad = False
+
+ # Freeze or unfreeze soft prompts
+ if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
+ for param in self.transformer.soft_prompt_hub.parameters():
+ param.requires_grad = False
+
+ def forward_vlm(
+ self,
+ input_ids: torch.LongTensor,
+ pixel_values: torch.FloatTensor,
+ image_mask: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ """
+ Encode text and multi-view images via Florence2 encoder.
+ """
+ batch_size, num_views = pixel_values.shape[:2]
+ flat_mask = image_mask.view(-1).to(dtype=torch.bool)
+ flat_images = pixel_values.flatten(0, 1)
+ num_valid = int(flat_mask.sum().item())
+ if num_valid == 0:
+ raise ValueError("At least one image view must be valid per batch.")
+
+ valid_images = flat_images[flat_mask]
+ valid_feats = self.vlm._encode_image(valid_images)
+ tokens_per_view, hidden_dim = valid_feats.shape[1:]
+
+ image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
+ image_features[flat_mask] = valid_feats
+ image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
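+ # Only the primary camera view is fused with the language tokens inside the Florence-2 encoder;
+ # the remaining views are flattened below and passed to the policy transformer as auxiliary visual inputs.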
+ inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
+ merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
+ image_features[:, 0],
+ inputs_embeds,
+ )
+
+ enc_out = self.vlm.language_model.model.encoder(
+ attention_mask=attention_mask,
+ inputs_embeds=merged_embeds,
+ )[0]
+
+ aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
+ return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ image_input: torch.FloatTensor,
+ image_mask: torch.Tensor,
+ domain_id: torch.LongTensor,
+ proprio: torch.Tensor,
+ action: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ """
+ Forward pass for the XVLA model.
+ """
+ target_dtype = self._get_target_dtype()
+ image_input = image_input.to(dtype=target_dtype)
+ proprio = proprio.to(dtype=target_dtype)
+ action = action.to(dtype=target_dtype)
+
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
+
+ batch_size = input_ids.shape[0]
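+ # Stratified flow-matching timestep sampling: one shared random offset plus evenly spaced per-sample
+ # shifts, wrapped into [0, 1), so each batch covers the whole time interval with low variance.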
+ t = (
+ torch.rand(1, device=input_ids.device, dtype=target_dtype)
+ + torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
+ ) % (1 - 1e-5)
+
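+ # Linearly interpolate between the ground-truth action chunk (t=0) and Gaussian noise (t=1).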
+ action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
+ proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
+
+ pred_action = self.transformer(
+ domain_id=domain_id,
+ action_with_noise=action_noisy_m,
+ t=t,
+ proprio=proprio_m,
+ **enc,
+ )
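+ # The transformer predicts the clean action chunk from the noisy one; the action space
+ # turns that prediction into the per-component training loss terms.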
+ return self.action_space.compute_loss(pred_action, action)
+
+ @torch.no_grad()
+ def generate_actions(
+ self,
+ input_ids: torch.LongTensor,
+ image_input: torch.FloatTensor,
+ image_mask: torch.Tensor,
+ domain_id: torch.LongTensor,
+ proprio: torch.Tensor,
+ steps: int,
+ ) -> torch.Tensor:
+ self.eval()
+
+ target_dtype = self._get_target_dtype()
+ image_input = image_input.to(dtype=target_dtype)
+ proprio = proprio.to(dtype=target_dtype)
+
+ enc = self.forward_vlm(input_ids, image_input, image_mask)
+
+ batch_size = input_ids.shape[0]
+ action_dim = self.dim_action
+
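+ # Start from pure Gaussian noise (t=1); `action` holds the current estimate of the clean action chunk.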
+ x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
+ action = torch.zeros_like(x1)
+
+ steps = max(1, int(steps))
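+ # Walk t from 1 down to 1/steps: blend the noise with the current estimate at level t,
+ # then re-predict the clean action chunk at each step.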
+ for i in range(steps, 0, -1):
+ t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
+ x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
+ proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
+ action = self.transformer(
+ domain_id=domain_id,
+ action_with_noise=x_t_m,
+ proprio=proprio_m,
+ t=t,
+ **enc,
+ )
+ return self.action_space.postprocess(action)
+
+
+class XVLAPolicy(PreTrainedPolicy):
+ """LeRobot-compliant wrapper built around the XVLA model."""
+
+ config_class = XVLAConfig
+ name = "xvla"
+
+ def __init__(self, config: XVLAConfig):
+ super().__init__(config)
+ config.validate_features()
+ florence_config = config.get_florence_config()
+ proprio_dim = config.max_state_dim if config.use_proprio else 0
+ self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
+ self.reset()
+
+ def reset(self) -> None:
+ self._queues = {
+ ACTION: deque(maxlen=self.config.n_action_steps),
+ }
+
+ def get_optim_params(self) -> dict:
+ """Return trainable named parameters for optimization.
+
+ Returns a dict of name -> param for all trainable parameters.
+ This enables the xvla-adamw optimizer to apply differential learning rates
+ based on parameter names (e.g., 1/10 LR for VLM components).
+ """
+ return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
+
+ def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
+ if not self.config.use_proprio or OBS_STATE not in batch:
+ return torch.zeros(batch_size, 0, device=device)
+ state = batch[OBS_STATE]
+ if state.ndim > 2:
+ state = state[:, -1, :]
+ return pad_vector(state, self.model.dim_proprio)
+
+ def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+ present_img_keys = [key for key in self.config.image_features if key in batch]
+ if len(present_img_keys) == 0:
+ raise ValueError(
+ "All image features are missing from the batch. "
+ f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
+ )
+
+ images = []
+ masks = []
+ for key in present_img_keys:
+ img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
+ if self.config.resize_imgs_with_padding is not None:
+ img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
+ images.append(img)
+ masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
+
+ stacked_imgs = torch.stack(images, dim=1)
+ stacked_masks = torch.stack(masks, dim=1)
+
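+ # Pad missing camera views with zero images and zeroed mask entries so every batch
+ # carries the configured number of views.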
+ total_views = self.config.num_image_views or stacked_imgs.size(1)
+ total_views = max(total_views, stacked_imgs.size(1))
+ num_pad = total_views - stacked_imgs.size(1)
+ if num_pad > 0:
+ pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
+ pad_imgs = stacked_imgs.new_zeros(pad_shape)
+ pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
+ stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
+ stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
+
+ return stacked_imgs, stacked_masks
+
+ def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
+ candidate = None
+ if self.config.domain_feature_key and self.config.domain_feature_key in batch:
+ candidate = batch[self.config.domain_feature_key]
+ elif "domain_id" in batch:
+ candidate = batch["domain_id"]
+
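+ # Default to domain 0 when no domain information is present; otherwise coerce the candidate
+ # into a 1-D LongTensor of length batch_size.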
+ if candidate is None:
+ return torch.zeros(batch_size, dtype=torch.long, device=device)
+
+ if not isinstance(candidate, torch.Tensor):
+ candidate = torch.as_tensor(candidate, device=device)
+ else:
+ candidate = candidate.to(device=device)
+
+ if candidate.ndim == 0:
+ candidate = candidate.expand(batch_size)
+ if candidate.ndim > 1:
+ candidate = candidate.view(candidate.shape[0], -1)[:, 0]
+ if candidate.shape[0] != batch_size:
+ candidate = candidate.expand(batch_size)
+ return candidate.to(dtype=torch.long)
+
+ def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
+ if ACTION not in batch:
+ raise ValueError("Batch is missing action targets required for training.")
+ actions = batch[ACTION]
+ if actions.ndim == 2:
+ actions = actions.unsqueeze(1)
+ actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
+ if actions.shape[-1] != self.model.dim_action:
+ actions = pad_vector(actions, self.model.dim_action)
+ return actions
+
+ def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+ input_ids = batch[OBS_LANGUAGE_TOKENS]
+ batch_size = input_ids.shape[0]
+ images, image_mask = self._prepare_images(batch)
+ domain_id = self._get_domain_id(batch, batch_size, images.device)
+ proprio = self._prepare_state(batch, batch_size, images.device)
+ return {
+ "input_ids": input_ids,
+ "image_input": images,
+ "image_mask": image_mask,
+ "domain_id": domain_id,
+ "proprio": proprio,
+ }
+
+ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
+ inputs = self._build_model_inputs(batch)
+ targets = self._prepare_action_targets(batch)
+ losses = self.model(action=targets, **inputs)
+ total_loss = sum(losses.values())
+
+ log_dict = {k: v.detach().item() for k, v in losses.items()}
+ log_dict["loss"] = total_loss.detach().item()
+ return total_loss, log_dict
+
+ def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ inputs = self._build_model_inputs(batch)
+ actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
+ return actions
+
+ @torch.no_grad()
+ def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
+ self.eval()
+ self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+ return self._get_action_chunk(batch)
+
+ @torch.no_grad()
+ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
+ self.eval()
+ self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
+ if len(self._queues[ACTION]) == 0:
+ actions = self._get_action_chunk(batch)
+ self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
+
+ return self._queues[ACTION].popleft()
+
+ @classmethod
+ def from_pretrained(
+ cls: builtins.type[T],
+ pretrained_name_or_path: str | Path,
+ *,
+ config: PreTrainedConfig | None = None,
+ force_download: bool = False,
+ resume_download: bool | None = None,
+ proxies: dict | None = None,
+ token: str | bool | None = None,
+ cache_dir: str | Path | None = None,
+ local_files_only: bool = False,
+ revision: str | None = None,
+ strict: bool = False,
+ **kwargs,
+ ):
+ """
+ Loads XVLA model weights with:
+ - automatic prefix 'model.' added to all keys
+ - skip list for layers that should remain randomly initialized
+ """
+ import safetensors.torch
+
+ # step 1: load config
+ # TODO: jadechoghari, fix this
+ if config is None:
+ config = PreTrainedConfig.from_pretrained(
+ pretrained_name_or_path=pretrained_name_or_path,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ token=token,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ revision=revision,
+ **kwargs,
+ )
+
+ model_id = str(pretrained_name_or_path)
+ instance = cls(config, **kwargs)
+ # step 2: locate model.safetensors
+ if os.path.isdir(model_id):
+ logging.info("Loading weights from local directory")
+ model_file = os.path.join(model_id, "model.safetensors")
+ else:
+ try:
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import HfHubHTTPError
+
+ model_file = hf_hub_download(
+ repo_id=model_id,
+ filename="model.safetensors",
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ token=token,
+ local_files_only=local_files_only,
+ )
+ except HfHubHTTPError as e:
+ raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
+
+ logging.info(f"Loading checkpoint from {model_file}")
+ # step 3: load state dict
+ state_dict = safetensors.torch.load_file(model_file)
+ encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
+ shared_key = "model.vlm.language_model.model.shared.weight"
+ if encoder_key in state_dict:
+ state_dict[shared_key] = state_dict[encoder_key]
+ # the `shared` embedding is tied to the encoder embedding, so reuse its weights (aliasing; a deepcopy would also work)
+ # step 4: load into instance
+ instance.load_state_dict(state_dict, strict=True)
+ logging.info("Loaded XVLA checkpoint")
+ # step 5: finalize
+ # Reapply dtype after loading state dict
+ instance.model._apply_dtype()
+ instance.to(config.device)
+ instance.eval()
+ return instance
+
+
+def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
+ if img.ndim != 4:
+ raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
+
+ current_height, current_width = img.shape[2:]
+ if current_height == height and current_width == width:
+ return img
+
+ ratio = max(current_width / width, current_height / height)
+ resized_height = int(current_height / ratio)
+ resized_width = int(current_width / ratio)
+ resized_img = F.interpolate(
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ pad_height = max(0, height - resized_height)
+ pad_width = max(0, width - resized_width)
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
+
+
+def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
+ if vector.shape[-1] == new_dim:
+ return vector
+ if new_dim == 0:
+ shape = list(vector.shape)
+ shape[-1] = 0
+ return vector.new_zeros(*shape)
+ shape = list(vector.shape)
+ current_dim = shape[-1]
+ shape[-1] = new_dim
+ new_vector = vector.new_zeros(*shape)
+ length = min(current_dim, new_dim)
+ new_vector[..., :length] = vector[..., :length]
+ return new_vector
+
+
+def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
+ current_len = tensor.size(dim)
+ if current_len == target_len:
+ return tensor
+ if current_len > target_len:
+ slices = [slice(None)] * tensor.dim()
+ slices[dim] = slice(0, target_len)
+ return tensor[tuple(slices)]
+ pad_shape = list(tensor.shape)
+ pad_shape[dim] = target_len - current_len
+ pad_tensor = tensor.new_zeros(pad_shape)
+ return torch.cat([tensor, pad_tensor], dim=dim)
diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py
new file mode 100644
index 000000000..7f7297b9a
--- /dev/null
+++ b/src/lerobot/policies/xvla/processor_xvla.py
@@ -0,0 +1,554 @@
+# ------------------------------------------------------------------------------
+# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------
+
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.datasets.factory import IMAGENET_STATS
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ ObservationProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ ProcessorStep,
+ ProcessorStepRegistry,
+ RenameObservationsProcessorStep,
+ TokenizerProcessorStep,
+ UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.processor.core import EnvTransition, TransitionKey
+from lerobot.utils.constants import (
+ OBS_IMAGES,
+ OBS_STATE,
+ POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ POLICY_PREPROCESSOR_DEFAULT_NAME,
+)
+
+
+def make_xvla_pre_post_processors(
+ config: XVLAConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Build the LeRobot processor pipelines for XVLA.
+ """
+
+ features = {**config.input_features, **config.output_features}
+ input_steps = [
+ RenameObservationsProcessorStep(rename_map={}),
+ AddBatchDimensionProcessorStep(),
+ TokenizerProcessorStep(
+ tokenizer_name=config.tokenizer_name,
+ max_length=config.tokenizer_max_length,
+ padding=config.pad_language_to,
+ padding_side=config.tokenizer_padding_side,
+ ),
+ XVLAImageToFloatProcessorStep(),
+ XVLAImageNetNormalizeProcessorStep(),
+ XVLAAddDomainIdProcessorStep(),
+ DeviceProcessorStep(device=config.device),
+ NormalizerProcessorStep(
+ features=features, norm_map=config.normalization_mapping, stats=dataset_stats
+ ),
+ ]
+ output_steps = [
+ UnnormalizerProcessorStep(
+ features=config.output_features,
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ DeviceProcessorStep(device="cpu"),
+ ]
+
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=input_steps,
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=output_steps,
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
+
+
+# Custom XVLA processor steps
+@dataclass
+class LiberoProcessorStep(ObservationProcessorStep):
+ """
+ Processes LIBERO observations into the LeRobot format.
+
+ This step handles the specific observation structure from LIBERO environments,
+ which includes nested robot_state dictionaries and image observations.
+
+ **State Processing:**
+ - Processes the `robot_state` dictionary which contains nested end-effector,
+ gripper, and joint information.
+ - Extracts and concatenates:
+ - End-effector position (3D)
+ - End-effector rotation matrix converted to a 6D rotation representation (6D)
+ - A zero placeholder dimension (1D)
+ - Zero-pads the resulting 10D proprio vector to 20D and maps it to `"observation.state"`.
+
+ **Image Processing:**
+ - Rotates images by 180 degrees by flipping both height and width dimensions.
+ - This accounts for the HuggingFaceVLA/libero camera orientation convention.
+ """
+
+ def _process_observation(self, observation):
+ """
+ Processes both image and robot_state observations from LIBERO.
+ """
+ processed_obs = observation.copy()
+ for key in list(processed_obs.keys()):
+ if key.startswith(f"{OBS_IMAGES}."):
+ img = processed_obs[key]
+
+ if key == f"{OBS_IMAGES}.image":
+ # Flip both H and W
+ img = torch.flip(img, dims=[2, 3])
+
+ processed_obs[key] = img
+ # Process robot_state into a flat state vector
+ if "observation.robot_state" in processed_obs:
+ robot_state = processed_obs.pop("observation.robot_state")
+
+ # Extract components
+ eef_pos = robot_state["eef"]["pos"] # (B, 3,)
+ eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
+ eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
+
+ extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
+
+ proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
+ state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
+ # ensure float32
+ state = state.float()
+ if state.dim() == 1:
+ state = state.unsqueeze(0)
+
+ processed_obs[OBS_STATE] = state
+ return processed_obs
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """
+ Transforms feature keys from the LIBERO format to the LeRobot standard.
+ """
+ new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
+
+ # copy over non-STATE features
+ for ft, feats in features.items():
+ if ft != PipelineFeatureType.STATE:
+ new_features[ft] = feats.copy()
+
+ # rebuild STATE features
+ state_feats = {}
+
+ # add our new flattened state
+ state_feats["observation.state"] = PolicyFeature(
+ key="observation.state",
+ shape=(20,),
+ dtype="float32",
+ )
+
+ new_features[PipelineFeatureType.STATE] = state_feats
+
+ return new_features
+
+ def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
+ """
+ Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
+
+ Args:
+ rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
+
+ Returns:
+ Tensor: 6D rotation representation, shape (B, 6)
+
+ Raises:
+ TypeError: if input is not a torch tensor
+ ValueError: if shape is not (B, 3, 3)
+ """
+
+ if not isinstance(rot_mats, torch.Tensor):
+ raise TypeError(f"_mat_to_rotate6d expects a torch.Tensor, got {type(rot_mats)}")
+
+ if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
+ raise ValueError(f"_mat_to_rotate6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
+
+ rot_mats = rot_mats.to(torch.float32)
+
+ col1 = rot_mats[:, :3, 0] # (B, 3)
+ col2 = rot_mats[:, :3, 1] # (B, 3)
+
+ rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
+
+ return rot6d
+
+ def observation(self, observation):
+ return self._process_observation(observation)
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="xvla_image_scale")
+class XVLAImageScaleProcessorStep(ProcessorStep):
+ """Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
+
+ This processor step multiplies all image observations by 255, which is required
+ for XVLA models that expect images in uint8-like range.
+
+ Args:
+ image_keys: List of observation keys that contain images to scale.
+ If None, will automatically detect keys starting with "observation.images."
+ """
+
+ image_keys: list[str] | None = None
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Scale image observations by 255."""
+ new_transition = transition.copy()
+ obs = new_transition.get(TransitionKey.OBSERVATION, {})
+ if obs is None:
+ return new_transition
+
+ # Make a copy of observations to avoid modifying the original
+ obs = obs.copy()
+
+ # Determine which keys to scale
+ keys_to_scale = self.image_keys
+ if keys_to_scale is None:
+ # Auto-detect image keys
+ keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
+
+ # Scale each image
+ for key in keys_to_scale:
+ if key in obs and isinstance(obs[key], torch.Tensor):
+ obs[key] = obs[key] * 255
+
+ new_transition[TransitionKey.OBSERVATION] = obs
+ return new_transition
+
+ def transform_features(self, features):
+ """Image scaling doesn't change feature structure."""
+ return features
+
+ def get_config(self) -> dict[str, Any]:
+ """Return serializable configuration."""
+ return {
+ "image_keys": self.image_keys,
+ }
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="xvla_image_to_float")
+class XVLAImageToFloatProcessorStep(ProcessorStep):
+ """Convert image observations from [0, 255] to [0, 1] range.
+
+ This processor step divides image observations by 255 to convert from uint8-like
+ range [0, 255] to float range [0, 1]. This is typically used when loading images
+ that are stored as uint8 values.
+
+ Args:
+ image_keys: List of observation keys that contain images to convert.
+ If None, will automatically detect keys starting with "observation.images."
+ validate_range: If True, validates that input values are in [0, 255] range (default: True)
+
+ Raises:
+ ValueError: If validate_range is True and image values are not in [0, 255] range.
+ """
+
+ image_keys: list[str] | None = None
+ validate_range: bool = True
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Convert image observations from [0, 255] to [0, 1]."""
+ new_transition = transition.copy()
+ obs = new_transition.get(TransitionKey.OBSERVATION, {})
+ if obs is None:
+ return new_transition
+
+ # Make a copy of observations to avoid modifying the original
+ obs = obs.copy()
+
+ # Determine which keys to convert
+ keys_to_convert = self.image_keys
+ if keys_to_convert is None:
+ # Auto-detect image keys
+ keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
+
+ # Convert each image
+ for key in keys_to_convert:
+ if key in obs and isinstance(obs[key], torch.Tensor):
+ tensor = obs[key]
+
+ min_val = tensor.min().item()
+ max_val = tensor.max().item()
+
+ if max_val <= 1.0:
+ obs[key] = tensor.float() # ensure float dtype, but no division
+ continue
+ # Validate that values are in [0, 255] range if requested
+ if self.validate_range and (min_val < 0.0 or max_val > 255.0):
+ raise ValueError(
+ f"Image '{key}' has values outside [0, 255] range: "
+ f"min={min_val:.4f}, max={max_val:.4f}. "
+ f"Cannot convert to [0, 1] range."
+ )
+
+ # Convert to float and divide by 255
+ obs[key] = tensor.float() / 255.0
+
+ new_transition[TransitionKey.OBSERVATION] = obs
+ return new_transition
+
+ def transform_features(self, features):
+ """Image conversion doesn't change feature structure."""
+ return features
+
+ def get_config(self) -> dict[str, Any]:
+ """Return serializable configuration."""
+ return {
+ "image_keys": self.image_keys,
+ "validate_range": self.validate_range,
+ }
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
+class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
+ """Normalize image observations using ImageNet statistics.
+
+ This processor step applies ImageNet normalization (mean and std) to image observations.
+ It validates that input values are in the [0, 1] range before normalizing.
+
+ The normalization formula is: (image - mean) / std
+
+ Args:
+ image_keys: List of observation keys that contain images to normalize.
+ If None, will automatically detect keys starting with "observation.images."
+
+ Raises:
+ ValueError: If image values are not in the [0, 1] range.
+ """
+
+ image_keys: list[str] | None = None
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Normalize image observations using ImageNet statistics."""
+ new_transition = transition.copy()
+ obs = new_transition.get(TransitionKey.OBSERVATION, {})
+ if obs is None:
+ return new_transition
+
+ # Make a copy of observations to avoid modifying the original
+ obs = obs.copy()
+
+ # Determine which keys to normalize
+ keys_to_normalize = self.image_keys
+ if keys_to_normalize is None:
+ # Auto-detect image keys
+ keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
+
+ # Normalize each image
+ for key in keys_to_normalize:
+ if key in obs and isinstance(obs[key], torch.Tensor):
+ tensor = obs[key]
+
+ # Validate that values are in [0, 1] range
+ min_val = tensor.min().item()
+ max_val = tensor.max().item()
+ if min_val < 0.0 or max_val > 1.0:
+ raise ValueError(
+ f"Image '{key}' has values outside [0, 1] range: "
+ f"min={min_val:.4f}, max={max_val:.4f}. "
+ f"ImageNet normalization requires input values in [0, 1]."
+ )
+
+ # Apply ImageNet normalization
+ mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
+ std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
+
+ # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
+ while mean.dim() < tensor.dim():
+ mean = mean.unsqueeze(0)
+ std = std.unsqueeze(0)
+
+ # Normalize: (image - mean) / std
+ obs[key] = (tensor - mean) / std
+
+ new_transition[TransitionKey.OBSERVATION] = obs
+ return new_transition
+
+ def transform_features(self, features):
+ """ImageNet normalization doesn't change feature structure."""
+ return features
+
+ def get_config(self) -> dict[str, Any]:
+ """Return serializable configuration."""
+ return {
+ "image_keys": self.image_keys,
+ }
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="xvla_add_domain_id")
+class XVLAAddDomainIdProcessorStep(ProcessorStep):
+ """Add domain_id to complementary data.
+
+ This processor step adds a domain_id tensor to the complementary data,
+ which is used by XVLA to identify different robot embodiments or task domains.
+
+ Args:
+ domain_id: The domain ID to add (default: 0)
+ """
+
+ domain_id: int = 0
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Add domain_id to complementary data."""
+ new_transition = transition.copy()
+ comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+ comp = {} if comp is None else comp.copy()
+
+ # Infer batch size from observation tensors
+ obs = new_transition.get(TransitionKey.OBSERVATION, {})
+ batch_size = 1
+ if obs:
+ for v in obs.values():
+ if isinstance(v, torch.Tensor):
+ batch_size = v.shape[0]
+ break
+
+ # Add domain_id tensor
+ comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
+
+ new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
+ return new_transition
+
+ def transform_features(self, features):
+ """Domain ID addition doesn't change feature structure."""
+ return features
+
+ def get_config(self) -> dict[str, Any]:
+ """Return serializable configuration."""
+ return {
+ "domain_id": self.domain_id,
+ }
+
+
+@dataclass
+@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
+class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
+ """Convert 6D rotation representation to axis-angle and reorganize action dimensions.
+
+ This processor step takes actions with 6D rotation representation and converts them to
+ axis-angle representation, reorganizing the action dimensions as:
+ - action[:, :3] -> target_eef (end-effector position)
+ - action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
+ - action[:, 9:10] -> gripper action
+
+ Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
+
+ Args:
+ expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
+ """
+
+ expected_action_dim: int = 10
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """Convert 6D rotation to axis-angle in action."""
+ new_transition = transition.copy()
+ action = new_transition.get(TransitionKey.ACTION)
+
+ if action is None or not isinstance(action, torch.Tensor):
+ return new_transition
+
+ # Convert to numpy for processing
+ device = action.device
+ dtype = action.dtype
+ action_np = action.cpu().numpy()
+
+ # Extract components
+ # action shape: (B, D) where D >= 10
+ target_eef = action_np[:, :3] # (B, 3)
+ rotation_6d = action_np[:, 3:9] # (B, 6)
+ target_act = action_np[:, 9:10] # (B, 1)
+
+ # Convert 6D rotation to axis-angle
+ target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
+
+ # Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
+ action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
+
+ # Convert gripper action to -1 or 1
+ action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
+
+ # Convert back to tensor
+ action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
+
+ new_transition[TransitionKey.ACTION] = action
+ return new_transition
+
+ def transform_features(self, features):
+ """Rotation conversion changes action dimension from 10 to 7."""
+ # Note: This is a simplified version. In practice, you might want to
+ # update the action feature shape in the features dict.
+ return features
+
+ def get_config(self) -> dict[str, Any]:
+ """Return serializable configuration."""
+ return {
+ "expected_action_dim": self.expected_action_dim,
+ }
+
+
+def make_xvla_libero_pre_post_processors() -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """
+ Build the LeRobot processor pipelines for XVLA with LIBERO environment.
+ """
+ pre_processor_steps: list[ProcessorStep] = []
+ post_processor_steps: list[ProcessorStep] = []
+ pre_processor_steps.extend(
+ [LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
+ )
+ post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=pre_processor_steps,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=post_processor_steps,
+ ),
+ )
diff --git a/src/lerobot/policies/xvla/soft_transformer.py b/src/lerobot/policies/xvla/soft_transformer.py
new file mode 100644
index 000000000..77ceb6e26
--- /dev/null
+++ b/src/lerobot/policies/xvla/soft_transformer.py
@@ -0,0 +1,415 @@
+# ------------------------------------------------------------------------------
+# Copyright 2025 2toINF (https://github.com/2toINF)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------
+
+from __future__ import annotations
+
+import math
+from collections.abc import Iterable
+from functools import partial
+from typing import Final
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as functional
+
+# ------------------------------- Small utils ----------------------------------
+
+
+def _to_2tuple(x) -> tuple:
+ """Minimal replacement for timm.layers.to_2tuple."""
+ if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
+ t = tuple(x)
+ return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
+ return (x, x)
+
+
+def _has_sdp_attention() -> bool:
+ """Check if we can use PyTorch fused scaled_dot_product_attention."""
+ return hasattr(functional, "scaled_dot_product_attention")
+
+
+# ---------------------------------- MLP --------------------------------------
+
+
+class Mlp(nn.Module):
+ """
+ MLP used in ViT-style blocks.
+
+ Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int | None = None,
+ out_features: int | None = None,
+ norm_layer: type[nn.Module] | None = None,
+ bias: bool | tuple[bool, bool] = True,
+ drop: float | tuple[float, float] = 0.0,
+ use_conv: bool = False,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = _to_2tuple(bias)
+ drop_probs = _to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = nn.GELU(approximate="tanh")
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Expect [B, T, C] for Linear variant; caller is responsible for shapes.
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+# -------------------------------- Attention ----------------------------------
+
+
+class Attention(nn.Module):
+ """
+ Multi-Head Self-Attention with optional fused SDPA fallback.
+
+ If PyTorch provides `scaled_dot_product_attention`, it will be used
+ (usually faster and more stable); otherwise we use a manual implementation.
+ """
+
+ fused_attn: Final[bool]
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: type[nn.Module] = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ self.fused_attn = _has_sdp_attention()
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor, shape [batch_size, seq_len, channels]
+ Input sequence.
+
+ Returns
+ -------
+ Tensor, shape [batch_size, seq_len, channels]
+ Output sequence after MHSA + projection.
+ """
+ batch_size, seq_len, channels = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
+ )
+ q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.fused_attn:
+ x = functional.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ ) # [batch_size, num_heads, seq_len, head_dim]
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
+
+ x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+# ------------------------------- Utilities -----------------------------------
+
+
+def basic_init(module: nn.Module) -> None:
+ """
+ Apply a basic initialization scheme to Linear layers.
+
+ - Weight: Xavier uniform initialization.
+ - Bias: Set to zero.
+ """
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.0)
+
+
+def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
+ """
+ Create sinusoidal timestep embeddings.
+
+ Parameters
+ ----------
+ t : torch.Tensor
+ Shape [B]. Each element is a timestep index, may be fractional.
+ dim : int
+ Dimensionality of the output embedding.
+ max_period : int, default=100
+ Controls the minimum frequency of the sinusoids.
+
+ Returns
+ -------
+ torch.Tensor
+ Shape [B, dim]. Sinusoidal embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
+ )
+ args = t[:, None] * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2 == 1:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+# ------------------------------- Core Layers ----------------------------------
+
+
+class DomainAwareLinear(nn.Module):
+ """
+ Linear layer with domain-conditioned parameters (per-sample).
+
+ Each domain has its own weight and bias vectors, stored in embeddings.
+ """
+
+ def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
+ super().__init__()
+ self.input_size = input_size
+ self.output_size = output_size
+ self.fc = nn.Embedding(num_domains, output_size * input_size)
+ self.bias = nn.Embedding(num_domains, output_size)
+ nn.init.xavier_uniform_(self.fc.weight)
+ nn.init.zeros_(self.bias.weight)
+
+ def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor
+ [B, I] or [B, T, I]
+ domain_id : LongTensor
+ [B], domain indices.
+
+ Returns
+ -------
+ Tensor
+ [batch_size, output_size] or [batch_size, seq_len, output_size]
+ """
+ batch_size = domain_id.shape[0]
+ squeeze_seq = False
+ if x.dim() == 2:
+ x = x.unsqueeze(1)
+ squeeze_seq = True
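+ # Look up per-sample weights and biases from the domain embeddings and apply them with a batched matmul.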
+ weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
+ bias = self.bias(domain_id).view(batch_size, self.output_size)
+ y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
+ if squeeze_seq:
+ y = y.squeeze(1)
+ return y
+
+
+class TransformerBlock(nn.Module):
+ """
+ Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
+ """
+
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=int(hidden_size * mlp_ratio),
+ drop=0.1,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Parameters
+ ----------
+ x : Tensor, [B, T, H]
+
+ Returns
+ -------
+ Tensor, [B, T, H]
+ """
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+# --------------------------- Main Model ---------------------------------------
+
+
+class SoftPromptedTransformer(nn.Module):
+ """
+ Multi-modal, domain-aware Transformer with optional soft prompts.
+
+    Action, proprioception, and timestep inputs are encoded into action tokens, fused
+    with projected vision-language and auxiliary visual features, optionally extended
+    with per-domain soft prompts, and decoded back into an action chunk. See ``forward``
+    for tensor shapes.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ multi_modal_input_size: int = 768,
+ depth: int = 24,
+ num_heads: int = 16,
+ mlp_ratio: float = 4.0,
+ num_domains: int = 20,
+ dim_action: int = 20,
+ dim_propio: int = 20,
+ dim_time: int = 32,
+ len_soft_prompts: int = 32,
+ max_len_seq: int = 512,
+ use_hetero_proj: bool = False,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dim_action = dim_action
+ self.dim_time = dim_time
+ self.len_soft_prompts = len_soft_prompts
+ self.use_hetero_proj = use_hetero_proj
+
+ self.blocks = nn.ModuleList(
+ [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
+ )
+
+ if use_hetero_proj:
+ self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
+ self.aux_visual_proj = DomainAwareLinear(
+ multi_modal_input_size, hidden_size, num_domains=num_domains
+ )
+ else:
+ self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
+ self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
+
+ self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
+ nn.init.normal_(self.pos_emb, std=0.02)
+
+ self.norm = nn.LayerNorm(hidden_size)
+ self.action_encoder = DomainAwareLinear(
+ dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
+ )
+ self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
+
+ if len_soft_prompts > 0:
+ self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
+ nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
+
+ self.apply(basic_init)
+
+ def forward(
+ self,
+ domain_id: torch.LongTensor,
+ vlm_features: torch.Tensor,
+ aux_visual_inputs: torch.Tensor,
+ action_with_noise: torch.Tensor,
+ proprio: torch.Tensor,
+ t: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Forward pass.
+
+        Parameters
+        ----------
+ domain_id : [B]
+ vlm_features : [B, T_vlm, D]
+ aux_visual_inputs : [B, T_aux, D]
+ action_with_noise : [B, T_action, dim_action]
+ proprio : [B, dim_propio]
+ t : [B]
+
+ Returns
+ -------
+ Tensor
+ Predicted actions, [batch_size, num_actions, dim_action]
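+
+        Examples
+        --------
+        A shape-only sketch with small illustrative dimensions (the defaults are larger):
+
+        >>> model = SoftPromptedTransformer(
+        ...     hidden_size=64, multi_modal_input_size=64, depth=2, num_heads=4,
+        ...     dim_action=7, dim_propio=7, len_soft_prompts=8,
+        ... )
+        >>> actions = model(
+        ...     domain_id=torch.zeros(2, dtype=torch.long),
+        ...     vlm_features=torch.randn(2, 30, 64),
+        ...     aux_visual_inputs=torch.randn(2, 10, 64),
+        ...     action_with_noise=torch.randn(2, 16, 7),
+        ...     proprio=torch.randn(2, 7),
+        ...     t=torch.rand(2),
+        ... )
+        >>> actions.shape
+        torch.Size([2, 16, 7])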
+ """
+ batch_size, num_actions = action_with_noise.shape[:2]
+
+ # Encode (action + proprio + time) → tokens
+ time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
+ time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
+ proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
+ action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
+ x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
+
+ # Project visual streams and concatenate
+ if self.use_hetero_proj:
+ x = torch.cat(
+ [
+ x,
+ self.vlm_proj(vlm_features, domain_id),
+ self.aux_visual_proj(aux_visual_inputs, domain_id),
+ ],
+ dim=1,
+ )
+ else:
+ x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
+
+        # Add positional embeddings (raises if the sequence exceeds max_len_seq)
+ seq_len = x.shape[1]
+ if seq_len > self.pos_emb.shape[1]:
+ raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
+ x = x + self.pos_emb[:, :seq_len, :]
+
+ # Append soft prompts
+ if self.len_soft_prompts > 0:
+ soft_prompts = self.soft_prompt_hub(domain_id).view(
+ batch_size, self.len_soft_prompts, self.hidden_size
+ )
+ x = torch.cat([x, soft_prompts], dim=1)
+
+ # Transformer backbone
+ for block in self.blocks:
+ x = block(x)
+
+ # Decode only the action segment
+ return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
diff --git a/src/lerobot/policies/xvla/utils.py b/src/lerobot/policies/xvla/utils.py
new file mode 100644
index 000000000..bf31ffd82
--- /dev/null
+++ b/src/lerobot/policies/xvla/utils.py
@@ -0,0 +1,138 @@
+import math
+
+import numpy as np
+
+
+def mat2quat(rmat):
+ """
+ Converts given rotation matrix to quaternion.
+
+ Args:
+ rmat (np.array): 3x3 rotation matrix
+
+ Returns:
+        np.array: (x,y,z,w) unit quaternion
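+
+    Example:
+        A minimal sanity check: the identity matrix maps to the identity quaternion.
+
+        >>> np.allclose(mat2quat(np.eye(3)), [0.0, 0.0, 0.0, 1.0])
+        True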
+ """
+ mat = np.asarray(rmat).astype(np.float32)[:3, :3]
+
+ m00 = mat[0, 0]
+ m01 = mat[0, 1]
+ m02 = mat[0, 2]
+ m10 = mat[1, 0]
+ m11 = mat[1, 1]
+ m12 = mat[1, 2]
+ m20 = mat[2, 0]
+ m21 = mat[2, 1]
+ m22 = mat[2, 2]
+ # symmetric matrix k
+ k = np.array(
+ [
+ [m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
+ [m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
+ [m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
+ ]
+ )
+ k /= 3.0
+ # quaternion is Eigen vector of k that corresponds to largest eigenvalue
+ w, v = np.linalg.eigh(k)
+ inds = np.array([3, 0, 1, 2])
+ q1 = v[inds, np.argmax(w)]
+ if q1[0] < 0.0:
+ np.negative(q1, q1)
+ inds = np.array([1, 2, 3, 0])
+ return q1[inds]
+
+
+def quat2axisangle(quat):
+ """
+ Converts quaternion to axis-angle format.
+ Returns a unit vector direction scaled by its angle in radians.
+
+ Args:
+        quat (np.array): (x,y,z,w) unit quaternion
+
+ Returns:
+ np.array: (ax,ay,az) axis-angle exponential coordinates
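+
+    Example:
+        A minimal sanity check: a 90-degree rotation about z yields an axis-angle
+        vector of magnitude pi/2 along z.
+
+        >>> quat = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])
+        >>> np.allclose(quat2axisangle(quat), [0.0, 0.0, np.pi / 2])
+        True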
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+def rotate6d_to_axis_angle(r6d):
+ """
+    Converts the 6D rotation representation (the first two columns of a rotation
+    matrix) to axis-angle via Gram-Schmidt orthonormalization.
+
+    Args:
+        r6d (np.ndarray): shape (6,) or (N, 6).
+
+    Returns:
+        np.ndarray: shape (3,) or (N, 3), axis-angle vectors.
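+
+    Example:
+        A minimal sanity check: the 6D representation of the identity rotation
+        maps to a (near-)zero axis-angle vector.
+
+        >>> r6d = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
+        >>> np.allclose(rotate6d_to_axis_angle(r6d), np.zeros(3), atol=1e-5)
+        True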
+ """
+    squeeze_output = False
+    if r6d.ndim == 1:
+        r6d = r6d[None, ...]
+        squeeze_output = True
+
+ a1 = r6d[:, 0:3]
+ a2 = r6d[:, 3:6]
+
+    # b1: normalize the first 3D vector
+ b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
+
+    # b2: make the second vector orthogonal to b1, then normalize
+ dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
+ b2_orth = a2 - dot_prod * b1
+ b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
+
+    # b3: complete the right-handed orthonormal frame
+ b3 = np.cross(b1, b2, axis=-1)
+
+ rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
+
+ axis_angle_list = []
+ for i in range(rotation_matrix.shape[0]):
+ quat = mat2quat(rotation_matrix[i])
+ axis_angle = quat2axisangle(quat)
+ axis_angle_list.append(axis_angle)
+
+ axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
+
+    if squeeze_output:
+        axis_angle_array = axis_angle_array[0]
+
+ return axis_angle_array
+
+
+def mat_to_rotate6d(abs_action):
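+    """
+    Extract the 6D rotation representation (the first two columns of the rotation
+    block) from a single rotation/transform matrix or a batch of matrices.
+
+    Args:
+        abs_action (np.array): shape (3, 3)/(4, 4), or a batch (N, 3, 3)/(N, 4, 4).
+
+    Returns:
+        np.array: shape (6,) or (N, 6).
+    """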
+ if len(abs_action.shape) == 2:
+ return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
+ elif len(abs_action.shape) == 3:
+ return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
+ else:
+ raise NotImplementedError
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This implements the stochastic-depth regularizer used in EfficientNet-style networks.
+    It is sometimes referred to as "DropConnect", but DropConnect is a different form of
+    dropout from a separate paper; see the discussion at
+    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956. The name
+    "drop path" and a `drop_prob` argument are used here instead of "DropConnect" and
+    "survival rate".
+
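+    Example:
+        At evaluation time (``training=False``) the input is returned unchanged:
+
+        >>> import torch
+        >>> x = torch.ones(4, 3)
+        >>> torch.equal(drop_path(x, drop_prob=0.5, training=False), x)
+        True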
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index 84cebb86c..d23b9d083 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -534,7 +534,7 @@ def eval_main(cfg: EvalPipelineConfig):
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
- env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
+ env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 3ee5faf37..1ebdee600 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -261,7 +261,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
- env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
+ env_preprocessor, env_postprocessor = make_env_pre_post_processors(
+ env_cfg=cfg.env, policy_cfg=cfg.policy
+ )
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py
new file mode 100644
index 000000000..a9603fdb0
--- /dev/null
+++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py
@@ -0,0 +1,318 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
+# ruff: noqa: E402
+
+import random
+from copy import deepcopy
+from typing import Any
+
+import numpy as np
+import pytest
+import torch
+
+pytest.importorskip("transformers")
+
+from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
+from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
+from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
+from tests.utils import require_cuda # noqa: E402
+
+# Constants
+DUMMY_ACTION_DIM = 7 # Standard robot arm action dimension
+DUMMY_STATE_DIM = 20 # Proprioceptive state dimension
+IMAGE_HEIGHT = 224
+IMAGE_WIDTH = 224
+NUM_VIEWS = 2 # Number of camera views
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"
+LIBERO_DOMAIN_ID = 0  # Domain ID used for example purposes
+
+# Expected values from original XVLA implementation (reference values)
+EXPECTED_ACTIONS_SHAPE = (30, 20)
+EXPECTED_ACTIONS_MEAN = 0.117606
+EXPECTED_ACTIONS_STD = 0.245411
+EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
+
+
+def set_seed_all(seed: int):
+ """Set random seed for all RNG sources to ensure reproducibility."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ # Set deterministic behavior
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.use_deterministic_algorithms(True, warn_only=True)
+
+
+def instantiate_lerobot_xvla(
+ from_pretrained: bool = False,
+ model_path: str = MODEL_PATH_LEROBOT,
+) -> tuple[
+ Any, # Policy
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """Instantiate LeRobot XVLA policy with preprocessor and postprocessor."""
+ if from_pretrained:
+ policy = XVLAPolicy.from_pretrained(
+ pretrained_name_or_path=model_path,
+ strict=False,
+ )
+ else:
+ config = XVLAConfig(
+ base_model_path=model_path,
+ n_action_steps=DUMMY_ACTION_DIM,
+ chunk_size=DUMMY_ACTION_DIM,
+ device=DEVICE,
+ num_image_views=NUM_VIEWS,
+        )  # TODO: consider passing resize_imgs_with_padding=(IMAGE_HEIGHT, IMAGE_WIDTH)?
+ policy = XVLAPolicy(config)
+
+ policy.to(DEVICE)
+ policy.config.device = DEVICE
+ preprocessor, postprocessor = make_xvla_pre_post_processors(
+ config=policy.config,
+ dataset_stats=None, # Pass None for dataset_stats to disable normalization (original XVLA doesn't normalize)
+ )
+
+ return policy, preprocessor, postprocessor
+
+
+def create_dummy_data(device=DEVICE):
+ """Create dummy data for testing both implementations."""
+ batch_size = 1
+ prompt = "Pick up the red block and place it in the bin"
+
+ # Create random RGB images in [0, 255] uint8 range (as PIL images would be)
+    # and keep them as uint8 CHW tensors (any conversion to float is left to the preprocessing pipeline)
+ def fake_rgb(h, w):
+ arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
+ t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
+ return t
+
+ batch = {
+ f"{OBS_IMAGES}.image": torch.stack(
+ [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
+ ).to(device),
+ f"{OBS_IMAGES}.image2": torch.stack(
+ [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
+ ).to(device),
+ OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
+ "task": [prompt for _ in range(batch_size)],
+ }
+
+ return batch
+
+
+# Pytest fixtures
+@pytest.fixture(scope="module")
+def xvla_components():
+ """Fixture to instantiate and provide all XVLA components for tests."""
+ print(f"\nTesting with DEVICE='{DEVICE}'")
+ print("\n[Setup] Instantiating LeRobot XVLA policy...")
+ policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
+ print("✔️ Model loaded successfully")
+ yield policy_obj, preprocessor_obj, postprocessor_obj
+
+
+@pytest.fixture(scope="module")
+def policy(xvla_components):
+ """Fixture to provide the XVLA policy for tests."""
+ return xvla_components[0]
+
+
+@pytest.fixture(scope="module")
+def preprocessor(xvla_components):
+ """Fixture to provide the XVLA preprocessor for tests."""
+ return xvla_components[1]
+
+
+@require_cuda
+def test_xvla_preprocessor_alignment(policy, preprocessor):
+ """Test that LeRobot XVLA preprocessor produces expected outputs."""
+ print("\n" + "=" * 80)
+ print("Test: XVLA Preprocessor Outputs")
+ print("=" * 80)
+
+ set_seed_all(42)
+
+ print("\nCreating dummy data...")
+ batch = create_dummy_data()
+
+ print("\n[LeRobot] Preprocessing...")
+ lerobot_observation = preprocessor(deepcopy(batch))
+ lerobot_inputs = policy._build_model_inputs(lerobot_observation)
+
+ print("\nVerifying preprocessor outputs:")
+ print("-" * 80)
+
+    # Expected shapes recorded from the original XVLA implementation (tester.txt)
+ expected_shapes = {
+ "domain_id": (1,),
+ "input_ids": (1, 50),
+ "proprio": (1, 20),
+ "image_mask": (1, 2),
+ "image_input": (1, 2, 3, 224, 224),
+ }
+
+ for key, expected_shape in expected_shapes.items():
+ if key in lerobot_inputs:
+ actual_shape = tuple(lerobot_inputs[key].shape)
+ print(f"\nKey: {key}")
+ print(f"Expected shape: {expected_shape}")
+ print(f"Actual shape: {actual_shape}")
+
+ if actual_shape == expected_shape:
+ print("Shape matches!")
+ else:
+ print("Shape mismatch!")
+
+ assert actual_shape == expected_shape, f"Shape mismatch for {key}"
+        else:
+            pytest.fail(f"Key '{key}' not found in inputs!")
+
+ print("\nAll preprocessor outputs have correct shapes!")
+
+
+@require_cuda
+def test_xvla_action_generation(policy, preprocessor):
+ """Test XVLA LeRobot implementation generates expected actions."""
+ print("\n" + "=" * 80)
+ print("Test: XVLA Action Generation Against Expected Values")
+ print("=" * 80)
+
+ set_seed_all(42)
+
+ print("\nCreating dummy data...")
+ batch = create_dummy_data()
+
+ print("\n[LeRobot] Running inference...")
+ lerobot_observation = preprocessor(deepcopy(batch))
+ lerobot_inputs = policy._build_model_inputs(lerobot_observation)
+
+ # Reset seed for inference
+ torch.manual_seed(42)
+ with torch.no_grad():
+ lerobot_actions = policy.model.generate_actions(**lerobot_inputs, steps=10)
+ lerobot_actions = lerobot_actions.squeeze(0).float().cpu()
+
+ print(f"LeRobot actions shape: {lerobot_actions.shape}")
+ print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}")
+ print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}")
+ print(f"LeRobot actions first 5: {lerobot_actions[0, :5]}")
+
+ print("\nExpected values (from original XVLA):")
+ print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
+ print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
+ print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
+ print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")
+
+ print("\nAction Comparison:")
+ print("-" * 80)
+
+ # Compare shapes
+ actual_shape = tuple(lerobot_actions.shape)
+ assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
+ f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
+ )
+ print(f"✔️ Shape matches: {actual_shape}")
+
+ # Compare statistics
+ actual_mean = lerobot_actions.mean().item()
+ actual_std = lerobot_actions.std().item()
+
+ mean_diff = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
+ std_diff = abs(actual_std - EXPECTED_ACTIONS_STD)
+
+ print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_diff:.6e})")
+ print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_diff:.6e})")
+
+ # Compare first 5 actions
+ actual_first_5 = lerobot_actions[0, :5]
+ first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
+
+ print("\nFirst 5 actions comparison:")
+ print(f" Actual: {actual_first_5}")
+ print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
+ print(f" Max diff: {first_5_diff.max().item():.6e}")
+ print(f" Mean diff: {first_5_diff.mean().item():.6e}")
+
+ # Check with different tolerances
+ tolerances = [1e-5, 1e-4, 1e-3, 1e-2]
+ for tol in tolerances:
+ is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
+ status = "Success" if is_close else "Failure"
+ print(f"{status}: First 5 actions close (atol={tol}): {is_close}")
+
+ # Assert with reasonable tolerance
+ tolerance = 1e-3
+ assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
+ f"First 5 actions differ by more than tolerance ({tolerance})"
+ )
+ print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
+
+
+@require_cuda
+def test_xvla_inference_reproducibility(policy, preprocessor):
+ """Test that XVLA inference is reproducible with the same seed."""
+ print("\n" + "=" * 80)
+ print("Test: XVLA Inference Reproducibility")
+ print("=" * 80)
+
+ print("\nCreating dummy data...")
+ batch = create_dummy_data()
+
+ # First inference
+ print("\n[Run 1] Running inference...")
+ set_seed_all(42)
+ lerobot_observation = preprocessor(deepcopy(batch))
+ lerobot_inputs = policy._build_model_inputs(lerobot_observation)
+ with torch.no_grad():
+ actions_1 = policy.model.generate_actions(**lerobot_inputs, steps=10)
+ actions_1 = actions_1.squeeze(0).float().cpu()
+
+ # Second inference with same seed
+ print("\n[Run 2] Running inference with same seed...")
+ set_seed_all(42)
+ lerobot_observation = preprocessor(deepcopy(batch))
+ lerobot_inputs = policy._build_model_inputs(lerobot_observation)
+ with torch.no_grad():
+ actions_2 = policy.model.generate_actions(**lerobot_inputs, steps=10)
+ actions_2 = actions_2.squeeze(0).float().cpu()
+
+ print("\nComparing two runs:")
+ print("-" * 80)
+ if torch.allclose(actions_1, actions_2, atol=1e-8):
+ print("Inference is perfectly reproducible!")
+ else:
+ diff = torch.abs(actions_1 - actions_2)
+ print("Small differences detected:")
+ print(f" Max diff: {diff.max().item():.6e}")
+ print(f" Mean diff: {diff.mean().item():.6e}")
+
+ assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"
+
+ print("\nInference is reproducible!")