diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 0dba5e1db..ad222b04f 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -85,7 +85,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional + run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv --maxfail=10 diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index b3be9ccdf..95562d0dd 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -78,7 +78,7 @@ jobs: echo "Dependencies unbound:" && cat pyproject.toml - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional + run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 896a0c10b..bfa3340d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -87,7 +87,7 @@ repos: # TODO(Steven): Uncomment when ready to use ##### Static Analysis & Typing ##### - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.18.2 + rev: v1.19.1 hooks: - id: mypy args: [--config-file=pyproject.toml] diff --git a/README.md b/README.md index f4c2a8406..02652d1c9 100644 --- a/README.md +++ b/README.md @@ -99,11 +99,11 @@ lerobot-train \ --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -| Category | Models | -| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [TDMPC](./docs/source/policy_tdmpc_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | -| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx) & QC-FQL (coming soon) | -| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | +| Category | Models | +| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub @@ -129,7 +129,7 @@ Learn how to implement your own simulation environment or benchmark and distribu - **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. - **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. -- **[Robotics Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. +- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. ## Citation diff --git a/pyproject.toml b/pyproject.toml index 574ab97df..71efc2acc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,13 @@ intelrealsense = [ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"] # Policies +wallx = [ + "transformers==4.49.0", + "peft==0.17.1", + "scipy==1.15.3", + "torchdiffeq==0.2.5", + "qwen_vl_utils==0.0.11" +] pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ @@ -142,7 +149,7 @@ async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] peft = ["lerobot[transformers-dep]", "peft>=0.18.0"] # Development -dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] +dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] @@ -161,6 +168,7 @@ all = [ "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", + # "lerobot[wallx]", "lerobot[pi]", "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn @@ -231,6 +239,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F403"] +"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original [tool.ruff.lint.isort] combine-as-imports = true @@ -322,9 +331,9 @@ disallow_untyped_defs = true disallow_incomplete_defs = true check_untyped_defs = true -# [[tool.mypy.overrides]] -# module = "lerobot.optim.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.optim.*" +ignore_errors = false [[tool.mypy.overrides]] module = "lerobot.model.*" @@ -374,3 +383,40 @@ ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" # ignore_errors = false + +[tool.uv] +# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0 +conflicts = [ + [ + { extra = "wallx" }, + { extra = "transformers-dep" }, + ], + [ + { extra = "wallx" }, + { extra = "pi" }, + ], + [ + { extra = "wallx" }, + { extra = "smolvla" }, + ], + [ + { extra = "wallx" }, + { extra = "groot" }, + ], + [ + { extra = "wallx" }, + { extra = "xvla" }, + ], + [ + { extra = "wallx" }, + { extra = "hilserl" }, + ], + [ + { extra = "wallx" }, + { extra = "libero" }, + ], + [ + { extra = "wallx" }, + { extra = "all" }, + ], +] diff --git a/src/lerobot/optim/factory.py b/src/lerobot/optim/factory.py index bab95d0ce..699289993 100644 --- a/src/lerobot/optim/factory.py +++ b/src/lerobot/optim/factory.py @@ -35,6 +35,8 @@ def make_optimizer_and_scheduler( tuple[Optimizer, LRScheduler | None]: The couple (Optimizer, Scheduler). Scheduler can be `None`. """ params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() + if cfg.optimizer is None: + raise ValueError("Optimizer config is required but not provided in TrainPipelineConfig") optimizer = cfg.optimizer.build(params) lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None return optimizer, lr_scheduler diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index 5120f828c..2b75353d9 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +from collections.abc import Iterable from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any @@ -29,6 +30,17 @@ from lerobot.utils.constants import ( ) from lerobot.utils.io_utils import deserialize_json_into_object +# Type alias for parameters accepted by optimizer build() methods. +# This matches PyTorch's optimizer signature while also supporting: +# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA) +# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC) +OptimizerParams = ( + Iterable[torch.nn.Parameter] # From model.parameters() + | Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides + | dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR + | dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups +) + @dataclass class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): @@ -45,13 +57,24 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): return "adam" @abc.abstractmethod - def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: """ Build the optimizer. It can be a single optimizer or a dictionary of optimizers. + NOTE: Multiple optimizers are useful when you have different models to optimize. For example, you can have one optimizer for the policy and another one for the value function in reinforcement learning settings. + Args: + params: Parameters to optimize. Accepts multiple formats depending on the optimizer: + - Iterable[Parameter]: From model.parameters() - standard PyTorch usage + - Iterable[dict]: List of param groups with 'params' key and optional + 'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies) + - dict[str, Parameter]: From dict(model.named_parameters()) for optimizers + that apply differential learning rates by parameter name (e.g., XVLA) + - dict[str, Iterable]: For multi-optimizer configs where each key maps to + a separate optimizer's parameters (e.g., SAC with actor/critic/temperature) + Returns: The optimizer or a dictionary of optimizers. """ @@ -67,7 +90,7 @@ class AdamConfig(OptimizerConfig): weight_decay: float = 0.0 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.Adam(params, **kwargs) @@ -82,7 +105,7 @@ class AdamWConfig(OptimizerConfig): weight_decay: float = 1e-2 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.AdamW(params, **kwargs) @@ -98,7 +121,7 @@ class SGDConfig(OptimizerConfig): weight_decay: float = 0.0 grad_clip_norm: float = 10.0 - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.SGD(params, **kwargs) @@ -139,21 +162,19 @@ class XVLAAdamWConfig(OptimizerConfig): soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR) soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01) - def build(self, params: dict) -> torch.optim.Optimizer: + def build(self, params: OptimizerParams) -> torch.optim.Optimizer: """ Build AdamW optimizer with differential learning rates. - Expects `named_parameters()` as input (dict of name -> param). - Applies: - - lr * 0.1 for all VLM-related parameters - - lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup) - - full lr for all other parameters - Args: - params: Dictionary of parameter names to parameters (from named_parameters()) + params: Must be a dict[str, Parameter] from dict(model.named_parameters()) + or equivalent. Returns: AdamW optimizer with parameter groups for VLM, soft-prompts, and other components + + Raises: + AssertionError: If params is not a dict (e.g., from model.parameters()) """ assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs." @@ -174,7 +195,7 @@ class XVLAAdamWConfig(OptimizerConfig): # Start at warmup scale, scheduler will warm up to soft_prompt_lr soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale - param_groups = [ + param_groups: list[dict[str, Any]] = [ { "params": vlm_group, "lr": self.lr * 0.1, @@ -224,19 +245,25 @@ class MultiAdamConfig(OptimizerConfig): grad_clip_norm: float = 10.0 optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict) - def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]: + def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]: """Build multiple Adam optimizers. Args: - params_dict: Dictionary mapping parameter group names to lists of parameters - The keys should match the keys in optimizer_groups + params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names + to iterables of parameters. The keys should match the keys in optimizer_groups. + Typically from policies that need separate optimizers (e.g., SAC with + actor/critic/temperature). Returns: Dictionary mapping parameter group names to their optimizers + + Raises: + AssertionError: If params is not a dict """ + assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs." optimizers = {} - for name, params in params_dict.items(): + for name, group_params in params.items(): # Get group-specific hyperparameters or use defaults group_config = self.optimizer_groups.get(name, {}) @@ -248,7 +275,7 @@ class MultiAdamConfig(OptimizerConfig): "weight_decay": group_config.get("weight_decay", self.weight_decay), } - optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs) + optimizers[name] = torch.optim.Adam(group_params, **optimizer_kwargs) return optimizers diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index b5d54b396..4af7f0802 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -30,7 +30,7 @@ from lerobot.utils.io_utils import deserialize_json_into_object @dataclass class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): - num_warmup_steps: int + num_warmup_steps: int | None @property def type(self) -> str: diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index ceefb0d56..99275e787 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig +from .wall_x.configuration_wall_x import WallXConfig as WallXConfig from .xvla.configuration_xvla import XVLAConfig as XVLAConfig __all__ = [ @@ -34,4 +35,5 @@ __all__ = [ "VQBeTConfig", "GrootConfig", "XVLAConfig", + "WallXConfig", ] diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index a9b1280bf..fb43eacdb 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -42,6 +42,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import validate_visual_features_consistency from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor.converters import ( @@ -62,7 +63,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x". Returns: The policy class corresponding to the given name. @@ -118,6 +119,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.xvla.modeling_xvla import XVLAPolicy return XVLAPolicy + elif name == "wall_x": + from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy + + return WallXPolicy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -135,7 +140,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", - "reward_classifier". + "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -166,6 +171,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return GrootConfig(**kwargs) elif policy_type == "xvla": return XVLAConfig(**kwargs) + elif policy_type == "wall_x": + return WallXConfig(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -357,6 +364,7 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, XVLAConfig): from lerobot.policies.xvla.processor_xvla import ( make_xvla_pre_post_processors, @@ -367,6 +375,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, WallXConfig): + from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors + + processors = make_wall_x_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + else: try: processors = _make_processors_from_policy_config( diff --git a/src/lerobot/policies/sarm/README.md b/src/lerobot/policies/sarm/README.md new file mode 100644 index 000000000..e0e49834b --- /dev/null +++ b/src/lerobot/policies/sarm/README.md @@ -0,0 +1,14 @@ +## Paper + +https://arxiv.org/abs/2509.25358 + +## Citation + +```bibtex +@article{chen2025sarm, + title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation}, + author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp}, + journal={arXiv preprint arXiv:2509.25358}, + year={2025} +} +``` diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md new file mode 100644 index 000000000..78548bd8d --- /dev/null +++ b/src/lerobot/policies/wall_x/README.md @@ -0,0 +1,35 @@ +# WALL-OSS + +This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. + +--- + +## Model Overview + +| Feature | Description | +| ------------------ | ----------------------------------------------------- | --- | +| Base Model | Qwen2.5-VL (Vision-Language Model) | +| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | +| Architecture | Mixture of Experts (MoE) with action-specific routing | | +| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | + +--- + +## Citation + +If you use this work, please cite: + +```bibtex +@article{zhai2025igniting, + title = {Igniting VLMs Toward the Embodied Space}, + author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, + journal = {arXiv preprint arXiv:2509.11766}, + year = {2025} +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**. diff --git a/src/lerobot/policies/wall_x/__init__.py b/src/lerobot/policies/wall_x/__init__.py new file mode 100644 index 000000000..d80c27bda --- /dev/null +++ b/src/lerobot/policies/wall_x/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_wall_x import WallXConfig + +__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"] diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py new file mode 100644 index 000000000..0d10a8f98 --- /dev/null +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -0,0 +1,165 @@ +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("wall_x") +@dataclass +class WallXConfig(PreTrainedConfig): + """ + Configuration class for Wall-X policy. + + Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching. + It supports cross-embodiment robotic control through unified action representations. + + This config supports multi-modal learning with vision, language, and action data. + """ + + # ==================== Input / Output Structure ==================== + n_obs_steps: int = 1 + chunk_size: int = 32 # action_horizon in wall-x + n_action_steps: int = 32 + + # Action dimension - wall-x uses 20 + max_action_dim: int = 20 + max_state_dim: int = 20 # For proprioception + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # ==================== Action Prediction ==================== + # Pretrained model paths + pretrained_name_or_path: str = "x-square-robot/wall-oss-flow" + + # Tokenizer settings + action_tokenizer_path: str | None = "physical-intelligence/fast" + + # Action prediction mode: "diffusion" or "fast" + prediction_mode: str = "diffusion" + + # Attention Implementation, options: "eager", "flash_attention_2", "sdpa" + # NOTE: flash-attn==2.7.4.post1 is required for flash_attention_2 implementation + attn_implementation: str = "eager" + + # ==================== Optimizer Presets ==================== + optimizer_lr: float = 2e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.01 + optimizer_grad_clip_norm: float = 1.0 + + scheduler_warmup_steps: int = 1000 + scheduler_decay_steps: int = 100000 + scheduler_decay_lr: float = 1e-6 + + def __post_init__(self): + super().__post_init__() + + # Input validation + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + + if self.prediction_mode not in ["diffusion", "fast"]: + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") + + # Assign use_fast_tokenizer based on prediction_mode + if self.prediction_mode == "fast": + self.use_fast_tokenizer = True + elif self.prediction_mode == "diffusion": + self.use_fast_tokenizer = False + self.action_tokenizer_path = None # disable action tokenizer for diffusion mode + else: + raise ValueError(f"prediction_mode must be 'diffusion' or 'fast', got {self.prediction_mode}") + + def validate_features(self) -> None: + """Validate and set up input/output features.""" + image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL] + if not image_features: + raise ValueError( + "Wall-X policy requires at least one visual input feature. " + "No features of type FeatureType.VISUAL found in input_features." + ) + + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature + else: + state_shape = self.input_features["observation.state"].shape + state_dim = state_shape[0] if state_shape else 0 + if state_dim > self.max_state_dim: + raise ValueError( + f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. " + f"Either reduce state dimension or increase max_state_dim in config." + ) + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + else: + action_shape = self.output_features["action"].shape + action_dim = action_shape[0] if action_shape else 0 + if action_dim > self.max_action_dim: + raise ValueError( + f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. " + f"Either reduce action dimension or increase max_action_dim in config." + ) + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/wall_x/constant.py b/src/lerobot/policies/wall_x/constant.py new file mode 100644 index 000000000..43e5e7fb6 --- /dev/null +++ b/src/lerobot/policies/wall_x/constant.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Constants and Configuration Data. +""" + +CAMERA_NAME_MAPPING = { + "face_view": "front view", + "left_wrist_view": "left wrist view", + "right_wrist_view": "right wrist view", + "move1_view": "move view", + "move2_view": "move view", + "wall_view": "wall view", + "top_view": "top view", +} + +RESOLUTION = 256 + +# Parameters for preprocessing +MAX_PIXELS = 16384 * 28 * 28 +MIN_PIXELS = 4 * 28 * 28 +IMAGE_FACTOR = 28 +PRIORITY_ORDER = None +GENERATE_SUBTASK_RATIO = 0.0 +MODEL_TYPE = "qwen2_5" + +TOKENIZER_MAX_LENGTH = 768 diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py new file mode 100644 index 000000000..c401c8d60 --- /dev/null +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -0,0 +1,2008 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X: Cross-embodiment robotic control using Qwen2.5-VL with flow matching. + +[Paper](https://github.com/x2-robot/wall-x) + +Install wall-x extra dependencies: +```bash +pip install -e ".[wall_x]" +``` + +Example of finetuning a wall-x model: +```bash +lerobot-train \ +--policy.type=wall_x \ +--dataset.repo_id=your/dataset \ +--batch_size=32 \ +--steps=100000 +``` +""" + +import math +from collections import deque +from os import PathLike +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from peft import LoraConfig, get_peft_model +from PIL import Image +from qwen_vl_utils.vision_process import smart_resize +from torch import Tensor +from torch.distributions import Beta +from torch.nn import CrossEntropyLoss +from torchdiffeq import odeint +from transformers import AutoProcessor, BatchFeature +from transformers.cache_utils import ( + StaticCache, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, +) +from transformers.utils import is_torchdynamo_compiling, logging + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import populate_queues +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig +from lerobot.policies.wall_x.constant import ( + GENERATE_SUBTASK_RATIO, + IMAGE_FACTOR, + MAX_PIXELS, + MIN_PIXELS, + MODEL_TYPE, + PRIORITY_ORDER, + RESOLUTION, + TOKENIZER_MAX_LENGTH, +) +from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLACausalLMOutputWithPast, + Qwen2_5_VLMoEModel, +) +from lerobot.policies.wall_x.utils import ( + get_wallx_normal_text, + preprocesser_call, + process_grounding_points, + replace_action_token, +) +from lerobot.utils.constants import ACTION, OBS_STATE + +logger = logging.get_logger(__name__) + + +class SinusoidalPosEmb(nn.Module): + """Sinusoidal positional embedding for diffusion timesteps.""" + + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class ActionHead(nn.Module): + """ + Action prediction head with flow matching. + + Implements Beta-distributed noise scheduling and temporal embeddings + for action sequence prediction. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.action_dim = sum(config.dof_config.values()) + self.propri_dim = sum(config.agent_pos_config.values()) + self.hidden_size = config.hidden_size + + # Beta distribution for noise scheduling + self.beta_alpha = 1.5 + self.beta_beta = 1.0 + self.s = 0.999 + + # Sinusoidal timestep embedding + self.time_embed = SinusoidalPosEmb(config.hidden_size) + + # Action embedding network + # *2 for action + DOF mask concatenation + self.w1 = nn.Linear(self.action_dim * 2, self.hidden_size, bias=False) + self.w2 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=False) # *2 for action + time + self.w3 = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.act_fn = nn.SiLU() + + # Project back to action space + self.action_proj_back = nn.Linear(self.hidden_size, self.action_dim, bias=False) + + # Proprioception projection + self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False) + + def sample_time(self, batch_size, device): + """Sample timesteps using Beta distribution (always in float32 for numerical stability).""" + beta_dist = Beta( + torch.tensor(self.beta_alpha, dtype=torch.float32, device=device), + torch.tensor(self.beta_beta, dtype=torch.float32, device=device), + ) + sample = beta_dist.sample([batch_size]) + time = (1 - sample) * self.s + return time + + def forward(self, action_chunk, dof_mask=None): + """ + Process action sequences with noise injection for training. + + Args: + action_chunk: Action sequences [batch, seq_len, action_dim] + dof_mask: DOF mask [batch, seq_len, action_dim] + + Returns: + tuple: (action_embeddings, flow_target) + """ + batch_size = action_chunk.shape[0] + device = action_chunk.device + weight_dtype = self.w1.weight.dtype + + # Sample time outside of autocast (Beta distribution needs float32) + time = self.sample_time(batch_size, device) + t = time.unsqueeze(-1).unsqueeze(-1) + + # Noise and flow computation in float32 + noise = torch.randn_like(action_chunk, dtype=torch.float32) + action_chunk_f32 = action_chunk.to(torch.float32) + noisy_action = (1 - t) * noise + t * action_chunk_f32 + flow = action_chunk_f32 - noise + + # Project noisy actions + if dof_mask is not None: + noisy_action = torch.cat([noisy_action, dof_mask.to(torch.float32)], dim=-1) + + # Convert to weight dtype for linear layers + noisy_action = noisy_action.to(dtype=weight_dtype) + action_embed = self.w1(noisy_action) + + # Generate time embeddings and combine + time_embed = self.time_embed(time) + time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) + time_embed = time_embed.to(dtype=weight_dtype) + + concat_embed = torch.cat([action_embed, time_embed], dim=-1) + concat_embed = self.w2(concat_embed) + embed = self.w3(self.act_fn(concat_embed)) + + return embed, flow + + def step(self, timestep, noisy_action, dof_mask=None): + """Single denoising step for inference.""" + weight_dtype = self.w1.weight.dtype + + if dof_mask is not None: + noisy_action = torch.cat([noisy_action, dof_mask], dim=-1) + noisy_action = noisy_action.to(dtype=weight_dtype) + + time_embed = self.time_embed(timestep) + action_embed = self.w1(noisy_action) + + time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) + time_embed = time_embed.to(device=noisy_action.device, dtype=weight_dtype) + + concat_embed = torch.cat([action_embed, time_embed], dim=-1) + concat_embed = self.w2(concat_embed) + embed = self.w3(self.act_fn(concat_embed)) + + return embed + + def flow_loss(self, action_hidden_states, flow, dof_mask=None): + """Compute flow matching loss (all computations in float32 for stability).""" + # Ensure all inputs are float32 + action_hidden_states = action_hidden_states.to(torch.float32) + flow = flow.to(torch.float32) + + action_pred = self.action_proj_back(action_hidden_states) + loss = F.mse_loss(action_pred, flow, reduction="none") + + if dof_mask is not None: + dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1]).to(torch.float32) + loss = loss * dof_mask + + return loss + + def proprioception_proj(self, proprioception, dof_mask=None, use_history=False): + """Project proprioceptive data to hidden space.""" + # Ensure proper device and dtype alignment + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( + dtype=self.propri_proj.weight.dtype + ) + + if dof_mask is not None: + # Concatenate proprioception with DOF mask + # TODO: Use variable-based dimension checking for better flexibility + if use_history: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + else: + proprioception = torch.cat([proprioception, dof_mask], dim=-1) + + proprioception = proprioception.to(device=self.propri_proj.weight.device).to( + dtype=self.propri_proj.weight.dtype + ) + return self.propri_proj(proprioception) + + +class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): + """ + Qwen2.5 Vision-Language Mixture of Experts model for action processing. + + This model extends the base Qwen2.5 VL model with action token processing capabilities + and optional LoRA fine-tuning support. + """ + + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"] + + @classmethod + def from_pretrained( + cls, + pretrained_name_or_path, + config=None, + action_tokenizer_path=None, + attn_implementation: str = "eager", + cache_dir: str | PathLike | None = None, + force_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + revision: str = "main", + strict: bool = False, + **kwargs: Any, + ): + """ + Load model from pretrained model path. + + Args: + pretrained_model_path (str): Model directory path containing model.safetensors file + config_path (str, optional): Configuration file path, if None will look for qwen25_config.json in pretrained_model_path + action_tokenizer_path (str, optional): Action tokenizer path, if None will load from default config + attn_implementation (str, optional): Attention implementation, if None will load from default config + **kwargs: Additional arguments + + Returns: + Qwen2_5_VLMoEForAction: Loaded model instance + """ + if config is None: + config = cls.config_class.from_pretrained( + pretrained_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + strict=strict, + **kwargs, + ) + if attn_implementation is not None: + config._attn_implementation = attn_implementation + processor = AutoProcessor.from_pretrained(pretrained_name_or_path, use_fast=True) + if action_tokenizer_path is not None: + action_tokenizer = AutoProcessor.from_pretrained(action_tokenizer_path, trust_remote_code=True) + processor.action_processor = action_tokenizer + else: + action_tokenizer = None + # Initialize model with configuration and processor + model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs) + + # Resize token embeddings to match processor tokenizer vocabulary size + model.resize_token_embeddings(len(processor.tokenizer)) + + # Try to load the model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + sd = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + state_dict = {} + # filter normalizer statistic params + del_keys = [] + for key in sd.keys(): + if "action_preprocessor.normalizer" in key: + del_keys.append(key) + for key in del_keys: + del sd[key] + state_dict.update(sd) + + model.load_state_dict(state_dict, strict=False) + + return model + + def __init__( + self, + config, + use_fast_tokenizer=False, + processor=None, + action_tokenizer=None, + action_mapper=None, + flow_loss_weight=1.0, + ): + """ + Initialize the Qwen2.5 VLMoE model for action processing. + + Args: + config: Model configuration + use_fast_tokenizer (bool): Whether to use fast tokenizer + processor: Text and image processor + action_tokenizer: Action-specific tokenizer + action_mapper: Action mapping utility + flow_loss_weight (float): Weight for flow loss computation + """ + super().__init__(config) + + # Initialize vision transformer and language model components + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLMoEModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize loss function without reduction for channel-wise loss computation + self.loss_fct = CrossEntropyLoss(reduction="none") + self.flow_loss_weight = flow_loss_weight + self.use_fast_tokenizer = use_fast_tokenizer + self.processor = processor + self.action_tokenizer = action_tokenizer + + # Define action token IDs + self.define_action_token_id() + + # Cache for rope deltas + self.rope_deltas = None + + # Initialize action preprocessor + self.action_preprocessor = ActionHead(config) + + # Apply LoRA if specified in configuration + if hasattr(config, "use_lora") and config.use_lora: + self.add_lora( + r=config.lora_r, + lora_alpha=config.lora_alpha, + target_modules=config.lora_target_modules, + lora_dropout=config.lora_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + def to_bfloat16_for_selected_params(self): + self.to(dtype=torch.bfloat16) + + params_to_keep_float32 = [] + + for name, param in self.named_parameters(): + if "input_layernorm" in name or "post_attention_layernorm" in name or "model.norm" in name: + params_to_keep_float32.append(name) + if "action_preprocessor" in name: + params_to_keep_float32.append(name) + + for name, param in self.named_parameters(): + if name in params_to_keep_float32: + param.data = param.data.to(torch.float32) + + def define_action_token_id(self): + """ + Define action token IDs based on tokenizer configuration. + + Creates mappings for fast action tokens, proprioception tokens, and general action tokens. + """ + # Create list of fast action token IDs + fast_action_token_list = [] + if self.use_fast_tokenizer: + for i in range(self.processor.tokenizer.init_kwargs["action_token_vocab_size"]): + action_token_id = self.processor.tokenizer.convert_tokens_to_ids(f"<|action_token_{i}|>") + fast_action_token_list.append(action_token_id) + + # Get special action token IDs + action_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|action|>") + propri_token_id = self.processor.tokenizer.convert_tokens_to_ids("<|propri|>") + + # Store action token ID mappings + self.action_token_id_set = { + "fast_action_token_list": fast_action_token_list, + "propri_token_id": propri_token_id, + "action_token_id": action_token_id, + } + + def add_lora(self, r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.1): + """ + Add LoRA (Low-Rank Adaptation) adapters to the model. + + Args: + r (int): Rank of adaptation + lora_alpha (int): LoRA scaling parameter + target_modules (list): List of module names to apply LoRA to + lora_dropout (float): Dropout probability for LoRA layers + """ + config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + self.model = get_peft_model(self.model, config) + + # Print information about trainable parameters + self.model.print_trainable_parameters() + + def get_input_embeddings(self): + """Get input embeddings layer.""" + return self.model.embed_tokens + + def set_input_embeddings(self, value): + """Set input embeddings layer.""" + self.model.embed_tokens = value + + def get_output_embeddings(self): + """Get output embeddings layer.""" + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + """Set output embeddings layer.""" + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + """Set the decoder model.""" + self.model = decoder + + def get_decoder(self): + """Get the decoder model.""" + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate 3D RoPE (Rotary Position Embedding) indices for vision and text tokens. + + This method computes position embeddings that account for the temporal, height, and width + dimensions of vision tokens (images/videos) while maintaining standard 1D position embeddings + for text tokens. + + For vision tokens, 3D position embeddings are calculated based on: + - Temporal dimension: Time patches in videos + - Height dimension: Vertical patches in images/video frames + - Width dimension: Horizontal patches in images/video frames + + For text tokens, standard 1D position embeddings are used, continuing from the maximum + vision position ID plus 1. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs of shape (batch_size, sequence_length) + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (num_images, 3) for [temporal, height, width] + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (num_videos, 3) for [temporal, height, width] + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid (num_videos,) + attention_mask (torch.Tensor, optional): Attention mask (batch_size, sequence_length) + + Returns: + tuple: + - position_ids (torch.LongTensor): 3D position IDs of shape (3, batch_size, sequence_length) + - mrope_position_deltas (torch.Tensor): Position deltas for mRoPE of shape (batch_size, 1) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + # Initialize 3D position IDs tensor + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + # Process each sequence in the batch + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + + # Find vision tokens and count images/videos + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + # Process each vision token (image or video) + for _ in range(image_nums + video_nums): + # Find next image or video token + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + # Determine if processing image or video token + if ed_image < ed_video: + # Process image token + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + # Process video token + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + + # Calculate grid dimensions after spatial merging + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + # Add position IDs for text tokens before vision token + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # Calculate 3D position embeddings for vision tokens + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + # Calculate temporal position IDs with time scaling + time_tensor = ( + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + ) + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + # Calculate spatial position IDs + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) + + # Add 3D position IDs for vision tokens + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + # Add position IDs for remaining text tokens + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # Concatenate all position IDs for this sequence + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # Handle case without vision tokens - use standard 1D position embeddings + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def train_step_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, # MoE token type assignments + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, # Action trajectory chunks + proprioception: torch.FloatTensor | None = None, # Joint position/orientation data + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, + **kwargs, + ) -> tuple | Qwen2_5_VLACausalLMOutputWithPast: + """ + Forward pass for training with multi-modal inputs including vision, text, and action data. + + This method handles the complete forward pass during training, processing various input modalities + including images, videos, text, proprioceptive data, and action sequences. It computes losses + for both language modeling and action prediction using flow matching. + + Args: + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs for generation + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for loss computation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions (temporal, height, width) + video_grid_thw (torch.LongTensor, optional): Video grid dimensions (temporal, height, width) + action_chunk (torch.FloatTensor, optional): Action trajectory data chunks + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data (joint positions, etc.) + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask for action tokens + agent_pos_mask (torch.FloatTensor, optional): Agent position mask for proprioceptive data + **kwargs: Additional keyword arguments + + Returns: + Union[Tuple, Qwen2_5_VLACausalLMOutputWithPast]: Model outputs including losses, logits, + and auxiliary information, or tuple if return_dict=False + """ + batch_size, seq_length = input_ids.shape + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + delta = ( + (cache_position[0] + self.rope_deltas).to(self.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=self.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data (joint positions, orientations, etc.) + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + proprioception = self.action_preprocessor.proprioception_proj( + proprioception, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + mask = input_ids == self.action_token_id_set["propri_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + proprioception_mask = mask_expanded.to(inputs_embeds.device) + + proprioception = proprioception.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(proprioception_mask, proprioception) + elif self.training: + # Dummy forward pass to ensure gradient registration in DDP + # This handles cases where one process has proprioception data while another doesn't + # Without this, DDP would hang waiting for a gradient that will never be computed + dummy_input = torch.randn( + 2, + self.action_preprocessor.propri_dim * 2, + device=inputs_embeds.device, + ) + dummy_forward = self.action_preprocessor.proprioception_proj(dummy_input) + dummy_loss = sum(p.sum() for p in dummy_forward) + inputs_embeds = inputs_embeds + 0 * dummy_loss + + # Process action chunk data + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to(inputs_embeds.dtype) + dof_mask = dof_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + noisy_action_emb, flow = self.action_preprocessor(action_chunk, dof_mask) + mask = input_ids == self.action_token_id_set["action_token_id"] + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + action_mask = mask_expanded.to(inputs_embeds.device) + + noisy_action_emb = noisy_action_emb.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(action_mask, noisy_action_emb) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Forward pass through the main model + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + moe_token_types=moe_token_types, # Pass token types for MoE routing + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = hidden_states.to(self.lm_head.weight.dtype) + logits = self.lm_head(hidden_states) + + # Initialize loss computation variables + loss = None + cross_entropy_loss, flow_loss = None, None + channel_loss_dict = None + channel_loss_count_dict = None + + # Compute losses if labels are provided + if labels is not None: + loss = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32) + + # Compute standard cross-entropy loss for language modeling + shift_logits = logits[..., :-1, :].contiguous().to(torch.float32) + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + + # Enable model parallelism by moving labels to correct device + shift_labels = shift_labels.to(shift_logits.device) + non_ignored_mask = shift_labels != -100 + _cross_entropy_loss = self.loss_fct(shift_logits, shift_labels) + cross_entropy_loss = ( + _cross_entropy_loss[non_ignored_mask].mean() + if non_ignored_mask.any() + else torch.tensor(0.0, device=shift_logits.device, dtype=torch.float32) + ) + + # Add cross-entropy loss to total loss if valid + if not torch.isnan(cross_entropy_loss): + loss = loss + cross_entropy_loss.to(torch.float32) + else: + with torch.no_grad(): + cross_entropy_loss.detach() + + if action_chunk is not None: + action_mask = input_ids == self.action_token_id_set["action_token_id"] + if action_mask.any(): + action_hidden_states = hidden_states[action_mask].to(torch.float32) + flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32) + _flow_loss = self.action_preprocessor.flow_loss(action_hidden_states, flow, dof_mask) + if isinstance(_flow_loss, torch.Tensor): + flow_loss = _flow_loss.mean() + if loss is not None: + loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32) + else: + loss = self.flow_loss_weight * flow_loss.to(torch.float32) + _flow_loss = _flow_loss.view(dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]) + + # Return outputs based on return_dict setting + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLACausalLMOutputWithPast( + loss=loss, + cross_entropy_loss=(cross_entropy_loss.clone() if cross_entropy_loss is not None else None), + flow_loss=flow_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + channel_loss_dict=channel_loss_dict, + channel_loss_count_dict=channel_loss_count_dict, + ) + + def predict_action(self, predict_mode: str, **kwargs): + """ + Predict actions using specified prediction mode. + + Args: + predict_mode (str): Prediction mode, either "fast" or "diffusion" + **kwargs: Additional arguments passed to the predict method + + Returns: + tuple: (predicted_action, ground_truth_action) where ground_truth_action may be None + """ + assert predict_mode in ["fast", "diffusion"] + + output = self.predict(predict_mode=predict_mode, **kwargs) + + return output["predict_action"], output.get("gt_action", None) + + @torch.no_grad() + def predict( + self, + predict_mode: str, + pred_horizon: int | None = None, + action_dim: int | None = None, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + action_chunk: torch.FloatTensor | None = None, + proprioception: torch.FloatTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + num_inference_timesteps: int | None = 10, + dof_mask: torch.FloatTensor | None = None, + agent_pos_mask: torch.FloatTensor | None = None, + re_generate: bool = False, + **kwargs, + ): + """ + Multi-modal prediction method supporting text generation, fast action prediction, and diffusion-based action prediction. + + This method handles three prediction modes: + 1. "text": Pure text generation using autoregressive decoding + 2. "fast": Fast action prediction using discrete action tokens + 3. "diffusion": Continuous action prediction using diffusion/flow matching + + Args: + predict_mode (str): Prediction mode ("text", "fast", or "diffusion") + pred_horizon (int, optional): Prediction horizon for action sequences + action_dim (int, optional): Dimensionality of action space + input_ids (torch.LongTensor, optional): Input token IDs + attention_mask (torch.Tensor, optional): Attention mask for input tokens + position_ids (torch.LongTensor, optional): Position IDs for tokens + past_key_values (List[torch.FloatTensor], optional): Cached key-value pairs + inputs_embeds (torch.FloatTensor, optional): Pre-computed input embeddings + moe_token_types (torch.LongTensor, optional): Token type assignments for MoE routing + labels (torch.LongTensor, optional): Target labels for evaluation + use_cache (bool, optional): Whether to use key-value caching + output_attentions (bool, optional): Whether to return attention weights + output_hidden_states (bool, optional): Whether to return hidden states + return_dict (bool, optional): Whether to return structured output + pixel_values (torch.Tensor, optional): Image pixel values + pixel_values_videos (torch.FloatTensor, optional): Video pixel values + image_grid_thw (torch.LongTensor, optional): Image grid dimensions + video_grid_thw (torch.LongTensor, optional): Video grid dimensions + action_chunk (torch.FloatTensor, optional): Ground truth action sequences + proprioception (torch.FloatTensor, optional): Proprioceptive sensor data + rope_deltas (torch.LongTensor, optional): RoPE position deltas + cache_position (torch.LongTensor, optional): Cache position indices + second_per_grid_ts (torch.Tensor, optional): Time interval per temporal grid + num_inference_timesteps (int, optional): Number of diffusion inference steps + dof_mask (torch.FloatTensor, optional): Degrees of freedom mask + agent_pos_mask (torch.FloatTensor, optional): Agent position mask + re_generate (bool, optional): Whether to use sampling for regeneration + **kwargs: Additional keyword arguments + + Returns: + dict: Dictionary containing prediction results with keys like: + - 'predict_action': Predicted action sequences + - 'gt_action': Ground truth actions (if available) + - 'input_text': Input text (for text/fast modes) + - 'predict_output_text': Generated text (for text/fast modes) + - 'gt_output_text': Ground truth text (for text/fast modes) + """ + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + + # Text and fast modes require batch size 1 for autoregressive generation + if predict_mode in ["text", "fast"]: + assert batch_size == 1, "predict only support batch size 1 for ar generation" + + # Set output configuration from model config if not specified + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process input embeddings with multi-modal data + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + + # Process image embeddings + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + + # Validate image token and feature count match + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # Process video embeddings + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + + # Validate video token and feature count match + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # Process proprioceptive data + if proprioception is not None: + proprioception = proprioception.to(inputs_embeds.device).to(inputs_embeds.dtype) + agent_pos_mask = agent_pos_mask.to(inputs_embeds.device).to(inputs_embeds.dtype) + proprio_embed = self.action_preprocessor.proprioception_proj( + proprioception, + agent_pos_mask, + use_history=proprioception.shape[1] > 1, + ) + proprioception_mask = input_ids == self.action_token_id_set["propri_token_id"] + proprio_embed = proprio_embed.to(torch.bfloat16) + inputs_embeds[proprioception_mask] = proprio_embed.reshape(-1, inputs_embeds.shape[-1]) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # Calculate RoPE position IDs if not provided + # Note: Cannot calculate rope deltas with 4D attention mask. TODO: Fix this limitation + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # Calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # Use previously calculated rope deltas to get correct position IDs + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # Prepare action chunk data if provided + if action_chunk is not None: + action_chunk = action_chunk.to(inputs_embeds.device).to(torch.float32) + + output = {} + + # Split input sequence for text and fast modes (not needed for diffusion) + if predict_mode == "text" or predict_mode == "fast": + # Look for generation prompt tokens: <|im_start|>assistant + generation_prompt_ids = torch.tensor( + [151644, 77091], device=input_ids.device, dtype=input_ids.dtype + ) + matches = (input_ids[0, :-1] == generation_prompt_ids[0]) & ( + input_ids[0, 1:] == generation_prompt_ids[1] + ) + + if matches.any(): + split_pos = torch.nonzero(matches, as_tuple=True)[0][0].item() + # Extract ground truth output tokens (including newline) + gt_output_ids = input_ids[:, split_pos + 3 :] + # Remove output part from input, keeping prompt + input_ids = input_ids[:, : split_pos + 3] + inputs_embeds = inputs_embeds[:, : split_pos + 3, :] + if attention_mask is not None: + attention_mask = attention_mask[:, : split_pos + 3] + if labels is not None: + labels = labels[:, split_pos + 3 :] + else: + raise ValueError( + "input_ids does not contain the generation prompt tokens <|im_start|>assistant" + ) + + # Decode input text for output + input_text = self.processor.batch_decode( + input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True + ) + output["input_text"] = input_text + + # Handle text and fast prediction modes using autoregressive generation + if predict_mode == "text" or predict_mode == "fast": + # Initialize MoE token types for generation + moe_token_types = torch.zeros_like(input_ids) + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "moe_token_types": moe_token_types, + "image_grid_thw": image_grid_thw, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + "proprioception": proprioception, + } + + # Generate output tokens + predict_output_ids = self.generate( + **batch, + max_new_tokens=100, + eos_token_id=[self.processor.tokenizer.eos_token_id], + use_cache=True, + pad_token_id=self.processor.tokenizer.pad_token_id, + temperature=(1.0 if not re_generate else 0.7), # Higher temperature for regeneration + do_sample=(False if not re_generate else True), # Enable sampling for regeneration + ) + + # Decode generated and ground truth text + gt_output_text = self.processor.batch_decode( + gt_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + predict_output_text = self.processor.batch_decode( + predict_output_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + output["gt_output_text"] = gt_output_text + output["predict_output_text"] = predict_output_text + + # Convert tokens to actions for fast prediction mode + if predict_mode == "fast": + action_id = [] + # Extract action tokens from generated sequence + for token_id_i in predict_output_ids[0]: + if token_id_i.item() >= self.processor.tokenizer.init_kwargs["action_token_start_index"]: + action_id.append( + token_id_i.item() - self.processor.tokenizer.init_kwargs["action_token_start_index"] + ) + + predict_action = self.processor.action_processor.decode( + [action_id], time_horizon=pred_horizon, action_dim=action_dim + ) + # Handle action decoding errors + if np.sum(predict_action) == 0: + print("Error in decoding action, predict_action is None") + output["predict_action"] = None + else: + # Convert discrete tokens to continuous actions + predict_action = torch.tensor(predict_action, device=self.device) + dof_mask = dof_mask.to(self.device).to(pixel_values.dtype) + # removed unnormalization step for now + predict_action = predict_action[:, :, dof_mask[0, 0, :].bool()] + output["predict_action"] = predict_action + + # Process ground truth actions if available + if action_chunk is not None: + # Apply DOF mask to get ground truth actions + # removed unnormalization step for now + action_chunk = action_chunk[:, :, dof_mask[0, 0, :].bool()] + output["gt_action"] = action_chunk + else: + output["gt_action"] = None + + # Handle diffusion-based action prediction + if predict_mode == "diffusion": + # Initialize with random noise + noisy_action = torch.randn( + size=(batch_size, pred_horizon, action_dim), + dtype=torch.float32, + device=inputs_embeds.device, + ) + dof_mask = dof_mask.to(inputs_embeds.device).to(torch.float32) + + def step(timestep, noisy_action): + """ + Single denoising step for diffusion process. + + Args: + timestep: Current diffusion timestep + noisy_action: Current noisy action estimate + + Returns: + torch.Tensor: Predicted clean action + """ + action_mask = input_ids == self.action_token_id_set["action_token_id"] + assert action_mask.any(), "No action token found in input_ids" + + # Prepare timestep for batch processing + timestep = timestep.unsqueeze(0).repeat(noisy_action.shape[0]) + action_embed = self.action_preprocessor.step( + timestep=timestep, noisy_action=noisy_action, dof_mask=dof_mask + ) + action_embed = action_embed.reshape(-1, inputs_embeds.shape[-1]) + + # Ensure action_embed has the correct dtype and device before assignment + action_embed = action_embed.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) + + # Create temporary copy of embeddings (clone preserves dtype) + temp_inputs_embeds = inputs_embeds.clone() + temp_inputs_embeds[action_mask] = action_embed + + # Forward pass through transformer + transformer_outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=temp_inputs_embeds, + moe_token_types=moe_token_types, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + + # Extract action predictions from hidden states + hidden_states = transformer_outputs.last_hidden_state + action_mask = input_ids == self.action_token_id_set["action_token_id"] + action_hidden_states = hidden_states[action_mask].to(torch.float32) + pred = self.action_preprocessor.action_proj_back(action_hidden_states) + return pred.reshape(batch_size, pred_horizon, action_dim) + + # Perform ODE integration for diffusion sampling + times = torch.linspace( + 0, + 1, + num_inference_timesteps + 1, + device=inputs_embeds.device, + dtype=torch.float32, + ) + action_trajectory = odeint(step, noisy_action, times, method="euler") + + # Extract final predicted action + # Removed unnormalization step for now + predict_action = action_trajectory[-1] + output["predict_action"] = predict_action + + # Process ground truth actions if available + # removed unnormalization step for now + if action_chunk is not None: + output["gt_action"] = action_chunk[:, :, dof_mask[0, 0, :].bool()] + + return output + + def forward(self, mode: str | None = None, predict_mode: str | None = "text", **kwargs): + """ + Main forward pass dispatcher for different execution modes. + + This method routes execution to appropriate forward functions based on the specified mode: + - No mode (None): Training step with gradient disabled + - 'predict': Prediction/inference mode + - 'train': Training mode with gradients enabled + - 'validate': Validation mode with gradients disabled + + Args: + mode (str, optional): Execution mode. If None, defaults to training step without gradients + predict_mode (str, optional): Prediction mode for 'predict' mode ("text", "fast", or "diffusion") + **kwargs: Additional arguments passed to the selected forward function + + Returns: + Model outputs appropriate for the selected mode + + Todo: + - Add support for distinguishing multi-modal data types in prediction mode + """ + if not mode: + with torch.no_grad(): + return self.train_step_forward(**kwargs) + elif mode == "predict": + return self.predict(predict_mode=predict_mode, **kwargs) + elif mode == "train": + return self.train_step_forward(use_cache=False, **kwargs) + elif mode == "validate": + with torch.no_grad(): + return self.train_step_forward(use_cache=False, **kwargs) + else: + raise NotImplementedError("invalid key") + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + moe_token_types=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + proprioception=None, + dof_mask=None, + agent_pos_mask=None, + **kwargs, + ): + """ + Prepare inputs for autoregressive generation with multi-modal support. + + This method handles input preparation for generation, including proper slicing of inputs + based on cache position, MoE token type management, and multi-modal data handling. + Vision inputs are selectively forwarded only when needed during generation. + + Args: + input_ids: Input token IDs + past_key_values: Cached key-value pairs from previous generation steps + attention_mask: Attention mask for input tokens + inputs_embeds: Pre-computed input embeddings + moe_token_types: Token type assignments for MoE routing + cache_position: Current cache position for generation + position_ids: Position IDs for tokens + use_cache: Whether to use key-value caching + pixel_values: Image pixel values + pixel_values_videos: Video pixel values + image_grid_thw: Image grid dimensions + video_grid_thw: Video grid dimensions + second_per_grid_ts: Time interval per temporal grid + proprioception: Proprioceptive sensor data + dof_mask: Degrees of freedom mask + agent_pos_mask: Agent position mask + **kwargs: Additional arguments + + Returns: + dict: Prepared model inputs for generation step + + Todo: + - Test this function thoroughly with various input configurations + + Note: + This is an overridden method that handles specific cases for multi-modal generation: + - Slices input_ids through cache_position to keep only unprocessed tokens + - Handles special cases for input_embeds, generation methods, and GPU synchronization + - Manages vision inputs to avoid unnecessary forward passes + """ + # Initialize MoE token types if not provided + if moe_token_types is None: + moe_token_types = torch.zeros_like( + input_ids + ) # FIXME: Handle case when input_embeds is used instead + else: + # Ensure moe_token_types length matches input_ids + if moe_token_types.shape[1] < input_ids.shape[1]: + # Calculate required padding length + pad_length = input_ids.shape[1] - moe_token_types.shape[1] + # Create padding tensor with default token type (0) + pad_tensor = torch.zeros( + (moe_token_types.shape[0], pad_length), + dtype=moe_token_types.dtype, + device=moe_token_types.device, + ) + # Concatenate padding to existing moe_token_types + moe_token_types = torch.cat([moe_token_types, pad_tensor], dim=1) + + # Handle input slicing based on cache state and special cases + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4: input_embeds case + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1: input_embeds provided + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3: GPU sync edge case + input_ids = input_ids[:, -cache_position.shape[0] :] + moe_token_types = moe_token_types[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (Exception 2 is no-op) + cache_pos = cache_position.clone() + input_ids = input_ids[:, cache_pos] + moe_token_types = moe_token_types[:, cache_pos] + + # Skip vision inputs for continuation steps (not initial generation) + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # Determine whether to use inputs_embeds or input_ids for this generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + # Prepare 4D causal attention mask for static cache + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + # Assemble all model inputs for generation + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "moe_token_types": moe_token_types, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + "proprioception": proprioception, + "dof_mask": dof_mask, + "agent_pos_mask": agent_pos_mask, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate tensor separation lengths. + + These parameters are computed directly from input_ids rather than being passed through + the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (torch.LongTensor): Input token IDs of shape (batch_size, sequence_length) + + Returns: + tuple: + - image_nums (torch.LongTensor): Number of images per sample + - video_nums (torch.LongTensor): Number of videos per sample + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Find vision start tokens and their following tokens + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + # Count images and videos following vision start tokens + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + """ + Expand inputs for generation with support for multi-modal tensors. + + This is an overridden method that supports expanding tensors without a standard batch + size dimension, specifically for vision-related tensors: + - pixel_values.shape[0] = sum(sequence_lengths for all image samples) + - image_grid_thw.shape[0] = sum(num_images for all samples) + - Similar patterns for video tensors + + Args: + expand_size (int): Factor by which to expand inputs (for beam search, etc.) + is_encoder_decoder (bool): Whether using encoder-decoder architecture + input_ids (torch.LongTensor, optional): Input token IDs + **model_kwargs: Additional model arguments to expand + + Returns: + tuple: (expanded_input_ids, expanded_model_kwargs) + """ + if expand_size == 1: + return input_ids, model_kwargs + + # Define keys for vision-related tensors that need special handling + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + """Expand vision-related tensors based on image/video counts per sample.""" + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + """Split tensor by lengths and repeat each sample.""" + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # Split images into samples and compute sequence lengths + samples = torch.split(image_grid_thw, list(image_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # Expand based on number of images per sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + # Split videos into samples and compute sequence lengths + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + # Expand based on number of videos per sample + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + # Handle list-type temporal grid data + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + """Expand standard tensors using repeat_interleave.""" + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # Expand visual inputs only if input_ids is available for counting images/videos + # If input_ids is unavailable, visual inputs won't be used, so no expansion needed + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + # Expand input_ids using standard repeat_interleave + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + # Expand all other model arguments + model_kwargs = _expand_dict_for_generation(model_kwargs) + + # Handle encoder-decoder specific expansion + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +class WallXPolicy(PreTrainedPolicy): + """ + Wall-X policy for cross-embodiment robotic control. + + Integrates Qwen2.5-VL vision-language model with action prediction + using flow matching for continuous action spaces. + """ + + config_class = WallXConfig + name = "wall_x" + + def __init__(self, config: WallXConfig): + super().__init__(config) + config.validate_features() + self.config = config + + # Initialize the wall-x model + self.model = Qwen2_5_VLMoEForAction.from_pretrained( + pretrained_name_or_path=config.pretrained_name_or_path, + action_tokenizer_path=config.action_tokenizer_path, + attn_implementation=config.attn_implementation, + ) + self.model.to(config.device) + self.model.to_bfloat16_for_selected_params() + + self.reset() + + def reset(self): + """Reset action queue.""" + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def get_optim_params(self): + """Get parameters for optimization.""" + return self.parameters() + + def preprocess_inputs( + self, + batch: dict[str, Any], + ) -> BatchFeature: + """ + Convert a batch of LeRobot dataset items to Wall-X model input format. + + This processes a batched dictionary where tensors have batch dimension first. + + Args: + batch: Dictionary with batched tensors: + - "observation.state": (batch_size, state_dim) or (batch_size, n_obs_steps, state_dim) + - "action": (batch_size, chunk_size, action_dim) + - "observation.images.": (batch_size, C, H, W) + - "task": List[str] of length batch_size + + Returns: + BatchFeature containing batched model inputs + """ + use_fast_tokenizer = self.config.use_fast_tokenizer + + # Get batch size from state tensor + batch_size = batch[OBS_STATE].shape[0] + + # ==================== PROCESS ALL SAMPLES ==================== + all_image_inputs = [] + all_texts = [] + + # Find image keys in batch + img_keys = [key for key in self.config.image_features if key in batch] + + for i in range(batch_size): + # Vision preprocessing per sample + processed_frames = [] + orig_height, orig_width = None, None + resized_height, resized_width = None, None + + for key in img_keys: + current_obs = batch[key][i].clone() # (C, H, W) + if current_obs.dim() == 3: + current_obs = current_obs.permute(1, 2, 0) # (H, W, C) + + img_pil = Image.fromarray((current_obs * 255).to(torch.uint8).cpu().numpy()) + orig_width, orig_height = img_pil.size + + target_size = RESOLUTION + if target_size != -1: + if orig_width > orig_height: + new_width = target_size + new_height = int(target_size * orig_height / orig_width) + else: + new_height = target_size + new_width = int(target_size * orig_width / orig_height) + img_pil = img_pil.resize((new_width, new_height)) + + current_width, current_height = img_pil.size + resized_height, resized_width = smart_resize( + current_height, + current_width, + factor=IMAGE_FACTOR, + min_pixels=MIN_PIXELS, + max_pixels=MAX_PIXELS, + ) + resized_img = img_pil.resize((resized_width, resized_height)) + processed_frames.append(resized_img) + + all_image_inputs.append(processed_frames) + + # Text preprocessing + task_text = batch["task"][i] if isinstance(batch["task"], list) else batch["task"] + instruction_info = {"instruction": task_text} + + frame_index = batch["frame_index"][i] if "frame_index" in batch else 0 + complete_text, _ = get_wallx_normal_text( + instruction_info, + self.config.chunk_size, + frame_index, + PRIORITY_ORDER, + img_keys, + generate_subtask_ratio=GENERATE_SUBTASK_RATIO, + ) + + text = process_grounding_points( + complete_text, orig_height, orig_width, resized_height, resized_width, MODEL_TYPE + ) + all_texts.append(text) + + # ==================== PROCESS AGENT POS ==================== + agent_pos = batch[OBS_STATE] # (batch_size, state_dim) + if agent_pos.dim() == 2: + agent_pos = agent_pos.unsqueeze(1) # (batch_size, 1, state_dim) + agent_pos_mask = (~torch.isnan(agent_pos)).float() + agent_pos = agent_pos.nan_to_num(nan=0.0) + + if agent_pos.shape[-1] != 20: + pad_size = 20 - agent_pos.shape[-1] + agent_pos = torch.cat( + [ + agent_pos, + torch.zeros(agent_pos.shape[0], agent_pos.shape[1], pad_size, device=agent_pos.device), + ], + dim=-1, + ) + agent_pos_mask = torch.cat( + [ + agent_pos_mask, + torch.zeros( + agent_pos_mask.shape[0], + agent_pos_mask.shape[1], + pad_size, + device=agent_pos_mask.device, + ), + ], + dim=-1, + ) + + # ==================== PROCESS ACTIONS ==================== + action = batch.get(ACTION) # (batch_size, chunk_size, action_dim) + if action is not None: + if action.dim() == 2: + action = action.unsqueeze(1) + dof_mask = (~torch.isnan(action)).float() + action = action.nan_to_num(nan=0.0) + + if action.shape[-1] != 20: + pad_size = 20 - action.shape[-1] + action = torch.cat( + [action, torch.zeros(action.shape[0], action.shape[1], pad_size, device=action.device)], + dim=-1, + ) + dof_mask = torch.cat( + [ + dof_mask, + torch.zeros(dof_mask.shape[0], dof_mask.shape[1], pad_size, device=dof_mask.device), + ], + dim=-1, + ) + else: + action_dim = self.config.output_features["action"].shape[0] + dof_mask = torch.cat( + [ + torch.ones( + batch_size, self.config.chunk_size, action_dim, device=batch[OBS_STATE].device + ), + torch.zeros( + batch_size, self.config.chunk_size, 20 - action_dim, device=batch[OBS_STATE].device + ), + ], + dim=-1, + ) + + # ==================== ACTION TOKEN REPLACEMENT ==================== + all_texts = replace_action_token( + all_texts, + action, + self.model.action_tokenizer if use_fast_tokenizer else None, + dof_mask, + ) + + # ==================== TOKENIZATION ==================== + inputs = preprocesser_call( + processor=self.model.processor, + text=all_texts, + images=all_image_inputs, + videos=None, + padding=True, + truncation=True, + return_tensors="pt", + max_length=TOKENIZER_MAX_LENGTH, + ) + + # ==================== ADDITIONAL INPUTS ==================== + action_token_id = self.model.processor.tokenizer.convert_tokens_to_ids("<|action|>") + moe_token_types = inputs.input_ids == action_token_id + + inputs["proprioception"] = agent_pos + inputs["agent_pos_mask"] = agent_pos_mask + inputs["action_chunk"] = action + inputs["dof_mask"] = dof_mask + inputs["moe_token_types"] = moe_token_types + inputs["frame_index"] = ( + batch["frame_index"] + if "frame_index" in batch + else torch.zeros(batch_size, device=batch[OBS_STATE].device) + ) + + # Move all tensors to the correct device + device = self.config.device + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.to(device) + + return inputs + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """ + Training forward pass using Qwen2_5_VLMoEForAction. + + Args: + batch: Dictionary containing preprocessed inputs from preprocess_inputs() + Expected keys: input_ids, attention_mask, pixel_values, image_grid_thw, + proprioception, agent_pos_mask, action_chunk, dof_mask, moe_token_types, + etc. + + Returns: + tuple: (loss, loss_dict) + """ + batch = self.preprocess_inputs( + batch, + ) + + # Call the underlying model's forward with mode="train" + outputs = self.model(**batch, mode="train") + + # Extract losses from output + loss = outputs.loss + loss_dict = { + "loss": loss.item() if loss is not None else 0.0, + } + + if outputs.flow_loss is not None: + loss_dict["flow_loss"] = outputs.flow_loss.item() + if outputs.cross_entropy_loss is not None: + loss_dict["cross_entropy_loss"] = outputs.cross_entropy_loss.item() + + # Add channel losses if available + if outputs.channel_loss_dict is not None: + for key, value in outputs.channel_loss_dict.items(): + if isinstance(value, torch.Tensor): + loss_dict[f"channel_{key}"] = value.item() + + return loss, loss_dict + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict action chunk for evaluation.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + batch = self.preprocess_inputs( + batch, + ) + + if self.config.prediction_mode == "diffusion": + output = self.model( + **batch, + action_dim=self.config.max_action_dim, + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="diffusion", + ) + elif self.config.prediction_mode == "fast": + output = self.model( + **batch, + action_dim=self.config.output_features["action"].shape[0], + pred_horizon=self.config.chunk_size, + mode="predict", + predict_mode="fast", + ) + else: + raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented") + + # Extract action tensor from output dictionary + actions = output["predict_action"] + + # Unpad actions to actual action dimension + action_dim = self.config.output_features["action"].shape[0] + actions = actions[:, :, :action_dim] + + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select single action for environment execution.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + # Use action queue + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + + return self._queues[ACTION].popleft() diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py new file mode 100644 index 000000000..e4e281541 --- /dev/null +++ b/src/lerobot/policies/wall_x/processor_wall_x.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.wall_x.configuration_wall_x import WallXConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_wall_x_pre_post_processors( + config: WallXConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the Wall-X policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations + 2. Adding a batch dimension + 4. Normalizing input and output features based on dataset statistics + 5. Moving all data to the specified device + + The post-processing pipeline handles the model's output by: + 1. Unnormalizing the output actions to their original scale + 2. Moving data to the CPU + + Args: + config: The configuration object for the Wall-X policy + dataset_stats: A dictionary of statistics for normalization + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + WallXTaskProcessor(), # Process task description + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +@ProcessorStepRegistry.register(name="wall_x_task_processor") +class WallXTaskProcessor(ComplementaryDataProcessorStep): + """ + A processor step that ensures the task description is properly formatted for Wall-X. + + This step handles task preprocessing similar to Qwen-VL requirements. + """ + + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + # Provide default task if none specified + complementary_data["task"] = "Execute the robot action." + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: ensure proper formatting + if not task.endswith("."): + new_complementary_data["task"] = f"{task}." + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: format each + new_complementary_data["task"] = [t if t.endswith(".") else f"{t}." for t in task] + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py new file mode 100644 index 000000000..731ef3b3e --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/configuration_qwen2_5_vl.py @@ -0,0 +1,248 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation + + +class Qwen2_5_VLVisionConfig(PretrainedConfig): + model_type = "qwen2_5_vl" + base_config_key = "vision_config" + + def __init__( + self, + depth=32, + hidden_size=3584, + hidden_act="silu", + intermediate_size=3420, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + tokens_per_second=4, + window_size=112, + out_hidden_size=3584, + fullatt_block_indexes=[7, 15, 23, 31], + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.tokens_per_second = tokens_per_second + self.window_size = window_size + self.fullatt_block_indexes = fullatt_block_indexes + self.out_hidden_size = out_hidden_size + + +class Qwen2_5_VLConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a + Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 152064): + Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2_5_VLModel`] + hidden_size (`int`, *optional*, defaults to 8192): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 29568): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 80): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 64): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 80): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + vision_config (`Dict`, *optional*): + The config for the visual encoder initialization. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + + ```python + >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig + + >>> # Initializing a Qwen2_5_VL style configuration + >>> configuration = Qwen2_5_VLConfig() + + >>> # Initializing a model from the Qwen2-VL-7B style configuration + >>> model = Qwen2_5_VLForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2_5_vl" + sub_configs = {"vision_config": Qwen2_5_VLVisionConfig} + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2_5_VL` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + num_experts=4, + experts=None, + dof_config=None, + noise_scheduler=None, + dim_inputs=(1536, 1536), + attention_moe=False, + mlp_moe=False, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + self.vision_config = self.sub_configs["vision_config"]() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.layer_types = ["dense"] * num_hidden_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + self.num_experts = num_experts + self.experts = experts + self.dof_config = dof_config + self.noise_scheduler = noise_scheduler + self.dim_inputs = tuple(dim_inputs) + self.attention_moe = attention_moe + self.mlp_moe = mlp_moe + + if self.rope_scaling is not None and "type" in self.rope_scaling: + if self.rope_scaling["type"] == "mrope": + self.rope_scaling["type"] = "default" + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self, ignore_keys={"mrope_section"}) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + + @property + def text_config(self): + return self + + +__all__ = ["Qwen2_5_VLConfig"] diff --git a/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py new file mode 100644 index 000000000..490e25095 --- /dev/null +++ b/src/lerobot/policies/wall_x/qwen_model/qwen2_5_vl_moe.py @@ -0,0 +1,2788 @@ +import math +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from transformers import AutoConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import ( + Cache, + DynamicCache, + SlidingWindowCache, + StaticCache, +) +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) + +from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + flash_attn_func = None + + +if is_flash_attn_2_available(): + pass +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + if max_seqlen is None: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], + torch.finfo(q.dtype).min, + device=q.device, + dtype=q.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int | None = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = ( + False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + window_index = window_index.to(hidden_states.device) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen_full = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + max_seqlen_window = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen_full + else: + cu_seqlens_now = cu_window_seqlens + max_seqlen_now = max_seqlen_window + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, + hidden_states, + cu_seqlens_now, + None, + position_embeddings, + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + position_embeddings=position_embeddings, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer( + "inv_freq", inv_freq, persistent=False + ) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if ( + seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len + ): # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_rate, + softmax_scale=None, + causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] + | None = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = ( + expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + ) + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + ) + w_index = ( + torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + ) + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1 + is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = [ + "pixel_values", + "image_grid_thw", + "pixel_values_videos", + "video_grid_thw", + "second_per_grid_ts", + ] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +@dataclass +class Qwen2_5_VLACausalLMOutputWithPast(ModelOutput): + loss: torch.FloatTensor | None = None + flow_loss: torch.FloatTensor | None = None + cross_entropy_loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + channel_loss_dict: dict[torch.FloatTensor] | None = None + channel_loss_count_dict: dict[torch.FloatTensor] | None = None + + +class BlockSparseMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.hidden_size = config["hidden_size"] + self.intermediate_size = config["intermediate_size"] + self.hidden_act = config["hidden_act"] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[self.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class SparseMoeBlock(nn.Module): + def __init__(self, config, num_experts: int): + super().__init__() + self.num_experts = num_experts + self.experts = nn.ModuleList([BlockSparseMLP(config.experts[i]) for i in range(num_experts)]) + + if not hasattr(config, "dim_inputs") or not config.dim_inputs: + raise ValueError("Config must contain valid dim_inputs") + + self.dim_inputs = config.dim_inputs + + def forward(self, hidden_states: torch.Tensor, experts_indices: torch.Tensor) -> torch.Tensor: + """ + Route different hidden_states to corresponding experts for processing. + + Args: + hidden_states (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + experts_indices (torch.Tensor): Tensor of shape (batch_size, seq_length), + indicating the expert index assigned to each token. + + Returns: + output (torch.Tensor): Tensor of shape (batch_size, seq_length, hidden_dim). + """ + batch_size, seq_length, hidden_dim = hidden_states.size() + output = torch.zeros_like(hidden_states) + + for expert_idx, expert in enumerate(self.experts): + mask = experts_indices == expert_idx + if mask.sum() == 0: + continue + dim_input = self.dim_inputs[expert_idx] + + selected_hidden = hidden_states[mask] + processed_hidden = expert(selected_hidden[:, :dim_input]) + + batch_indices, seq_indices = torch.where(mask) + output[batch_indices, seq_indices, :dim_input] = processed_hidden + + return output + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer_with_MoE(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int, num_experts: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if config.mlp_moe: + self.moe = SparseMoeBlock(config, num_experts=num_experts) + self.mlp = None + else: + self.mlp = Qwen2_5_VLMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + token_types=None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + hidden_states = hidden_states.to(self.input_layernorm.weight.dtype) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states.to(self.self_attn.q_proj.weight.dtype) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = hidden_states.to(self.post_attention_layernorm.weight.dtype) + hidden_states = self.post_attention_layernorm(hidden_states) + if self.mlp is None: # using moe mlp + hidden_states = hidden_states.to(self.moe.experts[0].down_proj.weight.dtype) + hidden_states = self.moe(hidden_states, token_types) + else: + hidden_states = hidden_states.to(self.mlp.down_proj.weight.dtype) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs + + +class Qwen2_5_VLMoEModel(Qwen2_5_VLPreTrainedModel): + """Qwen2.5-VL model with Mixture of Experts (MoE) architecture. + + This model extends the base Qwen2.5-VL model by incorporating MoE layers + for improved scalability and specialization across different token types. + """ + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + num_experts: int | None = None, + *args, + **kwargs, + ): + """Load a pretrained model with optional MoE configuration. + + Args: + pretrained_model_name_or_path: Path or name of the pretrained model + num_experts: Number of experts for MoE layers (if not in config) + *args: Additional arguments passed to parent class + **kwargs: Additional keyword arguments passed to parent class + + Returns: + Initialized model instance with MoE configuration + """ + config = kwargs.get("config") + if config is None: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + + # Override number of experts if specified + if num_experts is not None: + config.num_experts = num_experts + + kwargs["config"] = config + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + def __init__(self, config: Qwen2_5_VLConfig): + """Initialize the Qwen2.5-VL MoE model. + + Args: + config: Model configuration containing architecture parameters + """ + super().__init__(config) + + # Basic model parameters + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Model components + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + # Decoder layers with MoE support + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer_with_MoE(config, layer_idx, config.num_experts) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # Model configuration + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Get the input embedding layer. + + Returns: + The token embedding layer + """ + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the input embedding layer. + + Args: + value: New embedding layer to use + """ + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + moe_token_types: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple | BaseModelOutputWithPast: + # Set default output options + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Validate inputs + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if moe_token_types is None: + raise ValueError("moe_token_types must be provided for MoE routing") + + # Handle gradient checkpointing compatibility + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Initialize cache if needed + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + # Get input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Set up cache position + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # Set up position IDs (hardcoded 3 dimensions for temporal, height, width) + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # Create causal attention mask + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + moe_token_types, + ) + + hidden_states = inputs_embeds + + # Create position embeddings to be shared across decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # Initialize output collections + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Process through decoder layers + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + # Use gradient checkpointing during training + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + moe_token_types, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + # Regular forward pass + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + token_types=moe_token_types, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + # Update cache if using it + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + # Collect attention weights if requested + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Apply final layer normalization + hidden_states = self.norm(hidden_states) + + # Add final hidden states if collecting all states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + # Return outputs in requested format + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + moe_token_types: torch.LongTensor | None = None, + ): + """Update causal attention mask with support for bidirectional attention for specific token types. + + This method creates and modifies attention masks to support different attention patterns: + - Standard causal (unidirectional) attention for most tokens + - Bidirectional attention for specific token types (e.g., MoE routing tokens) + + Args: + attention_mask: Input attention mask to avoid attending to padding tokens + input_tensor: Input embeddings tensor for shape and device information + cache_position: Position indices for caching mechanisms + past_key_values: Cached key-value pairs from previous forward passes + output_attentions: Whether attention weights will be returned + moe_token_types: Optional tensor indicating token types for MoE routing + (type 1 tokens will use bidirectional attention) + + Returns: + Updated causal attention mask, or None if using Flash Attention 2 + """ + # Flash Attention 2 handles masking internally + if self.config._attn_implementation == "flash_attention_2": + return None + + # Calculate sequence lengths for cache management + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # For SDPA (Scaled Dot Product Attention), use `is_causal` argument when possible + # instead of explicit attention mask to enable Flash Attention 2 dispatch + # Note: This optimization is not compatible with static cache + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + # Check if we can ignore the causal mask and rely on SDPA's internal handling + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + # Extract tensor properties for mask creation + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + # Determine target length based on cache type + if using_sliding_window_cache or using_static_cache: + # Use maximum cache shape for sliding window or static caches + target_length = past_key_values.get_max_cache_shape() + else: + # For dynamic cache or no cache, calculate based on attention mask or sequence length + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # Generate 4D causal attention mask from 2D input mask if provided + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + # Modify mask to support bidirectional attention for specific token types + if moe_token_types is not None: + # Identify positions of type 1 tokens (MoE routing tokens) + type1_tokens = (moe_token_types == 1).unsqueeze(1).unsqueeze(2) # Shape: [B, 1, 1, S] + + # Create bidirectional attention region for type 1 tokens + # This allows type 1 tokens to attend to each other bidirectionally + type1_mask = torch.zeros_like(causal_mask) # Shape: [B, num_heads, S, S] + type1_region = type1_tokens & type1_tokens.transpose(-1, -2) # Shape: [B, 1, S, S] + type1_mask = type1_mask.masked_fill(type1_region, 1.0).to(torch.bool) + + # Apply bidirectional attention: zero out causal constraints in type 1 regions + causal_mask = torch.where( + type1_mask, # Where type 1 tokens interact with each other + torch.zeros_like(causal_mask), # Remove causal masking (allow bidirectional) + causal_mask, # Keep original causal masking for other regions + ) + + # Handle special case for SDPA with CUDA/XPU devices + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Ensure attention to all tokens in fully masked rows for memory-efficient attention + # This is required for F.scaled_dot_product_attention's memory-efficient path + # when using left padding. See: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +__all__ = [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + "Qwen2_5_VLDecoderLayer_with_MoE", + "Qwen2_5_VLMoEModel", +] diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py new file mode 100644 index 000000000..2ea40b377 --- /dev/null +++ b/src/lerobot/policies/wall_x/utils.py @@ -0,0 +1,631 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wall-X Utility Functions. + +Contains data processing utilities, text formatting functions, and helper classes +for the Wall-X cross-embodiment robotic control model. +""" + +import random +import re +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any + +import torch +from transformers import BatchFeature + +from lerobot.policies.wall_x.constant import ( + CAMERA_NAME_MAPPING, +) +from lerobot.utils.constants import OBS_IMAGES + + +@dataclass +class X2RDataProcessingConfig: + """Configuration class for X2R data processing pipeline. + + This class contains all the necessary parameters for processing robotic data + including camera mappings, tactile sensor configurations, action predictions, + and various processing options. + """ + + # Action prediction configuration + predict_action_keys: list[str] = field(default_factory=list) + obs_action_keys: list[str] = field(default_factory=list) + + # Image resolution settings for different views + resolution: dict[str, int] = field( + default_factory=lambda: { + "face_view": -1, + "left_wrist_view": 128, + "right_wrist_view": 128, + } + ) + + # Dataset splitting + train_test_split: float = 0.9 + split_seed: int = 42 + + # Instruction handling + priority_order: dict[str, float] | None = None + + # Vision model parameters + model_type: str = "qwen2_5" + max_pixels: int = 16384 * 28 * 28 + min_pixels: int = 4 * 28 * 28 + image_factor: int = 28 + + generate_subtask_ratio: float = 0.0 + + def __post_init__(self): + """Post-initialization validation and setup.""" + # Validate train/test split + if not 0 < self.train_test_split < 1: + raise ValueError(f"train_test_split must be between 0 and 1, got {self.train_test_split}") + + def as_dict(self) -> dict: + """Convert configuration to dictionary format. + + Returns: + Dict: Configuration as dictionary + """ + return self.__dict__ + + def update(self, **kwargs) -> "X2RDataProcessingConfig": + """Update configuration parameters. + + Args: + **kwargs: Key-value pairs to update + + Returns: + X2RDataProcessingConfig: Updated configuration instance + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown configuration parameter: {key}") + return self + + +def preprocesser_call( + processor, + images: list | Any | None = None, + text: str | list[str] | None = None, + videos: list | Any | None = None, + padding: bool | str = False, + truncation: bool | None = None, + max_length: int | None = None, + return_tensors: str = "pt", +) -> BatchFeature: + """Unified preprocessing function for Wall-X model handling text, image and video inputs. + + Processes inputs into format suitable for multimodal transformer models, including: + - Text tokenization and special token handling + - Image/video processing through image processor + - Attention mask and label generation + - Padding and truncation handling + + Args: + processor: Multimodal processor containing tokenizer and image processor + images: Input images (PIL, numpy arrays, or torch tensors) + text: Text or list of texts to tokenize + videos: Input videos (numpy arrays or torch tensors) + padding: Whether to pad sequences to same length + truncation: Whether to truncate sequences longer than max_length + max_length: Maximum length for truncation/padding + return_tensors: Format for returned tensors ('pt', 'np', etc.) + + Returns: + BatchFeature containing processed inputs with keys: + - input_ids: Tokenized text + - attention_mask: Attention mask for text + - pixel_values: Processed image pixels + - pixel_values_videos: Processed video frames + - image_grid_thw: Image grid dimensions for LLM + - video_grid_thw: Video grid dimensions for LLM + - labels: Training labels with masking + """ + # Process image inputs + if images is not None and len(images) > 0: + image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + # Process video inputs + if videos is not None: + videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors) + video_grid_thw = videos_inputs["video_grid_thw"] + else: + videos_inputs = {} + video_grid_thw = None + + # Ensure text input is in list format + if not isinstance(text, list): + text = [text] + + # Process image placeholder tokens in text + if image_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|image_pad|>" in text[i]: + # Add bounds checking to avoid index overflow + if index >= len(image_grid_thw): + print( + f"Warning: Number of image placeholders ({index + 1}) " + f"exceeds actual images ({len(image_grid_thw)}), " + f"skipping remaining placeholder processing" + ) + break + # Replace image placeholder with actual token count + token_count = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace("<|image_pad|>", "<|placeholder|>" * token_count, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") + + # Process video placeholder tokens in text + if video_grid_thw is not None: + merge_length = processor.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while "<|video_pad|>" in text[i]: + # Replace video placeholder with actual token count + token_count = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace("<|video_pad|>", "<|placeholder|>" * token_count, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>") + + # Tokenize complete input text + text_inputs = processor.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + # Get pad token ID for label generation + pad_token_id = processor.tokenizer.pad_token_id + if pad_token_id is None: + pad_token_id = processor.tokenizer.eos_token_id + + # Generate labels for multi-turn dialogue, keeping only assistant response loss + labels = torch.full_like(text_inputs.input_ids, -100) + assistant_marker = "<|im_start|>assistant\n" + im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>") + assistant_tokens = processor.tokenizer("<|im_start|>assistant\n", add_special_tokens=False).input_ids + + for i in range(len(text)): + assistant_regions = [] + parts = text[i].split(assistant_marker) + + # Process each part to determine which tokens belong to assistant responses + # Count left padding tokens + num_left_pads = 0 + for token_id in text_inputs.input_ids[i]: + if token_id == pad_token_id: + num_left_pads += 1 + else: + break + current_pos = num_left_pads + + for j, part in enumerate(parts): + part_tokens = processor.tokenizer(part, add_special_tokens=False).input_ids + if j == 0: + # First part is system prompt or user question, all labels are -100 + current_pos += len(part_tokens) + continue + + # From second part onwards, each part starts with assistant response + for k in range(current_pos + 1, len(text_inputs.input_ids[i])): + if text_inputs.input_ids[i][k] == im_end_token_id: + assistant_regions.append((current_pos + len(assistant_tokens), k + 2)) + break + current_pos += len(part_tokens) + 3 + + # Set labels for assistant response regions + for start, end in assistant_regions: + labels[i][start:end] = text_inputs.input_ids[i][start:end] + + # Mask special action tokens in labels + action_token_id = processor.tokenizer.encode("<|action|>")[0] + propri_token_id = processor.tokenizer.encode("<|propri|>")[0] + labels[labels == action_token_id] = -100 + labels[labels == propri_token_id] = -100 + labels[labels == processor.tokenizer.pad_token_id] = -100 + + # Set labels to None if all are invalid to skip cross entropy loss + if (labels != -100).any().item(): + text_inputs["labels"] = labels + else: + text_inputs["labels"] = None + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + +def process_grounding_points( + text: str, + orig_height: int, + orig_width: int, + resized_height: int, + resized_width: int, + model_type: str, +) -> str: + """Process grounding point coordinates in text based on image resizing. + + Adjusts coordinate values in tags to match resized image dimensions + for different model types (qwen2, qwen2_5). + + Args: + text: Input text containing tags with coordinates + orig_height: Original image height + orig_width: Original image width + resized_height: Resized image height + resized_width: Resized image width + model_type: Model type for coordinate processing ('qwen2' or 'qwen2_5') + + Returns: + Text with adjusted coordinate values + """ + # Regex pattern to match tags and their contents + point_pattern = re.compile(r"(.*?)") + + def process_match(match): + """Process a single point match and adjust coordinates.""" + coords_str = match.group(1) + try: + # Extract coordinates from string + coords = list(map(int, re.findall(r"\d+", coords_str))) + + # Calculate resize scale factors + scale_w = resized_width / orig_width + scale_h = resized_height / orig_height + + if len(coords) == 2: + x, y = coords + if model_type == "qwen2_5": + # Qwen2.5 uses pixel coordinates + new_x = max(0, min(round(x * scale_w), resized_width - 1)) + new_y = max(0, min(round(y * scale_h), resized_height - 1)) + elif model_type == "qwen2": + # Qwen2 normalizes to [0, 1000) range + new_x = max(0, min(999.999, (x / orig_width) * 1000)) + new_y = max(0, min(999.999, (y / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x, new_y] + + elif len(coords) == 4: + x1, y1, x2, y2 = coords + if model_type == "qwen2_5": + new_x1 = max(0, min(round(x1 * scale_w), resized_width - 1)) + new_y1 = max(0, min(round(y1 * scale_h), resized_height - 1)) + new_x2 = max(0, min(round(x2 * scale_w), resized_width - 1)) + new_y2 = max(0, min(round(y2 * scale_h), resized_height - 1)) + elif model_type == "qwen2": + new_x1 = max(0, min(999.999, (x1 / orig_width) * 1000)) + new_y1 = max(0, min(999.999, (y1 / orig_height) * 1000)) + new_x2 = max(0, min(999.999, (x2 / orig_width) * 1000)) + new_y2 = max(0, min(999.999, (y2 / orig_height) * 1000)) + else: + raise ValueError(f"Unsupported model type: {model_type}") + coords = [new_x1, new_y1, new_x2, new_y2] + + # Return processed point tag + return f"[{', '.join(map(str, coords))}]" + + except (ValueError, TypeError): + # Return original content if processing fails + return match.group(0) + + # Replace all matching point tags + processed_text = point_pattern.sub(process_match, text) + return processed_text + + +def get_frame_instruction( + instruction_info: dict[str, Any], + frame_idx: int | None = None, + truncate_keys: list[str] | None = None, +) -> tuple[dict[str, Any], int | None]: + """Extract frame-specific instruction from instruction dictionary. + + Args: + instruction_info: Dictionary containing instruction components + frame_idx: Current frame index + truncate_keys: Keys that trigger truncation when found + + Returns: + Tuple of (frame_instruction_dict, split_end_frame) + """ + if truncate_keys is None: + truncate_keys = [ + "subtask_generation", + "distribute", + "subtask_generation_zh", + "distribute_zh", + ] + + instruction_for_frame = {} + split_end = None + + for key, value in instruction_info.items(): + if isinstance(value, dict): + # Handle frame-range specific instructions + for frame_range, frame_instruction in value.items(): + start_frame, end_frame = map(int, frame_range.split(" ")) + if start_frame <= frame_idx < end_frame or (start_frame == frame_idx): + instruction_for_frame[key] = frame_instruction + if truncate_keys is not None and split_end is None and key in truncate_keys: + split_end = end_frame + 1 + break + else: + instruction_for_frame[key] = value + + return instruction_for_frame, split_end + + +def get_task_instruction( + frame_instruction_info: dict[str, Any], priority_order: OrderedDict | None = None +) -> str: + """Construct task instruction from available instruction fields using priority sampling. + + Args: + frame_instruction_info: Dictionary containing instruction fields + priority_order: OrderedDict specifying sampling probability for each field + + Returns: + Combined instruction string with priority components + """ + # Default priority settings + default_priority_order = OrderedDict( + { + "subtask_generation": 0.25, + "subtask_generation_zh": 0.25, + "distribute": 0.25, + "distribute_zh": 0.25, + } + ) + + if priority_order is not None: + priority_order = OrderedDict(priority_order) + else: + priority_order = default_priority_order + + got_instruction = False + task_instruction = "" + + # Sample instruction components based on priority probabilities + for key, prob in priority_order.items(): + if key in frame_instruction_info and frame_instruction_info[key] != "": + if got_instruction: + if random.random() >= prob: + continue + + task_instruction += f"\n{frame_instruction_info[key]}" + got_instruction = True + break + + # Fall back to base instruction if no priority components found + if not got_instruction: + task_instruction = frame_instruction_info.get("instruction", "") + + return task_instruction + + +def get_wallx_normal_text( + instruction_info: dict[str, Any], + action_chunk_size: int, + frame_idx: int, + priority_order: OrderedDict | None = None, + img_keys: list[str] | None = None, + generate_subtask_ratio: float = 0.0, +) -> tuple[str, bool]: + """Construct complete multimodal prompt text for Wall-X model. + + Formats input using special tokens including: + - System message + - User observations (with image placeholders) + - Task instructions + - Proprioception prompts + - Assistant responses (with action tokens) + + Args: + instruction_info: Dictionary containing instruction components + action_chunk_size: Number of action tokens to generate + frame_idx: Current frame index + priority_order: Priority order for instruction sampling + img_keys: List of image keys + generate_subtask_ratio: Probability of generating subtask instead of actions + + Returns: + Tuple of (formatted_prompt_text, is_subtask_generation) + """ + # Special tokens for formatting + role_start_symbol = "<|im_start|>" + role_end_symbol = "<|im_end|>" + vision_start_symbol = "<|vision_start|>" + vision_end_symbol = "<|vision_end|>" + image_pad_symbol = "<|image_pad|>" + propri_symbol = "<|propri|>" + action_symbol = "<|action|>" + action_fast_symbol = "<|action_fast|>" + + # System prologue + prologue = f"{role_start_symbol}system\nYou are a helpful assistant.{role_end_symbol}\n" + + # User request with observation + user_request = f"{role_start_symbol}user\nObservation:" + if img_keys: + img_keys = img_key_mapping(img_keys) + for key in img_keys: + user_request += f" {key}: {vision_start_symbol}{image_pad_symbol}{vision_end_symbol}" + user_request += "\nInstruction:" + + # Get frame-specific instruction + frame_instruction_info, _ = get_frame_instruction(instruction_info, frame_idx=frame_idx) + + generate_subtask = False + priority_keys = ["subtask_generation", "distribute"] + + # Decide whether to generate subtask or actions + if ( + bool(set(frame_instruction_info.keys()) & set(priority_keys)) + and random.random() < generate_subtask_ratio + ): + # Generate subtask (equivalent to VQA task) + instruction = frame_instruction_info.get("instruction", "") + text_prompt = "\nPredict the next action in language.\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + + # Find output instruction from priority keys + for key in priority_keys: + if key in frame_instruction_info: + output_instruction = frame_instruction_info[key] + break + + assistant_output = f"{role_start_symbol}assistant\n{output_instruction}\n{role_end_symbol}" + generate_subtask = True + else: + # Generate actions + instruction = get_task_instruction(frame_instruction_info, priority_order=priority_order) + text_prompt = f"\nPredict the next action in robot action.\nProprioception: {propri_symbol}\n" + user_message = f"{user_request} {instruction}{text_prompt}{role_end_symbol}\n" + assistant_output = f"{role_start_symbol}assistant\n{action_fast_symbol}{role_end_symbol}\n{action_symbol * action_chunk_size}" + + complete_text = prologue + user_message + assistant_output + return complete_text, generate_subtask + + +def img_key_mapping(img_keys: list[str]) -> list[str]: + """Map image keys to camera names. + + Args: + img_keys: List of image keys + + Returns: + List of camera names + """ + processed_img_keys = [] + for key in img_keys: + key = key.replace(OBS_IMAGES + ".", "") + if key in CAMERA_NAME_MAPPING: + key = CAMERA_NAME_MAPPING[key] + else: + if "view" in key: + key = key.replace("_", " ") + else: + key = key + " view" + processed_img_keys.append(key) + return processed_img_keys + + +def get_action_tokens(normalized_actions: torch.Tensor | list, action_tokenizer) -> list[list[str]]: + """Convert normalized actions to action token strings. + + Args: + normalized_actions: Normalized action arrays/tensors + action_tokenizer: Tokenizer for converting actions to tokens + + Returns: + List of action token string lists for each sample + """ + if isinstance(normalized_actions, torch.Tensor): + normalized_actions = normalized_actions.cpu().numpy() + + all_action_tokens = [] + for i in range(len(normalized_actions)): + if isinstance(normalized_actions[i], torch.Tensor): + normalized_actions[i] = normalized_actions[i].cpu().numpy() + + token_id = action_tokenizer(normalized_actions[i]) + action_tokens = [f"<|action_token_{j}|>" for j in token_id[0]] + all_action_tokens.append(action_tokens) + + return all_action_tokens + + +def pad_action_token_strs( + actions_token_lists: list[list[str]], + pad_token: str = "<|endoftext|>", # nosec B107 +) -> list[str]: + """Pad action token lists to same length and join as strings. + + Args: + actions_token_lists: List of action token lists for each sample + pad_token: Token used for padding + + Returns: + List of padded action token strings + """ + max_len = max(len(tokens) for tokens in actions_token_lists) + padded_action_strs = [] + + for tokens in actions_token_lists: + padded_tokens = tokens + ["<|im_end|>\n"] + [pad_token] * (max_len - len(tokens)) + padded_action_strs.append("".join(padded_tokens)) + + return padded_action_strs + + +def replace_action_token( + text: list[str], + norm_action: torch.Tensor | None, + action_tokenizer, + dof_masks: torch.Tensor | None = None, +) -> list[str]: + """Replace action placeholders in text with actual action tokens. + + Args: + text: List of text strings with action placeholders + norm_action: Normalized action tensors + action_tokenizer: Tokenizer for converting actions to tokens + dof_masks: Masks for degrees of freedom + + Returns: + List of text strings with action tokens replaced + """ + if action_tokenizer is not None and norm_action is not None: + # Extract actions based on chunk sizes and DOF masks + norm_action = [action[:32, dof_masks[i, 0].bool()] for i, action in enumerate(norm_action)] + + # Convert to action tokens and pad + actions_fast_tokens = get_action_tokens(norm_action, action_tokenizer) + actions_fast_token_strs = pad_action_token_strs(actions_fast_tokens) + + # Replace action placeholders with actual tokens + actions_fast_token_idx = 0 + for i in range(len(text)): + if "<|action_fast|>" in text[i]: + text[i] = text[i].replace( + "<|action_fast|><|im_end|>\n", + actions_fast_token_strs[actions_fast_token_idx], + ) + actions_fast_token_idx += 1 + + # Remove remaining action placeholders + text = [t.replace("<|action|>", "") for t in text] + else: + # Remove action placeholders when no tokenizer available + text = [t.replace("<|action_fast|><|im_end|>\n", "") for t in text] + + return text diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py new file mode 100644 index 000000000..837907041 --- /dev/null +++ b/tests/policies/wall_x/test_wallx.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!""" + +import os + +import pytest +import torch + +# Skip if openpi or transformers is not available +pytest.importorskip("peft") +pytest.importorskip("transformers==4.49.0") + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local Wall-X installation and is not meant for CI", +) + +from lerobot.policies.factory import make_policy_config # noqa: E402 +from lerobot.policies.wall_x import WallXConfig # noqa: E402 +from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402 +from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402 +from lerobot.utils.random_utils import set_seed # noqa: E402 + + +def test_policy_instantiation(): + # Create config + set_seed(42) + config = WallXConfig(device="cuda") + + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(7,), + ), + "observation.images.face_view": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "observation.images.face_view": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + } + + # Instantiate policy + policy = WallXPolicy(config) + preprocessor, postprocessor = make_wall_x_pre_post_processors(config=config, dataset_stats=dataset_stats) + # Test forward pass with dummy data + batch_size = 1 + device = config.device + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + loss, loss_dict = policy.forward(batch) + print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + + # Test inference + batch = { + "observation.state": torch.randn(batch_size, 7, dtype=torch.float32, device=device), + "observation.images.face_view": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + with torch.no_grad(): + action = policy.select_action(batch) + action = postprocessor(action) + print(f"Action: {action}") + print(f"Action prediction successful. Action shape: {action.shape}") + except Exception as e: + print(f"Action prediction failed: {e}") + raise + + +def test_config_creation(): + """Test policy config creation through factory.""" + try: + config = make_policy_config( + policy_type="wall_x", + ) + print("Config created successfully through factory") + print(f" Config type: {type(config).__name__}") + except Exception as e: + print(f"Config creation failed: {e}") + raise + + +if __name__ == "__main__": + test_policy_instantiation() + test_config_creation()