Compare commits

..

8 Commits

Author SHA1 Message Date
Maxime Ellerbach ecf342d481 small fix for the preprocessor and padded images 2026-06-16 11:27:51 +00:00
Maxime Ellerbach 1e762d5240 linting 2026-06-15 12:11:39 +00:00
Maxime Ellerbach 35c3302f4d re-parenting of some layers to enable proper zero-3 FSDP 2026-06-15 12:11:27 +00:00
Maxime Ellerbach a323ea67b6 preparing for training adding some temporary debug code aswell to visualize model output 2026-06-12 15:25:28 +00:00
Maxime Ellerbach 7c063c3fbc changing reproducable results 2026-06-12 08:57:11 +00:00
Maxime Ellerbach 9cf12c941d big refactor to use models from diffusers and transformers 2026-06-12 08:56:58 +00:00
ZibinDong 4039da81c6 Add FastWAM policy review updates 2026-06-09 13:37:59 +00:00
ZibinDong b3a28a49f6 Add FastWAM policy 2026-06-09 13:37:59 +00:00
29 changed files with 5866 additions and 1802 deletions
+2
View File
@@ -67,6 +67,8 @@
title: VLA-JEPA
- local: eo1
title: EO-1
- local: fastwam
title: FastWAM
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
+163
View File
@@ -0,0 +1,163 @@
# FastWAM
FastWAM is a World Action Model policy for robot control. The LeRobot integration exposes FastWAM through the standard policy API so it can be configured with `policy.type=fastwam`, trained with `lerobot-train`, and loaded through the LeRobot pretrained policy interface.
## Model Overview
FastWAM keeps video modeling during training, but uses direct action prediction at inference time instead of iteratively generating future observations. This LeRobot policy wraps the FastWAM action model, adapts LeRobot batches to FastWAM training samples, and provides the standard processor pipeline for normalization and action postprocessing.
The implementation initializes the visual world-model components from `Wan-AI/Wan2.2-TI2V-5B` by default and predicts action chunks with shape `[batch, action_horizon, action_dim]`.
### What the LeRobot Integration Covers
- Standard `policy.type=fastwam` configuration through LeRobot
- Image, state, action, and language-task batch adaptation
- Action chunk inference through `select_action` and `predict_action_chunk`
- Checkpoint save/load through the LeRobot policy APIs
- Configurable LIBERO gripper action postprocessing
## Installation Requirements
Install LeRobot from source, then install FastWAM dependencies:
```bash
pip install -e ".[fastwam]"
```
This installs the FastWAM policy extra from `pyproject.toml`: `transformers`,
`diffusers`, `ftfy`, and `regex`, plus LeRobot's base dependencies.
For LIBERO evaluation, install the benchmark dependencies too:
```bash
pip install -e ".[fastwam,libero]"
```
This installs both extras. In addition to the FastWAM dependencies above, the
`libero` extra installs LeRobot dataset dependencies, `hf-libero` on Linux, and
`scipy`.
FastWAM uses the Wan2.2 TI2V backbone. The default model id is:
```python
policy.model_id=Wan-AI/Wan2.2-TI2V-5B
```
## Data Requirements
FastWAM expects a LeRobot dataset with:
- one or more visual observations whose widths concatenate to `policy.image_size[1]`
- `observation.state` when `policy.proprio_dim` is not `None`
- `action`
- a language task instruction through the dataset task field, or precomputed `context` and `context_mask` tensors
The default visual setup is one image feature named `observation.images.image` with shape `(3, 224, 448)`. If the dataset uses two cameras, configure `policy.input_features` so their heights match `224` and their widths sum to `448`.
## Usage
Create a new FastWAM policy with:
```bash
lerobot-train \
--dataset.repo_id=your-org/your-dataset \
--policy.type=fastwam \
--policy.action_dim=7 \
--policy.proprio_dim=8 \
--policy.action_horizon=32 \
--policy.n_action_steps=10 \
--policy.image_size='[224,448]' \
--output_dir=./outputs/fastwam_training \
--job_name=fastwam_training \
--steps=300000 \
--batch_size=8 \
--policy.device=cuda
```
Evaluate an existing LeRobot-format checkpoint on LIBERO-10 with:
```bash
lerobot-eval \
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
--policy.device=cuda \
--policy.torch_dtype=float32 \
--policy.n_action_steps=10 \
--env.type=libero \
--env.task=libero_10 \
--env.observation_height=224 \
--env.observation_width=224 \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=0 \
--env.episode_length=600
```
For `libero_goal`, `libero_spatial`, and `libero_object`, use
`--env.episode_length=300`.
For real-robot rollout, use the same checkpoint path:
```bash
lerobot-rollout \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--policy.path=your-org/fastwam-real-robot
```
## Configuration Notes
### Image Features
`policy.image_size` is the size of the concatenated FastWAM image tensor as `(height, width)`. Each configured image feature must have shape `(3, height, camera_width)`, and all camera widths must sum to the configured width.
### Action Chunking
`policy.action_horizon` controls the number of future actions supervised during training and predicted during inference. `policy.n_action_steps` controls how many actions are consumed before the policy predicts a fresh chunk. `policy.n_action_steps` must be less than or equal to `policy.action_horizon`.
### Wan Components
FastWAM loads the Wan VAE, video DiT, text encoder, and tokenizer from the configured Wan model directory or Hugging Face Hub model id. LeRobot-format FastWAM checkpoints saved by `save_pretrained` also copy the local Wan component files needed by `from_pretrained`.
### LIBERO Action Toggle
FastWAM LIBERO checkpoints use `policy.toggle_action_dimensions=[-1]` by
default to match the gripper action convention used by the original FastWAM
evaluation pipeline:
```bash
--policy.toggle_action_dimensions='[-1]'
```
## Results
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 97.6% | 500 |
| libero_object | 99.0% | 500 |
| libero_goal | 95.0% | 500 |
| libero_10 | 94.0% | 500 |
| **average** | **96.4%** | 2000 |
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300` (1x H20 140 GB).
## References
- [Fast-WAM paper](https://arxiv.org/abs/2603.16666)
- [Fast-WAM project page](https://yuantianyuan01.github.io/FastWAM/)
- [Fast-WAM code](https://github.com/yuantianyuan01/FastWAM)
- [Released upstream checkpoints](https://huggingface.co/yuanty/fastwam)
- [Wan2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)
## Citation
```bibtex
@article{yuan2026fastwam,
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
journal = {arXiv preprint arXiv:2603.16666},
year = {2026},
url = {https://arxiv.org/abs/2603.16666}
}
```
+56
View File
@@ -0,0 +1,56 @@
## Research Paper
Paper: https://arxiv.org/abs/2603.16666
## Repository
Code: https://github.com/yuantianyuan01/FastWAM
Project page: https://yuantianyuan01.github.io/FastWAM/
## Citation
```bibtex
@article{yuan2026fastwam,
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
journal = {arXiv preprint arXiv:2603.16666},
year = {2026},
url = {https://arxiv.org/abs/2603.16666}
}
```
## Additional Resources
Base video model: https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B
Released upstream checkpoints: https://huggingface.co/yuanty/fastwam
## Results
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 97.6% | 500 |
| libero_object | 99.0% | 500 |
| libero_goal | 95.0% | 500 |
| libero_10 | 94.0% | 500 |
| **average** | **96.4%** | 2000 |
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300`.
For LIBERO-10, use `--env.task=libero_10 --env.episode_length=600`:
```bash
lerobot-eval \
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
--policy.device=cuda \
--policy.torch_dtype=float32 \
--policy.n_action_steps=10 \
--env.type=libero \
--env.task=libero_10 --env.observation_height=256 --env.observation_width=256 \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=0 --env.episode_length=600
```
+5 -2
View File
@@ -214,9 +214,12 @@ groot = [
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
recap = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
fastwam = [
"lerobot[transformers-dep]",
"lerobot[diffusers-dep]",
]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
@@ -281,6 +284,7 @@ all = [
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
"lerobot[fastwam]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
@@ -297,7 +301,6 @@ all = [
"lerobot[sarm]",
"lerobot[robometer]",
"lerobot[topreward]",
"lerobot[recap]",
"lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
]
+2
View File
@@ -18,6 +18,7 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
@@ -42,6 +43,7 @@ __all__ = [
"ACTConfig",
"DiffusionConfig",
"EO1Config",
"FastWAMConfig",
"GaussianActorConfig",
"GrootConfig",
"MolmoAct2Config",
+15
View File
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .fastwam.configuration_fastwam import FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
elif name == "fastwam":
from .fastwam.modeling_fastwam import FastWAMPolicy
return FastWAMPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
elif policy_type == "fastwam":
return FastWAMConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -448,6 +455,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, FastWAMConfig):
from .fastwam.processor_fastwam import make_fastwam_pre_post_processors
processors = make_fastwam_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+1
View File
@@ -0,0 +1 @@
../../../../docs/source/policy_fastwam_README.md
@@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_distributional_value_function import DistributionalVFConfig
from .modeling_distributional_value_function import DistributionalVFRewardModel
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
from .configuration_fastwam import FastWAMConfig
from .modeling_fastwam import FastWAMPolicy
from .processor_fastwam import make_fastwam_pre_post_processors
__all__ = [
"DistributionalVFConfig",
"DistributionalVFRewardModel",
"make_distributional_vf_pre_post_processors",
"FastWAMConfig",
"FastWAMPolicy",
"make_fastwam_pre_post_processors",
]
@@ -0,0 +1,394 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.configs import (
FeatureType,
NormalizationMode,
PolicyFeature,
PreTrainedConfig,
)
from lerobot.optim import AdamWConfig
from lerobot.utils.constants import ACTION, OBS_STATE
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
"patch_size",
"in_dim",
"hidden_dim",
"ffn_dim",
"freq_dim",
"text_dim",
"out_dim",
"num_heads",
"attn_head_dim",
"num_layers",
)
_FASTWAM_ACTION_BASE_COMPAT_KEYS = (
"hidden_dim",
"ffn_dim",
"num_heads",
"attn_head_dim",
"num_layers",
"text_dim",
"freq_dim",
)
def default_video_dit_config(action_dim: int) -> dict[str, Any]:
return {
"patch_size": [1, 2, 2],
"in_dim": 48,
"hidden_dim": 3072,
"ffn_dim": 14336,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 48,
"num_heads": 24,
"attn_head_dim": 128,
"num_layers": 30,
"eps": 1.0e-6,
"separated_timestep": True,
"use_gradient_checkpointing": False,
"video_attention_mask_mode": "first_frame_causal",
"action_conditioned": False,
"action_dim": action_dim,
"action_group_causal_mask_mode": "group_diagonal",
"fp32_attention": True,
}
def default_action_dit_config(action_dim: int) -> dict[str, Any]:
return {
"action_dim": action_dim,
"hidden_dim": 1024,
"ffn_dim": 4096,
"num_heads": 24,
"attn_head_dim": 128,
"num_layers": 30,
"text_dim": 4096,
"freq_dim": 256,
"eps": 1.0e-6,
"use_gradient_checkpointing": False,
"fp32_attention": True,
}
def _coerce_enum(enum_cls: type, value: Any) -> Any:
if isinstance(value, enum_cls):
return value
try:
return enum_cls(value)
except (TypeError, ValueError):
return getattr(enum_cls, str(value), value)
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
if features is None:
return None
coerced = {}
for name, feature in features.items():
if isinstance(feature, PolicyFeature):
coerced[name] = feature
continue
coerced[name] = PolicyFeature(
type=_coerce_enum(FeatureType, feature["type"]),
shape=tuple(feature["shape"]),
)
return coerced
def _is_local_model_id(value: str) -> bool:
path = Path(value).expanduser()
return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists()
def _validate_wan_model_id(value: str, field_name: str) -> str:
if value == WAN22_MODEL_ID or _is_local_model_id(value):
return value
raise ValueError(f"`{field_name}` must be `{WAN22_MODEL_ID}` or an explicit local path, got `{value}`.")
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
"""Return whether `fastwam-base` partial weights can initialize this config."""
default_video_config = default_video_dit_config(config.action_dim)
default_action_config = default_action_dit_config(config.action_dim)
return all(
config.video_dit_config.get(key) == default_video_config.get(key)
for key in _FASTWAM_VIDEO_BASE_COMPAT_KEYS
) and all(
config.action_dit_config.get(key) == default_action_config.get(key)
for key in _FASTWAM_ACTION_BASE_COMPAT_KEYS
)
@PreTrainedConfig.register_subclass("fastwam")
@dataclass
class FastWAMConfig(PreTrainedConfig):
"""Configuration for the FastWAM LeRobot policy.
Args:
action_dim (int): Number of scalar action channels per timestep.
proprio_dim (int | None): Number of proprioception channels used as an
extra text-context token. `None` disables proprio conditioning.
action_horizon (int): Number of actions predicted by one policy call.
num_video_frames (int): Raw video sampling window (in dataset frames). The
model actually operates on `model_video_frames` frames after subsampling
by `action_video_freq_ratio`.
action_video_freq_ratio (int): Actions are sampled at this multiple of the
video frame rate. Video frames are taken every `action_video_freq_ratio`-th
raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames
spanning the same time window as `action_horizon` actions (ratio actions
per video frame).
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
context_len (int): Maximum text embedding token length.
video_dit_config (dict[str, Any] | None): Wan video expert config.
action_dit_config (dict[str, Any] | None): Action expert config.
use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT
experts (trades compute for memory; propagated into the DiT configs).
freeze_video_expert (bool): Freeze the ~5B Wan video expert
(`model.video_expert`) so only the action expert + proprio encoder train.
Cuts the AdamW optimizer footprint substantially; the video expert keeps its
pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the
now-gradient-free video loss compute.)
"""
n_obs_steps: int = 1
action_dim: int = 7
proprio_dim: int | None = 8
action_horizon: int = 32
n_action_steps: int = 32
num_video_frames: int = 33
action_video_freq_ratio: int = 4
image_size: tuple[int, int] = (224, 448)
context_len: int = 128
model_id: str = WAN22_MODEL_ID
tokenizer_model_id: str = WAN22_MODEL_ID
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
tokenizer_max_len: int = 128
load_text_encoder: bool = True
mot_checkpoint_mixed_attn: bool = False
torch_dtype: str = "bfloat16"
prompt_template: str = (
"A video recorded from a robot's point of view executing the following instruction: {task}"
)
num_inference_steps: int = 10
inference_seed: int | None = 42
rand_device: str = "cpu"
text_cfg_scale: float = 1.0
negative_prompt: str = ""
sigma_shift: float | None = None
tiled: bool = False
fp32_attention: bool = True
use_gradient_checkpointing: bool = False
freeze_video_expert: bool = False
toggle_action_dimensions: list[int] = field(default_factory=list)
video_scheduler: dict[str, float | int] = field(
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
)
action_scheduler: dict[str, float | int] = field(
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
)
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
video_dit_config: dict[str, Any] | None = None
action_dit_config: dict[str, Any] | None = None
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
input_features: dict[str, PolicyFeature] | None = None
output_features: dict[str, PolicyFeature] | None = None
optimizer_lr: float = 1.0e-4
optimizer_weight_decay: float = 1.0e-2
def __post_init__(self) -> None:
super().__post_init__()
self.image_size = tuple(self.image_size)
self.model_id = _validate_wan_model_id(self.model_id, "model_id")
self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
self.input_features = _coerce_policy_features(self.input_features)
self.output_features = _coerce_policy_features(self.output_features)
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim)
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
if self.input_features is None:
height, width = self.image_size
self.input_features = {
"observation.images.image": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
}
if self.proprio_dim is not None:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.proprio_dim,),
)
if self.output_features is None:
self.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))}
self.validate_features()
if self.pretrained_path or self.use_peft or not self.base_model_id:
return
if not is_fastwam_base_compatible_config(self):
return
self.pretrained_path = Path(self.base_model_id)
self._auto_pretrained_path = True
def _save_pretrained(self, save_directory: Path) -> None:
if not getattr(self, "_auto_pretrained_path", False):
super()._save_pretrained(save_directory)
return
pretrained_path = self.pretrained_path
self.pretrained_path = None
try:
super()._save_pretrained(save_directory)
finally:
self.pretrained_path = pretrained_path
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self) -> None:
return None
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
"""Rebuild visual input features from the dataset's real camera keys.
FastWAM's `__post_init__` installs a synthetic single-image default
(`observation.images.image` at full `image_size` width). For datasets
with one or more separately-named cameras (e.g. `observation.images.top`,
`observation.images.wrist`), this hook — invoked by `make_policy` once the
dataset metadata is known — replaces that default with the actual camera
keys, each declared at the policy's native per-camera resolution
(`image_size[0]` x `image_size[1] // num_cameras`). The accompanying
resize step in `make_fastwam_pre_post_processors` resizes raw frames to
match, so heterogeneous source resolutions (e.g. 480x640) are supported.
"""
image_keys = sorted(
key
for key, feature in dataset_features.items()
if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
)
if not image_keys:
return
height, total_width = self.image_size
per_cam_width = total_width // len(image_keys)
new_inputs: dict[str, PolicyFeature] = {
key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width))
for key in image_keys
}
if self.proprio_dim is not None and OBS_STATE in dataset_features:
new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))
self.input_features = new_inputs
self.validate_features()
def validate_features(self) -> None:
if self.action_dim <= 0:
raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
if self.action_horizon <= 0:
raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
if self.n_action_steps > self.action_horizon:
raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
if self.action_video_freq_ratio <= 0:
raise ValueError(
f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}."
)
# Video frames are subsampled by action_video_freq_ratio; the resulting model frame
# count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the
# original FastWAM dataset asserts).
if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0:
raise ValueError(
f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by "
f"`action_video_freq_ratio` ({self.action_video_freq_ratio})."
)
if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0:
raise ValueError(
f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) "
"must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)."
)
if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0:
raise ValueError(
f"`action_horizon` ({self.action_horizon}) must be divisible by the number of "
f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})."
)
if not self.image_features:
raise ValueError("FastWAM requires at least one image feature.")
if self.action_feature is None:
raise ValueError("FastWAM requires `action` in output_features.")
action_shape = tuple(self.action_feature.shape)
if action_shape != (self.action_dim,):
raise ValueError(
f"FastWAM action feature shape must be ({self.action_dim},), got {action_shape}."
)
if self.proprio_dim is not None:
state_feature = self.robot_state_feature
if state_feature is None:
raise ValueError("FastWAM requires `observation.state` when `proprio_dim` is set.")
state_shape = tuple(state_feature.shape)
if state_shape != (self.proprio_dim,):
raise ValueError(
f"FastWAM state feature shape must be ({self.proprio_dim},), got {state_shape}."
)
height, width = self.image_size
image_width_sum = 0
for name, feature in self.image_features.items():
shape = tuple(feature.shape)
if len(shape) != 3 or shape[0] != 3:
raise ValueError(f"FastWAM image feature `{name}` must have shape (3, H, W), got {shape}.")
if shape[1] != height:
raise ValueError(f"FastWAM image feature `{name}` height must be {height}, got {shape[1]}.")
image_width_sum += shape[2]
if image_width_sum != width:
raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
@property
def model_video_frames(self) -> int:
"""Number of video frames the model actually operates on, after subsampling the
raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9)."""
return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1
@property
def observation_delta_indices(self) -> list[int]:
# Load the video frames the model is supervised on: the future window subsampled by
# action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is
# thus `action_video_freq_ratio` actions apart, while actions load at the full rate
# (`action_delta_indices` = range(action_horizon)). Returning None would load only the
# current frame, making the video target a static repeat (degenerate supervision).
return list(range(0, self.num_video_frames, self.action_video_freq_ratio))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.action_horizon))
@property
def reward_delta_indices(self) -> None:
return None
@@ -0,0 +1,540 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any
import torch
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import OBS_STATE
from .configuration_fastwam import FastWAMConfig
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
from .wan_video_dit import WanVideoDiT
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
# eval episode's action chunks through `infer_joint` so the predicted video latents are
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
_DEBUG_PRED_RATE_DIV = 1
class FastWAMPolicy(PreTrainedPolicy):
"""LeRobot policy wrapper for FastWAM.
Args:
config (FastWAMConfig): FastWAM policy configuration.
dataset_stats (dict[str, dict[str, Tensor]] | None): Optional LeRobot
dataset statistics passed by the training/evaluation stack.
"""
config_class = FastWAMConfig
name = "fastwam"
def __init__(
self,
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
):
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
# dataset feature metadata is already applied to `config` by make_policy upstream,
# so we accept and ignore them, matching the other LeRobot policies.
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.model = self._build_core_model(config)
if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None:
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
# so its params drop out of the optimizer (and DDP skips them).
self.model.video_expert.requires_grad_(False)
# The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so
# `video_expert.requires_grad_` no longer reaches them — freeze them via the layers.
mot = getattr(self.model, "mot", None)
if mot is not None and getattr(mot, "layers", None) is not None:
for layer in mot.layers:
if "video" in layer.blocks:
layer.blocks["video"].requires_grad_(False)
self.reset()
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
# counts only eval-rollout resets (one per episode), not this __init__ one.
self._debug_constructed = True
self._debug_episode_index = -1
self._debug_seen_tasks: set[str] = set()
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video: list | None = None
self._debug_pairs: list = []
@classmethod
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
"""Shape-aware load that supports cross-embodiment fine-tuning.
`safetensors.load_model(strict=False)` ignores missing/unexpected keys but
still raises on a shape mismatch for a shared key. When fine-tuning from a
checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim
checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and
proprio encoder legitimately differ in shape. With `strict=False` we drop
only those shape-mismatched tensors — leaving them at their freshly
initialized values — and load every compatible tensor. With `strict=True`
the standard exact-match loader is used.
"""
from safetensors import safe_open
model_state_dict = model.state_dict()
mismatched = []
with safe_open(model_file, framework="pt") as f:
checkpoint_keys = list(f.keys())
for key in checkpoint_keys:
if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple(
f.get_slice(key).get_shape()
):
mismatched.append(key)
if not mismatched:
return super()._load_as_safetensor(model, model_file, map_location, strict)
if strict:
raise RuntimeError(
f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
f"strict=True: {mismatched}"
)
from safetensors.torch import load_file
logging.warning(
"FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
"every compatible weight: %s",
len(mismatched),
mismatched,
)
state_dict = load_file(model_file, device="cpu")
for key in mismatched:
state_dict.pop(key, None)
model.load_state_dict(state_dict, strict=False)
if map_location and map_location != "cpu":
model.to(map_location)
return model
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
# instead would make `list(...)` yield the key string "params".
params = (
list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters())
)
proprio_encoder = getattr(self.model, "proprio_encoder", None)
if proprio_encoder is not None:
params.extend(list(proprio_encoder.parameters()))
return [p for p in params if p.requires_grad]
def reset(self) -> None:
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
# true-vs-pred video if it was a captured one (pairs accumulate only while
# capturing), then reset per-episode capture state.
if getattr(self, "_debug_constructed", False):
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
self._save_debug_video()
self._debug_episode_index += 1
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video = None
self._debug_pairs = []
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
`FastWAM.build_inputs` consumes (`video`, `action`, `context`/`context_mask`,
per-frame `proprio`).
The LeRobot training loop passes raw `observation.images.*`, a single-step
`observation.state` `[B, D]`, `action`, and a language `task` string. We do
only the translation `build_inputs` can't: stack the camera frames into a
video, encode the prompt with the (frozen) text encoder (mirroring inference,
so language-conditioned datasets need no precomputed context), and give proprio
the per-frame axis `build_inputs` indexes. All shape/presence validation is
left to `build_inputs`, the single authority on the contract.
"""
sample = dict(batch)
if "video" not in sample:
sample["video"] = _stack_video_from_images(batch, self.config)
if "context" not in sample or "context_mask" not in sample:
prompt = _prompt_from_batch(batch=batch, config=self.config)
if prompt is None:
raise KeyError(
"FastWAM training requires a `task`/`prompt` to encode text context, "
"or precomputed `context`/`context_mask` in the batch."
)
sample["context"], sample["context_mask"] = self.model.encode_prompt(prompt)
if self.config.proprio_dim is not None and "proprio" not in sample:
state = sample.get(OBS_STATE)
if state is not None:
# LeRobot gives a single-step state [B, D]; build_inputs expects
# per-frame [B, T, D] and uses frame 0, so add a T=1 axis.
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
return sample
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Compute FastWAM training loss for a LeRobot batch.
Args:
batch (dict[str, Tensor]): Batch containing FastWAM-ready keys
(`video`, `action`, `context`, `context_mask`) or LeRobot keys
that can be adapted (`observation.images.*`, `observation.state`,
`action`, `action_is_pad`).
Returns:
tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of
logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)`
contract the LeRobot training loop expects.
"""
sample = self._batch_to_training_sample(batch)
loss, metrics = self.model.training_loss(sample)
return loss, dict(metrics or {})
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
"""Predict a chunk of actions from the current FastWAM observation.
Args:
batch (dict[str, Tensor]): Inference batch with `input_image` or
image observation keys, plus `context/context_mask` or `prompt`.
Returns:
Tensor: Action chunk with shape `[B, action_horizon, action_dim]`.
"""
self.eval()
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
batch_size = _infer_kwargs_batch_size(infer_kwargs)
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
# run the joint video+action path so the predicted video is VAE-decoded; stash it
# so select_action can pair each predicted frame with the real obs that follows.
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
out = self.model.infer_joint(
**infer_kwargs,
num_video_frames=self.config.model_video_frames,
test_action_with_infer_action=False,
)
# The decoded rollout has model_video_frames frames spanning the full
# action_horizon (action_video_freq_ratio actions per frame); the per-step
# pairing indexes into it, so keep all frames.
self._debug_last_video = out["video"]
action = _action_from_model_output(out)
elif batch_size == 1:
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
else:
action = torch.cat(
[
_action_from_model_output(
self.model.infer_action(
**_slice_infer_kwargs(infer_kwargs, index=i, batch_size=batch_size)
)
)
for i in range(batch_size)
],
dim=0,
)
return action.to(device=batch_device(batch), dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
self.eval()
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
# get the first episode of every task).
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
self._debug_episode_started = True
task = self._debug_task_name(batch)
if task not in self._debug_seen_tasks:
self._debug_seen_tasks.add(task)
self._debug_capturing = True
self._debug_episode_task = task
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
if capturing:
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
if capturing:
self._debug_capture_pair(batch)
self._debug_step_in_chunk += 1
return self._action_queue.popleft()
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
@staticmethod
def _debug_task_name(batch: dict[str, Any]) -> str:
task = batch.get("task")
if isinstance(task, (list, tuple)):
task = task[0] if task else None
return str(task) if task else "no_task"
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
video = getattr(self, "_debug_last_video", None)
if not video:
return
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
idx = min(
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
len(video) - 1,
)
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
self._debug_pairs.append(pair)
@staticmethod
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
from PIL import ImageDraw
draw = ImageDraw.Draw(pair)
draw.text((3, 3), "true", fill=(255, 255, 0))
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
@staticmethod
def _debug_tensor_to_pil(image: Tensor):
from PIL import Image
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
def _debug_hstack(left, right):
from PIL import Image
if right.height != left.height:
right = right.resize((round(right.width * left.height / right.height), left.height))
canvas = Image.new("RGB", (left.width + right.width, left.height))
canvas.paste(left, (0, 0))
canvas.paste(right, (left.width, 0))
return canvas
def _save_debug_video(self) -> None:
import re
import numpy as np
from lerobot.utils.io_utils import write_video
pairs = getattr(self, "_debug_pairs", None)
if not pairs:
return
out_dir = Path("outputs/fastwam_debug")
out_dir.mkdir(parents=True, exist_ok=True)
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
write_video(path, frames, fps=30)
logging.info(
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
)
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
"""Build the FastWAM core for training / inference.
Only the trainable parts (the MoT DiT and the proprio encoder) are
materialized empty here and then filled from the policy's
`model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE
and UMT5 text encoder are loaded with their real weights from the
`Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared
across checkpoints) and are intentionally excluded from `model.safetensors`
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
"""
dtype = _dtype_from_name(config.torch_dtype)
device = config.device
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype)
mot = MoT(
mixtures={"video": video_expert, "action": action_expert},
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
)
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
if config.load_text_encoder
else None
)
return FastWAM(
video_expert=video_expert,
action_expert=action_expert,
mot=mot,
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
text_encoder=text_encoder,
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
text_dim=int(config.video_dit_config["text_dim"]),
proprio_dim=config.proprio_dim,
device=device,
torch_dtype=dtype,
video_train_shift=float(config.video_scheduler["train_shift"]),
video_infer_shift=float(config.video_scheduler["infer_shift"]),
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
action_train_shift=float(config.action_scheduler["train_shift"]),
action_infer_shift=float(config.action_scheduler["infer_shift"]),
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
loss_lambda_video=float(config.loss["lambda_video"]),
loss_lambda_action=float(config.loss["lambda_action"]),
)
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
return {
"prompt": _prompt_from_batch(batch=batch, config=config),
"input_image": _input_image_from_batch(batch, config),
"action_horizon": config.action_horizon,
"proprio": batch.get("proprio", batch.get(OBS_STATE)),
"context": batch.get("context"),
"context_mask": batch.get("context_mask"),
"negative_prompt": batch.get("negative_prompt", config.negative_prompt),
"text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)),
"num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)),
"sigma_shift": batch.get("sigma_shift", config.sigma_shift),
"seed": batch.get("seed", config.inference_seed),
"rand_device": batch.get("rand_device", config.rand_device),
"tiled": bool(batch.get("tiled", config.tiled)),
}
def _prompt_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Any:
prompt = batch.get("prompt")
if prompt is not None:
return prompt
task = batch.get("task")
if task is None:
return None
if isinstance(task, str):
return config.prompt_template.format(task=task)
if isinstance(task, (list, tuple)):
return [config.prompt_template.format(task=str(item)) for item in task]
return config.prompt_template.format(task=str(task))
def _action_from_model_output(output: Any) -> Tensor:
action = output["action"] if isinstance(output, dict) else output
if action.ndim == 2:
action = action.unsqueeze(0)
return action
def _infer_kwargs_batch_size(infer_kwargs: dict[str, Any]) -> int:
image = infer_kwargs["input_image"]
if not isinstance(image, Tensor):
raise TypeError(f"`input_image` must be a tensor, got {type(image).__name__}.")
if image.ndim == 3:
return 1
if image.ndim == 4:
return int(image.shape[0])
raise ValueError(f"`input_image` must be [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
def _slice_infer_kwargs(infer_kwargs: dict[str, Any], *, index: int, batch_size: int) -> dict[str, Any]:
return {
key: _slice_infer_value(value, index=index, batch_size=batch_size)
for key, value in infer_kwargs.items()
}
def _slice_infer_value(value: Any, *, index: int, batch_size: int) -> Any:
if isinstance(value, Tensor) and value.ndim > 0 and value.shape[0] == batch_size:
return value[index : index + 1]
if isinstance(value, (list, tuple)) and len(value) == batch_size:
return value[index]
return value
def _dtype_from_name(name: str) -> torch.dtype:
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
if name not in dtype_map:
raise ValueError(f"Unsupported torch dtype `{name}`.")
return dtype_map[name]
def batch_device(batch: dict[str, Any]) -> torch.device:
for value in batch.values():
if isinstance(value, Tensor):
return value.device
return torch.device("cpu")
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
if image.ndim == 4:
# [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time.
image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1)
elif image.ndim == 5:
# [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W].
image = image.permute(0, 2, 1, 3, 4)
else:
raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.")
return image
def _input_image_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
if "input_image" in batch:
return _prepare_infer_image(batch["input_image"], config)
video = batch.get("video")
if video is None:
video = _stack_video_from_images(batch, config)
if video.ndim == 5:
return _prepare_infer_image(video[:, :, 0], config)
if video.ndim == 4:
return _prepare_infer_image(video, config)
raise ValueError(f"Cannot build input image from tensor with shape {tuple(video.shape)}.")
def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
if image.ndim == 3:
image = image.unsqueeze(0)
if image.ndim != 4:
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
target_h, target_w = config.image_size
if tuple(image.shape[-2:]) != (target_h, target_w):
raise ValueError(
"FastWAM policy expects preprocessed image tensors with shape "
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
"Run the FastWAM preprocessor before calling the policy."
)
return image
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,183 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
ImageCropResizeProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_fastwam import FastWAMConfig
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
"""`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack.
FastWAM loads a per-camera video stack, so image observations arrive as
``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
single leading batch dim (resize raises on 5-D input), so we flatten any leading
dims into the batch, apply the base 4-D crop/resize, then restore the leading shape.
Crop/resize params and feature-shape bookkeeping are inherited unchanged.
"""
def observation(self, observation: dict) -> dict:
# Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) that share
# the `observation.images.` prefix but are padding flags, not frames. The base crop/resize
# matches on the `"image"` substring, so set these aside and restore them untouched rather
# than letting it try to resize a mask.
pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key}
leads: dict[str, tuple] = {}
flat_input = {key: value for key, value in observation.items() if key not in pad_keys}
for key, img in list(flat_input.items()):
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
leads[key] = tuple(img.shape[:-3])
flat_input[key] = img.reshape(-1, *img.shape[-3:])
processed = super().observation(flat_input)
out = dict(processed)
for key, lead in leads.items():
im = processed[key]
out[key] = im.reshape(*lead, *im.shape[-3:])
out.update(pad_keys)
return out
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
"""Apply FastWAM LIBERO toggle semantics to configured action dimensions."""
toggle_dimensions: list[int]
def action(self, action: PolicyAction) -> PolicyAction:
if not self.toggle_dimensions:
return action
processed_action = action.clone()
action_dim = int(processed_action.shape[-1])
for dim in self.toggle_dimensions:
resolved_dim = dim if dim >= 0 else action_dim + dim
if resolved_dim < 0 or resolved_dim >= action_dim:
raise ValueError(
f"FastWAM action toggle dimension {dim} is out of bounds for action dim {action_dim}."
)
value = processed_action[..., resolved_dim]
value = value * 2.0 - 1.0
processed_action[..., resolved_dim] = torch.sign(-value)
return processed_action
def get_config(self) -> dict[str, Any]:
return {"toggle_dimensions": self.toggle_dimensions}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def make_fastwam_pre_post_processors(
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Create LeRobot pre- and post-processing pipelines for FastWAM.
Args:
config (FastWAMConfig): Policy configuration controlling device and
normalization feature metadata.
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): Optional
LeRobot dataset statistics used by normalization processors.
Returns:
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: Input and
output processor pipelines discoverable by LeRobot.
"""
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
for key, feature in config.input_features.items():
if feature.type != FeatureType.VISUAL:
continue
channels = int(feature.shape[0])
normalization_stats[key] = {
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
}
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
*resize_steps,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=normalization_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=normalization_stats,
),
]
if config.toggle_action_dimensions:
output_steps.append(
FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,25 @@
# Wan2.2 Upstream Subset
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
- Upstream repository: https://github.com/Wan-Video/Wan2.2
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
- License: Apache-2.0, matching the license in `LICENSE.txt` from the upstream repository
Copied files:
- `wan/modules/attention.py`
- `wan/modules/model.py`
- `wan/modules/__init__.py`
- `wan/utils/fm_solvers.py`
- `wan/utils/__init__.py`
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
UMT5 text encoder, and tokenizer are no longer vendored — they come from
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
Current FastWAM adapters that directly reuse this vendored subset:
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
@@ -0,0 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
__all__ = [
"WanModel",
"flash_attention",
]
@@ -0,0 +1,183 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
"flash_attention",
"attention",
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == "cuda" and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens, strict=False)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens, strict=False)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens, strict=False)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.", stacklevel=2)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic,
)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.",
stacklevel=2,
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
)
out = out.transpose(1, 2).contiguous()
return out
@@ -0,0 +1,519 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
__all__ = ["WanModel"]
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size,
)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(
self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6
):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, L1, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with torch.amp.autocast("cuda", dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), seq_lens, grid_sizes, freqs
)
with torch.amp.autocast("cuda", dtype=torch.float32):
x = x + y * e[2].squeeze(2)
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
with torch.amp.autocast("cuda", dtype=torch.float32):
x = x + y * e[5].squeeze(2)
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, L1, C]
"""
assert e.dtype == torch.float32
with torch.amp.autocast("cuda", dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
return x
class WanModel(ModelMixin, ConfigMixin):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
_no_split_modules = ["WanAttentionBlock"]
@register_to_config
def __init__(
self,
model_type="t2v",
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to False):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList(
[
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
]
)
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# initialize weights
self.init_weights()
def forward(
self,
x,
t,
context,
seq_len,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == "i2v":
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
with torch.amp.autocast("cuda", dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()
)
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
)
# arguments
kwargs = {
"e": e0,
"seq_lens": seq_lens,
"grid_sizes": grid_sizes,
"freqs": self.freqs,
"context": context,
"context_lens": context_lens,
}
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist(), strict=False):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size, strict=False)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
@@ -0,0 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .fm_solvers import get_sampling_sigmas
__all__ = [
"get_sampling_sigmas",
]
@@ -0,0 +1,9 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = shift * sigma / (1 + (shift - 1) * sigma)
return sigma
@@ -0,0 +1,111 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from diffusers import AutoencoderKLWan
class WanVideoVAE38(torch.nn.Module):
"""FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).
16x spatial / 4x temporal compression, 48 latent channels. diffusers'
`AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/
`latents_std`), so `encode`/`decode` here apply the same standardization the
Wan reference uses `(latents - mean) / std` done in fp32 for stability.
`encode` uses the deterministic posterior mode, matching the original VAE
which returned the latent mean `mu`.
"""
upsampling_factor = 16
temporal_downsample_factor = 4
z_dim = 48
def __init__(
self,
dtype: torch.dtype = torch.float32,
device: str | torch.device = "cuda",
*,
pretrained: AutoencoderKLWan,
) -> None:
super().__init__()
# The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch,
# so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from
# the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path.
self.vae = pretrained.to(device=device, dtype=dtype)
# Read the standardization stats from the VAE's own config (diffusers populates
# these from vae/config.json) — single source of truth, no local copy. diffusers'
# encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves.
# Non-persistent: kept out of state_dict.
self.register_buffer(
"latents_mean",
torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
self.register_buffer(
"latents_std",
torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
param = next(self.vae.parameters())
return param.device, param.dtype
def encode(
self,
videos: list[torch.Tensor] | torch.Tensor,
device: str | torch.device | None = None,
tiled: bool = False,
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
if isinstance(videos, (list, tuple)):
videos = torch.stack(list(videos))
dev, dtype = self._device_dtype()
mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float()
mean = self.latents_mean.float().to(mu.device)
std = self.latents_std.float().to(mu.device)
return (mu - mean) / std
def decode(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
device: str | torch.device | None = None,
tiled: bool = False,
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
if isinstance(hidden_states, (list, tuple)):
hidden_states = torch.stack(list(hidden_states))
dev, dtype = self._device_dtype()
z = hidden_states.float()
z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device)
out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample
return out.float().clamp_(-1.0, 1.0)
__all__ = ["WanVideoVAE38"]
@@ -0,0 +1,172 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, UMT5EncoderModel
if TYPE_CHECKING:
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
logger = logging.getLogger(__name__)
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
WAN_T5_TOKENIZER = "google/umt5-xxl"
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
class WanTextEncoder(torch.nn.Module):
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
Exposes `.dim` (hidden size) and `forward(ids, mask) -> [B, L, dim]`, matching
the call in `FastWAM.encode_prompt`.
"""
def __init__(
self,
dtype: torch.dtype = torch.bfloat16,
device: str | torch.device = "cuda",
*,
pretrained: torch.nn.Module,
) -> None:
super().__init__()
# UMT5-XXL is a fixed pretrained encoder — never trained from scratch, so a real
# `UMT5EncoderModel` (with weights) must always be supplied (loaded from the
# diffusers repo by `load_pretrained_wan_text_encoder`). No random/offline build.
self.model = pretrained.to(device=device, dtype=dtype)
self.dim = int(self.model.config.d_model)
def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
return self.model(input_ids=ids, attention_mask=mask.long()).last_hidden_state
class WanTokenizer:
"""UMT5 tokenizer wrapper returning `(input_ids, attention_mask)` like the
FastWAM call site expects."""
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.seq_len = int(seq_len)
def __call__(
self,
sequence: str | Sequence[str],
return_mask: bool = False,
add_special_tokens: bool = True,
**_: Any,
):
if isinstance(sequence, str):
sequence = [sequence]
out = self.tokenizer(
list(sequence),
padding="max_length",
truncation=True,
max_length=self.seq_len,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)
if return_mask:
return out.input_ids, out.attention_mask
return out.input_ids
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len))
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
encoder = UMT5EncoderModel.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
)
return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
def resolve_wan_dit_paths(
model_id_or_path: str | Path,
*,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
) -> list[Path]:
"""Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir."""
path = Path(model_id_or_path).expanduser()
if path.is_dir():
return sorted(path.glob(WAN_DIT_PATTERN))
snapshot_path = snapshot_download(
repo_id=str(model_id_or_path),
revision=revision,
cache_dir=cache_dir,
local_files_only=local_files_only,
allow_patterns=[WAN_DIT_PATTERN],
)
return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN))
def load_wan_video_dit(
paths: list[str | Path],
*,
dit_config: dict[str, Any],
torch_dtype: torch.dtype,
device: str,
) -> WanVideoDiT:
model = WanVideoDiT(**dit_config)
state_dict = _read_wan_dit_safetensors(paths)
model.load_state_dict(state_dict, strict=False)
return model.to(device=device, dtype=torch_dtype)
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
state_dict = {}
for path in paths:
state_dict.update(load_file(str(path), device="cpu"))
return state_dict
__all__ = [
"WAN22_DIFFUSERS_MODEL_ID",
"WAN_DIT_PATTERN",
"WAN_T5_TOKENIZER",
"WanTextEncoder",
"WanTokenizer",
"build_wan_tokenizer",
"load_pretrained_wan_text_encoder",
"load_pretrained_wan_vae",
"load_wan_video_dit",
"resolve_wan_dit_paths",
]
@@ -0,0 +1,813 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as functional
from einops import rearrange
from .wan.modules.model import (
WanAttentionBlock,
WanLayerNorm,
WanModel,
WanRMSNorm,
rope_apply,
rope_params,
sinusoidal_embedding_1d,
)
from .wan.utils.fm_solvers import get_sampling_sigmas
logger = logging.getLogger(__name__)
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
*args,
**kwargs,
):
if use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output
def fastwam_masked_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
num_heads: int,
ctx_mask: torch.Tensor | None = None,
fp32_attention: bool = True,
) -> torch.Tensor:
"""FastWAM masked attention wrapper for MoT masks and CPU test coverage.
The official Wan attention implementation is still used as the source of
the projection/norm modules. This wrapper only replaces the final attention
kernel because FastWAM needs explicit boolean masks for video/action MoT
routing, while the upstream FlashAttention path accepts sequence lengths
but not arbitrary [query, key] masks.
"""
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
if fp32_attention:
q = q.float()
k = k.float()
v = v.float()
else:
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
x = functional.scaled_dot_product_attention(q, k, v, attn_mask=ctx_mask)
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
return get_sampling_sigmas(num_inference_steps, shift)
class WanContinuousFlowMatchScheduler:
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10):
if num_train_timesteps <= 0:
raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}")
if shift <= 0:
raise ValueError(f"`shift` must be positive, got {shift}")
self.num_train_timesteps = int(num_train_timesteps)
self.shift = float(shift)
self.eps = float(eps)
self._y_min, self._weight_norm_const = self._precompute_training_weight_stats()
@staticmethod
def _phi(u: torch.Tensor, shift: float) -> torch.Tensor:
return shift * u / (1.0 + (shift - 1.0) * u)
def _precompute_training_weight_stats(self) -> tuple[float, float]:
steps = self.num_train_timesteps
u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1]
t_grid = self._phi(u_grid, self.shift) * float(steps)
y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2)
y_min = float(y_grid.min().item())
y_shifted_grid = y_grid - y_min
norm_const = float(y_shifted_grid.mean().item())
return y_min, norm_const
def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if batch_size <= 0:
raise ValueError(f"`batch_size` must be positive, got {batch_size}")
u = torch.rand((batch_size,), device=device, dtype=torch.float32)
sigma = self._phi(u, self.shift)
timestep = sigma * float(self.num_train_timesteps)
return timestep.to(dtype=dtype)
def training_weight(self, timestep: torch.Tensor) -> torch.Tensor:
t = timestep.to(dtype=torch.float32)
steps = float(self.num_train_timesteps)
y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2)
y_shifted = y - self._y_min
weight = y_shifted / (self._weight_norm_const + self.eps)
if weight.numel() == 1:
return weight.reshape(())
return weight
def add_noise(
self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor
) -> torch.Tensor:
sigma = (timestep / float(self.num_train_timesteps)).to(
original_samples.device, dtype=original_samples.dtype
)
if sigma.ndim == 0:
return (1 - sigma) * original_samples + sigma * noise
sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1)))
return (1 - sigma) * original_samples + sigma * noise
@staticmethod
def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
del timestep
return noise - sample
def build_inference_schedule(
self,
num_inference_steps: int,
device: torch.device,
dtype: torch.dtype,
shift_override: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}")
shift = self.shift if shift_override is None else float(shift_override)
if shift <= 0:
raise ValueError(f"`shift` must be positive, got {shift}")
sigma_steps = torch.as_tensor(
_get_wan_sampling_sigmas(num_inference_steps, shift),
device=device,
dtype=torch.float32,
)
timesteps = sigma_steps * float(self.num_train_timesteps)
sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)])
deltas = sigma_next - sigma_steps
return timesteps.to(dtype=dtype), deltas.to(dtype=dtype)
@staticmethod
def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
delta = delta.to(sample.device, dtype=sample.dtype)
if delta.ndim == 0:
return sample + model_output * delta
delta = delta.view(-1, *([1] * (sample.ndim - 1)))
return sample + model_output * delta
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
return rope_params(end, dim, theta)
def apply_dense_rope(x: torch.Tensor, freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float32).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)
def _linear_input(linear: nn.Linear, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=linear.weight.dtype)
def _wan_layer_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
if isinstance(norm, WanLayerNorm) and norm.weight is not None:
weight = norm.weight.float()
bias = norm.bias.float() if norm.bias is not None else None
return functional.layer_norm(x.float(), norm.normalized_shape, weight, bias, norm.eps).to(
dtype=x.dtype
)
return norm(x)
def create_group_causal_attn_mask(
num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal"
) -> torch.Tensor:
if mode not in ["causal", "group_diagonal"]:
raise ValueError(f"`mode` must be 'causal' or 'group_diagonal', got {mode}.")
if num_temporal_groups <= 0:
raise ValueError(f"`num_temporal_groups` must be positive, got {num_temporal_groups}.")
if num_query_per_group <= 0:
raise ValueError(f"`num_query_per_group` must be positive, got {num_query_per_group}.")
if num_key_per_group <= 0:
raise ValueError(f"`num_key_per_group` must be positive, got {num_key_per_group}.")
total_num_query_tokens = num_temporal_groups * num_query_per_group
total_num_key_tokens = num_temporal_groups * num_key_per_group
query_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_query_per_group).unsqueeze(1)
key_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group).unsqueeze(0)
if mode == "causal":
attn_mask = query_time_indices >= key_time_indices
else:
attn_mask = query_time_indices == key_time_indices
if attn_mask.shape != (total_num_query_tokens, total_num_key_tokens):
raise RuntimeError("Attention mask shape mismatch.")
return attn_mask
class FastWAMAttentionBlock(WanAttentionBlock):
"""Wan attention block with FastWAM's arbitrary boolean mask support."""
def __init__(
self,
hidden_dim: int,
attn_head_dim: int,
num_heads: int,
ffn_dim: int,
eps: float = 1e-6,
fp32_attention: bool = True,
):
attention_dim = attn_head_dim * num_heads
if hidden_dim == attention_dim:
super().__init__(
dim=hidden_dim,
ffn_dim=ffn_dim,
num_heads=num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=eps,
)
else:
nn.Module.__init__(self)
self.dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = (-1, -1)
self.qk_norm = True
self.cross_attn_norm = True
self.eps = eps
self.norm1 = WanLayerNorm(hidden_dim, eps)
self.self_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
self.norm3 = WanLayerNorm(hidden_dim, eps, elementwise_affine=True)
self.cross_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
self.norm2 = WanLayerNorm(hidden_dim, eps)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, ffn_dim),
nn.GELU(approximate="tanh"),
nn.Linear(ffn_dim, hidden_dim),
)
self.modulation = nn.Parameter(torch.randn(1, 6, hidden_dim) / hidden_dim**0.5)
self.attn_head_dim = attn_head_dim
self.fp32_attention = bool(fp32_attention)
@staticmethod
def split_modulation(block, t_mod: torch.Tensor):
has_seq = len(t_mod.shape) == 4
chunk_dim = 2 if has_seq else 1
base_mod = block.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (base_mod + t_mod).chunk(
6, dim=chunk_dim
)
if has_seq:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
shift_msa.squeeze(2),
scale_msa.squeeze(2),
gate_msa.squeeze(2),
shift_mlp.squeeze(2),
scale_mlp.squeeze(2),
gate_mlp.squeeze(2),
)
return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
def project_self_attention(
self, x: torch.Tensor, freqs: torch.Tensor | dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.self_attn.norm_q(self.self_attn.q(x))
k = self.self_attn.norm_k(self.self_attn.k(x))
v = self.self_attn.v(x)
if isinstance(freqs, dict):
b, s = x.shape[:2]
q = rope_apply(
q.view(b, s, self.num_heads, self.attn_head_dim),
freqs["grid_sizes"],
freqs["freqs"],
).flatten(2)
k = rope_apply(
k.view(b, s, self.num_heads, self.attn_head_dim),
freqs["grid_sizes"],
freqs["freqs"],
).flatten(2)
else:
q = apply_dense_rope(q, freqs, self.num_heads)
k = apply_dense_rope(k, freqs, self.num_heads)
return q, k, v
def apply_cross_attention(
self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor | None = None
) -> torch.Tensor:
if context_mask is not None and context_mask.dim() == 3:
context_mask = context_mask.unsqueeze(1)
attn = self.cross_attn
b, n, d = x.size(0), attn.num_heads, attn.head_dim
q = attn.norm_q(attn.q(x)).view(b, -1, n * d)
k = attn.norm_k(attn.k(context)).view(b, -1, n * d)
v = attn.v(context).view(b, -1, n * d)
x = fastwam_masked_attention(
q=q,
k=k,
v=v,
num_heads=n,
ctx_mask=context_mask,
fp32_attention=self.fp32_attention,
)
return attn.o(_linear_input(attn.o, x))
def project_self_attention_output(self, x: torch.Tensor) -> torch.Tensor:
return self.self_attn.o(_linear_input(self.self_attn.o, x))
def apply_norm1(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm1, x)
def apply_norm2(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm2, x)
def apply_norm3(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm3, x)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
t_mod: torch.Tensor,
freqs: torch.Tensor,
context_mask: torch.Tensor | None = None,
self_attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.split_modulation(self, t_mod)
residual_x = x
attn_input = modulate(self.apply_norm1(x), shift_msa, scale_msa)
q, k, v = self.project_self_attention(attn_input, freqs)
y = fastwam_masked_attention(
q=q,
k=k,
v=v,
num_heads=self.num_heads,
ctx_mask=self_attn_mask,
fp32_attention=self.fp32_attention,
)
x = residual_x + gate_msa * self.project_self_attention_output(y)
x = x + self.apply_cross_attention(self.apply_norm3(x), context, context_mask=context_mask)
mlp_input = modulate(self.apply_norm2(x), shift_mlp, scale_mlp)
return x + gate_mlp * self.ffn(mlp_input)
class _FastWAMProjectedAttention(nn.Module):
def __init__(self, hidden_dim: int, attention_dim: int, num_heads: int, eps: float):
super().__init__()
self.dim = hidden_dim
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
self.q = nn.Linear(hidden_dim, attention_dim)
self.k = nn.Linear(hidden_dim, attention_dim)
self.v = nn.Linear(hidden_dim, attention_dim)
self.o = nn.Linear(attention_dim, hidden_dim)
self.norm_q = WanRMSNorm(attention_dim, eps=eps)
self.norm_k = WanRMSNorm(attention_dim, eps=eps)
class WanVideoDiT(WanModel):
def __init__(
self,
hidden_dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: tuple[int, int, int],
num_heads: int,
attn_head_dim: int,
num_layers: int,
has_image_input: bool = False,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
separated_timestep: bool = False,
require_vae_embedding: bool = False,
require_clip_embedding: bool = False,
fuse_vae_embedding_in_latents: bool = True,
action_conditioned: bool = False,
action_dim: int = 7,
action_group_causal_mask_mode="causal",
video_attention_mask_mode: str = "bidirectional",
use_gradient_checkpointing: bool = False,
fp32_attention: bool = True,
):
del in_dim_control_adapter
if has_image_input:
raise ValueError("FastWAM currently expects Wan2.2 TI2V latents with fused image conditioning.")
if has_image_pos_emb:
raise ValueError("FastWAM does not support extra image positional embeddings in WanVideoDiT.")
if has_ref_conv:
raise ValueError("FastWAM does not support reference convolutions in WanVideoDiT.")
if add_control_adapter:
raise ValueError("FastWAM does not support control adapters in WanVideoDiT.")
if require_clip_embedding:
raise ValueError("FastWAM does not support CLIP embedding conditioning in WanVideoDiT.")
if require_vae_embedding or not fuse_vae_embedding_in_latents:
raise ValueError("FastWAM expects VAE conditioning to be fused in latents.")
if attn_head_dim != hidden_dim // num_heads:
raise ValueError(
"`attn_head_dim` must match the upstream Wan head dimension `hidden_dim // num_heads`; "
f"got {attn_head_dim} vs {hidden_dim // num_heads}."
)
super().__init__(
model_type="ti2v",
patch_size=patch_size,
text_len=512,
in_dim=in_dim,
dim=hidden_dim,
ffn_dim=ffn_dim,
freq_dim=freq_dim,
text_dim=text_dim,
out_dim=out_dim,
num_heads=num_heads,
num_layers=num_layers,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=eps,
)
self.blocks = torch.nn.ModuleList(
[
FastWAMAttentionBlock(
hidden_dim=hidden_dim,
attn_head_dim=attn_head_dim,
num_heads=num_heads,
ffn_dim=ffn_dim,
eps=eps,
fp32_attention=fp32_attention,
)
for _ in range(num_layers)
]
)
self.init_weights()
self.hidden_dim = hidden_dim
self.attn_head_dim = attn_head_dim
self.separated_timestep = separated_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.video_attention_mask_mode = str(video_attention_mask_mode)
self.action_conditioned = action_conditioned
self.action_dim = action_dim
self.fp32_attention = bool(fp32_attention)
if self.action_conditioned:
self.action_embedding = torch.nn.Linear(action_dim, hidden_dim)
self.action_group_causal_mask_mode = action_group_causal_mask_mode
self.use_gradient_checkpointing = use_gradient_checkpointing
if self.use_gradient_checkpointing:
logger.info(
"Using gradient checkpointing for DiT blocks. This will save memory but use more computation."
)
def patchify(self, x: torch.Tensor):
return self.patch_embedding(x)
def _validate_forward_inputs(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None,
action: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if x.ndim != 5:
raise ValueError(f"`latents` must be 5D [B, C, T, H, W], got shape {tuple(x.shape)}")
num_latent_frames = x.shape[2]
if context.ndim != 3:
raise ValueError(f"`context` must be 3D [B, L, D], got shape {tuple(context.shape)}")
if timestep.ndim != 1:
raise ValueError(f"`timestep` must be 1D [B] or [1], got shape {tuple(timestep.shape)}")
if self.action_conditioned:
allow_text_only_single_frame = num_latent_frames == 1 and action is None
if not allow_text_only_single_frame:
if action is None:
raise ValueError("Action input is required for action-conditioned model.")
if action.ndim != 3:
raise ValueError(
f"`action` must be 3D [B, action_horizon, action_dim], got shape {tuple(action.shape)}"
)
if action.shape[2] != self.action_dim:
raise ValueError(
f"`action` last dimension must be {self.action_dim}, got {action.shape[2]}"
)
if num_latent_frames <= 1:
raise ValueError(
f"video length must be > 1 for action-conditioned model, got {num_latent_frames}"
)
if action.shape[1] % (num_latent_frames - 1) != 0:
raise ValueError(
"action horizon must be divisible by (num_latent_frames - 1), "
f"got action_horizon={action.shape[1]}"
)
if context_mask is None:
context_mask = torch.ones(
(context.shape[0], context.shape[1]), dtype=torch.bool, device=context.device
)
else:
if context_mask.ndim != 2:
raise ValueError(f"`context_mask` must be 2D [B, L], got shape {tuple(context_mask.shape)}")
if context_mask.shape[0] != context.shape[0] or context_mask.shape[1] != context.shape[1]:
raise ValueError(
"`context_mask` shape must match `context` shape [B, L], "
f"got {tuple(context_mask.shape)} vs {tuple(context.shape)}"
)
batch_size = x.shape[0]
if batch_size != context.shape[0]:
if not self.training and batch_size == 1:
x = x.expand(context.shape[0], -1, -1, -1, -1)
batch_size = context.shape[0]
else:
raise ValueError(
f"Batch mismatch between latents and context: {batch_size} vs {context.shape[0]}."
)
if timestep.shape[0] not in (1, batch_size):
raise ValueError(
f"`timestep` length must be 1 or batch_size({batch_size}), got {timestep.shape[0]}"
)
if timestep.shape[0] == 1 and batch_size > 1:
if self.training:
raise ValueError("During training, timestep length must match batch_size.")
timestep = timestep.expand(batch_size)
return x, timestep, context_mask
def build_video_to_video_mask(
self,
video_seq_len: int,
video_tokens_per_frame: int,
device: torch.device,
) -> torch.Tensor:
if video_seq_len <= 0:
raise ValueError(f"`video_seq_len` must be positive, got {video_seq_len}")
if video_tokens_per_frame <= 0:
raise ValueError(f"`video_tokens_per_frame` must be positive, got {video_tokens_per_frame}")
if self.video_attention_mask_mode == "bidirectional":
return torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
if self.video_attention_mask_mode == "per_frame_causal":
if video_seq_len % video_tokens_per_frame != 0:
raise ValueError(
"`video_seq_len` must be divisible by `video_tokens_per_frame` in `per_frame_causal` mode, "
f"got {video_seq_len} and {video_tokens_per_frame}"
)
num_video_frames = video_seq_len // video_tokens_per_frame
frame_causal = torch.tril(
torch.ones((num_video_frames, num_video_frames), dtype=torch.bool, device=device)
)
return frame_causal.repeat_interleave(video_tokens_per_frame, dim=0).repeat_interleave(
video_tokens_per_frame, dim=1
)
if self.video_attention_mask_mode == "first_frame_causal":
video_mask = torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
first_frame_tokens = min(video_tokens_per_frame, video_seq_len)
video_mask[:first_frame_tokens, first_frame_tokens:] = False
return video_mask
raise ValueError(f"Unsupported video attention mask mode: {self.video_attention_mask_mode}")
def pre_dit(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
action: torch.Tensor | None = None,
fuse_vae_embedding_in_latents: bool = False,
) -> dict[str, Any]:
x, timestep, context_mask = self._validate_forward_inputs(
x=x,
timestep=timestep,
context=context,
context_mask=context_mask,
action=action,
)
model_dtype = self.patch_embedding.weight.dtype
x = x.to(dtype=model_dtype)
context = context.to(dtype=model_dtype)
if action is not None:
action = action.to(dtype=model_dtype)
batch_size = x.shape[0]
patch_h = int(self.patch_size[1])
patch_w = int(self.patch_size[2])
if x.shape[3] % patch_h != 0 or x.shape[4] % patch_w != 0:
raise ValueError(
"Latent spatial shape must be divisible by DiT patch size, "
f"got HxW=({x.shape[3]}, {x.shape[4]}), patch=({patch_h}, {patch_w})"
)
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
raise NotImplementedError(
"FastWAM currently requires separated timesteps with fused VAE latents."
)
token_timesteps = torch.ones(
(batch_size, x.shape[2], tokens_per_frame),
dtype=model_dtype,
device=timestep.device,
) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
token_timesteps[:, 0, :] = 0
token_timesteps = token_timesteps.reshape(batch_size, -1)
# Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
# Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
# Upstream guarantees this via an fp32 autocast region, so it holds even when the
# model runs in bf16. Mirror that here, then cast the per-block modulation back to
# model_dtype so the bf16 attention blocks are not upcast to fp32.
with torch.amp.autocast("cuda", dtype=torch.float32):
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
t_mod = t_mod.to(dtype=model_dtype)
x = self.patchify(x)
f, h, w = x.shape[2:]
context = self.text_embedding(context)
context_len = context.shape[1]
if self.action_conditioned and action is not None:
action_len = action.shape[1]
action_emb = self.action_embedding(action)
action_pos_embed = sinusoidal_embedding_1d(
self.hidden_dim, torch.arange(action_len, device=action_emb.device)
).to(dtype=action_emb.dtype)
action_emb = action_emb + action_pos_embed.unsqueeze(0)
context = torch.cat([context, action_emb], dim=1)
num_temporal_groups = f - 1
if num_temporal_groups <= 0:
raise ValueError(
"Action-conditioned context mask requires at least 2 latent frames when `action` is provided."
)
if action_emb.shape[1] % num_temporal_groups != 0:
raise ValueError(
f"Action embedding length {action_emb.shape[1]} must be divisible by "
f"number of temporal groups {num_temporal_groups}"
)
action_group_mask = create_group_causal_attn_mask(
num_temporal_groups=num_temporal_groups,
num_query_per_group=tokens_per_frame,
num_key_per_group=action_len // num_temporal_groups,
mode=self.action_group_causal_mask_mode,
).to(context.device)
seq_len = f * h * w
final_context_mask = torch.zeros(
(batch_size, seq_len, context.shape[1]), dtype=torch.bool, device=context.device
)
final_context_mask[:, :, :context_len] = context_mask.unsqueeze(1).expand(-1, seq_len, -1)
final_context_mask[:, tokens_per_frame:, context_len:] = action_group_mask.unsqueeze(0).expand(
batch_size, -1, -1
)
context_mask = final_context_mask
elif self.action_conditioned and action is None:
if f != 1:
raise ValueError(
"Action-conditioned model requires `action` unless running single-frame text-only mode "
"with num_latent_frames=1."
)
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
else:
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
x_tokens = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
grid_sizes = torch.tensor([[f, h, w]] * batch_size, dtype=torch.long, device=x_tokens.device)
freqs = {"grid_sizes": grid_sizes, "freqs": self.freqs.to(x_tokens.device)}
return {
"tokens": x_tokens,
"freqs": freqs,
"t": t,
"t_mod": t_mod,
"context": context,
"context_mask": context_mask,
"meta": {
"grid_sizes": grid_sizes,
"tokens_per_frame": tokens_per_frame,
"batch_size": batch_size,
},
}
def post_dit(self, x_tokens: torch.Tensor, pre_state: dict[str, Any]) -> torch.Tensor:
x = self.head(x_tokens, pre_state["t"])
return torch.stack(super().unpatchify(x, pre_state["meta"]["grid_sizes"]))
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
action: torch.Tensor | None = None,
fuse_vae_embedding_in_latents: bool = False,
):
pre_state = self.pre_dit(
x=x,
timestep=timestep,
context=context,
context_mask=context_mask,
action=action,
fuse_vae_embedding_in_latents=fuse_vae_embedding_in_latents,
)
x_tokens = pre_state["tokens"]
context_emb = pre_state["context"]
t_mod = pre_state["t_mod"]
freqs = pre_state["freqs"]
context_attn_mask = pre_state["context_mask"]
self_attn_mask = (
self.build_video_to_video_mask(
video_seq_len=x_tokens.shape[1],
video_tokens_per_frame=int(pre_state["meta"]["tokens_per_frame"]),
device=x_tokens.device,
)
if self.video_attention_mask_mode != "bidirectional"
else None
)
for block in self.blocks:
if self.use_gradient_checkpointing:
x_tokens = gradient_checkpoint_forward(
block,
self.use_gradient_checkpointing,
x_tokens,
context_emb,
t_mod,
freqs,
context_mask=context_attn_mask,
self_attn_mask=self_attn_mask,
)
else:
x_tokens = block(
x_tokens,
context_emb,
t_mod,
freqs,
context_mask=context_attn_mask,
self_attn_mask=self_attn_mask,
)
return self.post_dit(x_tokens, pre_state)
__all__ = [
"FastWAMAttentionBlock",
"WanContinuousFlowMatchScheduler",
"WanVideoDiT",
"apply_dense_rope",
"create_group_causal_attn_mask",
"fastwam_masked_attention",
"gradient_checkpoint_forward",
"modulate",
"precompute_freqs_cis",
"sinusoidal_embedding_1d",
]
-4
View File
@@ -13,9 +13,6 @@
# limitations under the License.
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig as DistributionalVFConfig,
)
from .factory import (
get_reward_model_class as get_reward_model_class,
make_reward_model as make_reward_model,
@@ -29,7 +26,6 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
__all__ = [
# Configuration classes
"DistributionalVFConfig",
"RewardClassifierConfig",
"RobometerConfig",
"SARMConfig",
@@ -1,108 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weights initialized from a pre-trained PI05 actor checkpoint.
"""
from dataclasses import dataclass, field
from lerobot.configs import FeatureType, NormalizationMode
from lerobot.configs.rewards import RewardModelConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
@RewardModelConfig.register_subclass("distributional_value_function")
@dataclass
class DistributionalVFConfig(RewardModelConfig):
"""Configuration for RECAP's distributional value function.
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
(Eq. 1 in the paper).
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
about 670M params and initialize from the PI05 actor checkpoint.
"""
# Backbone
paligemma_variant: str = "gemma_2b"
num_hidden_layers: int = 6
num_vision_layers: int = 13
# Distributional head
num_value_bins: int = 201
value_support_min: float = -1.0
value_support_max: float = 0.0
hl_gauss_sigma_ratio: float = 5.0
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
target_method: str = "hl_gauss"
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
# When False, terminal states use the same target method as non-terminal states.
use_one_hot_terminal: bool = True
# Image
image_resolution: tuple[int, int] = (224, 224)
# Tokenizer
tokenizer_max_length: int = 64
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
# Pass a PI05 checkpoint path or Hub repo_id here.
# After training, load the value function with RewardModel.from_pretrained() instead.
init_from_actor_path: str = ""
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
}
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=3e-4,
weight_decay=1e-4,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=500,
num_decay_steps=50000,
)
def validate_features(self) -> None:
if not self.input_features:
return
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
if not has_image:
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
@@ -1,567 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modeling for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
Inputs: single image observation + task text prompt ("Task: {task}.")
Outputs: softmax distribution over value bins; expected value E[V] for inference.
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
with optional one-hot targets for terminal states; MC returns normalized per task.
Weight initialization: vision tower, multi-modal projector, token embeddings, and
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.rewards.pretrained import PreTrainedRewardModel
from lerobot.utils.import_utils import _transformers_available, require_package
from .configuration_distributional_value_function import DistributionalVFConfig
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaRMSNorm,
_gated_residual,
_get_pi_gemma_decoder_layer_base,
)
else:
CONFIG_MAPPING = None
modeling_gemma = None
PaliGemmaForConditionalGenerationWithPiGemma = None
PiGemmaRMSNorm = None
_gated_residual = None
_get_pi_gemma_decoder_layer_base = None
PALIGEMMA_VOCAB_SIZE = 257152
class DistributionalVFRewardModel(PreTrainedRewardModel):
"""Distributional value function model for RECAP.
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
per-task normalized Monte Carlo returns.
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
causal attention, [CLS] token, and Linear(D, num_bins) value head.
The expected value is E[V] = sum(softmax(logits) * bin_centers).
"""
name = "distributional_value_function"
config_class = DistributionalVFConfig
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
require_package("transformers", extra="recap")
super().__init__(config)
self.config = config
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
# Get base dimensions from the paligemma variant (OpenPI config format)
base_config = get_gemma_config(config.paligemma_variant)
hidden_dim = base_config.width
mlp_dim = base_config.mlp_dim
num_layers = config.num_hidden_layers
# HuggingFace GemmaConfig for transformer layers
gemma_config = CONFIG_MAPPING["gemma"](
head_dim=base_config.head_dim,
hidden_size=hidden_dim,
intermediate_size=mlp_dim,
num_attention_heads=base_config.num_heads,
num_hidden_layers=num_layers,
num_key_value_heads=base_config.num_kv_heads,
vocab_size=PALIGEMMA_VOCAB_SIZE,
hidden_activation="gelu_pytorch_tanh",
)
self.gemma_config = gemma_config
self.hidden_dim = hidden_dim
self.num_value_bins = config.num_value_bins
# Single learned [CLS] token for value prediction
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
# Value projection head: Linear(hidden_dim, num_bins)
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
# Transformer layers (overwritten by _initialize_from_actor on first run)
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
)
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
# PaliGemmaConfig wraps both vision and text configs into a single model
paligemma_config = CONFIG_MAPPING["paligemma"]()
paligemma_config.text_config = gemma_config
paligemma_config.vision_config.image_size = config.image_resolution[0]
paligemma_config.vision_config.intermediate_size = 4304
paligemma_config.vision_config.projection_dim = 2048
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
self.vision_tower = paligemma_full.model.vision_tower
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
self.token_embedding = paligemma_full.model.language_model.embed_tokens
del paligemma_full
# Truncate vision tower to num_vision_layers
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
vision_encoder = self.vision_tower.vision_model.encoder
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
# Bin support: evenly spaced centers from value_support_min to value_support_max
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
self.register_buffer("bin_centers", bin_centers, persistent=False)
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
# Overwrite with pre-trained PI05 actor weights (first training run only)
if config.init_from_actor_path:
self._initialize_from_actor()
def _initialize_from_actor(self) -> None:
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
Called on first training run only (when init_from_actor_path is set).
"""
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
actor_model = actor_policy.model
paligemma_model = actor_model.paligemma_with_expert.paligemma
source_language_model = paligemma_model.model.language_model
# Transformer components
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
num_layers = self.gemma_config.num_hidden_layers
for i in range(num_layers):
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
self.norm.load_state_dict(source_language_model.norm.state_dict())
# Vision tower (truncate source first, then copy)
source_vision_tower = paligemma_model.model.vision_tower
if hasattr(source_vision_tower, "vision_model") and hasattr(
source_vision_tower.vision_model, "encoder"
):
source_encoder = source_vision_tower.vision_model.encoder
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
# Multi-modal projector
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
# Token embedding table
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
del actor_policy
def embed_image(self, image: Tensor) -> Tensor:
"""Embed images using the value function's SigLIP vision tower.
Args:
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
Returns:
[batch_size, num_patches, hidden_dim] projected image features.
"""
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.vision_tower(image, return_dict=True)
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
image_features = image_features / (self.hidden_dim**0.5)
if image_features.dtype != out_dtype:
image_features = image_features.to(out_dtype)
return image_features
def embed_text(self, token_ids: Tensor) -> Tensor:
"""Embed text token IDs using the value function's token embedding table.
Args:
token_ids: [batch_size, seq_len] integer token IDs
Returns:
[batch_size, seq_len, hidden_dim] text embeddings
"""
return self.token_embedding(token_ids)
def _get_cls_embedding(self, batch_size: int) -> Tensor:
"""Get [CLS] token embedding expanded to batch size.
Args:
batch_size: number of samples in the batch.
Returns:
[batch_size, 1, hidden_dim] learned [CLS] embedding.
"""
return self.cls_embedding.expand(batch_size, -1, -1)
def forward_value(
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
) -> dict[str, Tensor]:
"""Core forward pass through the distributional value function.
Args:
vision_features: [batch_size, num_patches, hidden_dim]
text_embeddings: [batch_size, seq_len, hidden_dim]
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
Returns:
logits: [batch_size, num_value_bins]
probs: [batch_size, num_value_bins]
value: [batch_size, 1]
"""
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
batch_size = text_embeddings.shape[0]
device = text_embeddings.device
# Build sequence: [vision, text, CLS]
cls_embedding = self._get_cls_embedding(batch_size)
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
# Build causal attention mask
vision_len = vision_features.shape[1]
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
full_seq_len = full_padding_mask.shape[1]
# Causal mask
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
# Combine causal mask with padding mask
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
batch_size, 1, full_seq_len, full_seq_len
)
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
cos, sin = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
norm_output = layer.input_layernorm(hidden_states, cond=None)
if isinstance(norm_output, tuple):
hidden_states_normed, gate = norm_output
else:
hidden_states_normed, gate = norm_output, None
input_shape = hidden_states_normed.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
attention_output, _ = modeling_gemma.eager_attention_forward(
layer.self_attn,
query_states,
key_states,
value_states,
attention_mask,
layer.self_attn.scaling,
)
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
projected_attention = layer.self_attn.o_proj(attention_output)
if gate is not None:
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
else:
projected_attention = hidden_states + projected_attention
after_attention_residual = projected_attention.clone()
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
if isinstance(norm_output, tuple):
mlp_input, gate = norm_output
else:
mlp_input, gate = norm_output, None
mlp_output = layer.mlp(mlp_input)
if gate is not None:
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
else:
hidden_states = after_attention_residual + mlp_output
hidden_states = self.norm(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
# Extract [CLS] token (last position in the sequence)
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
# Value head: Linear(hidden_dim, num_bins) -> logits
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
value_probs = F.softmax(value_logits, dim=-1)
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
dim=-1, keepdim=True
)
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
"""HL-Gauss soft target distribution.
Places a Gaussian N(target, sigma^2) over the bin support and computes
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
Functions via Classification for Scalable Deep RL", Section 3.1.
arXiv:2403.03950
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
# Bin edges: half a bin-width outside the first/last center
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
self.num_value_bins - 1
)
support_edges = torch.linspace(
self.config.value_support_min - bin_width / 2,
self.config.value_support_max + bin_width / 2,
self.num_value_bins + 1,
device=target_value.device,
dtype=target_value.dtype,
)
# CDF of N(target, sigma^2) evaluated at each edge
cdf_at_edges = 0.5 * (
1.0
+ torch.erf(
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
/ (self.hl_gauss_sigma * math.sqrt(2))
)
) # [batch_size, num_bins + 1]
# Normalize: z = cdf(max_edge) - cdf(min_edge)
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
# Bin probabilities = differences of consecutive CDF values, normalized
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
return bin_probabilities
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
"""Dirac delta (C51) projection: split probability between two nearest bins.
Standard distributional RL projection from Bellemare et al. 2017.
"A Distributional Perspective on Reinforcement Learning"
arXiv:1707.06887
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] target probability distribution.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
target_value = target_value.to(dtype=self.bin_centers.dtype)
bin_width = self.bin_centers[1] - self.bin_centers[0]
normalized_position = (target_value - self.config.value_support_min) / bin_width
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
weight_upper = normalized_position - lower_bin_idx.float()
weight_lower = upper_bin_idx.float() - normalized_position
same_bin = lower_bin_idx == upper_bin_idx
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
batch_size = target_value.shape[0]
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
batch_indices = torch.arange(batch_size, device=target_value.device)
target_distribution[batch_indices, lower_bin_idx] += weight_lower
target_distribution[batch_indices, upper_bin_idx] += weight_upper
return target_distribution
def one_hot_target(self, target_value: Tensor) -> Tensor:
"""One-hot target for terminal states (exact return, no smoothing).
Args:
target_value: [batch_size] or [batch_size, 1] target values.
Returns:
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
"""
if target_value.ndim == 2:
target_value = target_value.squeeze(-1)
target_value = target_value.to(dtype=self.bin_centers.dtype)
nearest_bin_idx = torch.argmin(
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
)
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
def compute_target_distribution(
self,
target_value: Tensor,
is_terminal: Tensor,
method: str = "hl_gauss",
use_one_hot_terminal: bool = True,
) -> Tensor:
"""Compute target distribution using configured method.
Args:
target_value: [batch_size] scalar return targets
is_terminal: [batch_size] boolean terminal flags
method: "hl_gauss" or "dirac_delta"
use_one_hot_terminal: if True, terminal states get one-hot targets
(exact return, no smoothing). If False, all states use the same method.
Returns:
[batch_size, num_value_bins] target probability distribution
"""
if method == "hl_gauss":
base_distribution = self.hl_gauss_target(target_value)
elif method == "dirac_delta":
base_distribution = self.dirac_delta_target(target_value)
else:
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
if not use_one_hot_terminal:
return base_distribution
terminal_distribution = self.one_hot_target(target_value)
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Training forward pass — computes cross-entropy loss against MC return targets.
The batch is expected to be preprocessed by the processor pipeline.
Keys expected in batch:
- observation.images.*: [B, C, H, W] preprocessed images
- observation.language_tokens: [B, seq_len] tokenized task prompt
- observation.language_attention_mask: [B, seq_len] padding mask
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
- is_terminal: [B] boolean terminal flags
Returns:
(loss, output_dict) where loss is scalar cross-entropy
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
# Get first image key from batch
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
mc_return = batch["mc_return"]
is_terminal = batch["is_terminal"]
# Embed observations
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
# Forward through value function transformer
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
value_logits = vf_output["logits"]
predicted_value = vf_output["value"]
# Compute target distribution
target_distribution = self.compute_target_distribution(
mc_return,
is_terminal,
method=self.config.target_method,
use_one_hot_terminal=self.config.use_one_hot_terminal,
)
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
log_probs = F.log_softmax(value_logits, dim=-1)
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
output_dict = {
"loss": loss.item(),
"predicted_value_mean": predicted_value.mean().item(),
"mc_return_mean": mc_return.mean().item(),
}
return loss, output_dict
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
"""Compute V(s) for a batch of observations. Used for advantage scoring.
Args:
batch: preprocessed batch with images and tokenized text
Returns:
[batch_size] tensor of predicted values V(s)
"""
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
if not image_keys:
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
images = batch[image_keys[0]]
token_ids = batch[OBS_LANGUAGE_TOKENS]
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
vision_features = self.embed_image(images)
text_embeddings = self.embed_text(token_ids)
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
return vf_output["value"].squeeze(-1) # [batch_size]
@@ -1,235 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for RECAP's distributional value function.
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
https://pi.website/blog/pistar06
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
Training targets (mc_return, is_terminal) are NOT routed through the processor.
They are dataset columns read directly from the batch in the model's forward().
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
)
from lerobot.processor.converters import to_tensor
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_distributional_value_function import DistributionalVFConfig
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
@dataclass
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
"""Format the task string for the distributional value function.
The value function receives only visual observations and task text.
Builds prompt: "Task: {task}."
"""
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
tasks = complementary_data.get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
if isinstance(tasks, str):
tasks = [tasks]
full_prompts = []
for task in tasks:
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
full_prompts.append(f"Task: {cleaned_text}.")
new_complementary_data = dict(complementary_data)
new_complementary_data[self.task_key] = full_prompts
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {"task_key": self.task_key}
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
@dataclass
class DistributionalVFImagePreprocessorStep(ProcessorStep):
"""Resize and normalize images for the value function's SigLIP vision tower.
Expects float images in [0, 1].
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
- Scale to [-1, 1] for SigLIP
"""
image_resolution: tuple[int, int] = (224, 224)
image_keys: tuple[str, ...] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
observation = transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
image_keys = self.image_keys or tuple(
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
)
if not image_keys:
raise KeyError(
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
)
new_observation = dict(observation)
for image_key in image_keys:
image = new_observation[image_key]
if not isinstance(image, Tensor):
image = to_tensor(image)
if image.dtype != torch.float32:
image = image.to(torch.float32)
is_channels_first = image.ndim == 4 and image.shape[1] == 3
if is_channels_first:
image = image.permute(0, 2, 3, 1)
if image.shape[1:3] != self.image_resolution:
image = resize_with_pad_torch(image, *self.image_resolution)
image = image * 2.0 - 1.0
if is_channels_first:
image = image.permute(0, 3, 1, 2)
new_observation[image_key] = image
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = new_observation
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def get_config(self) -> dict[str, Any]:
return {
"image_resolution": self.image_resolution,
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
}
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
return tuple(
feature_name
for feature_name, feature in config.input_features.items()
if feature.type == FeatureType.VISUAL
)
def make_distributional_vf_pre_post_processors(
config: DistributionalVFConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create pre/post processors for the distributional value function.
Preprocessor steps:
1. Rename observations (no-op by default)
2. Add a batch dimension
3. Normalize features (images use identity, so they stay in [0, 1])
4. Format task prompt: "Task: {task}."
5. Tokenize with the PaliGemma tokenizer
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
7. Move tensors to the configured device
Training targets (mc_return, is_terminal) are not processed here.
The model reads them directly from the batch in forward().
The postprocessor is a no-op because the value function does not need
action postprocessing.
"""
image_keys = _visual_image_keys(config)
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=[
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DistributionalVFPrepareTaskPromptStep(),
TokenizerProcessorStep(
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DistributionalVFImagePreprocessorStep(
image_resolution=config.image_resolution,
image_keys=image_keys or None,
),
DeviceProcessorStep(device=config.device or "cpu"),
],
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
to_transition=batch_to_transition,
to_output=transition_to_batch,
)
postprocessor = PolicyProcessorPipeline(
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
)
return preprocessor, postprocessor
-19
View File
@@ -24,7 +24,6 @@ from lerobot.configs.rewards import RewardModelConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from .classifier.configuration_classifier import RewardClassifierConfig
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
from .pretrained import PreTrainedRewardModel
from .robometer.configuration_robometer import RobometerConfig
from .sarm.configuration_sarm import SARMConfig
@@ -64,12 +63,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
return TOPRewardModel
elif name == "distributional_value_function":
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel
else:
try:
return _get_reward_model_cls_from_name(name=name)
@@ -103,8 +96,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
return RobometerConfig(**kwargs)
elif reward_type == "topreward":
return TOPRewardConfig(**kwargs)
elif reward_type == "distributional_value_function":
return DistributionalVFConfig(**kwargs)
else:
try:
config_cls = RewardModelConfig.get_choice_class(reward_type)
@@ -200,16 +191,6 @@ def make_reward_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(reward_cfg, DistributionalVFConfig):
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
return make_distributional_vf_pre_post_processors(
config=reward_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_reward_model_config(
@@ -0,0 +1,386 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
import torch
from safetensors import safe_open
from torch import nn
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE
class FakeFastWAMCore(nn.Module):
def __init__(self):
super().__init__()
self.dit = nn.Linear(2, 2)
def training_loss(self, sample):
assert sample["video"].ndim == 5
assert sample["context"].ndim == 3
return sample[ACTION].sum() * 0.0 + torch.tensor(1.0), {"loss_action": 1.0}
def infer_action(self, **kwargs):
return {"action": torch.ones(1, kwargs["action_horizon"], 3)}
def test_fastwam_is_registered_and_publicly_exported():
cfg = make_policy_config(
"fastwam",
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
assert isinstance(cfg, FastWAMConfig)
assert cfg.type == "fastwam"
assert get_policy_class("fastwam") is FastWAMPolicy
def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
cfg = FastWAMConfig()
cfg.save_pretrained(tmp_path)
saved = json.loads((tmp_path / "config.json").read_text())
assert saved["pretrained_path"] is None
assert cfg.image_features["observation.images.image"].type == FeatureType.VISUAL
assert cfg.action_feature.shape == (7,)
assert cfg.robot_state_feature.shape == (8,)
with pytest.raises(ValueError, match="image feature"):
FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))})
with pytest.raises(ValueError, match="tokenizer_model_id"):
FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(2, 2),
device="cpu",
toggle_action_dimensions=[-1],
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 2, 2)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
base_model_id=None,
)
dataset_stats = {
"observation.images.image": {
"mean": torch.full((3, 1, 1), 0.2),
"std": torch.full((3, 1, 1), 0.1),
},
OBS_STATE: {
"mean": torch.tensor([1.0, 3.0]),
"std": torch.tensor([2.0, 4.0]),
},
ACTION: {
"mean": torch.zeros(3),
"std": torch.ones(3),
},
}
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)
processed = preprocessor(
{
"observation.images.image": torch.tensor(
[
[[0.0, 0.5], [1.0, 0.5]],
[[0.0, 0.5], [1.0, 0.5]],
[[0.0, 0.5], [1.0, 0.5]],
]
),
OBS_STATE: torch.tensor([3.0, 7.0]),
}
)
preprocessor.save_pretrained(tmp_path, config_filename="policy_preprocessor.json")
postprocessor.save_pretrained(tmp_path, config_filename="policy_postprocessor.json")
_, loaded_postprocessor = make_pre_post_processors(cfg, pretrained_path=str(tmp_path))
expected_image = torch.tensor(
[[[[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]]]]
)
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
assert torch.allclose(processed["observation.images.image"], expected_image)
assert torch.allclose(processed[OBS_STATE], torch.tensor([[1.0, 1.0]]))
assert torch.equal(dataset_stats["observation.images.image"]["mean"], torch.full((3, 1, 1), 0.2))
assert any(isinstance(step, FastWAMActionToggleProcessorStep) for step in loaded_postprocessor.steps)
assert torch.equal(
loaded_postprocessor(torch.tensor([[0.25, 0.5, 1.0]])), torch.tensor([[0.25, 0.5, -1.0]])
)
def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
captured = []
class CapturingCore(FakeFastWAMCore):
def infer_action(self, **kwargs):
captured.append(
{
"image_shape": tuple(kwargs["input_image"].shape),
"proprio_shape": tuple(kwargs["proprio"].shape),
"prompt": kwargs["prompt"],
}
)
return {"action": torch.full((1, kwargs["action_horizon"], 3), float(len(captured)))}
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore())
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(16, 16),
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
base_model_id=None,
)
policy = FastWAMPolicy(cfg)
loss, metrics = policy.forward(
{
"observation.images.image": torch.zeros(1, 3, 16, 16),
OBS_STATE: torch.zeros(1, 2),
ACTION: torch.zeros(1, 4, 3),
"context": torch.zeros(1, 5, 4096),
"context_mask": torch.ones(1, 5, dtype=torch.bool),
}
)
action = policy.predict_action_chunk(
{
"observation.images.image": torch.stack(
[
torch.zeros(3, 16, 16),
torch.ones(3, 16, 16),
]
),
OBS_STATE: torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"task": ["task 0", "task 1"],
}
)
assert loss.item() == 1.0
assert metrics["loss_action"] == 1.0
assert action.shape == (2, 4, 3)
assert action[:, 0, 0].tolist() == [1.0, 2.0]
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
assert [item["proprio_shape"] for item in captured] == [(1, 2), (1, 2)]
assert [item["prompt"] for item in captured] == [
cfg.prompt_template.format(task="task 0"),
cfg.prompt_template.format(task="task 1"),
]
class CoreWithFrozenComponents(FakeFastWAMCore):
"""Fake core mirroring the real one: frozen VAE / text encoder held as
*unregistered* attributes (via `object.__setattr__`) so they are excluded from
`state_dict()` and the saved checkpoint, but still moved by the `_apply` override."""
def __init__(self):
super().__init__()
object.__setattr__(self, "vae", nn.Linear(2, 2))
object.__setattr__(self, "text_encoder", nn.Linear(2, 2))
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
def _apply(self, fn, *args, **kwargs):
super()._apply(fn, *args, **kwargs)
self.vae._apply(fn)
self.text_encoder._apply(fn)
return self
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
def build_core(self, config):
core = CoreWithFrozenComponents()
with torch.no_grad():
core.dit.weight.fill_(0.5)
return core
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)
reference = FastWAMPolicy(cfg)
with torch.no_grad():
reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight
reference.save_pretrained(tmp_path)
# Building from Wan2.2 must never happen on a checkpoint load.
def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
monkeypatch.setattr(
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
fail_if_wan_pretrained_is_loaded,
)
policy = FastWAMPolicy.from_pretrained(tmp_path)
assert isinstance(policy.model, CoreWithFrozenComponents)
# The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25))
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
save_dir = tmp_path / "saved"
policy.save_pretrained(save_dir)
assert (save_dir / "model.safetensors").is_file()
# No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
assert not (save_dir / "google").exists()
with safe_open(save_dir / "model.safetensors", framework="pt") as f:
keys = set(f.keys())
# Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
# encoder are excluded (loaded from the diffusers/transformers repos at init).
assert any(key.startswith("model.dit.") for key in keys)
assert not any(key.startswith("model.vae.") for key in keys)
assert not any(key.startswith("model.text_encoder.") for key in keys)
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
# Unregistered: excluded from state_dict and from the optimizer's parameter set.
sd = policy.state_dict()
assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd)
param_names = [n for n, _ in policy.named_parameters()]
assert not any("vae" in n or "text_encoder" in n for n in param_names)
# ...but the `_apply` override still carries them through `.to()` (dtype stands in
# for device on a CPU box), so they never strand off the rest of the model.
policy.to(torch.float64)
assert policy.model.dit.weight.dtype == torch.float64 # registered
assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply
assert policy.model.text_encoder.weight.dtype == torch.float64
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
cfg = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448), base_model_id=None)
cfg.save_pretrained(tmp_path)
loaded = PreTrainedConfig.from_pretrained(tmp_path)
assert loaded.type == "fastwam"
assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL
assert loaded.action_feature.shape == (7,)
assert loaded.robot_state_feature.shape == (8,)
def test_vae_adapter_empty_build_encode_decode_shapes():
"""Offline glue check of the diffusers-backed VAE adapter (random weights).
Validates the encode/decode contract 48 latent channels, 16x spatial / 4x
temporal compression, list-or-batch input, scaling round-trip without any
weight download. (Numerical fidelity vs the original Wan VAE is a separate,
GPU + real-weights verification step.)
"""
pytest.importorskip("diffusers")
from diffusers import AutoencoderKLWan
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
# Production always loads a real pretrained VAE from the diffusers repo; here we
# build the same architecture with random weights and dummy standardization stats
# to exercise the adapter's shape/scaling contract offline (fidelity is checked
# separately, with real weights, on GPU).
arch = {
"base_dim": 160,
"decoder_base_dim": 256,
"z_dim": 48,
"dim_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_scales": [],
"temporal_downsample": [False, True, True],
"dropout": 0.0,
"is_residual": True,
"in_channels": 12,
"out_channels": 12,
"patch_size": 2,
"scale_factor_spatial": 16,
"scale_factor_temporal": 4,
"clip_output": False,
"latents_mean": [0.0] * 48,
"latents_std": [1.0] * 48,
}
raw = AutoencoderKLWan.from_config(arch)
vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw)
assert vae.z_dim == 48
assert vae.upsampling_factor == 16
assert vae.temporal_downsample_factor == 4
video = torch.rand(1, 3, 5, 32, 32) * 2 - 1 # [B,C,T,H,W] in [-1,1]
latents = vae.encode(video)
assert latents.shape == (1, 48, 2, 2, 2) # T'=(5-1)//4+1, H'=W'=32//16
decoded = vae.decode(latents)
assert decoded.shape[0] == 1 and decoded.shape[1] == 3 and decoded.shape[-2:] == (32, 32)
assert decoded.min() >= -1.0 and decoded.max() <= 1.0
# list input is accepted and equals the batched path
assert torch.equal(vae.encode([video[0]]), latents)
@@ -1,518 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RECAP's distributional value function."""
from __future__ import annotations
import pytest
import torch
from lerobot.configs.rewards import RewardModelConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
DistributionalVFConfig,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_IMAGES
from tests.utils import skip_if_package_missing
BATCH_SIZE = 4
NUM_BINS = 201
IMAGE_KEY = f"{OBS_IMAGES}.top"
def _make_config(**overrides) -> DistributionalVFConfig:
defaults = {
"init_from_actor_path": "",
"device": "cpu",
"image_resolution": (224, 224),
}
defaults.update(overrides)
config = DistributionalVFConfig(**defaults)
config.input_features = {
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
}
return config
def _make_model():
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
return DistributionalVFRewardModel(_make_config())
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
return {
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
"mc_return": torch.rand(batch_size, device=device) * -1.0,
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
}
def test_config_registered_in_reward_model_registry():
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
known = RewardModelConfig.get_known_choices()
assert "distributional_value_function" in known
def test_factory_returns_correct_class():
"""get_reward_model_class returns DistributionalVFRewardModel."""
from lerobot.rewards.factory import get_reward_model_class
cls = get_reward_model_class("distributional_value_function")
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
assert cls is DistributionalVFRewardModel
def test_make_reward_model_config_factory():
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
from lerobot.rewards.factory import make_reward_model_config
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
assert isinstance(config, DistributionalVFConfig)
assert config.num_value_bins == 101
@skip_if_package_missing("transformers")
def test_hl_gauss_sums_to_one():
"""HL-Gauss target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
dist = model.hl_gauss_target(targets)
assert dist.shape == (4, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_non_negative():
"""HL-Gauss target probabilities are all non-negative."""
model = _make_model()
targets = torch.linspace(-1.0, 0.0, 10)
dist = model.hl_gauss_target(targets)
assert (dist >= 0).all()
@skip_if_package_missing("transformers")
def test_hl_gauss_expected_value_matches():
"""E[V] under HL-Gauss distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.hl_gauss_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
@skip_if_package_missing("transformers")
def test_hl_gauss_handles_2d_input():
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
dist = model.hl_gauss_target(targets)
assert dist.shape == (2, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_sums_to_one():
"""Dirac delta target distribution sums to 1 for each sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
dist = model.dirac_delta_target(targets)
assert dist.shape == (5, NUM_BINS)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_at_most_two_nonzero():
"""Dirac delta places probability on at most two adjacent bins."""
model = _make_model()
targets = torch.tensor([-0.7523, -0.0013])
dist = model.dirac_delta_target(targets)
for i in range(2):
assert (dist[i] > 0).sum() <= 2
@skip_if_package_missing("transformers")
def test_dirac_delta_expected_value_matches():
"""E[V] under Dirac delta distribution matches the target value."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -0.9])
dist = model.dirac_delta_target(targets)
expected = (dist * model.bin_centers).sum(dim=-1)
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
@skip_if_package_missing("transformers")
def test_dirac_delta_boundary_values_clamped():
"""Values outside support are clamped to boundary bins."""
model = _make_model()
targets = torch.tensor([-1.5, 0.5])
dist = model.dirac_delta_target(targets)
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
assert dist[0, 0] == 1.0
assert dist[1, -1] == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_single_nonzero():
"""One-hot target has exactly one non-zero bin per sample."""
model = _make_model()
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
dist = model.one_hot_target(targets)
assert dist.shape == (4, NUM_BINS)
for i in range(4):
assert (dist[i] > 0).sum() == 1
assert dist[i].sum() == 1.0
@skip_if_package_missing("transformers")
def test_one_hot_nearest_bin():
"""One-hot target activates the bin closest to the target value."""
model = _make_model()
targets = torch.tensor([-0.5])
dist = model.one_hot_target(targets)
hot_idx = dist[0].argmax()
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
@skip_if_package_missing("transformers")
def test_terminal_gets_one_hot():
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
is_terminal = torch.tensor([False, True, False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
)
for i in range(4):
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
assert (dist[1] > 0).sum() == 1
assert (dist[3] > 0).sum() == 1
assert (dist[0] > 0).sum() > 2
assert (dist[2] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_no_terminal_override_when_disabled():
"""When use_one_hot_terminal=False, terminal states use the base method."""
model = _make_model()
targets = torch.tensor([-0.5, -0.3])
is_terminal = torch.tensor([False, True])
dist = model.compute_target_distribution(
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
)
assert (dist[1] > 0).sum() > 2
@skip_if_package_missing("transformers")
def test_model_has_expected_components():
"""Model scaffold contains all architectural components."""
model = _make_model()
assert hasattr(model, "vision_tower")
assert hasattr(model, "multi_modal_projector")
assert hasattr(model, "token_embedding")
assert hasattr(model, "layers")
assert hasattr(model, "value_head")
assert hasattr(model, "cls_embedding")
assert hasattr(model, "norm")
assert hasattr(model, "rotary_emb")
assert hasattr(model, "bin_centers")
@skip_if_package_missing("transformers")
def test_model_bin_centers_shape():
"""Bin centers buffer has shape (num_value_bins,)."""
model = _make_model()
assert model.bin_centers.shape == (NUM_BINS,)
@skip_if_package_missing("transformers")
def test_model_layer_count():
"""Transformer has num_hidden_layers (6) layers."""
model = _make_model()
assert len(model.layers) == 6
@skip_if_package_missing("transformers")
def test_model_value_head_output_dim():
"""Value head outputs num_value_bins logits."""
model = _make_model()
assert model.value_head.out_features == NUM_BINS
@skip_if_package_missing("transformers")
def test_forward_returns_loss_and_dict():
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
model = _make_model()
batch = _make_batch()
loss, output_dict = model.forward(batch)
assert loss.shape == ()
assert torch.isfinite(loss)
assert "loss" in output_dict
assert "predicted_value_mean" in output_dict
assert "mc_return_mean" in output_dict
@skip_if_package_missing("transformers")
def test_forward_loss_is_positive():
"""Cross-entropy loss is strictly positive for random weights."""
model = _make_model()
batch = _make_batch()
loss, _ = model.forward(batch)
assert loss.item() > 0
@skip_if_package_missing("transformers")
def test_compute_reward_returns_correct_shape():
"""compute_reward returns [batch_size] tensor of finite float32 values."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=3)
with torch.no_grad():
values = model.compute_reward(batch)
assert values.shape == (3,)
assert values.dtype == torch.float32
assert torch.isfinite(values).all()
@skip_if_package_missing("transformers")
def test_compute_reward_values_in_support_range():
"""Predicted values lie within [value_support_min, value_support_max]."""
model = _make_model()
model.eval()
batch = _make_batch(batch_size=8)
with torch.no_grad():
values = model.compute_reward(batch)
assert (values >= -1.0 - 0.01).all()
assert (values <= 0.0 + 0.01).all()
@skip_if_package_missing("transformers")
def test_processor_pipeline_produces_expected_keys():
"""Full preprocessor pipeline produces tokenized text and processed images."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
make_distributional_vf_pre_post_processors,
)
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
config = _make_config()
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
raw_batch = {
IMAGE_KEY: torch.rand(3, 224, 224),
"task": "pick up the cup",
}
processed = preprocessor(raw_batch)
assert OBS_LANGUAGE_TOKENS in processed
assert OBS_LANGUAGE_ATTENTION_MASK in processed
assert IMAGE_KEY in processed
@skip_if_package_missing("transformers")
def test_gradient_flows_through_value_head():
"""Backprop produces non-zero gradients on the value head."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.value_head.weight.grad is not None
assert not torch.all(model.value_head.weight.grad == 0)
@skip_if_package_missing("transformers")
def test_gradient_flows_through_cls_embedding():
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
model = _make_model()
model.train()
batch = _make_batch()
loss, _ = model.forward(batch)
loss.backward()
assert model.cls_embedding.grad is not None
assert not torch.all(model.cls_embedding.grad == 0)
def test_config_requires_visual_feature():
"""validate_features raises if no VISUAL feature is present."""
config = DistributionalVFConfig(init_from_actor_path="")
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
}
with pytest.raises(ValueError, match="VISUAL"):
config.validate_features()
def test_config_passes_with_visual_feature():
"""validate_features succeeds when a VISUAL feature is present."""
config = _make_config()
config.validate_features()
@skip_if_package_missing("transformers")
def test_save_load_pretrained_roundtrip(tmp_path):
"""Saved model can be loaded back with identical weights."""
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
DistributionalVFRewardModel,
)
model = _make_model()
model._save_pretrained(tmp_path)
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
orig_sd = model.state_dict()
loaded_sd = loaded.state_dict()
assert set(orig_sd.keys()) == set(loaded_sd.keys())
for key in orig_sd:
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
@skip_if_package_missing("transformers")
def test_image_preprocessor_normalizes_to_minus_one_one():
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 224, 224, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.min() >= -1.0 - 1e-5
assert image.max() <= 1.0 + 1e-5
@skip_if_package_missing("transformers")
def test_image_preprocessor_resizes_with_pad():
"""Image preprocessor resizes non-square images to target resolution."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFImagePreprocessorStep,
)
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
transition = {
TransitionKey.OBSERVATION: {
IMAGE_KEY: torch.rand(1, 480, 640, 3),
},
}
result = step(transition)
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
assert image.shape[1:3] == (224, 224)
def test_task_prompt_formats_correctly():
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: pick up the cup."
def test_task_prompt_handles_string_input():
"""Task prompt step accepts a plain string (not just a list)."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
}
result = step(transition)
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
assert prompt == "Task: open drawer."
def test_task_prompt_raises_on_missing_task():
"""Task prompt step raises ValueError when task key is absent."""
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
DistributionalVFPrepareTaskPromptStep,
)
step = DistributionalVFPrepareTaskPromptStep()
transition = {
TransitionKey.COMPLEMENTARY_DATA: {},
}
with pytest.raises(ValueError, match="No task found"):
step(transition)
Generated
+344 -342
View File
File diff suppressed because it is too large Load Diff