diff --git a/src/lerobot/policies/wall_x/__init__.py b/src/lerobot/policies/wall_x/__init__.py new file mode 100644 index 000000000..16fd2c8ab --- /dev/null +++ b/src/lerobot/policies/wall_x/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_wall_x import WallXConfig +from .modeling_wall_x import WallXPolicy +from .processor_wall_x import make_wall_x_pre_post_processors + +__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"] diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index f17742301..c0936f427 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Any from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @@ -29,8 +30,48 @@ class WallXConfig(PreTrainedConfig): Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching. It supports cross-embodiment robotic control through unified action representations. + + This config supports multi-modal learning with vision, language, and action data. """ - # Input / output structure + + # ==================== Model and Paths Configuration ==================== + # Logging + log_name: str = "wall_x_training" + log_project: str = "vla_training" + model_type: str = "wall-oss" + + # Pretrained model paths + pretrained_wallx_path: str | None = None # Path to pretrained Wall-X model + save_path: str | None = None # Path to save checkpoints + processor_path: str | None = None # Path to processor (defaults to pretrained_wallx_path) + action_tokenizer_path: str | None = None # Path to action tokenizer (for FAST mode) + + # Tokenizer settings + use_fast_tokenizer: bool = False # True: train FAST, False: train Flow + + # ==================== Profiling Configuration ==================== + profile: bool = False + profile_save_path: str | None = None + profile_wait_iters: int = 10 + profile_warmup_iters: int = 5 + profile_active_iters: int = 2 + + # ==================== Training Hyperparameters ==================== + num_warmup_steps: int = 100 + num_training_steps: int = 64000000 + learning_rate: float = 5e-5 + min_lr: float = 5e-5 + num_epoch: int = 100 + gradient_accumulation_steps: int = 32 + batch_size_per_gpu: int = 8 + padding_side: str = "left" + epoch_save_interval: int = 10 + + # Training optimization + fsdp2: bool = False + torch_compile: bool = False + + # ==================== Input / Output Structure ==================== n_obs_steps: int = 1 chunk_size: int = 32 # action_horizon in wall-x n_action_steps: int = 32 @@ -53,7 +94,7 @@ class WallXConfig(PreTrainedConfig): # Tokenizer tokenizer_max_length: int = 256 - # Model architecture + # ==================== Model Architecture ==================== vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct" load_vlm_weights: bool = True @@ -79,6 +120,7 @@ class WallXConfig(PreTrainedConfig): num_key_value_heads: int = 4 # 8 for 7B model vocab_size: int = 152064 + # ==================== Action Prediction ==================== # Action prediction mode: "flow" or "fast" prediction_mode: str = "flow" @@ -93,7 +135,8 @@ class WallXConfig(PreTrainedConfig): num_inference_timesteps: int = 10 # Number of ODE solver steps ode_solver_method: str = "euler" # ODE solver method - # Degrees of freedom configuration - example for bimanual robot + # ==================== Robot Configuration ==================== + # Degrees of freedom configuration - defines action space dof_config: dict = field(default_factory=lambda: { "left_ee_pos": 3, "left_ee_rot": 3, @@ -103,7 +146,7 @@ class WallXConfig(PreTrainedConfig): "right_gripper": 1, }) - # Proprioception configuration (mirrors dof_config) + # Proprioception configuration (typically mirrors dof_config) agent_pos_config: dict = field(default_factory=lambda: { "left_ee_pos": 3, "left_ee_rot": 3, @@ -113,12 +156,23 @@ class WallXConfig(PreTrainedConfig): "right_gripper": 1, }) - # MoE configuration + # Customized robot configuration + enable_customized_robot_config: bool = False + customized_robot_config: dict = field(default_factory=lambda: { + "name": "", + "customized_dof_config": {}, + "customized_agent_pos_config": {}, + }) + + # Normalization statistics path + norm_stats_path: str | None = None + + # ==================== MoE Configuration ==================== num_experts: int = 4 attention_moe: bool = False mlp_moe: bool = False - # Finetuning settings + # ==================== Finetuning Settings ==================== freeze_vision_encoder: bool = True train_expert_only: bool = False # wall-x trains more components train_action_head: bool = True @@ -126,7 +180,7 @@ class WallXConfig(PreTrainedConfig): # Cache use_cache: bool = True - # Training presets + # ==================== Optimizer Presets ==================== optimizer_lr: float = 2e-5 optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_eps: float = 1e-8 @@ -137,14 +191,48 @@ class WallXConfig(PreTrainedConfig): scheduler_decay_steps: int = 100000 scheduler_decay_lr: float = 1e-6 + # ==================== Dataset Configuration ==================== # Dataset-specific normalization statistics - # Maps dataset names to {min, delta} for action normalization action_statistics: dict = field(default_factory=dict) + # Data configuration + data_config: dict = field(default_factory=lambda: { + "use_lerobot": True, + "lerobot_config": { + "repo_id": "", + "root": None, + "episodes": None, + "image_transforms": None, + "delta_timestamps": None, + "tolerance_s": 1e-4, + "revision": None, + "force_cache_sync": False, + "download_videos": True, + "video_backend": None, + }, + "action_horizon": 32, + "train_test_split": 0.95, + "obs_action_keys": [], + "predict_action_keys": [], + "resolution": { + "face_view": 256, + "left_wrist_view": 256, + "right_wrist_view": 256, + "move1_view": 256, + "move2_view": 256, + "top_view": 256, + "wall_view": 256, + "multi_modal": 256, + }, + }) + + # ==================== Resume Configuration ==================== + resume_config: dict | None = field(default_factory=lambda: None) + def __post_init__(self): super().__post_init__() - """Input validation""" + # Input validation if self.n_action_steps > self.chunk_size: raise ValueError( f"The chunk size is the upper bound for the number of action steps per model invocation. Got " @@ -163,6 +251,232 @@ class WallXConfig(PreTrainedConfig): f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})" ) + # Sync prediction_mode with use_fast_tokenizer + if self.use_fast_tokenizer: + self.prediction_mode = "fast" + else: + self.prediction_mode = "flow" + + def get_train_config(self) -> dict: + """ + Extract the complete train_config dictionary matching the YAML training configuration format. + + This method constructs the full train_config from WallXConfig fields, suitable for + training scripts and Qwen2_5_VLMoEForAction.from_pretrained. + + Returns: + dict: Complete training configuration matching YAML structure. + """ + # Build customized_robot_config + if self.enable_customized_robot_config and self.customized_robot_config: + customized_robot_config = { + "name": self.customized_robot_config.get("name", ""), + "customized_dof_config": self.customized_robot_config.get( + "customized_dof_config", self.dof_config + ), + "customized_agent_pos_config": self.customized_robot_config.get( + "customized_agent_pos_config", self.agent_pos_config + ), + } + else: + customized_robot_config = { + "name": self.data_config.get("lerobot_config", {}).get("repo_id", ""), + "customized_dof_config": self.dof_config, + "customized_agent_pos_config": self.agent_pos_config, + } + + train_config = { + # Model and paths configuration + "log_name": self.log_name, + "log_project": self.log_project, + "model_type": self.model_type, + "pretrained_wallx_path": self.pretrained_wallx_path, + "save_path": self.save_path, + "use_fast_tokenizer": self.use_fast_tokenizer, + "action_tokenizer_path": self.action_tokenizer_path, + + # Profiling configuration + "profile": self.profile, + "profile_save_path": self.profile_save_path, + "profile_wait_iters": self.profile_wait_iters, + "profile_warmup_iters": self.profile_warmup_iters, + "profile_active_iters": self.profile_active_iters, + + # Training hyperparameters + "num_warmup_steps": self.num_warmup_steps, + "num_training_steps": self.num_training_steps, + "learning_rate": self.learning_rate, + "min_lr": self.min_lr, + "num_epoch": self.num_epoch, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "batch_size_per_gpu": self.batch_size_per_gpu, + "padding_side": self.padding_side, + "epoch_save_interval": self.epoch_save_interval, + + # Training optimization + "FSDP2": self.fsdp2, + "torch_compile": self.torch_compile, + + # Robot configuration + "dof_config": self.dof_config, + "agent_pos_config": self.agent_pos_config, + + # Normalization stats + "norm_stats_path": self.norm_stats_path, + + # Customized robot config + "enable_customized_robot_config": self.enable_customized_robot_config, + "customized_robot_config": customized_robot_config, + + # Resume configuration + "resume": self.resume_config, + + # Data configuration + "data": self.data_config, + } + + return train_config + + def get_dataload_config(self) -> dict: + """ + Extract data loading configuration from config. + + Returns: + dict: Data loading configuration for preprocessing. + """ + return { + "action_horizon": self.data_config.get("action_horizon", self.chunk_size), + "train_test_split": self.data_config.get("train_test_split", 0.95), + "split_seed": 42, + "predict_action_keys": self.data_config.get("predict_action_keys", []), + "obs_action_keys": self.data_config.get("obs_action_keys", []), + "resolution": self.data_config.get("resolution", {}), + "priority_order": None, + "max_length": self.tokenizer_max_length, + } + + def get_lerobot_config(self) -> dict: + """ + Extract LeRobot dataset configuration. + + Returns: + dict: LeRobot dataset configuration. + """ + return self.data_config.get("lerobot_config", {}) + + @classmethod + def from_yaml_dict(cls, yaml_dict: dict) -> "WallXConfig": + """ + Create a WallXConfig from a YAML configuration dictionary. + + Args: + yaml_dict: Dictionary loaded from YAML training config file. + + Returns: + WallXConfig instance with values from YAML. + """ + config_kwargs = {} + + # Model and paths + if "log_name" in yaml_dict: + config_kwargs["log_name"] = yaml_dict["log_name"] + if "log_project" in yaml_dict: + config_kwargs["log_project"] = yaml_dict["log_project"] + if "model_type" in yaml_dict: + config_kwargs["model_type"] = yaml_dict["model_type"] + if "pretrained_wallx_path" in yaml_dict: + config_kwargs["pretrained_wallx_path"] = yaml_dict["pretrained_wallx_path"] + if "save_path" in yaml_dict: + config_kwargs["save_path"] = yaml_dict["save_path"] + if "use_fast_tokenizer" in yaml_dict: + config_kwargs["use_fast_tokenizer"] = yaml_dict["use_fast_tokenizer"] + if "action_tokenizer_path" in yaml_dict: + config_kwargs["action_tokenizer_path"] = yaml_dict["action_tokenizer_path"] + + # Profiling + if "profile" in yaml_dict: + config_kwargs["profile"] = yaml_dict["profile"] + if "profile_save_path" in yaml_dict: + config_kwargs["profile_save_path"] = yaml_dict["profile_save_path"] + if "profile_wait_iters" in yaml_dict: + config_kwargs["profile_wait_iters"] = yaml_dict["profile_wait_iters"] + if "profile_warmup_iters" in yaml_dict: + config_kwargs["profile_warmup_iters"] = yaml_dict["profile_warmup_iters"] + if "profile_active_iters" in yaml_dict: + config_kwargs["profile_active_iters"] = yaml_dict["profile_active_iters"] + + # Training hyperparameters + if "num_warmup_steps" in yaml_dict: + config_kwargs["num_warmup_steps"] = yaml_dict["num_warmup_steps"] + config_kwargs["scheduler_warmup_steps"] = yaml_dict["num_warmup_steps"] + if "num_training_steps" in yaml_dict: + config_kwargs["num_training_steps"] = yaml_dict["num_training_steps"] + config_kwargs["scheduler_decay_steps"] = yaml_dict["num_training_steps"] + if "learning_rate" in yaml_dict: + config_kwargs["learning_rate"] = yaml_dict["learning_rate"] + config_kwargs["optimizer_lr"] = yaml_dict["learning_rate"] + if "min_lr" in yaml_dict: + config_kwargs["min_lr"] = yaml_dict["min_lr"] + config_kwargs["scheduler_decay_lr"] = yaml_dict["min_lr"] + if "num_epoch" in yaml_dict: + config_kwargs["num_epoch"] = yaml_dict["num_epoch"] + if "gradient_accumulation_steps" in yaml_dict: + config_kwargs["gradient_accumulation_steps"] = yaml_dict["gradient_accumulation_steps"] + if "batch_size_per_gpu" in yaml_dict: + config_kwargs["batch_size_per_gpu"] = yaml_dict["batch_size_per_gpu"] + if "padding_side" in yaml_dict: + config_kwargs["padding_side"] = yaml_dict["padding_side"] + if "epoch_save_interval" in yaml_dict: + config_kwargs["epoch_save_interval"] = yaml_dict["epoch_save_interval"] + + # Training optimization + if "FSDP2" in yaml_dict: + config_kwargs["fsdp2"] = yaml_dict["FSDP2"] + if "torch_compile" in yaml_dict: + config_kwargs["torch_compile"] = yaml_dict["torch_compile"] + + # Robot configuration + if "dof_config" in yaml_dict: + config_kwargs["dof_config"] = yaml_dict["dof_config"] + if "agent_pos_config" in yaml_dict: + config_kwargs["agent_pos_config"] = yaml_dict["agent_pos_config"] + + # Normalization stats + if "norm_stats_path" in yaml_dict: + config_kwargs["norm_stats_path"] = yaml_dict["norm_stats_path"] + + # Customized robot config + if "enable_customized_robot_config" in yaml_dict: + config_kwargs["enable_customized_robot_config"] = yaml_dict["enable_customized_robot_config"] + if "customized_robot_config" in yaml_dict: + config_kwargs["customized_robot_config"] = yaml_dict["customized_robot_config"] + + # Resume config + if "resume" in yaml_dict: + config_kwargs["resume_config"] = yaml_dict["resume"] + + # Data configuration + if "data" in yaml_dict: + data = yaml_dict["data"] + data_config = { + "use_lerobot": data.get("use_lerobot", True), + "action_horizon": data.get("action_horizon", 32), + "train_test_split": data.get("train_test_split", 0.95), + "obs_action_keys": data.get("obs_action_keys", []), + "predict_action_keys": data.get("predict_action_keys", []), + "resolution": data.get("resolution", {}), + } + if "lerobot_config" in data: + data_config["lerobot_config"] = data["lerobot_config"] + config_kwargs["data_config"] = data_config + + # Set chunk_size from action_horizon + if "action_horizon" in data: + config_kwargs["chunk_size"] = data["action_horizon"] + config_kwargs["n_action_steps"] = data["action_horizon"] + + return cls(**config_kwargs) + def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( lr=self.optimizer_lr, diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py new file mode 100644 index 000000000..872302ad6 --- /dev/null +++ b/src/lerobot/policies/wall_x/constant.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Constants and Configuration Data. + +Contains dataset names, key mappings, frequency mappings, and action statistics +for cross-embodiment robotic control. +""" + +from pathlib import Path + +# Add wall-x repo to path if available +WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x") + +CAMERA_NAME_MAPPING = { + "face_view": "front view", + "left_wrist_view": "left wrist view", + "right_wrist_view": "right wrist view", + "move1_view": "move view", + "move2_view": "move view", + "wall_view": "wall view", + "top_view": "top view", +} diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index ebab36e54..458028c3a 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -34,22 +34,64 @@ lerobot-train \ ``` """ +import glob import math +import os import sys from collections import deque from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from PIL import Image +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from peft import LoraConfig, get_peft_model +from safetensors.torch import load_file from torch import Tensor from torch.distributions import Beta +from torch.nn import CrossEntropyLoss +from torchdiffeq import odeint +from transformers import AutoConfig, AutoProcessor +from transformers.activations import ACT2FN +from transformers.cache_utils import ( + Cache, + DynamicCache, + SlidingWindowCache, + StaticCache, +) +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import is_torchdynamo_compiling, logging +from transformers import AutoProcessor, BatchFeature +from qwen_vl_utils.vision_process import smart_resize from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.utils import populate_queues +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE -from lerobot.utils.utils import get_safe_dtype + +from lerobot.policies.wall_x.utils import * +from lerobot.policies.wall_x.constant import * +from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2RMSNorm, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLRotaryEmbedding, + Qwen2_5_VLPreTrainedModel, + Qwen2_5_VLForConditionalGeneration, +) + +from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLDecoderLayer_with_MoE, + Qwen2_5_VLACausalLMOutputWithPast, +) # Add wall-x repo to path if available WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x") @@ -194,62 +236,2068 @@ class ActionHead(nn.Module): return loss - def project_proprioception(self, proprioception, dof_mask=None): + def proprioception_proj(self, proprioception, dof_mask=None, use_history=False): """Project proprioceptive data to hidden space.""" - proprioception = proprioception.to( - device=self.propri_proj.weight.device, + # Ensure proper device and dtype alignment + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( dtype=self.propri_proj.weight.dtype ) if dof_mask is not None: - proprioception = torch.cat([proprioception, dof_mask], dim=-1) + # Concatenate proprioception with DOF mask + # TODO: Use variable-based dimension checking for better flexibility + if use_history: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + else: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( + dtype=self.propri_proj.weight.dtype + ) return self.propri_proj(proprioception) -class WallXVLMWrapper(nn.Module): - """ - Wrapper around Qwen2.5-VL model from wall-x. +class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): + """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. - This class attempts to load the wall-x model if available, - otherwise provides a placeholder implementation. + This model extends the base Qwen2.5-VL model by incorporating MoE layers + for improved scalability and specialization across different token types. """ - def __init__(self, config: WallXConfig): - super().__init__() - self.config = config + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + num_experts: Optional[int] = None, + *args, + **kwargs, + ): + """Load a pretrained model with optional MoE configuration. - # Try to import wall-x model - try: - from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor - from qwen_vl_utils import process_vision_info + Args: + pretrained_model_name_or_path: Path or name of the pretrained model + num_experts: Number of experts for MoE layers (if not in config) + *args: Additional arguments passed to parent class + **kwargs: Additional keyword arguments passed to parent class - self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - config.vlm_model_name, - torch_dtype=torch.bfloat16 if config.device != "cpu" else torch.float32, - device_map=config.device if config.device != "cpu" else None, + Returns: + Initialized model instance with MoE configuration + """ + config = kwargs.get("config", None) + if config is None: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + + # Override number of experts if specified + if num_experts is not None: + config.num_experts = num_experts + + kwargs["config"] = config + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + def __init__(self, config: Qwen2_5_VLConfig): + """Initialize the Qwen2.5-VL MoE model. + + Args: + config: Model configuration containing architecture parameters + """ + super().__init__(config) + + # Basic model parameters + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Model components + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + + # Decoder layers with MoE support + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # Model configuration + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Get the input embedding layer. + + Returns: + The token embedding layer + """ + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the input embedding layer. + + Args: + value: New embedding layer to use + """ + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + moe_token_types: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # Set default output options + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Validate inputs + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" ) - self.processor = AutoProcessor.from_pretrained(config.vlm_model_name) - self.process_vision_info = process_vision_info - self.available = True + if moe_token_types is None: + raise ValueError("moe_token_types must be provided for MoE routing") - # Freeze vision encoder if requested - if config.freeze_vision_encoder: - for param in self.model.visual.parameters(): - param.requires_grad = False + # Handle gradient checkpointing compatibility + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False - except ImportError: - print("Warning: Could not import wall-x dependencies. Using placeholder.") - self.available = False - self.model = None - self.processor = None + # Initialize cache if needed + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() - def forward(self, **kwargs): - """Forward pass through VLM.""" - if not self.available: - raise RuntimeError("Wall-X VLM not available. Install required dependencies.") - return self.model(**kwargs) + # Get input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Set up cache position + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # Set up position IDs (hardcoded 3 dimensions for temporal, height, width) + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # Create causal attention mask + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + moe_token_types, + ) + + hidden_states = inputs_embeds + + # Create position embeddings to be shared across decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Process through decoder layers + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + # Use gradient checkpointing during training + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + moe_token_types, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + # Regular forward pass + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + token_types=moe_token_types, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + # Update cache if using it + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + # Collect attention weights if requested + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Apply final layer normalization + hidden_states = self.norm(hidden_states) + + # Add final hidden states if collecting all states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + # Return outputs in requested format + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + moe_token_types: Optional[torch.LongTensor] = None, + ): + """Update causal attention mask with support for bidirectional attention for specific token types. + + This method creates and modifies attention masks to support different attention patterns: + - Standard causal (unidirectional) attention for most tokens + - Bidirectional attention for specific token types (e.g., MoE routing tokens) + + Args: + attention_mask: Input attention mask to avoid attending to padding tokens + input_tensor: Input embeddings tensor for shape and device information + cache_position: Position indices for caching mechanisms + past_key_values: Cached key-value pairs from previous forward passes + output_attentions: Whether attention weights will be returned + moe_token_types: Optional tensor indicating token types for MoE routing + (type 1 tokens will use bidirectional attention) + + Returns: + Updated causal attention mask, or None if using Flash Attention 2 + """ + # Flash Attention 2 handles masking internally + if self.config._attn_implementation == "flash_attention_2": + return None + + # Calculate sequence lengths for cache management + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible + # instead of explicit attention mask to enable Flash Attention 2 dispatch + # Note: This optimization is not compatible with static cache + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + # Check if we can ignore the causal mask and rely on SDPA's internal handling + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + # Extract tensor properties for mask creation + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + # Determine target length based on cache type + if using_sliding_window_cache or using_static_cache: + # Use maximum cache shape for sliding window or static caches + target_length = past_key_values.get_max_cache_shape() + else: + # For dynamic cache or no cache, calculate based on attention mask or sequence length + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # Generate 4D causal attention mask from 2D input mask if provided + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + # Modify mask to support bidirectional attention for specific token types + if moe_token_types is not None: + # Identify positions of type 1 tokens (MoE routing tokens) + type1_tokens = ( + (moe_token_types == 1).unsqueeze(1).unsqueeze(2) + ) # Shape: [B, 1, 1, S] + + # Create bidirectional attention region for type 1 tokens + # This allows type 1 tokens to attend to each other bidirectionally + type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S] + type1_region = type1_tokens & type1_tokens.transpose( + -1, -2 + ) # Shape: [B, 1, S, S] + type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool) + + # Apply bidirectional attention: zero out causal constraints in type 1 regions + causal_mask = torch.where( + type1_mask, # Where type 1 tokens interact with each other + torch.zeros_like( + causal_mask + ), # Remove causal masking (allow bidirectional) + causal_mask, # Keep original causal masking for other regions + ) + + # Handle special case for SDPA with CUDA/XPU devices + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Ensure attention to all tokens in fully masked rows for memory-efficient attention + # This is required for F.scaled_dot_product_attention's memory-efficient path + # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange( + target_length, device=device + ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + return causal_mask + + +class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): + """ + Qwen2.5 Vision-Language Mixture of Experts model for action processing. + + This model extends the base Qwen2.5 VL model with action token processing capabilities + and optional LoRA fine-tuning support. + """ + + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"] + + @classmethod + def from_pretrained( + cls, + pretrained_model_path, + train_config, + config_path=None, + processor_path=None, + action_tokenizer_path=None, + **kwargs, + ): + """ + Load model from pretrained model path. + + Args: + pretrained_model_path (str): Model directory path containing model.safetensors file + config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path + processor_path (str, optional): Processor path, if None will load from default config + action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config + **kwargs: Additional arguments + + Returns: + Qwen2_5_VLMoEForAction: Loaded model instance + """ + # Load model components from pretrained path + config_path = os.path.join(pretrained_model_path, "config.json") + config = cls.config_class.from_pretrained(config_path) + processor = AutoProcessor.from_pretrained(pretrained_model_path, use_fast=True) + if action_tokenizer_path is not None: + processor.action_processor = AutoProcessor.from_pretrained( + action_tokenizer_path, trust_remote_code=True + ) + + # Initialize model with configuration and processor + model = cls(config, processor=processor, **kwargs) + + # Resize token embeddings to match processor tokenizer vocabulary size + model.resize_token_embeddings(len(processor.tokenizer)) + + # Load model state dict from safetensors file + safetensor_files = glob.glob( + os.path.join(pretrained_model_path, "*.safetensors") + ) + state_dict = {} + for file in safetensor_files: + sd = load_file(file, device="cpu") + # filter normalizer statistic params + del_keys = [] + for key in sd.keys(): + if "action_preprocessor.normalizer" in key: + print(f"filter load model weight {key}") + del_keys.append(key) + for key in del_keys: + del sd[key] + state_dict.update(sd) + + model.load_state_dict(state_dict, strict=False) + + return model + + def __init__( + self, + config, + use_fast_tokenizer=False, + processor=None, + action_tokenizer=None, + action_mapper=None, + flow_loss_weight=1.0, + ): + """ + Initialize the Qwen2.5 VLMoE model for action processing. + + Args: + config: Model configuration + use_fast_tokenizer (bool): Whether to use fast tokenizer + processor: Text and image processor + action_tokenizer: Action-specific tokenizer + action_mapper: Action mapping utility + flow_loss_weight (float): Weight for flow loss computation + """ + super().__init__(config) + + # Initialize vision transformer and language model components + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( + config.vision_config + ) + self.model = Qwen2_5_VLMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize loss function without reduction for channel-wise loss computation + self.loss_fct = CrossEntropyLoss(reduction="none") + self.flow_loss_weight = flow_loss_weight + self.use_fast_tokenizer = use_fast_tokenizer + self.processor = processor + + # Define action token IDs + self.define_action_token_id() + + # Cache for rope deltas + self.rope_deltas = None + + # Initialize action preprocessor + self.action_preprocessor = ActionHead(config) + + # Apply LoRA if specified in configuration + if hasattr(config, "use_lora") and config.use_lora: + self.add_lora( + r=config.lora_r, + lora_alpha=config.lora_alpha, + target_modules=config.lora_target_modules, + lora_dropout=config.lora_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + def to_bfloat16_for_selected_params(self): + self.to(dtype=torch.bfloat16) + + params_to_keep_float32 = [] + + for name, param in self.named_parameters(): + if "input_layernorm" in name or "post_attention_layernorm" in name or "model.norm" in name: + params_to_keep_float32.append(name) + if "action_preprocessor" in name: + params_to_keep_float32.append(name) + + for name, param in self.named_parameters(): + if name in params_to_keep_float32: + param.data = param.data.to(torch.float32) + + def define_action_token_id(self): + """ + Define action token IDs based on tokenizer configuration. + + Creates mappings for fast action tokens, proprioception tokens, and general action tokens. + """ + # Create list of fast action token IDs + fast_action_token_list = [] + if self.use_fast_tokenizer: + for i in range( + self.processor.tokenizer.init_kwargs["action_token_vocab_size"] + ): + action_token_id = self.processor.tokenizer.convert_tokens_to_ids( + f"<|action_token_{i}|>" + ) + fast_action_token_list.append(action_token_id) + + # Get special action token IDs + action_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|action|>") + propri_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|propri|>") + + # Store action token ID mappings + self.action_token_id_set = { + "fast_action_token_list": fast_action_token_list, + "propri_token_id": propri_token_id, + "action_token_id": action_token_id, + } + + def add_lora( + self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1 + ): + """ + Add LoRA (Low-Rank Adaptation) adapters to the model. + + Args: + r (int): Rank of adaptation + lora_alpha (int): LoRA scaling parameter + target_modules (list): List of module names to apply LoRA to + lora_dropout (float): Dropout probability for LoRA layers + """ + config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + self.model = get_peft_model(self.model, config) + + # Print information about trainable parameters + self.model.print_trainable_parameters() + + def get_input_embeddings(self): + """Get input embeddings layer.""" + return self.model.embed_tokens + + def set_input_embeddings(self, value): + """Set input embeddings layer.""" + self.model.embed_tokens = value + + def get_output_embeddings(self): + """Get output embeddings layer.""" + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Set output embeddings layer.""" + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + """Set the decoder model.""" + self.model = decoder + + def get_decoder(self): + """Get the decoder model.""" + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens. + + This method computes position embeddings that account for the temporal, height, and width + dimensions of vision tokens (images/videos) while maintaining standard 1D position embeddings + for text tokens. + + For vision tokens, 3D position embeddings are calculated based on: + - Temporal dimension: Time patches in videos + - Height dimension: Vertical patches in images/video frames + - Width dimension: Horizontal patches in images/video frames + + For text tokens, standard 1D position embeddings are used, continuing from the maximum + vision position ID plus 1. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs of shape (batch_size, sequence_length) + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (num_images, 3) for [temporal, height, width] + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (num_videos, 3) for [temporal, height, width] + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid (num_videos,) + attention_mask (torch.Tensor, optional): Attention mask (batch_size, sequence_length) + + Returns: + tuple: + - position_ids (torch.LongTensor): 3D position IDs of shape (3, batch_size, sequence_length) + - mrope_position_deltas (torch.Tensor): Position deltas for mRoPE of shape (batch_size, 1) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + # Initialize 3D position IDs tensor + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + # Process each sequence in the batch + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + # Find vision tokens and count images/videos + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + # Process each vision token (image or video) + for _ in range(image_nums + video_nums): + # Find next image or video token + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + # Determine if processing image or video token + if ed_image < ed_video: + # Process image token + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + # Process video token + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + + # Calculate grid dimensions after spatial merging + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + # Add position IDs for text tokens before vision token + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + # Calculate 3D position embeddings for vision tokens + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + # Calculate temporal position IDs with time scaling + time_tensor = ( + expanded_range + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + # Calculate spatial position IDs + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + + # Add 3D position IDs for vision tokens + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + # Add position IDs for remaining text tokens + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + # Concatenate all position IDs for this sequence + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # Handle case without vision tokens - use standard 1D position embeddings + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def train_step_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + moe_token_types: Optional[ + torch.LongTensor + ] = None, # MoE token type assignments + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + action_chunk: Optional[torch.FloatTensor] = None, # Action trajectory chunks + proprioception: Optional[ + torch.FloatTensor + ] = None, # Joint position/orientation data + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + dataset_names: Optional[str] = None, + dof_mask: Optional[torch.FloatTensor] = None, + agent_pos_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: + """ + Forward pass for training with multi-modal inputs including vision, text, and action data. + + This method handles the complete forward pass during training, processing various input modalities + including images, videos, text, proprioceptive data, and action sequences. It computes losses + for both language modeling and action prediction using flow matching. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs for generation + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for loss computation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (temporal, height, width) + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (temporal, height, width) + action_chunk (torch.FloatTensor, optional): Action trajectory data chunks + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data (joint positions, etc.) + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + dataset_names (str, optional): Names of datasets in the current batch + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens + agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data + **kwargs: Additional keyword arguments + + Returns: + Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: Model outputs including losses, logits, + and auxiliary information, or tuple if return_dict=False + """ + batch_size, seq_length = input_ids.shape + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and ( + attention_mask is None or attention_mask.ndim == 2 + ): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + delta = ( + (cache_position[0] + self.rope_deltas).to(self.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=self.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data (joint positions, orientations, etc.) + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to( + inputs_embeds.dtype + ) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to( + inputs_embeds.dtype + ) + proprioception = self.action_preprocessor.proprioception_proj( + proprioception, + dataset_names, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + mask = input_ids == self.action_token_id_set["propri_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + proprioception_mask = mask_expanded.to(inputs_embeds.device) + + proprioception = proprioception.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + proprioception_mask, proprioception + ) + elif self.training: + # Dummy forward pass to ensure gradient registration in DDP + # This handles cases where one process has proprioception data while another doesn't + # Without this, DDP would hang waiting for a gradient that will never be computed + dummy_input = torch.randn( + 2, + self.action_preprocessor.propri_dim * 2, + device=inputs_embeds.device, + ) + dummy_forward = self.action_preprocessor.proprioception_proj( + dummy_input + ) + dummy_loss = sum(p.sum() for p in dummy_forward) + inputs_embeds = inputs_embeds + 0 * dummy_loss + + # Process action chunk data + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to( + inputs_embeds.dtype + ) + dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + noisy_action_emb, flow = self.action_preprocessor( + action_chunk, dataset_names, dof_mask + ) + mask = input_ids == self.action_token_id_set["action_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + action_mask = mask_expanded.to(inputs_embeds.device) + + noisy_action_emb = noisy_action_emb.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + action_mask, noisy_action_emb + ) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Forward pass through the main model + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + moe_token_types=moe_token_types, # Pass token types for MoE routing + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states) + + # Initialize loss computation variables + loss = None + cross_entropy_loss, flow_loss = None, None + channel_loss_dict = None + channel_loss_count_dict = None + + # Compute losses if labels are provided + if labels is not None: + loss = 0 + action_accuracy = 0 + unique_datasets_name = list(set(dataset_names)) + + # Initialize per-dataset loss tracking dictionaries + channel_loss_dict = { + dataset_name: torch.tensor(0.0, device=logits.device) + for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES + } + channel_loss_count_dict = { + dataset_name: torch.tensor(0, device=logits.device) + for dataset_name in ACTION_DATASET_NAMES + MULTIMODAL_DATASET_NAMES + } + + # Compute standard cross-entropy loss for language modeling + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + + # Enable model parallelism by moving labels to correct device + shift_labels = shift_labels.to(shift_logits.device) + non_ignored_mask = shift_labels != -100 + _cross_entropy_loss = self.loss_fct(shift_logits, shift_labels) + cross_entropy_loss = ( + _cross_entropy_loss[non_ignored_mask].mean() + if non_ignored_mask.any() + else torch.tensor(0.0, device=shift_logits.device) + ) + + # Compute per-dataset channel losses + _cross_entropy_loss = _cross_entropy_loss.view(batch_size, seq_length - 1) + non_ignored_mask = non_ignored_mask.view(batch_size, seq_length - 1) + for dataset_name_i in unique_datasets_name: + dataset_mask = torch.tensor( + [name == dataset_name_i for name in dataset_names], + device=logits.device, + ) + combined_mask = dataset_mask.unsqueeze(1) & non_ignored_mask + channel_loss_dict[dataset_name_i] = ( + _cross_entropy_loss[combined_mask].sum() + if combined_mask.any() + else torch.tensor(0.0, device=shift_logits.device) + ) + channel_loss_count_dict[dataset_name_i] += combined_mask.sum() + + # Add cross-entropy loss to total loss if valid + if not torch.isnan(cross_entropy_loss): + loss += cross_entropy_loss + else: + with torch.no_grad(): + cross_entropy_loss.detach() + + # Compute action token prediction accuracy + shift_logits = logits[..., :-1, :].contiguous() + action_preds = shift_logits.argmax(dim=-1) + shift_labels = labels[..., 1:].contiguous() + if self.use_fast_tokenizer: + action_mask = ( + shift_labels > self.action_token_id_set["fast_action_token_list"][0] + ) + correct_preds = (action_preds == shift_labels) & action_mask + action_accuracy = ( + correct_preds.sum().float() / action_mask.sum().float() + ) + channel_loss_dict["action_accuracy"] = action_accuracy + + if action_chunk is not None: + action_mask = input_ids == self.action_token_id_set["action_token_id"] + if action_mask.any(): + action_hidden_states = hidden_states[action_mask].to(torch.float32) + flow = flow.reshape(-1, flow.shape[-1]) + _flow_loss = self.action_preprocessor.flow_loss( + action_hidden_states, flow, dof_mask + ) + if isinstance(_flow_loss, torch.Tensor): + flow_loss = _flow_loss.mean() + if loss is not None: + loss += self.flow_loss_weight * flow_loss + else: + loss = self.flow_loss_weight * flow_loss + _flow_loss = _flow_loss.view( + dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2] + ) + + # Return outputs based on return_dict setting + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLACausalLMOutputWithPast( + loss=loss, + cross_entropy_loss=( + cross_entropy_loss.clone() if cross_entropy_loss is not None else None + ), + flow_loss=flow_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + channel_loss_dict=channel_loss_dict, + channel_loss_count_dict=channel_loss_count_dict, + ) + + def predict_action(self, predict_mode: str, **kwargs): + """ + Predict actions using specified prediction mode. + + Args: + predict_mode (str): Prediction mode, either "fast" or "diffusion" + **kwargs: Additional arguments passed to the predict method + + Returns: + tuple: (predicted_action, ground_truth_action) where ground_truth_action may be None + """ + assert predict_mode in ["fast", "diffusion"] + + output = self.predict(predict_mode=predict_mode, **kwargs) + + return output["predict_action"], output.get("gt_action", None) + + @torch.no_grad() + def predict( + self, + predict_mode: str, + pred_horizon: Optional[int] = None, + action_dim: Optional[int] = None, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + moe_token_types: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + action_chunk: Optional[torch.FloatTensor] = None, + proprioception: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + num_inference_timesteps: Optional[int] = 10, + dataset_names: Optional[str] = None, + dof_mask: Optional[torch.FloatTensor] = None, + agent_pos_mask: Optional[torch.FloatTensor] = None, + re_generate: bool = False, + **kwargs, + ): + """ + Multi-modal prediction method supporting text generation, fast action prediction, and diffusion-based action prediction. + + This method handles three prediction modes: + 1. "text": Pure text generation using autoregressive decoding + 2. "fast": Fast action prediction using discrete action tokens + 3. "diffusion": Continuous action prediction using diffusion/flow matching + + Args: + predict_mode (str): Prediction mode ("text", "fast", or "diffusion") + pred_horizon (int, optional): Prediction horizon for action sequences + action_dim (int, optional): Dimensionality of action space + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for evaluation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions + video_grid_thw (torch.LongTensor, optional): Video grid dimensions + action_chunk (torch.FloatTensor, optional): Ground truth action sequences + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + num_inference_timesteps (int, optional): Number of diffusion inference steps + dataset_names (str, optional): Dataset names for normalization + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask + agent_pos_mask (torch.FloatTensor, optional): Agent position mask + re_generate (bool, optional): Whether to use sampling for regeneration + **kwargs: Additional keyword arguments + + Returns: + dict: Dictionary containing prediction results with keys like: + - 'predict_action': Predicted action sequences + - 'gt_action': Ground truth actions (if available) + - 'input_text': Input text (for text/fast modes) + - 'predict_output_text': Generated text (for text/fast modes) + - 'gt_output_text': Ground truth text (for text/fast modes) + """ + batch_size = ( + input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + ) + + # Text and fast modes require batch size 1 for autoregressive generation + if predict_mode in ["text", "fast"]: + assert ( + batch_size == 1 + ), "predict only support batch size 1 for ar generation" + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + + # Validate image token and feature count match + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to( + inputs_embeds.dtype + ) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to( + inputs_embeds.dtype + ) + proprio_embed = self.action_preprocessor.proprioception_proj( + proprioception, + dataset_names, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + proprioception_mask = ( + input_ids == self.action_token_id_set["propri_token_id"] + ) + proprio_embed = proprio_embed.to(torch.bfloat16) + inputs_embeds[proprioception_mask] = proprio_embed.reshape( + -1, inputs_embeds.shape[-1] + ) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and ( + attention_mask is None or attention_mask.ndim == 2 + ): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Prepare action chunk data if provided + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to(torch.float32) + + output = {} + + # Split input sequence for text and fast modes (not needed for diffusion) + if predict_mode == "text" or predict_mode == "fast": + # Look for generation prompt tokens: <|im_start|>assistant + generation_prompt_ids = torch.tensor( + [151644, 77091], device=input_ids.device, dtype=input_ids.dtype + ) + matches = (input_ids[0, :-1] == generation_prompt_ids[0]) & ( + input_ids[0, 1:] == generation_prompt_ids[1] + ) + + if matches.any(): + split_pos = torch.nonzero(matches, as_tuple=True)[0][0].item() + # Extract ground truth output tokens (including newline) + gt_output_ids = input_ids[:, split_pos + 3 :] + # Remove output part from input, keeping prompt + input_ids = input_ids[:, : split_pos + 3] + inputs_embeds = inputs_embeds[:, : split_pos + 3, :] + if attention_mask is not None: + attention_mask = attention_mask[:, : split_pos + 3] + if labels is not None: + labels = labels[:, split_pos + 3 :] + else: + raise Warning( + "input_ids does not contain the generation prompt tokens <|im_start|>assistant" + ) + + # Decode input text for output + input_text = self.processor.batch_decode( + input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True + ) + output["input_text"] = input_text + + # Handle text and fast prediction modes using autoregressive generation + if predict_mode == "text" or predict_mode == "fast": + # Initialize MoE token types for generation + moe_token_types = torch.zeros_like(input_ids) + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "moe_token_types": moe_token_types, + "image_grid_thw": image_grid_thw, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + "proprioception": proprioception, + "dataset_names": dataset_names, + } + + # Generate output tokens + predict_output_ids = self.generate( + **batch, + max_new_tokens=100, + eos_token_id=[self.processor.tokenizer.eos_token_id], + use_cache=True, + pad_token_id=self.processor.tokenizer.pad_token_id, + temperature=( + 1.0 if not re_generate else 0.7 + ), # Higher temperature for regeneration + do_sample=( + False if not re_generate else True + ), # Enable sampling for regeneration + ) + + # Decode generated and ground truth text + gt_output_text = self.processor.batch_decode( + gt_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + predict_output_text = self.processor.batch_decode( + predict_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + output["gt_output_text"] = gt_output_text + output["predict_output_text"] = predict_output_text + + # Convert tokens to actions for fast prediction mode + if predict_mode == "fast": + action_id = [] + # Extract action tokens from generated sequence + for token_id_i in predict_output_ids[0]: + if ( + token_id_i.item() + >= self.processor.tokenizer.init_kwargs["action_token_start_index"] + ): + action_id.append( + token_id_i.item() + - self.processor.tokenizer.init_kwargs[ + "action_token_start_index" + ] + ) + + predict_action = self.processor.action_processor.decode( + [action_id], time_horizon=pred_horizon, action_dim=action_dim + ) + # Handle action decoding errors + if np.sum(predict_action) == 0: + print("Error in decoding action, predict_action is None") + output["predict_action"] = None + else: + # Convert discrete tokens to continuous actions + predict_action = torch.tensor(predict_action, device=self.device) + dof_mask = dof_mask.to(self.device).to(pixel_values.dtype) + # removed unnormalization step for now + predict_action = ( + predict_action[:, :, dof_mask[0, 0, :].bool()] + ) + output["predict_action"] = predict_action + + # Process ground truth actions if available + if action_chunk is not None: + # Apply DOF mask to get ground truth actions + # removed unnormalization step for now + action_chunk = action_chunk[:, :, dof_mask[0, 0, :].bool()] + output["gt_action"] = action_chunk + else: + output["gt_action"] = None + + # Handle diffusion-based action prediction + if predict_mode == "diffusion": + # Initialize with random noise + noisy_action = torch.randn( + size=(batch_size, pred_horizon, action_dim), + dtype=torch.float32, + device=inputs_embeds.device, + ) + dof_mask = dof_mask.to(inputs_embeds.device).to(torch.float32) + + def step(timestep, noisy_action): + """ + Single denoising step for diffusion process. + + Args: + timestep: Current diffusion timestep + noisy_action: Current noisy action estimate + + Returns: + torch.Tensor: Predicted clean action + """ + action_mask = input_ids == self.action_token_id_set["action_token_id"] + assert action_mask.any(), "No action token found in input_ids" + + # Prepare timestep for batch processing + timestep = timestep.unsqueeze(0).repeat(noisy_action.shape[0]) + action_embed = self.action_preprocessor.step( + timestep=timestep, noisy_action=noisy_action, dof_mask=dof_mask + ) + action_embed = action_embed.reshape(-1, inputs_embeds.shape[-1]) + + # Ensure action_embed has the correct dtype and device before assignment + action_embed = action_embed.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + # Create temporary copy of embeddings (clone preserves dtype) + temp_inputs_embeds = inputs_embeds.clone() + temp_inputs_embeds[action_mask] = action_embed + + # Forward pass through transformer + transformer_outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=temp_inputs_embeds, + moe_token_types=moe_token_types, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + + # Extract action predictions from hidden states + hidden_states = transformer_outputs.last_hidden_state + action_mask = input_ids == self.action_token_id_set["action_token_id"] + action_hidden_states = hidden_states[action_mask].to(torch.float32) + pred = self.action_preprocessor.action_proj_back(action_hidden_states) + return pred.reshape(batch_size, pred_horizon, action_dim) + + # Perform ODE integration for diffusion sampling + times = torch.linspace( + 0, + 1, + num_inference_timesteps+1, + device=inputs_embeds.device, + dtype=torch.float32, + ) + action_trajectory = odeint(step, noisy_action, times, method="euler") + + # Extract final predicted action + # Removed unnormalization step for now + predict_action = action_trajectory[-1] + output["predict_action"] = predict_action + + # Process ground truth actions if available + # removed unnormalization step for now + if action_chunk is not None: + output["gt_action"] = action_chunk[:, :, dof_mask[0, 0, :].bool()] + + return output + + def forward( + self, mode: Optional[str] = None, predict_mode: Optional[str] = "text", **kwargs + ): + """ + Main forward pass dispatcher for different execution modes. + + This method routes execution to appropriate forward functions based on the specified mode: + - No mode (None): Training step with gradient disabled + - 'predict': Prediction/inference mode + - 'train': Training mode with gradients enabled + - 'validate': Validation mode with gradients disabled + + Args: + mode (str, optional): Execution mode. If None, defaults to training step without gradients + predict_mode (str, optional): Prediction mode for 'predict' mode ("text", "fast", or "diffusion") + **kwargs: Additional arguments passed to the selected forward function + + Returns: + Model outputs appropriate for the selected mode + + Todo: + - Add support for distinguishing multi-modal data types in prediction mode + """ + if not mode: + with torch.no_grad(): + return self.train_step_forward(**kwargs) + elif mode == "predict": + return self.predict(predict_mode=predict_mode, **kwargs) + elif mode == "train": + return self.train_step_forward(use_cache=False, **kwargs) + elif mode == "validate": + with torch.no_grad(): + return self.train_step_forward(use_cache=False, **kwargs) + else: + raise NotImplementedError("invalid key") + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + moe_token_types=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + dataset_names=None, + proprioception=None, + dof_mask=None, + agent_pos_mask=None, + **kwargs, + ): + """ + Prepare inputs for autoregressive generation with multi-modal support. + + This method handles input preparation for generation, including proper slicing of inputs + based on cache position, MoE token type management, and multi-modal data handling. + Vision inputs are selectively forwarded only when needed during generation. + + Args: + input_ids: Input token IDs + past_key_values: Cached key-value pairs from previous generation steps + attention_mask: Attention mask for input tokens + inputs_embeds: Pre-computed input embeddings + moe_token_types: Token type assignments for MoE routing + cache_position: Current cache position for generation + position_ids: Position IDs for tokens + use_cache: Whether to use key-value caching + pixel_values: Image pixel values + pixel_values_videos: Video pixel values + image_grid_thw: Image grid dimensions + video_grid_thw: Video grid dimensions + second_per_grid_ts: Time interval per temporal grid + dataset_names: Dataset names for processing + proprioception: Proprioceptive sensor data + dof_mask: Degrees of freedom mask + agent_pos_mask: Agent position mask + **kwargs: Additional arguments + + Returns: + dict: Prepared model inputs for generation step + + Todo: + - Test this function thoroughly with various input configurations + + Note: + This is an overridden method that handles specific cases for multi-modal generation: + - Slices input_ids through cache_position to keep only unprocessed tokens + - Handles special cases for input_embeds, generation methods, and GPU synchronization + - Manages vision inputs to avoid unnecessary forward passes + """ + # Initialize MoE token types if not provided + if moe_token_types is None: + moe_token_types = torch.zeros_like( + input_ids + ) # FIXME: Handle case when input_embeds is used instead + else: + # Ensure moe_token_types length matches input_ids + if moe_token_types.shape[1] < input_ids.shape[1]: + # Calculate required padding length + pad_length = input_ids.shape[1] - moe_token_types.shape[1] + # Create padding tensor with default token type (0) + pad_tensor = torch.zeros( + (moe_token_types.shape[0], pad_length), + dtype=moe_token_types.dtype, + device=moe_token_types.device, + ) + # Concatenate padding to existing moe_token_types + moe_token_types = torch.cat([moe_token_types, pad_tensor], dim=1) + + # Handle input slicing based on cache state and special cases + if past_key_values is not None: + if ( + inputs_embeds is not None and input_ids.shape[1] == 0 + ): # Exception 4: input_embeds case + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1: input_embeds provided + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3: GPU sync edge case + input_ids = input_ids[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (Exception 2 is no-op) + cache_pos = cache_position.clone() + input_ids = input_ids[:, cache_pos] + moe_token_types = moe_token_types[:, cache_pos] + + # Skip vision inputs for continuation steps (not initial generation) + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # Determine whether to use inputs_embeds or input_ids for this generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + # Prepare 4D causal attention mask for static cache + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = ( + self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + ) + + # Assemble all model inputs for generation + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "moe_token_types": moe_token_types, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + "proprioception": proprioception, + "dataset_names": dataset_names, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate tensor separation lengths. + + These parameters are computed directly from input_ids rather than being passed through + the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (torch.LongTensor): Input token IDs of shape (batch_size, sequence_length) + + Returns: + tuple: + - image_nums (torch.LongTensor): Number of images per sample + - video_nums (torch.LongTensor): Number of videos per sample + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Find vision start tokens and their following tokens + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + # Count images and videos following vision start tokens + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """ + Expand inputs for generation with support for multi-modal tensors. + + This is an overridden method that supports expanding tensors without a standard batch + size dimension, specifically for vision-related tensors: + - pixel_values.shape[0] = sum(sequence_lengths for all image samples) + - image_grid_thw.shape[0] = sum(num_images for all samples) + - Similar patterns for video tensors + + Args: + expand_size (int): Factor by which to expand inputs (for beam search, etc.) + is_encoder_decoder (bool): Whether using encoder-decoder architecture + input_ids (torch.LongTensor, optional): Input token IDs + **model_kwargs: Additional model arguments to expand + + Returns: + tuple: (expanded_input_ids, expanded_model_kwargs) + """ + if expand_size == 1: + return input_ids, model_kwargs + + # Define keys for vision-related tensors that need special handling + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + """Expand vision-related tensors based on image/video counts per sample.""" + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + """Split tensor by lengths and repeat each sample.""" + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat( + [sample.repeat(*repeat_args) for sample in samples], dim=0 + ) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # Split images into samples and compute sequence lengths + samples = torch.split(image_grid_thw, list(image_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # Expand based on number of images per sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + # Split videos into samples and compute sequence lengths + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + # Expand based on number of videos per sample + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + # Handle list-type temporal grid data + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples( + tensor, lengths=lengths, repeat_times=expand_size + ) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + """Expand standard tensors using repeat_interleave.""" + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave( + expand_size, dim=0 + ) + return dict_to_expand + + # Expand visual inputs only if input_ids is available for counting images/videos + # If input_ids is unavailable, visual inputs won't be used, so no expansion needed + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + # Expand input_ids using standard repeat_interleave + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + # Expand all other model arguments + model_kwargs = _expand_dict_for_generation(model_kwargs) + + # Handle encoder-decoder specific expansion + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation( + model_kwargs["encoder_outputs"] + ) + + return input_ids, model_kwargs class WallXPolicy(PreTrainedPolicy): @@ -269,10 +2317,7 @@ class WallXPolicy(PreTrainedPolicy): self.config = config # Initialize VLM wrapper - self.vlm = WallXVLMWrapper(config) - - # Initialize action head - self.action_head = ActionHead(config) + self.model = Qwen2_5_VLMoEForAction(config) self.reset() @@ -286,7 +2331,7 @@ class WallXPolicy(PreTrainedPolicy): """Get parameters for optimization.""" params = [] - if self.vlm.available: + if self.model.visual.available: # Add VLM parameters if not self.config.train_expert_only: params.extend(self.vlm.model.parameters()) @@ -297,59 +2342,224 @@ class WallXPolicy(PreTrainedPolicy): return params - def prepare_images(self, batch): - """Prepare images for VLM processing.""" - images = [] - present_img_keys = [key for key in self.config.image_features if key in batch] - - if len(present_img_keys) == 0: - raise ValueError("No image features found in batch") - - for key in present_img_keys: - img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key] - images.append(img) - - return images - - def prepare_state(self, batch): - """Prepare proprioceptive state.""" - state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE] - # Pad to expected dimension - if state.shape[-1] < self.config.max_state_dim: - padding = torch.zeros( - *state.shape[:-1], - self.config.max_state_dim - state.shape[-1], - device=state.device, - dtype=state.dtype - ) - state = torch.cat([state, padding], dim=-1) - return state - - def prepare_action(self, batch): - """Prepare action chunk.""" - actions = batch[ACTION] - # Pad to expected dimension - if actions.shape[-1] < self.config.max_action_dim: - padding = torch.zeros( - *actions.shape[:-1], - self.config.max_action_dim - actions.shape[-1], - device=actions.device, - dtype=actions.dtype - ) - actions = torch.cat([actions, padding], dim=-1) - return actions - - def _create_dof_mask(self, batch_size, device, dtype): - """Create DOF mask for action dimensions.""" - # Create mask showing which dimensions are active - mask = torch.ones( - batch_size, - self.config.chunk_size, - sum(self.config.dof_config.values()), - device=device, - dtype=dtype + def preprocess_inputs( + self, + batch: List[Dict[str, Any]], + config: Dict[str, Any], + dataload_config: Dict[str, Any], + norm_stats: Dict[str, Any], + lerobot_config: Dict[str, Any], + processor: Any, + action_tokenizer: Optional[Any] = None, + camera_keys: Optional[List[str]] = None, + ) -> BatchFeature: + """ + Convert a batch of LeRobot dataset items to Wall-X model input format. + + This is the batch version of convert_lerobot_to_wallx_format, processing multiple + samples together for efficient batched inference or training. + + Args: + batch: List of items from LeRobot dataset + config: Model configuration dict + dataload_config: Data loading configuration + norm_stats: Normalization statistics + lerobot_config: LeRobot config containing 'repo_id' + processor: Hugging Face processor for tokenization + action_tokenizer: Optional action tokenizer + camera_keys: List of camera keys in the dataset + + Returns: + BatchFeature containing batched model inputs + + Example: + >>> batch = [dataset[i] for i in range(8)] + >>> result = preprocess_inputs( + ... batch=batch, + ... config=config, + ... dataload_config=dataload_config, + ... norm_stats=norm_stats, + ... lerobot_config={'repo_id': 'lerobot/aloha_mobile_cabinet'}, + ... processor=processor, + ... ) + """ + repo_id = lerobot_config["repo_id"] + use_fast_tokenizer = config.get("use_fast_tokenizer", False) + + # Get key mappings + cam_key_mapping = KEY_MAPPINGS[repo_id]["camera"] + state_key = KEY_MAPPINGS[repo_id]["state"] + action_key = KEY_MAPPINGS[repo_id]["action"] + + # Build data config + data_config = X2RDataProcessingConfig().update( + train_test_split=dataload_config.get("train_test_split", 0.95), + split_seed=dataload_config.get("split_seed", 42), + predict_action_keys=dataload_config.get("predict_action_keys", []), + obs_action_keys=dataload_config.get("obs_action_keys", []), + resolution=dataload_config.get("resolution", None), + priority_order=dataload_config.get("priority_order", None), ) - return mask + + if camera_keys is None: + camera_keys = list(cam_key_mapping.keys()) + + # ==================== PROCESS ALL SAMPLES ==================== + all_image_inputs = [] + all_texts = [] + all_agent_pos = [] + all_actions = [] + all_frame_indices = [] + + for data in batch: + # Vision preprocessing per sample + processed_frames = [] + orig_height, orig_width = None, None + resized_height, resized_width = None, None + + for key in camera_keys: + current_obs = data[key].clone() + if current_obs.dim() == 3: + current_obs = current_obs.permute(1, 2, 0) + + img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) + orig_width, orig_height = img_pil.size + + cam_name = cam_key_mapping.get(key, key) + target_size = data_config.resolution.get(cam_name, -1) + if target_size != -1: + if orig_width > orig_height: + new_width = target_size + new_height = int(target_size * orig_height / orig_width) + else: + new_height = target_size + new_width = int(target_size * orig_width / orig_height) + img_pil = img_pil.resize((new_width, new_height)) + + current_width, current_height = img_pil.size + resized_height, resized_width = smart_resize( + current_height, + current_width, + factor=data_config.image_factor, + min_pixels=data_config.min_pixels, + max_pixels=data_config.max_pixels, + ) + resized_img = img_pil.resize((resized_width, resized_height)) + processed_frames.append(resized_img) + + all_image_inputs.append(processed_frames) + + # Text preprocessing + frame_index = data["frame_index"] + instruction_info = {"instruction": data["task"]} + + complete_text, _ = get_wallx_normal_text( + instruction_info, + dataload_config.get("action_horizon", 33) - 1, + frame_index, + data_config.priority_order, + cam_key_mapping, + generate_subtask_ratio=data_config.generate_subtask_ratio, + ) + + text = process_grounding_points( + complete_text, orig_height, orig_width, resized_height, resized_width, + data_config.model_type + ) + all_texts.append(text) + + # Collect raw values + all_agent_pos.append(data[state_key]) + all_actions.append(data[action_key]) + all_frame_indices.append(frame_index) + + # ==================== BATCH NORMALIZATION ==================== + action_min_stat = norm_stats["action"].min + action_delta = norm_stats["action"].delta + state_min_stat = norm_stats["state"].min + state_delta = norm_stats["state"].delta + + def normalize(x, min_stat, delta): + delta = torch.where(delta == 0, torch.ones_like(delta), delta) + x = (x - min_stat) / delta + x = x * 2 - 1 + return torch.clamp(x, -1, 1) + + # Stack and normalize agent_pos + agent_pos = torch.stack(all_agent_pos) + if agent_pos.dim() == 2: + agent_pos = agent_pos.unsqueeze(1) + agent_pos_mask = (~torch.isnan(agent_pos)).float() + agent_pos = agent_pos.nan_to_num(nan=0.0) + agent_pos = normalize(agent_pos, state_min_stat, state_delta) + + if agent_pos.shape[-1] != 20: + pad_size = 20 - agent_pos.shape[-1] + agent_pos = torch.cat([ + agent_pos, + torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size) + ], dim=-1) + agent_pos_mask = torch.cat([ + agent_pos_mask, + torch.zeros(agent_pos_mask.shape[0], agent_pos_mask.shape[1], pad_size) + ], dim=-1) + + # Stack and normalize actions + action = torch.stack(all_actions) + if action.dim() == 2: + action = action.unsqueeze(1) + dof_mask = (~torch.isnan(action)).float() + action = action.nan_to_num(nan=0.0) + action = normalize(action, action_min_stat, action_delta) + + if action.shape[-1] != 20: + pad_size = 20 - action.shape[-1] + action = torch.cat([ + action, + torch.zeros(action.shape[0], action.shape[1], pad_size) + ], dim=-1) + dof_mask = torch.cat([ + dof_mask, + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size) + ], dim=-1) + + # ==================== ACTION TOKEN REPLACEMENT ==================== + all_texts = replace_action_token( + all_texts, + action, + action_tokenizer if use_fast_tokenizer else None, + [repo_id] * len(batch), + dof_mask, + ) + + # ==================== TOKENIZATION ==================== + inputs = preprocesser_call( + processor=processor, + text=all_texts, + images=all_image_inputs, + videos=None, + padding=True, + truncation=True, + return_tensors="pt", + max_length=dataload_config.get("max_length", 768), + ) + + # ==================== ADDITIONAL INPUTS ==================== + action_token_id = processor.tokenizer.convert_tokens_to_ids("<|action|>") + moe_token_types = inputs.input_ids == action_token_id + + inputs["proprioception"] = agent_pos + inputs["agent_pos_mask"] = agent_pos_mask + inputs["action_chunk"] = action + inputs["dof_mask"] = dof_mask + inputs["moe_token_types"] = moe_token_types + inputs["dataset_names"] = [repo_id] * len(batch) + inputs["frame_index"] = torch.stack([ + torch.tensor(fi) if not isinstance(fi, torch.Tensor) else fi + for fi in all_frame_indices + ]) + + return inputs def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """ diff --git a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py new file mode 100644 index 000000000..439a36923 --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py @@ -0,0 +1,248 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + num_experts=4, + experts=None, + dof_config=None, + noise_scheduler=None, + dim_inputs=(1536, 1536), + attention_moe=False, + mlp_moe=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + self.num_experts = num_experts + self.experts = experts + self.dof_config = dof_config + self.noise_scheduler = noise_scheduler + self.dim_inputs = tuple(dim_inputs) + self.attention_moe = attention_moe + self.mlp_moe = mlp_moe + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations + # one can set it to "linear"/"dynamic" etc. to have scaled RoPE + # TODO: @raushan update config in the hub + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + +__all__ = ["Qwen2_5_VLConfig"] diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py new file mode 100644 index 000000000..9e8352ee6 --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -0,0 +1,2548 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from torch.nn import CrossEntropyLoss +from typing import Any, Dict, List, Optional, Tuple, Union + +from transformers.activations import ACT2FN +from transformers.cache_utils import ( + Cache, + DynamicCache, + SlidingWindowCache, + StaticCache, +) +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb + from flash_attn import flash_attn_func +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + flash_attn_func = None + + +if is_flash_attn_2_available(): + pass +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + if max_seqlen is None: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + ).reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], + torch.finfo(q.dtype).min, + device=q.device, + dtype=q.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states) + .reshape(seq_length, 3, self.num_heads, -1) + .permute(1, 0, 2, 3) + .unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros( + [1, seq_length, seq_length], device=q.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: Optional[int] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen2_5_VLVisionBlock(config, config._attn_implementation) + for _ in range(config.depth) + ] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + window_index = window_index.to(hidden_states.device) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen_full = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen_window = ( + (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() + ) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, + hidden_states, + cu_seqlens_now, + None, + position_embeddings, + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + position_embeddings=position_embeddings, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len + and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = ( + self.inv_freq[None, None, :, None] + .float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[ + :, :, None, : + ].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = ( + device_type + if isinstance(device_type, str) and device_type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat( + [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + sin = torch.cat( + [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1 + ).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=True + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_rate, + softmax_scale=None, + causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + ): + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = ( + attention_mask[:, -1].sum().item() != input_tensor.size()[0] + ) + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended( + causal_mask, min_dtype + ) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange( + target_length, device=device + ) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if ( + not isinstance(past_key_values, SlidingWindowCache) + or sequence_length > target_length + ): + sliding_attend_mask = torch.arange( + target_length, device=device + ) <= (cache_position.reshape(-1, 1) - config.sliding_window) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + return causal_mask + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( + config.vision_config + ) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = ( + expanded_range + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and ( + attention_mask is None or attention_mask.ndim == 2 + ): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = ( + self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat( + [sample.repeat(*repeat_args) for sample in samples], dim=0 + ) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples( + tensor, lengths=lengths, repeat_times=expand_size + ) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave( + expand_size, dim=0 + ) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation( + model_kwargs["encoder_outputs"] + ) + + return input_ids, model_kwargs + +@dataclass +class Qwen2_5_VLACausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + flow_loss: Optional[torch.FloatTensor] = None + cross_entropy_loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + channel_loss_dict: Optional[dict[torch.FloatTensor]] = None + channel_loss_count_dict: Optional[dict[torch.FloatTensor]] = None + + +class BlockSparseMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.hidden_size = config["hidden_size"] + self.intermediate_size = config["intermediate_size"] + self.hidden_act = config["hidden_act"] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[self.hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class SparseMoeBlock(nn.Module): + def __init__(self, config, num_experts: int): + super().__init__() + self.num_experts = num_experts + self.experts = nn.ModuleList([BlockSparseMLP(config.experts[i]) for i in range(num_experts)]) + + if not hasattr(config, 'dim_inputs') or not config.dim_inputs: + raise ValueError("Config must contain valid dim_inputs") + + self.dim_inputs = config.dim_inputs + + def forward(self, hidden_states: torch.Tensor, experts_indices: torch.Tensor) -> torch.Tensor: + """ + Route different hidden_states to corresponding experts for processing. + + Args: + hidden_states (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + experts_indices (torch.Tensor): Tensor of shape (batch_size, seq_length), + indicating the expert index assigned to each token. + + Returns: + output (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + """ + batch_size, seq_length, hidden_dim = hidden_states.size() + output = torch.zeros_like(hidden_states) + + for expert_idx, expert in enumerate(self.experts): + mask = (experts_indices == expert_idx) + if mask.sum() == 0: + continue + dim_input = self.dim_inputs[expert_idx] + + selected_hidden = hidden_states[mask] + processed_hidden = expert(selected_hidden[:, :dim_input]) + + batch_indices, seq_indices = torch.where(mask) + output[batch_indices, seq_indices, :dim_input] = processed_hidden + + return output + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int, num_experts: int): + super().__init__() + self.hidden_size = config.hidden_size + + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + ): + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + if config.mlp_moe: + self.moe = SparseMoeBlock(config, num_experts=num_experts) + self.mlp = None + else: + self.mlp = Qwen2_5_VLMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + token_types=None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = hidden_states.to(self.input_layernorm.weight.dtype) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states.to(self.self_attn.q_proj.weight.dtype) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = hidden_states.to(self.post_attention_layernorm.weight.dtype) + hidden_states = self.post_attention_layernorm(hidden_states) + if self.mlp is None: # using moe mlp + hidden_states = hidden_states.to(self.moe.experts[0].down_proj.weight.dtype) + hidden_states = self.moe( + hidden_states, token_types + ) + else: + hidden_states = hidden_states.to(self.mlp.down_proj.weight.dtype) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs + +__all__ = [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + "Qwen2_5_VLDecoderLayer_with_MoE", +] \ No newline at end of file diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py new file mode 100644 index 000000000..3a908317d --- /dev/null +++ b/src/lerobot/policies/wall_x/utils.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Utility Functions. + +Contains data processing utilities, text formatting functions, and helper classes +for the Wall-X cross-embodiment robotic control model. +""" + +import json +import random +import re +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import BatchFeature + +from lerobot.policies.wall_x.constant import ( + CAMERA_NAME_MAPPING, + FREQUENCY_MAPPING, + KEY_MAPPINGS, + MULTIMODAL_DATASET_NAMES, +) + + +@dataclass +class X2RDataProcessingConfig: + """Configuration class for X2R data processing pipeline. + + This class contains all the necessary parameters for processing robotic data + including camera mappings, tactile sensor configurations, action predictions, + and various processing options. + """ + + # Action prediction configuration + predict_action_keys: List[str] = field(default_factory=list) + obs_action_keys: List[str] = field(default_factory=list) + + # Image resolution settings for different views + resolution: Dict[str, int] = field( + default_factory=lambda: { + "face_view": -1, + "left_wrist_view": 128, + "right_wrist_view": 128, + } + ) + + # Dataset splitting + train_test_split: float = 0.9 + split_seed: int = 42 + + # Instruction handling + priority_order: Optional[Dict[str, float]] = None + + # Vision model parameters + model_type: str = "qwen2_5" + max_pixels: int = 16384 * 28 * 28 + min_pixels: int = 4 * 28 * 28 + image_factor: int = 28 + + generate_subtask_ratio: float = 0.0 + + def __post_init__(self): + """Post-initialization validation and setup.""" + # Validate train/test split + if not 0 < self.train_test_split < 1: + raise ValueError( + f"train_test_split must be between 0 and 1, got {self.train_test_split}" + ) + + def as_dict(self) -> Dict: + """Convert configuration to dictionary format. + + Returns: + Dict: Configuration as dictionary + """ + return self.__dict__ + + def update(self, **kwargs) -> "X2RDataProcessingConfig": + """Update configuration parameters. + + Args: + **kwargs: Key-value pairs to update + + Returns: + X2RDataProcessingConfig: Updated configuration instance + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown configuration parameter: {key}") + return self + +def preprocesser_call( + processor, + images: Optional[Union[List, Any]] = None, + text: Optional[Union[str, List[str]]] = None, + videos: Optional[Union[List, Any]] = None, + padding: Union[bool, str] = False, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, + return_tensors: str = "pt", +) -> BatchFeature: + """Unified preprocessing function for Wall-X model handling text, image and video inputs. + + Processes inputs into format suitable for multimodal transformer models, including: + - Text tokenization and special token handling + - Image/video processing through image processor + - Attention mask and label generation + - Padding and truncation handling + + Args: + processor: Multimodal processor containing tokenizer and image processor + images: Input images (PIL, numpy arrays, or torch tensors) + text: Text or list of texts to tokenize + videos: Input videos (numpy arrays or torch tensors) + padding: Whether to pad sequences to same length + truncation: Whether to truncate sequences longer than max_length + max_length: Maximum length for truncation/padding + return_tensors: Format for returned tensors ('pt', 'np', etc.) + + Returns: + BatchFeature containing processed inputs with keys: + - input_ids: Tokenized text + - attention_mask: Attention mask for text + - pixel_values: Processed image pixels + - pixel_values_videos: Processed video frames + - image_grid_thw: Image grid dimensions for LLM + - video_grid_thw: Video grid dimensions for LLM + - labels: Training labels with masking + """ + # Process image inputs + if images is not None and len(images) > 0: + image_inputs = processor.image_processor( + images=images, videos=None, return_tensors=return_tensors + ) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + # Process video inputs + if videos is not None: + videos_inputs = processor.image_processor( + images=None, videos=videos, return_tensors=return_tensors + ) + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + # Ensure text input is in list format + if not isinstance(text, list): + text = [text] + + # Process image placeholder tokens in text + if image_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|image_pad|>" in text[i]: + # Add bounds checking to avoid index overflow + if index >= len(image_grid_thw): + print( + f"Warning: Number of image placeholders ({index + 1}) " + f"exceeds actual images ({len(image_grid_thw)}), " + f"skipping remaining placeholder processing" + ) + break + # Replace image placeholder with actual token count + token_count = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace( + "<|image_pad|>", "<|placeholder|>" * token_count, 1 + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") + + # Process video placeholder tokens in text + if video_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|video_pad|>" in text[i]: + # Replace video placeholder with actual token count + token_count = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace( + "<|video_pad|>", "<|placeholder|>" * token_count, 1 + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") + + # Tokenize complete input text + text_inputs = processor.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + # Get pad token ID for label generation + pad_token_id = processor.tokenizer.pad_token_id + if pad_token_id is None: + pad_token_id = processor.tokenizer.eos_token_id + + # Generate labels for multi-turn dialogue, keeping only assistant response loss + labels = torch.full_like(text_inputs.input_ids, -100) + assistant_marker = "<|im_start|>assistant\n" + im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") + assistant_tokens = processor.tokenizer( + "<|im_start|>assistant\n", add_special_tokens=False + ).input_ids + + for i in range(len(text)): + assistant_regions = [] + parts = text[i].split(assistant_marker) + + # Process each part to determine which tokens belong to assistant responses + # Count left padding tokens + num_left_pads = 0 + for token_id in text_inputs.input_ids[i]: + if token_id == pad_token_id: + num_left_pads += 1 + else: + break + current_pos = num_left_pads + + for j, part in enumerate(parts): + part_tokens = processor.tokenizer(part, add_special_tokens=False).input_ids + if j == 0: + # First part is system prompt or user question, all labels are -100 + current_pos += len(part_tokens) + continue + + # From second part onwards, each part starts with assistant response + for k in range(current_pos + 1, len(text_inputs.input_ids[i])): + if text_inputs.input_ids[i][k] == im_end_token_id: + assistant_regions.append( + (current_pos + len(assistant_tokens), k + 2) + ) + break + current_pos += len(part_tokens) + 3 + + # Set labels for assistant response regions + for start, end in assistant_regions: + labels[i][start:end] = text_inputs.input_ids[i][start:end] + + # Mask special action tokens in labels + action_token_id = processor.tokenizer.encode("<|action|>")[0] + propri_token_id = processor.tokenizer.encode("<|propri|>")[0] + labels[labels == action_token_id] = -100 + labels[labels == propri_token_id] = -100 + labels[labels == processor.tokenizer.pad_token_id] = -100 + + # Set labels to None if all are invalid to skip cross entropy loss + if (labels != -100).any().item(): + text_inputs["labels"] = labels + else: + text_inputs["labels"] = None + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + +def process_grounding_points( + text: str, + orig_height: int, + orig_width: int, + resized_height: int, + resized_width: int, + model_type: str, +) -> str: + """Process grounding point coordinates in text based on image resizing. + + Adjusts coordinate values in tags to match resized image dimensions + for different model types (qwen2, qwen2_5). + + Args: + text: Input text containing tags with coordinates + orig_height: Original image height + orig_width: Original image width + resized_height: Resized image height + resized_width: Resized image width + model_type: Model type for coordinate processing ('qwen2' or 'qwen2_5') + + Returns: + Text with adjusted coordinate values + """ + # Regex pattern to match tags and their contents + point_pattern = re.compile(r"(.*?)") + + def process_match(match): + """Process a single point match and adjust coordinates.""" + coords_str = match.group(1) + try: + # Extract coordinates from string + coords = list(map(int, re.findall(r"\d+", coords_str))) + + # Calculate resize scale factors + scale_w = resized_width / orig_width + scale_h = resized_height / orig_height + + if len(coords) == 2: + x, y = coords + if model_type == "qwen2_5": + # Qwen2.5 uses pixel coordinates + new_x = max(0, min(round(x * scale_w), resized_width - 1)) + new_y = max(0, min(round(y * scale_h), resized_height - 1)) + elif model_type == "qwen2": + # Qwen2 normalizes to [0, 1000) range + new_x = max(0, min(999.999, (x / orig_width) * 1000)) + new_y = max(0, min(999.999, (y / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x, new_y] + + elif len(coords) == 4: + x1, y1, x2, y2 = coords + if model_type == "qwen2_5": + new_x1 = max(0, min(round(x1 * scale_w), resized_width - 1)) + new_y1 = max(0, min(round(y1 * scale_h), resized_height - 1)) + new_x2 = max(0, min(round(x2 * scale_w), resized_width - 1)) + new_y2 = max(0, min(round(y2 * scale_h), resized_height - 1)) + elif model_type == "qwen2": + new_x1 = max(0, min(999.999, (x1 / orig_width) * 1000)) + new_y1 = max(0, min(999.999, (y1 / orig_height) * 1000)) + new_x2 = max(0, min(999.999, (x2 / orig_width) * 1000)) + new_y2 = max(0, min(999.999, (y2 / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x1, new_y1, new_x2, new_y2] + + # Return processed point tag + return f'[{", ".join(map(str, coords))}]' + + except (ValueError, TypeError): + # Return original content if processing fails + return match.group(0) + + # Replace all matching point tags + processed_text = point_pattern.sub(process_match, text) + return processed_text + + +def get_frame_instruction( + instruction_info: Dict[str, Any], + frame_idx: Optional[int] = None, + truncate_keys: Optional[List[str]] = None, +) -> Tuple[Dict[str, Any], Optional[int]]: + """Extract frame-specific instruction from instruction dictionary. + + Args: + instruction_info: Dictionary containing instruction components + frame_idx: Current frame index + truncate_keys: Keys that trigger truncation when found + + Returns: + Tuple of (frame_instruction_dict, split_end_frame) + """ + if truncate_keys is None: + truncate_keys = [ + "subtask_generation", + "distribute", + "subtask_generation_zh", + "distribute_zh", + ] + + instruction_for_frame = {} + split_end = None + + for key, value in instruction_info.items(): + if isinstance(value, dict): + # Handle frame-range specific instructions + for frame_range, frame_instruction in value.items(): + start_frame, end_frame = map(int, frame_range.split(" ")) + if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): + instruction_for_frame[key] = frame_instruction + if ( + truncate_keys is not None + and split_end is None + and key in truncate_keys + ): + split_end = end_frame + 1 + break + else: + instruction_for_frame[key] = value + + return instruction_for_frame, split_end + + +def get_task_instruction( + frame_instruction_info: Dict[str, Any], priority_order: Optional[OrderedDict] = None +) -> str: + """Construct task instruction from available instruction fields using priority sampling. + + Args: + frame_instruction_info: Dictionary containing instruction fields + priority_order: OrderedDict specifying sampling probability for each field + + Returns: + Combined instruction string with priority components + """ + # Default priority settings + default_priority_order = OrderedDict( + { + "subtask_generation": 0.25, + "subtask_generation_zh": 0.25, + "distribute": 0.25, + "distribute_zh": 0.25, + } + ) + + if priority_order is not None: + priority_order = OrderedDict(priority_order) + else: + priority_order = default_priority_order + + got_instruction = False + task_instruction = "" + + # Sample instruction components based on priority probabilities + for key, prob in priority_order.items(): + if key in frame_instruction_info and frame_instruction_info[key] != "": + if got_instruction: + if random.random() >= prob: + continue + + task_instruction += f"\n{frame_instruction_info[key]}" + got_instruction = True + break + + # Fall back to base instruction if no priority components found + if not got_instruction: + task_instruction = frame_instruction_info.get("instruction", "") + + return task_instruction + + +def get_wallx_normal_text( + instruction_info: Dict[str, Any], + action_chunk_size: int, + frame_idx: int, + priority_order: Optional[OrderedDict] = None, + cam_mapping: Optional[Dict[str, str]] = None, + generate_subtask_ratio: float = 0.0, +) -> Tuple[str, bool]: + """Construct complete multimodal prompt text for Wall-X model. + + Formats input using special tokens including: + - System message + - User observations (with image placeholders) + - Task instructions + - Proprioception prompts + - Assistant responses (with action tokens) + + Args: + instruction_info: Dictionary containing instruction components + action_chunk_size: Number of action tokens to generate + frame_idx: Current frame index + priority_order: Priority order for instruction sampling + cam_mapping: Camera name mapping dictionary + generate_subtask_ratio: Probability of generating subtask instead of actions + + Returns: + Tuple of (formatted_prompt_text, is_subtask_generation) + """ + # Special tokens for formatting + role_start_symbol = "<|im_start|>" + role_end_symbol = "<|im_end|>" + vision_start_symbol = "<|vision_start|>" + vision_end_symbol = "<|vision_end|>" + image_pad_symbol = "<|image_pad|>" + propri_symbol = "<|propri|>" + action_symbol = "<|action|>" + action_fast_symbol = "<|action_fast|>" + + # System prologue + prologue = ( + f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" + ) + + # User request with observation + user_request = f"{role_start_symbol}user\nObservation:" + if cam_mapping: + for _, cam_name in cam_mapping.items(): + view_name = CAMERA_NAME_MAPPING.get(cam_name, cam_name) + user_request += f" {view_name}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" + user_request += "\nInstruction:" + + # Get frame-specific instruction + frame_instruction_info, _ = get_frame_instruction( + instruction_info, frame_idx=frame_idx + ) + + generate_subtask = False + priority_keys = ["subtask_generation", "distribute"] + + # Decide whether to generate subtask or actions + if ( + bool(set(frame_instruction_info.keys()) & set(priority_keys)) + and random.random() < generate_subtask_ratio + ): + # Generate subtask (equivalent to VQA task) + instruction = frame_instruction_info.get("instruction", "") + text_prompt = "\nPredict the next action in language.\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + + # Find output instruction from priority keys + for key in priority_keys: + if key in frame_instruction_info: + output_instruction = frame_instruction_info[key] + break + + assistant_output = ( + f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}" + ) + generate_subtask = True + else: + # Generate actions + instruction = get_task_instruction( + frame_instruction_info, priority_order=priority_order + ) + text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" + + complete_text = prologue + user_message + assistant_output + return complete_text, generate_subtask + + +def get_action_tokens( + normalized_actions: Union[torch.Tensor, List], action_tokenizer +) -> List[List[str]]: + """Convert normalized actions to action token strings. + + Args: + normalized_actions: Normalized action arrays/tensors + action_tokenizer: Tokenizer for converting actions to tokens + + Returns: + List of action token string lists for each sample + """ + if isinstance(normalized_actions, torch.Tensor): + normalized_actions = normalized_actions.cpu().numpy() + + all_action_tokens = [] + for i in range(len(normalized_actions)): + if isinstance(normalized_actions[i], torch.Tensor): + normalized_actions[i] = normalized_actions[i].cpu().numpy() + + token_id = action_tokenizer(normalized_actions[i]) + action_tokens = [f"<|action_token_{j}|>" for j in token_id[0]] + all_action_tokens.append(action_tokens) + + return all_action_tokens + + +def pad_action_token_strs( + actions_token_lists: List[List[str]], pad_token: str = "<|endoftext|>" +) -> List[str]: + """Pad action token lists to same length and join as strings. + + Args: + actions_token_lists: List of action token lists for each sample + pad_token: Token used for padding + + Returns: + List of padded action token strings + """ + max_len = max(len(tokens) for tokens in actions_token_lists) + padded_action_strs = [] + + for tokens in actions_token_lists: + padded_tokens = ( + tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens)) + ) + padded_action_strs.append("".join(padded_tokens)) + + return padded_action_strs + + +def replace_action_token( + text: List[str], + norm_action: Optional[torch.Tensor], + action_tokenizer, + dataset_names: List[str], + dof_masks: Optional[torch.Tensor] = None, +) -> List[str]: + """Replace action placeholders in text with actual action tokens. + + Args: + text: List of text strings with action placeholders + norm_action: Normalized action tensors + action_tokenizer: Tokenizer for converting actions to tokens + dataset_names: Names of datasets for each sample + dof_masks: Masks for degrees of freedom + + Returns: + List of text strings with action tokens replaced + """ + # Filter out multimodal dataset names + dataset_names = [ + name for name in dataset_names if name not in MULTIMODAL_DATASET_NAMES + ] + + # Get required action chunk sizes + required_chunk_sizes = [32 for name in dataset_names] + + if action_tokenizer is not None and norm_action is not None: + # Extract actions based on chunk sizes and DOF masks + norm_action = [ + action[: required_chunk_sizes[i], dof_masks[i, 0].bool()] + for i, action in enumerate(norm_action) + ] + + # Convert to action tokens and pad + actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) + actions_fast_token_strs = pad_action_token_strs(actions_fast_tokens) + + # Replace action placeholders with actual tokens + actions_fast_token_idx = 0 + for i in range(len(text)): + if "<|action_fast|>" in text[i]: + text[i] = text[i].replace( + "<|action_fast|><|im_end|>\n", + actions_fast_token_strs[actions_fast_token_idx], + ) + actions_fast_token_idx += 1 + + # Remove remaining action placeholders + text = [t.replace("<|action|>", "") for t in text] + else: + # Remove action placeholders when no tokenizer available + text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] + + return text +