mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 14:39:44 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c063c3fbc | |||
| 9cf12c941d | |||
| 4039da81c6 | |||
| b3a28a49f6 | |||
| 49755a3d9e | |||
| 09808183ca |
@@ -67,6 +67,8 @@
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: fastwam
|
||||
title: FastWAM
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
# FastWAM
|
||||
|
||||
FastWAM is a World Action Model policy for robot control. The LeRobot integration exposes FastWAM through the standard policy API so it can be configured with `policy.type=fastwam`, trained with `lerobot-train`, and loaded through the LeRobot pretrained policy interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
FastWAM keeps video modeling during training, but uses direct action prediction at inference time instead of iteratively generating future observations. This LeRobot policy wraps the FastWAM action model, adapts LeRobot batches to FastWAM training samples, and provides the standard processor pipeline for normalization and action postprocessing.
|
||||
|
||||
The implementation initializes the visual world-model components from `Wan-AI/Wan2.2-TI2V-5B` by default and predicts action chunks with shape `[batch, action_horizon, action_dim]`.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=fastwam` configuration through LeRobot
|
||||
- Image, state, action, and language-task batch adaptation
|
||||
- Action chunk inference through `select_action` and `predict_action_chunk`
|
||||
- Checkpoint save/load through the LeRobot policy APIs
|
||||
- Configurable LIBERO gripper action postprocessing
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Install LeRobot from source, then install FastWAM dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[fastwam]"
|
||||
```
|
||||
|
||||
This installs the FastWAM policy extra from `pyproject.toml`: `transformers`,
|
||||
`diffusers`, `ftfy`, and `regex`, plus LeRobot's base dependencies.
|
||||
|
||||
For LIBERO evaluation, install the benchmark dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[fastwam,libero]"
|
||||
```
|
||||
|
||||
This installs both extras. In addition to the FastWAM dependencies above, the
|
||||
`libero` extra installs LeRobot dataset dependencies, `hf-libero` on Linux, and
|
||||
`scipy`.
|
||||
|
||||
FastWAM uses the Wan2.2 TI2V backbone. The default model id is:
|
||||
|
||||
```bash
|
||||
--policy.model_id=Wan-AI/Wan2.2-TI2V-5B
|
||||
```
|
||||
|
||||
## Data Requirements
|
||||
|
||||
FastWAM expects a LeRobot dataset with:
|
||||
|
||||
- one or more visual observations whose widths concatenate to `policy.image_size[1]`
|
||||
- `observation.state` when `policy.proprio_dim` is not `None`
|
||||
- `action`
|
||||
- a language task instruction through the dataset task field, or precomputed `context` and `context_mask` tensors
|
||||
|
||||
The default visual setup is one image feature named `observation.images.image` with shape `(3, 224, 448)`. If the dataset uses two cameras, configure `policy.input_features` so their heights match `224` and their widths sum to `448`.
|
||||
|
||||
## Usage
|
||||
|
||||
Create a new FastWAM policy with:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-org/your-dataset \
|
||||
--policy.type=fastwam \
|
||||
--policy.action_dim=7 \
|
||||
--policy.proprio_dim=8 \
|
||||
--policy.action_horizon=32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.image_size='[224,448]' \
|
||||
--output_dir=./outputs/fastwam_training \
|
||||
--job_name=fastwam_training \
|
||||
--steps=300000 \
|
||||
--batch_size=8 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
Evaluate an existing LeRobot-format checkpoint on LIBERO-10 with:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
|
||||
--policy.device=cuda \
|
||||
--policy.torch_dtype=float32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.observation_height=224 \
|
||||
--env.observation_width=224 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=0 \
|
||||
--env.episode_length=600
|
||||
```
|
||||
|
||||
For `libero_goal`, `libero_spatial`, and `libero_object`, use
|
||||
`--env.episode_length=300`.
|
||||
|
||||
For real-robot rollout, use the same checkpoint path:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--policy.path=your-org/fastwam-real-robot
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Features
|
||||
|
||||
`policy.image_size` is the size of the concatenated FastWAM image tensor as `(height, width)`. Each configured image feature must have shape `(3, height, camera_width)`, and all camera widths must sum to the configured width.
|
||||
|
||||
### Action Chunking
|
||||
|
||||
`policy.action_horizon` controls the number of future actions supervised during training and predicted during inference. `policy.n_action_steps` controls how many actions are consumed before the policy predicts a fresh chunk. `policy.n_action_steps` must be less than or equal to `policy.action_horizon`.
|
||||
|
||||
### Wan Components
|
||||
|
||||
FastWAM loads the Wan VAE, video DiT, text encoder, and tokenizer from the configured Wan model directory or Hugging Face Hub model id. LeRobot-format FastWAM checkpoints saved by `save_pretrained` also copy the local Wan component files needed by `from_pretrained`.
|
||||
|
||||
### LIBERO Action Toggle
|
||||
|
||||
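The released FastWAM LIBERO checkpoints use `policy.toggle_action_dimensions=[-1]`
|
||||
to match the gripper action convention used by the original FastWAM
|
||||
evaluation pipeline. The policy config itself defaults to an empty list (no toggling), so pass the flag explicitly when a config does not already set it: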
|
||||
|
||||
```bash
|
||||
--policy.toggle_action_dimensions='[-1]'
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| --- | ---: | ---: |
|
||||
| libero_spatial | 97.6% | 500 |
|
||||
| libero_object | 99.0% | 500 |
|
||||
| libero_goal | 95.0% | 500 |
|
||||
| libero_10 | 94.0% | 500 |
|
||||
| **average** | **96.4%** | 2000 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300` (1x H20 140 GB).
|
||||
|
||||
## References
|
||||
|
||||
- [Fast-WAM paper](https://arxiv.org/abs/2603.16666)
|
||||
- [Fast-WAM project page](https://yuantianyuan01.github.io/FastWAM/)
|
||||
- [Fast-WAM code](https://github.com/yuantianyuan01/FastWAM)
|
||||
- [Released upstream checkpoints](https://huggingface.co/yuanty/fastwam)
|
||||
- [Wan2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{yuan2026fastwam,
|
||||
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
|
||||
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
|
||||
journal = {arXiv preprint arXiv:2603.16666},
|
||||
year = {2026},
|
||||
url = {https://arxiv.org/abs/2603.16666}
|
||||
}
|
||||
```
|
||||
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://arxiv.org/abs/2603.16666
|
||||
|
||||
## Repository
|
||||
|
||||
Code: https://github.com/yuantianyuan01/FastWAM
|
||||
|
||||
Project page: https://yuantianyuan01.github.io/FastWAM/
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{yuan2026fastwam,
|
||||
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
|
||||
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
|
||||
journal = {arXiv preprint arXiv:2603.16666},
|
||||
year = {2026},
|
||||
url = {https://arxiv.org/abs/2603.16666}
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Base video model: https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B
|
||||
|
||||
Released upstream checkpoints: https://huggingface.co/yuanty/fastwam
|
||||
|
||||
## Results
|
||||
|
||||
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
|
||||
|
||||
| Suite | Success rate | n_episodes |
|
||||
| --- | ---: | ---: |
|
||||
| libero_spatial | 97.6% | 500 |
|
||||
| libero_object | 99.0% | 500 |
|
||||
| libero_goal | 95.0% | 500 |
|
||||
| libero_10 | 94.0% | 500 |
|
||||
| **average** | **96.4%** | 2000 |
|
||||
|
||||
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300`.
|
||||
|
||||
For LIBERO-10, use `--env.task=libero_10 --env.episode_length=600`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
|
||||
--policy.device=cuda \
|
||||
--policy.torch_dtype=float32 \
|
||||
--policy.n_action_steps=10 \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 --env.observation_height=256 --env.observation_width=256 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50 \
|
||||
--seed=0 --env.episode_length=600
|
||||
```
|
||||
@@ -216,6 +216,10 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
fastwam = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
@@ -280,6 +284,7 @@ all = [
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
"lerobot[fastwam]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
    """Tell whether ``teleop`` is actuated, i.e. can be driven through position feedback.

    A teleoperator qualifies when its ``feedback_features`` is non-empty and it
    exposes both the ``disable_torque`` and ``enable_torque`` motor-control
    methods (e.g. SO-101, OpenArmMini).

    TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
    """
    # No feedback features means there is nothing to actuate.
    if not teleop.feedback_features:
        return False
    # Both torque controls must be present, checked in the same order as before.
    return all(hasattr(teleop, method) for method in ("disable_torque", "enable_torque"))
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
    """Drive an actuated teleop towards ``target_pos`` by linear interpolation.

    The teleoperator must support feedback (non-empty ``feedback_features`` and
    ``disable_torque`` / ``enable_torque`` implemented). Torque is enabled, the
    current pose is read once with ``get_action()``, and ``max(int(duration_s * fps), 1) + 1``
    interpolated poses are sent through ``send_feedback``, sleeping ``1 / fps``
    seconds after each one.

    ``target_pos`` is expected to be in the teleop's action/feedback key space.
    For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
    the robot action key space directly. Keys of the current pose that are
    absent from ``target_pos`` keep their current value.

    TODO(Maxime): This call is blocking for the whole motion; during this time the
    follower robot does not receive new actions, which could be an issue on LeKiwi.
    """
    teleop.enable_torque()
    start_pos = teleop.get_action()
    n_steps = max(int(duration_s * fps), 1)

    for idx in range(n_steps + 1):
        alpha = idx / n_steps
        blended = {}
        for key in start_pos:
            if key in target_pos:
                blended[key] = start_pos[key] * (1 - alpha) + target_pos[key] * alpha
            else:
                # Joint not part of the target: hold it where it started.
                blended[key] = start_pos[key]
        teleop.send_feedback(blended)
        time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
    robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
    """Linearly interpolate the follower robot from ``current`` to ``target``.

    Used when the teleop is non-actuated: instead of driving the leader arm to
    the follower, the follower is brought to the teleop's current pose so the
    robot meets the operator's hand rather than jumping to it on the first frame.

    Both ``current`` and ``target`` must be in the robot action key space
    (i.e. the output of ``robot_action_processor``). Keys of ``current`` missing
    from ``target`` are sent unchanged. ``max(int(duration_s * fps), 1) + 1``
    actions are sent with ``robot.send_action``, each followed by a
    ``1 / fps`` second sleep.
    """
    n_steps = max(int(duration_s * fps), 1)

    for idx in range(n_steps + 1):
        alpha = idx / n_steps
        blended = {}
        for key in current:
            if key in target:
                blended[key] = current[key] * (1 - alpha) + target[key] * alpha
            else:
                blended[key] = current[key]
        robot.send_action(blended)
        time.sleep(1 / fps)
|
||||
|
||||
@@ -18,6 +18,7 @@ from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
|
||||
@@ -42,6 +43,7 @@ __all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"EO1Config",
|
||||
"FastWAMConfig",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MolmoAct2Config",
|
||||
|
||||
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .fastwam.configuration_fastwam import FastWAMConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
|
||||
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
elif name == "fastwam":
|
||||
from .fastwam.modeling_fastwam import FastWAMPolicy
|
||||
|
||||
return FastWAMPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
elif policy_type == "fastwam":
|
||||
return FastWAMConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -448,6 +455,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, FastWAMConfig):
|
||||
from .fastwam.processor_fastwam import make_fastwam_pre_post_processors
|
||||
|
||||
processors = make_fastwam_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_fastwam_README.md
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modeling_fastwam import FastWAMPolicy
|
||||
from .processor_fastwam import make_fastwam_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"FastWAMConfig",
|
||||
"FastWAMPolicy",
|
||||
"make_fastwam_pre_post_processors",
|
||||
]
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
PolicyFeature,
|
||||
PreTrainedConfig,
|
||||
)
|
||||
from lerobot.optim import AdamWConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
# Hub id of the Wan2.2 TI2V backbone. It is the default for `model_id` and
# `tokenizer_model_id`, and the only non-local value `_validate_wan_model_id` accepts.
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
# Default for `FastWAMConfig.base_model_id`; `__post_init__` uses it as `pretrained_path`
# when no path is set, PEFT is off, and the DiT configs pass the base-compatibility check.
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
|
||||
|
||||
|
||||
# Video DiT config keys that `is_fastwam_base_compatible_config` compares against
# `default_video_dit_config`; all of them must match for the check to pass.
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
|
||||
"patch_size",
|
||||
"in_dim",
|
||||
"hidden_dim",
|
||||
"ffn_dim",
|
||||
"freq_dim",
|
||||
"text_dim",
|
||||
"out_dim",
|
||||
"num_heads",
|
||||
"attn_head_dim",
|
||||
"num_layers",
|
||||
)
|
||||
|
||||
# Action DiT config keys that `is_fastwam_base_compatible_config` compares against
# `default_action_dit_config`; all of them must match for the check to pass.
_FASTWAM_ACTION_BASE_COMPAT_KEYS = (
|
||||
"hidden_dim",
|
||||
"ffn_dim",
|
||||
"num_heads",
|
||||
"attn_head_dim",
|
||||
"num_layers",
|
||||
"text_dim",
|
||||
"freq_dim",
|
||||
)
|
||||
|
||||
|
||||
def default_video_dit_config(action_dim: int) -> dict[str, Any]:
    """Build the default settings dict for the video DiT.

    Args:
        action_dim: Value stored under the ``"action_dim"`` key.

    Returns:
        A newly created dict on every call, so callers may mutate it freely.
    """
    settings: dict[str, Any] = {}
    # Layer and tensor sizes.
    settings.update(
        patch_size=[1, 2, 2],
        in_dim=48,
        hidden_dim=3072,
        ffn_dim=14336,
        freq_dim=256,
        text_dim=4096,
        out_dim=48,
        num_heads=24,
        attn_head_dim=128,
        num_layers=30,
    )
    # Numerical and runtime switches. The "seperated_timestep" spelling is the
    # actual key and must be kept as-is.
    settings.update(
        eps=1.0e-6,
        seperated_timestep=True,
        use_gradient_checkpointing=False,
        video_attention_mask_mode="first_frame_causal",
    )
    # Action-related entries.
    settings.update(
        action_conditioned=False,
        action_dim=action_dim,
        action_group_causal_mask_mode="group_diagonal",
        fp32_attention=True,
    )
    return settings
|
||||
|
||||
|
||||
def default_action_dit_config(action_dim: int) -> dict[str, Any]:
    """Build the default settings dict for the action DiT.

    Args:
        action_dim: Value stored under the ``"action_dim"`` key.

    Returns:
        A newly created dict on every call.
    """
    return dict(
        action_dim=action_dim,
        hidden_dim=1024,
        ffn_dim=4096,
        num_heads=24,
        attn_head_dim=128,
        num_layers=30,
        text_dim=4096,
        freq_dim=256,
        eps=1.0e-6,
        use_gradient_checkpointing=False,
        fp32_attention=True,
    )
|
||||
|
||||
|
||||
def _coerce_enum(enum_cls: type, value: Any) -> Any:
|
||||
if isinstance(value, enum_cls):
|
||||
return value
|
||||
try:
|
||||
return enum_cls(value)
|
||||
except (TypeError, ValueError):
|
||||
return getattr(enum_cls, str(value), value)
|
||||
|
||||
|
||||
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
|
||||
if features is None:
|
||||
return None
|
||||
coerced = {}
|
||||
for name, feature in features.items():
|
||||
if isinstance(feature, PolicyFeature):
|
||||
coerced[name] = feature
|
||||
continue
|
||||
coerced[name] = PolicyFeature(
|
||||
type=_coerce_enum(FeatureType, feature["type"]),
|
||||
shape=tuple(feature["shape"]),
|
||||
)
|
||||
return coerced
|
||||
|
||||
|
||||
def _is_local_model_id(value: str) -> bool:
|
||||
path = Path(value).expanduser()
|
||||
return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists()
|
||||
|
||||
|
||||
def _validate_wan_model_id(value: str, field_name: str) -> str:
    """Return ``value`` if it is an accepted Wan model location.

    Accepted values are exactly ``WAN22_MODEL_ID`` or anything
    ``_is_local_model_id`` recognises as a local path.

    Args:
        value: Model id or path to check.
        field_name: Config field name, used only in the error message.

    Raises:
        ValueError: If ``value`` is neither the Wan2.2 id nor a local path.
    """
    accepted = value == WAN22_MODEL_ID or _is_local_model_id(value)
    if not accepted:
        raise ValueError(f"`{field_name}` must be `{WAN22_MODEL_ID}` or an explicit local path, got `{value}`.")
    return value
|
||||
|
||||
|
||||
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
    """Return whether `fastwam-base` partial weights can initialize this config.

    The video and action DiT configs of ``config`` are compared with the
    defaults built for ``config.action_dim``, but only on the keys listed in
    ``_FASTWAM_VIDEO_BASE_COMPAT_KEYS`` and ``_FASTWAM_ACTION_BASE_COMPAT_KEYS``.
    A key missing on both sides counts as a match.
    """

    def _agrees(actual: dict, expected: dict, keys: tuple) -> bool:
        return all(actual.get(key) == expected.get(key) for key in keys)

    expected_video = default_video_dit_config(config.action_dim)
    expected_action = default_action_dit_config(config.action_dim)

    video_ok = _agrees(config.video_dit_config, expected_video, _FASTWAM_VIDEO_BASE_COMPAT_KEYS)
    # The action config is only inspected when the video config already matches.
    return video_ok and _agrees(config.action_dit_config, expected_action, _FASTWAM_ACTION_BASE_COMPAT_KEYS)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("fastwam")
|
||||
@dataclass
|
||||
class FastWAMConfig(PreTrainedConfig):
|
||||
"""Configuration for the FastWAM LeRobot policy.
|
||||
|
||||
Args:
|
||||
action_dim (int): Number of scalar action channels per timestep.
|
||||
proprio_dim (int | None): Number of proprioception channels used as an
|
||||
extra text-context token. `None` disables proprio conditioning.
|
||||
action_horizon (int): Number of actions predicted by one policy call.
|
||||
num_video_frames (int): Number of video frames used by FastWAM rollout.
|
||||
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
|
||||
context_len (int): Maximum text embedding token length.
|
||||
video_dit_config (dict[str, Any] | None): Wan video expert config.
|
||||
action_dit_config (dict[str, Any] | None): Action expert config.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
action_dim: int = 7
|
||||
proprio_dim: int | None = 8
|
||||
action_horizon: int = 32
|
||||
n_action_steps: int = 32
|
||||
num_video_frames: int = 33
|
||||
image_size: tuple[int, int] = (224, 448)
|
||||
context_len: int = 128
|
||||
model_id: str = WAN22_MODEL_ID
|
||||
tokenizer_model_id: str = WAN22_MODEL_ID
|
||||
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
|
||||
tokenizer_max_len: int = 128
|
||||
load_text_encoder: bool = True
|
||||
mot_checkpoint_mixed_attn: bool = False
|
||||
torch_dtype: str = "bfloat16"
|
||||
prompt_template: str = (
|
||||
"A video recorded from a robot's point of view executing the following instruction: {task}"
|
||||
)
|
||||
num_inference_steps: int = 10
|
||||
inference_seed: int | None = 42
|
||||
rand_device: str = "cpu"
|
||||
text_cfg_scale: float = 1.0
|
||||
negative_prompt: str = ""
|
||||
sigma_shift: float | None = None
|
||||
tiled: bool = False
|
||||
fp32_attention: bool = True
|
||||
toggle_action_dimensions: list[int] = field(default_factory=list)
|
||||
video_scheduler: dict[str, float | int] = field(
|
||||
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
|
||||
)
|
||||
action_scheduler: dict[str, float | int] = field(
|
||||
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
|
||||
)
|
||||
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
|
||||
video_dit_config: dict[str, Any] | None = None
|
||||
action_dit_config: dict[str, Any] | None = None
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
input_features: dict[str, PolicyFeature] | None = None
|
||||
output_features: dict[str, PolicyFeature] | None = None
|
||||
optimizer_lr: float = 1.0e-4
|
||||
optimizer_weight_decay: float = 1.0e-2
|
||||
|
||||
def __post_init__(self) -> None:
    """Normalize fields, install default features, and select the base checkpoint.

    Steps, in order:
      1. Run the parent ``__post_init__``.
      2. Coerce ``image_size`` to a tuple, validate the Wan model ids, coerce
         the feature mappings to ``PolicyFeature`` and the toggle list to ints.
      3. Fill missing (``None`` or empty) DiT configs with the defaults for
         ``action_dim`` and force their ``"fp32_attention"`` entry to
         ``fp32_attention``.
      4. Install default input/output features when none were given, then
         call ``validate_features``.
      5. Unless a ``pretrained_path`` is already set, PEFT is used, or
         ``base_model_id`` is empty, point ``pretrained_path`` at
         ``base_model_id`` when the DiT configs are base-compatible, and mark
         that choice with ``_auto_pretrained_path``.

    Raises:
        ValueError: From model-id validation or ``validate_features``.
    """
    super().__post_init__()
    self.image_size = tuple(self.image_size)
    self.model_id = _validate_wan_model_id(self.model_id, "model_id")
    self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
    self.input_features = _coerce_policy_features(self.input_features)
    self.output_features = _coerce_policy_features(self.output_features)
    self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]

    # Shallow-copy caller-supplied dicts before writing "fp32_attention" into
    # them, so a dict shared between configs (or reused by the caller) is not
    # modified behind the caller's back.
    self.video_dit_config = (
        dict(self.video_dit_config) if self.video_dit_config else default_video_dit_config(self.action_dim)
    )
    self.action_dit_config = (
        dict(self.action_dit_config) if self.action_dit_config else default_action_dit_config(self.action_dim)
    )
    self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
    self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)

    if self.input_features is None:
        # Synthetic default: one already-concatenated image plus optional state.
        height, width = self.image_size
        self.input_features = {
            "observation.images.image": PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, height, width),
            )
        }
        if self.proprio_dim is not None:
            self.input_features[OBS_STATE] = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.proprio_dim,),
            )
    if self.output_features is None:
        self.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))}
    self.validate_features()

    if self.pretrained_path or self.use_peft or not self.base_model_id:
        return
    if not is_fastwam_base_compatible_config(self):
        return
    self.pretrained_path = Path(self.base_model_id)
    # Remember that the path was chosen automatically; `_save_pretrained`
    # reads this flag to avoid persisting it.
    self._auto_pretrained_path = True
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
    """Save the config, hiding an automatically chosen ``pretrained_path``.

    When ``_auto_pretrained_path`` is set, ``pretrained_path`` is temporarily
    cleared while the parent implementation runs and is restored afterwards,
    even if saving raises. Otherwise the parent implementation is called
    directly.
    """
    auto_selected = getattr(self, "_auto_pretrained_path", False)
    if auto_selected:
        remembered_path = self.pretrained_path
        self.pretrained_path = None
        try:
            super()._save_pretrained(save_directory)
        finally:
            self.pretrained_path = remembered_path
        return

    super()._save_pretrained(save_directory)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
    """Return an ``AdamWConfig`` built from ``optimizer_lr`` and ``optimizer_weight_decay``."""
    learning_rate = self.optimizer_lr
    decay = self.optimizer_weight_decay
    return AdamWConfig(lr=learning_rate, weight_decay=decay)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
    """Return ``None``: this config defines no scheduler preset."""
    preset = None
    return preset
|
||||
|
||||
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
    """Rebuild visual input features from the dataset's real camera keys.

    FastWAM's `__post_init__` installs a synthetic single-image default
    (`observation.images.image` at full `image_size` width). For datasets
    with one or more separately-named cameras (e.g. `observation.images.top`,
    `observation.images.wrist`), this hook — invoked by `make_policy` once the
    dataset metadata is known — replaces that default with the actual camera
    keys, each declared at the policy's native per-camera resolution
    (`image_size[0]` x `image_size[1] // num_cameras`). The accompanying
    resize step in `make_fastwam_pre_post_processors` resizes raw frames to
    match, so heterogeneous source resolutions (e.g. 480x640) are supported.

    Args:
        dataset_features: Mapping from dataset feature name to its metadata
            mapping. Only keys starting with `observation.images.` whose
            `"dtype"` is `"video"` or `"image"` are treated as cameras. If
            there are none, the config is left untouched.

    Raises:
        ValueError: If `image_size[1]` is not divisible by the number of
            cameras, or if the rebuilt features fail `validate_features`.
            In both cases `input_features` keeps its previous value.
    """
    image_keys = sorted(
        key
        for key, feature in dataset_features.items()
        if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
    )
    if not image_keys:
        return

    height, total_width = self.image_size
    per_cam_width, leftover = divmod(total_width, len(image_keys))
    if leftover:
        # Floor division would silently drop pixels and only fail later with a
        # generic "widths must sum to ..." message; report the real cause here.
        raise ValueError(
            f"FastWAM cannot split `image_size` width {total_width} evenly across "
            f"{len(image_keys)} cameras ({', '.join(image_keys)}); "
            "use a width divisible by the number of cameras."
        )

    new_inputs: dict[str, PolicyFeature] = {
        key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width)) for key in image_keys
    }
    if self.proprio_dim is not None and OBS_STATE in dataset_features:
        new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))

    # Install the new features, but roll back if they do not validate so the
    # config is never left in a half-updated state.
    previous_inputs = self.input_features
    self.input_features = new_inputs
    try:
        self.validate_features()
    except ValueError:
        self.input_features = previous_inputs
        raise
|
||||
|
||||
def validate_features(self) -> None:
    """Check that the scalar settings and declared features are mutually consistent.

    The checks run in order and the first failure is reported:

    * `action_dim`, `action_horizon` and `n_action_steps` must be positive, and
      `n_action_steps` may not exceed `action_horizon`;
    * `num_video_frames` must satisfy `T % 4 == 1`;
    * at least one image feature and an `action` output feature must exist, and
      the action shape must be `(action_dim,)`;
    * when `proprio_dim` is set, a state feature of shape `(proprio_dim,)` is
      required;
    * every image feature must be `(3, image_size[0], W_i)` and the widths
      `W_i` must add up to `image_size[1]`.

    Raises:
        ValueError: If any of the conditions above is violated.
    """
    if self.action_dim <= 0:
        raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
    if self.action_horizon <= 0:
        raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
    # The policy sizes its action queue with `n_action_steps`; a non-positive
    # value would otherwise only surface later as an empty or invalid queue.
    if self.n_action_steps <= 0:
        raise ValueError(f"`n_action_steps` must be positive, got {self.n_action_steps}.")
    if self.n_action_steps > self.action_horizon:
        raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
    if self.num_video_frames % 4 != 1:
        raise ValueError(f"`num_video_frames` must satisfy T % 4 == 1, got {self.num_video_frames}.")
    if not self.image_features:
        raise ValueError("FastWAM requires at least one image feature.")
    if self.action_feature is None:
        raise ValueError("FastWAM requires `action` in output_features.")
    action_shape = tuple(self.action_feature.shape)
    if action_shape != (self.action_dim,):
        raise ValueError(
            f"FastWAM action feature shape must be ({self.action_dim},), got {action_shape}."
        )
    if self.proprio_dim is not None:
        state_feature = self.robot_state_feature
        if state_feature is None:
            raise ValueError("FastWAM requires `observation.state` when `proprio_dim` is set.")
        state_shape = tuple(state_feature.shape)
        if state_shape != (self.proprio_dim,):
            raise ValueError(
                f"FastWAM state feature shape must be ({self.proprio_dim},), got {state_shape}."
            )
    height, width = self.image_size
    # Camera widths are accumulated and compared with the configured total width.
    image_width_sum = 0
    for name, feature in self.image_features.items():
        shape = tuple(feature.shape)
        if len(shape) != 3 or shape[0] != 3:
            raise ValueError(f"FastWAM image feature `{name}` must have shape (3, H, W), got {shape}.")
        if shape[1] != height:
            raise ValueError(f"FastWAM image feature `{name}` height must be {height}, got {shape[1]}.")
        image_width_sum += shape[2]
    if image_width_sum != width:
        raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
|
||||
|
||||
@property
def observation_delta_indices(self) -> None:
    """Observation frame offsets; always `None` for this config."""
    offsets = None
    return offsets
|
||||
|
||||
@property
def action_delta_indices(self) -> list[int]:
    """Action step offsets `0 .. action_horizon - 1`, one per step of the chunk."""
    offsets = [*range(self.action_horizon)]
    return offsets
|
||||
|
||||
@property
def reward_delta_indices(self) -> None:
    """Reward frame offsets; always `None` for this config."""
    offsets = None
    return offsets
|
||||
@@ -0,0 +1,383 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
"""LeRobot policy wrapper for FastWAM.
|
||||
|
||||
Args:
|
||||
config (FastWAMConfig): FastWAM policy configuration.
|
||||
dataset_stats (dict[str, dict[str, Tensor]] | None): Optional LeRobot
|
||||
dataset statistics passed by the training/evaluation stack.
|
||||
"""
|
||||
|
||||
config_class = FastWAMConfig
|
||||
name = "fastwam"
|
||||
|
||||
def __init__(
    self,
    config: FastWAMConfig,
    dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
    """Validate the config, build the FastWAM core model and reset the action queue.

    Args:
        config: FastWAM policy configuration; its `validate_features()` is
            called before anything is built.
        dataset_stats: Optional dataset statistics; stored on the policy and
            forwarded to the parent constructor.
    """
    super().__init__(config, dataset_stats)
    # Fail before the (potentially expensive) model construction below.
    config.validate_features()
    self.dataset_stats = dataset_stats
    self.config = config
    core_model = self._build_core_model(config)
    self.model = core_model
    self.reset()
|
||||
|
||||
@classmethod
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
    """Load a safetensors checkpoint, tolerating shape mismatches when not strict.

    The checkpoint header is scanned for keys that also exist in `model` but
    with a different shape.

    * No such key: loading is delegated unchanged to the parent implementation.
    * Some such keys and `strict=True`: a `RuntimeError` listing them is raised.
    * Some such keys and `strict=False`: those tensors are dropped from the
      checkpoint (so the model keeps its current values for them), a warning is
      logged, and everything else is loaded with `load_state_dict(strict=False)`.
      The model is then moved to `map_location` unless that is empty or `"cpu"`.

    The warning text describes the non-strict path as a "cross-embodiment load",
    i.e. reusing a checkpoint whose action/state dimensions differ.

    Returns:
        The `model` object that was passed in (or whatever the parent loader
        returns on the delegated path).

    Raises:
        RuntimeError: On a shape mismatch under `strict=True`.
    """
    from safetensors import safe_open

    expected_shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}
    with safe_open(model_file, framework="pt") as handle:
        mismatched = [
            name
            for name in list(handle.keys())
            if name in expected_shapes
            and expected_shapes[name] != tuple(handle.get_slice(name).get_shape())
        ]

    if len(mismatched) == 0:
        return super()._load_as_safetensor(model, model_file, map_location, strict)
    if strict:
        raise RuntimeError(
            f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
            f"strict=True: {mismatched}"
        )

    from safetensors.torch import load_file

    logging.warning(
        "FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
        "every compatible weight: %s",
        len(mismatched),
        mismatched,
    )
    checkpoint = load_file(model_file, device="cpu")
    for name in mismatched:
        checkpoint.pop(name, None)
    model.load_state_dict(checkpoint, strict=False)
    needs_move = bool(map_location) and map_location != "cpu"
    if needs_move:
        model.to(map_location)
    return model
|
||||
|
||||
def get_optim_params(self) -> dict[str, Any]:
    """Collect the trainable parameters handed to the optimizer.

    Parameters come from `self.model.dit` when the model has that attribute,
    otherwise from the whole model, plus those of `self.model.proprio_encoder`
    when it exists. Parameters with `requires_grad=False` are skipped.

    Each parameter is returned at most once. When the fallback to the whole
    model is taken, the proprio encoder's parameters are already included, so
    appending them again would put duplicates in the optimizer group.

    Returns:
        `{"params": [...]}` with unique trainable parameters in first-seen order.
    """
    if hasattr(self.model, "dit"):
        candidates = list(self.model.dit.parameters())
    else:
        candidates = list(self.model.parameters())
    proprio_encoder = getattr(self.model, "proprio_encoder", None)
    if proprio_encoder is not None:
        candidates.extend(proprio_encoder.parameters())

    seen_ids: set[int] = set()
    params = []
    for param in candidates:
        # Identity (not equality) is what distinguishes parameters.
        if not param.requires_grad or id(param) in seen_ids:
            continue
        seen_ids.add(id(param))
        params.append(param)
    return {"params": params}
|
||||
|
||||
def reset(self) -> None:
    """Drop any queued actions by installing a fresh, empty queue.

    The queue holds at most `config.n_action_steps` entries.
    """
    queue_capacity = self.config.n_action_steps
    self._action_queue: deque[Tensor] = deque(maxlen=queue_capacity)
|
||||
|
||||
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
    """Turn a LeRobot batch into the sample dict passed to `model.training_loss`.

    The batch is shallow-copied and only missing entries are filled in:

    * `video`: built from the `observation.images.*` entries.
    * `context` / `context_mask`: obtained by encoding the prompt derived from
      `prompt` / `task` with `model.encode_prompt` when either key is absent.
    * `proprio`: taken from `observation.state` when `proprio_dim` is set; a
      2-D state gets an extra axis inserted at position 1 (presumably the
      per-frame axis the core model indexes - confirm against `build_inputs`).

    No other validation happens here.

    Raises:
        KeyError: If text context must be encoded but the batch carries neither
            `prompt` nor `task`.
    """
    sample = {**batch}
    if "video" not in sample:
        sample["video"] = _stack_video_from_images(batch, self.config)

    has_text_context = "context" in sample and "context_mask" in sample
    if not has_text_context:
        prompt = _prompt_from_batch(batch=batch, config=self.config)
        if prompt is None:
            raise KeyError(
                "FastWAM training requires a `task`/`prompt` to encode text context, "
                "or precomputed `context`/`context_mask` in the batch."
            )
        context, context_mask = self.model.encode_prompt(prompt)
        sample["context"] = context
        sample["context_mask"] = context_mask

    needs_proprio = self.config.proprio_dim is not None and "proprio" not in sample
    state = sample.get(OBS_STATE) if needs_proprio else None
    if state is not None:
        # A single-step state has no time axis yet; add one of length 1.
        if state.ndim == 2:
            state = state.unsqueeze(1)
        sample["proprio"] = state
    return sample
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
    """Compute the FastWAM training loss for a batch.

    Args:
        batch: Either FastWAM-ready keys (`video`, `action`, `context`,
            `context_mask`) or LeRobot keys that `_batch_to_training_sample`
            can adapt.

    Returns:
        A dict whose `loss` entry is exactly the tensor returned by
        `model.training_loss`, plus one tensor per reported metric, moved to
        (or created on) the loss device.
    """
    sample = self._batch_to_training_sample(batch)
    loss, metrics = self.model.training_loss(sample)
    output = {"loss": loss}
    for key, value in (metrics or {}).items():
        # A metric that happens to be called "loss" must not replace the
        # differentiable loss (it may be a detached copy or a plain float).
        if key == "loss":
            continue
        if isinstance(value, Tensor):
            output[key] = value.to(device=loss.device)
        else:
            output[key] = torch.as_tensor(value, device=loss.device)
    return output
|
||||
|
||||
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
    """Predict an action chunk for every sample in the batch.

    The policy is switched to eval mode. A batch of one is sent to
    `model.infer_action` directly; larger batches are inferred one sample at a
    time and the per-sample chunks are concatenated along dimension 0.

    Args:
        batch: Inference batch with `input_image`, `video` or image observation
            keys, plus `context`/`context_mask` or `prompt`/`task`.

    Returns:
        Float32 tensor on the device of the first tensor found in `batch`
        (documented upstream as `[B, action_horizon, action_dim]`).
    """
    self.eval()
    infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
    batch_size = _infer_kwargs_batch_size(infer_kwargs)

    if batch_size == 1:
        action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
    else:
        chunks = []
        for index in range(batch_size):
            sample_kwargs = _slice_infer_kwargs(infer_kwargs, index=index, batch_size=batch_size)
            prediction = self.model.infer_action(**sample_kwargs)
            chunks.append(_action_from_model_output(prediction))
        action = torch.cat(chunks, dim=0)

    target_device = batch_device(batch)
    return action.to(device=target_device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
    """Return the next action, predicting a new chunk only when the queue is empty.

    On a refill, the first `config.n_action_steps` steps of the predicted chunk
    are queued step by step (the step axis is moved to the front so each queue
    entry covers the whole batch for one step).
    """
    self.eval()
    if not self._action_queue:
        chunk = self.predict_action_chunk(batch, **kwargs)
        kept = chunk[:, : self.config.n_action_steps]
        step_major = kept.transpose(0, 1)
        self._action_queue.extend(step_major)
    return self._action_queue.popleft()
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module:
    """Assemble the FastWAM core model from the policy configuration.

    Steps, in order:

    1. Instantiate the video expert (`WanVideoDiT`) and action expert
       (`ActionDiT`) from their config dicts and move them to the configured
       device/dtype, then combine them in a `MoT`.
    2. Load the text encoder through `load_pretrained_wan_text_encoder` when
       `config.load_text_encoder` is truthy, otherwise leave it as `None`.
    3. Load the VAE through `load_pretrained_wan_vae` and build the tokenizer.
    4. Pass everything, together with scheduler and loss settings read from the
       config, to `FastWAM`.

    NOTE(review): upstream notes state that the VAE and text encoder are frozen
    and kept out of `model.safetensors`; that is decided inside `FastWAM` and
    the loader helpers, not here - confirm there.
    """
    from .modular_fastwam import ActionDiT, FastWAM, MoT
    from .wan_components import (
        build_wan_tokenizer,
        load_pretrained_wan_text_encoder,
        load_pretrained_wan_vae,
    )
    from .wan_video_dit import WanVideoDiT

    dtype = _dtype_from_name(config.torch_dtype)
    device = config.device
    placement = {"device": device, "dtype": dtype}

    video_expert = WanVideoDiT(**config.video_dit_config).to(**placement)
    action_expert = ActionDiT(**config.action_dit_config).to(**placement)
    experts = {"video": video_expert, "action": action_expert}
    mot = MoT(mixtures=experts, mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn)

    text_encoder = None
    if config.load_text_encoder:
        text_encoder = load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
    vae = load_pretrained_wan_vae(torch_dtype=dtype, device=device)
    tokenizer = build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len)

    return FastWAM(
        video_expert=video_expert,
        action_expert=action_expert,
        mot=mot,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_dim=int(config.video_dit_config["text_dim"]),
        proprio_dim=config.proprio_dim,
        device=device,
        torch_dtype=dtype,
        video_train_shift=float(config.video_scheduler["train_shift"]),
        video_infer_shift=float(config.video_scheduler["infer_shift"]),
        video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
        action_train_shift=float(config.action_scheduler["train_shift"]),
        action_infer_shift=float(config.action_scheduler["infer_shift"]),
        action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
        loss_lambda_video=float(config.loss["lambda_video"]),
        loss_lambda_action=float(config.loss["lambda_action"]),
    )
|
||||
|
||||
|
||||
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
    """Assemble the keyword arguments for `infer_action` from a batch and the config.

    The prompt and input image are derived from the batch. `proprio` falls back
    to `observation.state`. Every sampling option (`negative_prompt`,
    `text_cfg_scale`, `num_inference_steps`, `sigma_shift`, `seed`,
    `rand_device`, `tiled`) is taken from the batch when present and from the
    config otherwise.
    """

    def pick(name: str, fallback: Any) -> Any:
        # Per-call override from the batch, config value otherwise.
        return batch.get(name, fallback)

    kwargs: dict[str, Any] = {}
    kwargs["prompt"] = _prompt_from_batch(batch=batch, config=config)
    kwargs["input_image"] = _input_image_from_batch(batch, config)
    kwargs["action_horizon"] = config.action_horizon
    kwargs["proprio"] = pick("proprio", batch.get(OBS_STATE))
    kwargs["context"] = batch.get("context")
    kwargs["context_mask"] = batch.get("context_mask")
    kwargs["negative_prompt"] = pick("negative_prompt", config.negative_prompt)
    kwargs["text_cfg_scale"] = float(pick("text_cfg_scale", config.text_cfg_scale))
    kwargs["num_inference_steps"] = int(pick("num_inference_steps", config.num_inference_steps))
    kwargs["sigma_shift"] = pick("sigma_shift", config.sigma_shift)
    kwargs["seed"] = pick("seed", config.inference_seed)
    kwargs["rand_device"] = pick("rand_device", config.rand_device)
    kwargs["tiled"] = bool(pick("tiled", config.tiled))
    return kwargs
|
||||
|
||||
|
||||
def _prompt_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Any:
|
||||
prompt = batch.get("prompt")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
|
||||
task = batch.get("task")
|
||||
if task is None:
|
||||
return None
|
||||
if isinstance(task, str):
|
||||
return config.prompt_template.format(task=task)
|
||||
if isinstance(task, (list, tuple)):
|
||||
return [config.prompt_template.format(task=str(item)) for item in task]
|
||||
return config.prompt_template.format(task=str(task))
|
||||
|
||||
|
||||
def _action_from_model_output(output: Any) -> Tensor:
|
||||
action = output["action"] if isinstance(output, dict) else output
|
||||
if action.ndim == 2:
|
||||
action = action.unsqueeze(0)
|
||||
return action
|
||||
|
||||
|
||||
def _infer_kwargs_batch_size(infer_kwargs: dict[str, Any]) -> int:
|
||||
image = infer_kwargs["input_image"]
|
||||
if not isinstance(image, Tensor):
|
||||
raise TypeError(f"`input_image` must be a tensor, got {type(image).__name__}.")
|
||||
if image.ndim == 3:
|
||||
return 1
|
||||
if image.ndim == 4:
|
||||
return int(image.shape[0])
|
||||
raise ValueError(f"`input_image` must be [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
|
||||
|
||||
|
||||
def _slice_infer_kwargs(infer_kwargs: dict[str, Any], *, index: int, batch_size: int) -> dict[str, Any]:
    """Build the keyword arguments for sample `index` of a batched inference call.

    Each value is passed through `_slice_infer_value`; keys and their order are
    preserved.
    """
    sliced: dict[str, Any] = {}
    for name, value in infer_kwargs.items():
        sliced[name] = _slice_infer_value(value, index=index, batch_size=batch_size)
    return sliced
|
||||
|
||||
|
||||
def _slice_infer_value(value: Any, *, index: int, batch_size: int) -> Any:
|
||||
if isinstance(value, Tensor) and value.ndim > 0 and value.shape[0] == batch_size:
|
||||
return value[index : index + 1]
|
||||
if isinstance(value, (list, tuple)) and len(value) == batch_size:
|
||||
return value[index]
|
||||
return value
|
||||
|
||||
|
||||
def _dtype_from_name(name: str) -> torch.dtype:
|
||||
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
|
||||
if name not in dtype_map:
|
||||
raise ValueError(f"Unsupported torch dtype `{name}`.")
|
||||
return dtype_map[name]
|
||||
|
||||
|
||||
def batch_device(batch: dict[str, Any]) -> torch.device:
    """Return the device of the first tensor value in `batch`, or CPU if there is none."""
    tensor_devices = (value.device for value in batch.values() if isinstance(value, Tensor))
    return next(tensor_devices, torch.device("cpu"))
|
||||
|
||||
|
||||
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
image_keys = sorted(k for k in batch if k.startswith("observation.images."))
|
||||
if not image_keys:
|
||||
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
|
||||
images = [batch[key] for key in image_keys]
|
||||
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(2).repeat(1, 1, config.num_video_frames, 1, 1)
|
||||
if image.ndim != 5:
|
||||
raise ValueError(f"Expected image batch [B,C,H,W] or video [B,C,T,H,W], got {tuple(image.shape)}.")
|
||||
return image
|
||||
|
||||
|
||||
def _input_image_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
    """Pick the conditioning image for inference and check its size.

    Sources, by priority: an explicit `input_image`; the `video` entry; a video
    stacked from the `observation.images.*` entries. For a 5-D video the frame
    at index 0 of axis 2 is used; a 4-D tensor is used directly. The chosen
    image is then checked by `_prepare_infer_image`.

    Raises:
        ValueError: If the video source is neither 4-D nor 5-D.
    """
    if "input_image" in batch:
        frame = batch["input_image"]
    else:
        clip = batch.get("video")
        if clip is None:
            clip = _stack_video_from_images(batch, config)
        rank = clip.ndim
        if rank == 5:
            frame = clip[:, :, 0]
        elif rank == 4:
            frame = clip
        else:
            raise ValueError(f"Cannot build input image from tensor with shape {tuple(clip.shape)}.")
    return _prepare_infer_image(frame, config)
|
||||
|
||||
|
||||
def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
|
||||
|
||||
target_h, target_w = config.image_size
|
||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||
raise ValueError(
|
||||
"FastWAM policy expects preprocessed image tensors with shape "
|
||||
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
|
||||
"Run the FastWAM preprocessor before calling the policy."
|
||||
)
|
||||
return image
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,151 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
    """Apply FastWAM LIBERO toggle semantics to configured action dimensions.

    For each configured dimension the value `a` is replaced by
    `sign(-(2 * a - 1))`: values above 0.5 become -1, values below 0.5 become
    +1, and exactly 0.5 becomes 0.
    """

    # Indices into the last axis of the action; negative values count from the end.
    toggle_dimensions: list[int]

    def action(self, action: PolicyAction) -> PolicyAction:
        """Return a copy of `action` with the configured dimensions toggled.

        All indices are resolved and validated before anything is modified, and
        indices that name the same dimension (e.g. `-1` and `action_dim - 1`)
        are applied once; applying the toggle twice would corrupt the value.

        Raises:
            ValueError: If an index falls outside the action's last axis.
        """
        if not self.toggle_dimensions:
            return action

        action_dim = int(action.shape[-1])
        # A dict is used as an insertion-ordered set of resolved indices.
        resolved_dims: dict[int, None] = {}
        for dim in self.toggle_dimensions:
            resolved_dim = dim if dim >= 0 else action_dim + dim
            if resolved_dim < 0 or resolved_dim >= action_dim:
                raise ValueError(
                    f"FastWAM action toggle dimension {dim} is out of bounds for action dim {action_dim}."
                )
            resolved_dims[resolved_dim] = None

        processed_action = action.clone()
        for resolved_dim in resolved_dims:
            value = processed_action[..., resolved_dim] * 2.0 - 1.0
            processed_action[..., resolved_dim] = torch.sign(-value)
        return processed_action

    def get_config(self) -> dict[str, Any]:
        """Return the settings needed to recreate this step."""
        return {"toggle_dimensions": self.toggle_dimensions}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return `features` unchanged; the step does not alter feature shapes or types."""
        return features
|
||||
|
||||
|
||||
def make_fastwam_pre_post_processors(
    config: FastWAMConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
    """Create the LeRobot pre- and post-processing pipelines for FastWAM.

    Preprocessing: rename (empty map), add batch dimension, move to
    `config.device`, resize images to the height/width of the first declared
    visual feature, then normalize inputs and outputs.

    Postprocessing: unnormalize the action, optionally apply the action toggle
    step when `config.toggle_action_dimensions` is non-empty, then move the
    result to CPU.

    Args:
        config: Policy configuration supplying features, device, normalization
            mapping and toggle dimensions.
        dataset_stats: Optional dataset statistics. Entries for visual features
            are overridden with a fixed mean and std of 0.5.

    Returns:
        The `(preprocessor, postprocessor)` pipelines.
    """
    visual_features = {
        name: spec for name, spec in config.input_features.items() if spec.type == FeatureType.VISUAL
    }

    # Visual stats are pinned to mean 0.5 / std 0.5 (maps [0, 1] data to [-1, 1]).
    normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
    for name, spec in visual_features.items():
        stat_shape = (int(spec.shape[0]), 1, 1)
        normalization_stats[name] = {
            "mean": torch.full(stat_shape, 0.5, dtype=torch.float32),
            "std": torch.full(stat_shape, 0.5, dtype=torch.float32),
        }

    # One resize step, sized after the first visual feature, covers all cameras.
    resize_steps = []
    if visual_features:
        reference_shape = next(iter(visual_features.values())).shape
        target_hw = (int(reference_shape[1]), int(reference_shape[2]))
        resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw))

    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
    ]
    input_steps.extend(resize_steps)
    input_steps.append(
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=normalization_stats,
            device=config.device,
        )
    )

    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=normalization_stats,
        ),
    ]
    if config.toggle_action_dimensions:
        toggle_step = FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions)
        output_steps.append(toggle_step)
    output_steps.append(DeviceProcessorStep(device="cpu"))

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=input_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=output_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return (preprocessor, postprocessor)
|
||||
@@ -0,0 +1,25 @@
|
||||
# Wan2.2 Upstream Subset
|
||||
|
||||
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
|
||||
|
||||
- Upstream repository: https://github.com/Wan-Video/Wan2.2
|
||||
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
|
||||
- License: Apache-2.0, matching the license in `LICENSE.txt` from the upstream repository
|
||||
|
||||
Copied files:
|
||||
|
||||
- `wan/modules/attention.py`
|
||||
- `wan/modules/model.py`
|
||||
- `wan/modules/__init__.py`
|
||||
- `wan/utils/fm_solvers.py`
|
||||
- `wan/utils/__init__.py`
|
||||
|
||||
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
|
||||
UMT5 text encoder, and tokenizer are no longer vendored — they come from
|
||||
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
|
||||
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
|
||||
|
||||
Current FastWAM adapters that directly reuse this vendored subset:
|
||||
|
||||
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
|
||||
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
|
||||
__all__ = [
|
||||
"WanModel",
|
||||
"flash_attention",
|
||||
]
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_3_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
FLASH_ATTN_2_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
FLASH_ATTN_2_AVAILABLE = False
|
||||
|
||||
import warnings
|
||||
|
||||
__all__ = [
|
||||
"flash_attention",
|
||||
"attention",
|
||||
]
|
||||
|
||||
|
||||
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """Variable-length attention through FlashAttention 3 or 2.

    Args:
        q: [B, Lq, Nq, C1].
        k: [B, Lk, Nk, C1].
        v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
        q_lens: [B] valid query lengths, or None to use the full Lq per sample.
        k_lens: [B] valid key/value lengths, or None to use the full Lk per sample.
        dropout_p: Dropout probability (only forwarded to FlashAttention 2).
        softmax_scale: Scaling of QK^T before softmax.
        q_scale: Optional factor multiplied into the queries.
        causal: Whether to apply a causal mask.
        window_size: (left, right) sliding window; only forwarded to FlashAttention 2.
        deterministic: Forwarded to the kernel.
        dtype: Half-precision dtype used when an input is not float16/bfloat16.
        version: 3 to request FlashAttention 3, None to prefer it when
            installed; anything else selects FlashAttention 2.

    Returns:
        Attention output reshaped to [B, Lq, ...] and cast back to the dtype `q`
        had on entry.

    The inputs must live on a CUDA device with head dimension <= 256, and
    `dtype` must be float16 or bfloat16 (all enforced with asserts).
    """
    low_precision = (torch.float16, torch.bfloat16)
    assert dtype in low_precision
    assert q.device.type == "cuda" and q.size(-1) <= 256

    # Sizes and dtype are captured before the tensors are flattened below.
    batch, q_len_max, k_len_max = q.size(0), q.size(1), k.size(1)
    result_dtype = q.dtype

    def to_low_precision(tensor):
        return tensor if tensor.dtype in low_precision else tensor.to(dtype)

    def pack(tensor, lens):
        # Keep only the valid prefix of every sample and concatenate them.
        return torch.cat([row[:length] for row, length in zip(tensor, lens, strict=False)])

    def full_lengths(length, device):
        return torch.tensor([length] * batch, dtype=torch.int32).to(device=device, non_blocking=True)

    def cumulative(lens):
        # Cumulative sequence offsets with a leading zero, as the varlen kernels expect.
        return (
            torch.cat([lens.new_zeros([1]), lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True)
        )

    # Queries.
    if q_lens is None:
        q = to_low_precision(q.flatten(0, 1))
        q_lens = full_lengths(q_len_max, q.device)
    else:
        q = to_low_precision(pack(q, q_lens))

    # Keys and values.
    if k_lens is None:
        k = to_low_precision(k.flatten(0, 1))
        v = to_low_precision(v.flatten(0, 1))
        k_lens = full_lengths(k_len_max, k.device)
    else:
        k = to_low_precision(pack(k, k_lens))
        v = to_low_precision(pack(v, k_lens))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.", stacklevel=2)

    def shared_kwargs():
        return {
            "q": q,
            "k": k,
            "v": v,
            "cu_seqlens_q": cumulative(q_lens),
            "cu_seqlens_k": cumulative(k_lens),
            "max_seqlen_q": q_len_max,
            "max_seqlen_k": k_len_max,
            "softmax_scale": softmax_scale,
            "causal": causal,
            "deterministic": deterministic,
        }

    use_fa3 = (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE
    if use_fa3:
        # Note: dropout_p, window_size are not supported in FA3 now.
        packed = flash_attn_interface.flash_attn_varlen_func(
            seqused_q=None,
            seqused_k=None,
            **shared_kwargs(),
        )[0]
    else:
        assert FLASH_ATTN_2_AVAILABLE
        packed = flash_attn.flash_attn_varlen_func(
            dropout_p=dropout_p,
            window_size=window_size,
            **shared_kwargs(),
        )
    x = packed.unflatten(0, (batch, q_len_max))

    return x.type(result_dtype)
|
||||
|
||||
|
||||
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """Attention entry point: FlashAttention when installed, PyTorch SDPA otherwise.

    With FlashAttention 2 or 3 available, all arguments are forwarded to
    `flash_attention` (`fa_version` is passed as `version`).

    Otherwise `torch.nn.functional.scaled_dot_product_attention` is used on
    inputs cast to `dtype`. In that fallback:

    * `q_lens` / `k_lens` cannot be honoured (no padding mask); a warning is
      emitted when either is given.
    * `q_scale` and `softmax_scale` are applied so that results follow the same
      scaling as the FlashAttention path. `softmax_scale` is only forwarded
      when set, leaving SDPA's default scaling in place otherwise.
    * `window_size` and `deterministic` have no effect.
    * The result is returned in `dtype`, laid out as [B, Lq, N, C].
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn(
            "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.",
            stacklevel=2,
        )
    attn_mask = None

    # SDPA expects [B, N, L, C]; the inputs arrive as [B, L, N, C].
    q = q.transpose(1, 2).to(dtype)
    k = k.transpose(1, 2).to(dtype)
    v = v.transpose(1, 2).to(dtype)

    if q_scale is not None:
        q = q * q_scale

    # Only pass `scale` when requested so the default behaviour is unchanged.
    scale_kwargs = {} if softmax_scale is None else {"scale": softmax_scale}
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p, **scale_kwargs
    )

    out = out.transpose(1, 2).contiguous()
    return out
|
||||
@@ -0,0 +1,519 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from .attention import flash_attention
|
||||
|
||||
__all__ = ["WanModel"]
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim, position):
    """Return float64 sinusoidal features of shape [len(position), dim].

    The first `dim // 2` columns are cosines and the remaining ones sines of
    `position * 10000 ** (-j / (dim // 2))`. `dim` must be even.
    """
    assert dim % 2 == 0
    n_freqs = dim // 2
    pos64 = position.type(torch.float64)

    # Geometric frequency ladder, created with the same dtype/device as the positions.
    inv_freq = torch.pow(10000, -torch.arange(n_freqs).to(pos64).div(n_freqs))
    angles = torch.outer(pos64, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
|
||||
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex rotary factors of shape [max_seq_len, dim // 2].

    Entry (p, j) is the unit complex number with angle `p / theta ** (2j / dim)`.
    `dim` must be even.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    # Magnitude 1 everywhere: only the phase carries positional information.
    return torch.polar(torch.ones_like(angles), angles)
|
||||
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Apply 3-axis (frame, height, width) rotary embedding to `x`.

    Args:
        x: Tensor indexed as [B, L, num_heads, head_dim]; consecutive pairs of the
            last dim are treated as complex numbers.
        grid_sizes: Integer tensor whose rows are (F, H, W); the first F*H*W
            tokens of each sample are rotated, the remaining tokens are kept as-is.
        freqs: Complex rotary table with head_dim // 2 columns, split column-wise
            into frame / height / width parts.

    Returns:
        float32 tensor with the same shape as `x`.
    """
    num_heads, half_dim = x.size(2), x.size(3) // 2

    # Height and width each get a third of the channels; frames take the remainder.
    per_axis = half_dim // 3
    freqs_f, freqs_h, freqs_w = freqs.split([half_dim - 2 * per_axis, per_axis, per_axis], dim=1)

    rotated = []
    for idx, (f, h, w) in enumerate(grid_sizes.tolist()):
        n_tokens = f * h * w

        # View the valid tokens as complex numbers in float64.
        tokens = torch.view_as_complex(x[idx, :n_tokens].to(torch.float64).reshape(n_tokens, num_heads, -1, 2))

        # Broadcast each axis table over the full (f, h, w) grid and concatenate channels.
        grid_freqs = torch.cat(
            [
                freqs_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs_w[:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(n_tokens, 1, -1)

        # Rotate, go back to real pairs, then re-attach the untouched padding tokens.
        tokens = torch.view_as_real(tokens * grid_freqs).flatten(2)
        rotated.append(torch.cat([tokens, x[idx, n_tokens:]]))
    return torch.stack(rotated).float()
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # eps is added to the mean square (inside the root) for numerical safety.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # Normalize in float32, cast back to the input dtype, then apply the gain.
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight
|
||||
|
||||
|
||||
class WanLayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input dtype.

    Affine parameters are disabled unless `elementwise_affine=True` is passed.
    """

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        upcast = x.float()
        normalized = super().forward(upcast)
        return normalized.type_as(x)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with optional RMS-normalized q/k and rotary embedding."""

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # Projections (creation order kept: q, k, v, o, then the two norms).
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        if qk_norm:
            self.norm_q = WanRMSNorm(dim, eps=eps)
            self.norm_k = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        batch, length = x.shape[:2]
        head_shape = (batch, length, self.num_heads, self.head_dim)

        # Project (q/k are normalized over the full channel dim), then split heads.
        q = self.norm_q(self.q(x)).view(*head_shape)
        k = self.norm_k(self.k(x)).view(*head_shape)
        v = self.v(x).view(*head_shape)

        # Only queries and keys receive the rotary position encoding.
        attended = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size,
        )

        # Merge heads and project out.
        return self.o(attended.flatten(2))
|
||||
|
||||
|
||||
class WanCrossAttention(WanSelfAttention):
    """Cross-attention: queries come from `x`, keys/values from `context`.

    Reuses the projections and q/k norms of `WanSelfAttention`; no rotary
    embedding is applied here.
    """

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        # Sequence lengths differ between x and context, so let view() infer them.
        queries = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        keys = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        values = self.v(context).view(batch, -1, heads, head_dim)

        attended = flash_attention(queries, keys, values, k_lens=context_lens)

        # Merge heads and project out.
        return self.o(attended.flatten(2))
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention, modulated FFN.

    Six modulation vectors (shift/scale/gate for self-attention and for the
    FFN) are obtained by adding the learned `modulation` parameter to `e`.
    """

    def __init__(
        self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # Sub-layers (creation order kept so parameter registration is unchanged).
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        # Cross-attention always attends globally.
        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)
        )

        # Learned offset added to the incoming modulation signal.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            context(Tensor): Conditioning sequence given to cross-attention
            context_lens(Tensor): Valid lengths of `context`, forwarded to cross-attention
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast("cuda", dtype=torch.float32):
            chunks = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
        assert chunks[0].dtype == torch.float32
        sa_shift, sa_scale, sa_gate, ffn_shift, ffn_scale, ffn_gate = (c.squeeze(2) for c in chunks)

        # Self-attention with shift/scale on the input and a gate on the residual.
        attn_in = self.norm1(x).float() * (1 + sa_scale) + sa_shift
        attn_out = self.self_attn(attn_in, seq_lens, grid_sizes, freqs)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            x = x + attn_out * sa_gate

        # Cross-attention (un-modulated residual), then the gated feed-forward.
        x = x + self.cross_attn(self.norm3(x), context, context_lens)
        ffn_out = self.ffn(self.norm2(x).float() * (1 + ffn_scale) + ffn_shift)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            x = x + ffn_out * ffn_gate
        return x
|
||||
|
||||
|
||||
class Head(nn.Module):
    """Output head: modulated LayerNorm followed by a linear projection to patch values."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # Each token predicts every value of one patch.
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # Learned (shift, scale) offset added to the incoming modulation signal.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, L1, C]
        """
        assert e.dtype == torch.float32
        # NOTE(review): the projection is kept inside the autocast scope, as in the
        # upstream Wan implementation — confirm against the original indentation.
        with torch.amp.autocast("cuda", dtype=torch.float32):
            shift, scale = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
            x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))
        return x
|
||||
|
||||
|
||||
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
    _no_split_modules = ["WanAttentionBlock"]

    @register_to_config
    def __init__(
        self,
        model_type="t2v",
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
    ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant, one of 't2v', 'i2v', 'ti2v' or 's2v'
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
        )

        self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList(
            [
                WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
                for _ in range(num_layers)
            ]
        )

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # Per-head rotary table: frame part first, then height and width parts.
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        )

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps, either shape [B] (one per sample) or
                [B, seq_len] (one per token)
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == "i2v":
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            # Conditioning is concatenated along the channel axis.
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # Zero-pad every sample to seq_len tokens and batch them.
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])

        # time embeddings
        if t.dim() == 1:
            # Broadcast one timestep per sample over all tokens. The unsqueeze is
            # required: expand() aligns trailing dims, so a bare [B] tensor can only
            # be expanded to [B, seq_len] when B == 1.
            t = t.unsqueeze(1).expand(t.size(0), seq_len)
        # NOTE(review): everything up to the dtype assert is kept inside the autocast
        # scope, as in upstream Wan — confirm against the original indentation.
        with torch.amp.autocast("cuda", dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()
            )
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        # Zero-pad every text embedding to text_len tokens before projecting.
        context = self.text_embedding(
            torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
        )

        # arguments
        kwargs = {
            "e": e0,
            "seq_lens": seq_lens,
            "grid_sizes": grid_sizes,
            "freqs": self.freqs,
            "context": context,
            "context_lens": context_lens,
        }

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist(), strict=False):
            # Drop padding tokens, then unfold each token into its (p, q, r, c) patch.
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            # Interleave grid and patch axes so each spatial axis becomes contiguous.
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size, strict=False)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .fm_solvers import get_sampling_sigmas
|
||||
|
||||
__all__ = [
|
||||
"get_sampling_sigmas",
|
||||
]
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
    """Return `sampling_steps` shifted sigmas, starting at 1 and decreasing.

    A uniform grid on [1, 0) is warped by `shift * s / (1 + (shift - 1) * s)`;
    the final 0 endpoint is excluded.
    """
    uniform = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    return shift * uniform / (1 + (shift - 1) * uniform)
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
|
||||
class WanVideoVAE38(torch.nn.Module):
    """FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).

    The wrapped VAE works on raw latents, so this adapter standardizes them
    itself: `encode` returns `(mode - mean) / std` and `decode` applies the
    inverse before decoding. Both conversions run in float32, and `encode` uses
    the deterministic posterior mode rather than a sample.
    """

    upsampling_factor = 16
    temporal_downsample_factor = 4
    z_dim = 48

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: str | torch.device = "cuda",
        *,
        pretrained: AutoencoderKLWan,
    ) -> None:
        super().__init__()
        # A pretrained VAE instance is mandatory; there is no random-init path.
        self.vae = pretrained.to(device=device, dtype=dtype)

        # Standardization stats come straight from the VAE config and are reshaped to
        # broadcast over [B, z_dim, T, H, W]. Non-persistent: kept out of state_dict.
        for stat_name in ("latents_mean", "latents_std"):
            stat = torch.tensor(getattr(self.vae.config, stat_name)).view(1, self.z_dim, 1, 1, 1)
            self.register_buffer(stat_name, stat, persistent=False)

    def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
        """Return the device and dtype of the wrapped VAE's first parameter."""
        first_param = next(self.vae.parameters())
        return first_param.device, first_param.dtype

    def encode(
        self,
        videos: list[torch.Tensor] | torch.Tensor,
        device: str | torch.device | None = None,
        tiled: bool = False,
        tile_size: tuple[int, int] = (34, 34),
        tile_stride: tuple[int, int] = (18, 16),
    ) -> torch.Tensor:
        """Encode videos into standardized float32 latents.

        `device`, `tile_size` and `tile_stride` are accepted for interface
        compatibility and ignored. A list/tuple of videos is stacked into a batch.

        Raises:
            NotImplementedError: If `tiled` is true.
        """
        del device, tile_size, tile_stride
        if tiled:
            raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
        if isinstance(videos, (list, tuple)):
            videos = torch.stack(list(videos))
        vae_device, vae_dtype = self._device_dtype()
        posterior = self.vae.encode(videos.to(device=vae_device, dtype=vae_dtype)).latent_dist
        raw = posterior.mode().float()
        return (raw - self.latents_mean.float().to(raw.device)) / self.latents_std.float().to(raw.device)

    def decode(
        self,
        hidden_states: list[torch.Tensor] | torch.Tensor,
        device: str | torch.device | None = None,
        tiled: bool = False,
        tile_size: tuple[int, int] = (34, 34),
        tile_stride: tuple[int, int] = (18, 16),
    ) -> torch.Tensor:
        """Decode standardized latents into float32 videos clamped to [-1, 1].

        `device`, `tile_size` and `tile_stride` are accepted for interface
        compatibility and ignored. A list/tuple of latents is stacked into a batch.

        Raises:
            NotImplementedError: If `tiled` is true.
        """
        del device, tile_size, tile_stride
        if tiled:
            raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
        if isinstance(hidden_states, (list, tuple)):
            hidden_states = torch.stack(list(hidden_states))
        vae_device, vae_dtype = self._device_dtype()
        latents = hidden_states.float()
        # Undo the standardization in float32 before handing latents to the VAE.
        latents = latents * self.latents_std.float().to(latents.device) + self.latents_mean.float().to(
            latents.device
        )
        decoded = self.vae.decode(latents.to(device=vae_device, dtype=vae_dtype)).sample
        return decoded.float().clamp_(-1.0, 1.0)
|
||||
|
||||
|
||||
__all__ = ["WanVideoVAE38"]
|
||||
@@ -0,0 +1,174 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
|
||||
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
|
||||
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
|
||||
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
|
||||
WAN_T5_TOKENIZER = "google/umt5-xxl"
|
||||
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
    """FastWAM text-encoder contract over a pretrained encoder module.

    Exposes `.dim` (read from the wrapped model's `config.d_model`) and
    `forward(ids, mask)`, which returns the wrapped model's `last_hidden_state`.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cuda",
        *,
        pretrained: torch.nn.Module,
    ) -> None:
        super().__init__()
        # A pretrained encoder instance is mandatory; there is no random-init path.
        encoder = pretrained.to(device=device, dtype=dtype)
        self.model = encoder
        self.dim = int(encoder.config.d_model)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # The mask is cast to int64 before it is handed to the wrapped model.
        encoded = self.model(input_ids=ids, attention_mask=mask.long())
        return encoded.last_hidden_state
|
||||
|
||||
|
||||
class WanTokenizer:
    """Tokenizer wrapper that pads/truncates every prompt to a fixed length.

    Calling it returns `input_ids`, or `(input_ids, attention_mask)` when
    `return_mask=True`.
    """

    def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
        # Imported lazily so importing this module does not require transformers.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.seq_len = int(seq_len)

    def __call__(
        self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any
    ):
        # A single prompt is treated as a batch of one; extra keyword arguments are ignored.
        prompts = [sequence] if isinstance(sequence, str) else list(sequence)
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.seq_len,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )
        if not return_mask:
            return encoded.input_ids
        return encoded.input_ids, encoded.attention_mask
|
||||
|
||||
|
||||
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
    """Create a `WanTokenizer` for `WAN_T5_TOKENIZER` with the given fixed length."""
    max_len = int(tokenizer_max_len)
    return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=max_len)
|
||||
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
    """Load the `vae` subfolder of `WAN22_DIFFUSERS_MODEL_ID` and wrap it for FastWAM."""
    # Imported lazily so importing this module does not require diffusers.
    from diffusers import AutoencoderKLWan

    from .wan_adapters import WanVideoVAE38

    pretrained_vae = AutoencoderKLWan.from_pretrained(
        WAN22_DIFFUSERS_MODEL_ID,
        subfolder="vae",
        torch_dtype=torch_dtype,
    )
    return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=pretrained_vae)
|
||||
|
||||
|
||||
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
    """Load the `text_encoder` subfolder of `WAN22_DIFFUSERS_MODEL_ID` and wrap it for FastWAM."""
    # Imported lazily so importing this module does not require transformers.
    from transformers import UMT5EncoderModel

    pretrained_encoder = UMT5EncoderModel.from_pretrained(
        WAN22_DIFFUSERS_MODEL_ID,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=pretrained_encoder)
|
||||
|
||||
|
||||
def resolve_wan_dit_paths(
    model_id_or_path: str | Path,
    *,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
) -> list[Path]:
    """Return the sorted DiT shard paths matching `WAN_DIT_PATTERN`.

    If `model_id_or_path` is an existing directory it is searched directly.
    Otherwise it is treated as a Hub repo id and only the matching files are
    fetched with `snapshot_download`. The result may be empty when nothing matches.
    """
    candidate = Path(model_id_or_path).expanduser()
    if candidate.is_dir():
        shard_root = candidate
    else:
        # Imported lazily: only needed when the shards are not already on disk.
        from huggingface_hub import snapshot_download

        shard_root = Path(
            snapshot_download(
                repo_id=str(model_id_or_path),
                revision=revision,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                allow_patterns=[WAN_DIT_PATTERN],
            )
        )
    return sorted(shard_root.glob(WAN_DIT_PATTERN))
|
||||
|
||||
|
||||
def load_wan_video_dit(
    paths: list[str | Path],
    *,
    dit_config: dict[str, Any],
    torch_dtype: torch.dtype,
    device: str,
) -> WanVideoDiT:
    """Build a `WanVideoDiT` from `dit_config` and load safetensors shards into it.

    Weights are loaded with `strict=False`, so keys that do not line up between
    the checkpoint and the model are tolerated. Because that can hide a bad
    checkpoint, an empty load and any key mismatch are reported through the
    module logger instead of being dropped silently.

    Args:
        paths: Safetensors shard files to merge and load.
        dit_config: Keyword arguments for the `WanVideoDiT` constructor.
        torch_dtype: Dtype the model is cast to after loading.
        device: Device the model is moved to after loading.

    Returns:
        The model on `device` with dtype `torch_dtype`.
    """
    from .wan_video_dit import WanVideoDiT

    model = WanVideoDiT(**dit_config)
    state_dict = _read_wan_dit_safetensors(paths)
    if not state_dict:
        logger.warning(
            "No Wan DiT weights were read from %s; the model keeps its initial parameters.",
            [str(p) for p in paths],
        )

    load_result = model.load_state_dict(state_dict, strict=False)
    # getattr keeps this tolerant of a load_state_dict override that returns nothing.
    missing = list(getattr(load_result, "missing_keys", []) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", []) or [])
    if missing or unexpected:
        logger.info(
            "Loaded Wan DiT weights non-strictly: %d missing key(s), %d unexpected key(s).",
            len(missing),
            len(unexpected),
        )
    return model.to(device=device, dtype=torch_dtype)
|
||||
|
||||
|
||||
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
    """Merge all shards into one CPU state dict; later shards win on duplicate keys."""
    merged: dict[str, torch.Tensor] = {}
    for shard in paths:
        shard_tensors = load_file(str(shard), device="cpu")
        merged.update(shard_tensors)
    return merged
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WAN22_DIFFUSERS_MODEL_ID",
|
||||
"WAN_DIT_PATTERN",
|
||||
"WAN_T5_TOKENIZER",
|
||||
"WanTextEncoder",
|
||||
"WanTokenizer",
|
||||
"build_wan_tokenizer",
|
||||
"load_pretrained_wan_text_encoder",
|
||||
"load_pretrained_wan_vae",
|
||||
"load_wan_video_dit",
|
||||
"resolve_wan_dit_paths",
|
||||
]
|
||||
@@ -0,0 +1,814 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
from einops import rearrange
|
||||
|
||||
from .wan.modules.model import (
|
||||
WanAttentionBlock,
|
||||
WanLayerNorm,
|
||||
WanModel,
|
||||
WanRMSNorm,
|
||||
rope_apply,
|
||||
rope_params,
|
||||
sinusoidal_embedding_1d,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_custom_forward(module):
    """Wrap `module` in a plain function that forwards all arguments to it."""

    def custom_forward(*inputs, **kwargs):
        result = module(*inputs, **kwargs)
        return result

    return custom_forward
|
||||
|
||||
|
||||
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    *args,
    **kwargs,
):
    """Call `model(*args, **kwargs)`, optionally under activation checkpointing.

    When `use_gradient_checkpointing` is truthy the call goes through
    `torch.utils.checkpoint.checkpoint` in non-reentrant mode; otherwise the
    model is invoked directly.
    """
    if not use_gradient_checkpointing:
        return model(*args, **kwargs)
    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(model),
        *args,
        use_reentrant=False,
        **kwargs,
    )
|
||||
|
||||
|
||||
def fastwam_masked_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    ctx_mask: torch.Tensor | None = None,
    fp32_attention: bool = True,
) -> torch.Tensor:
    """Multi-head attention with an explicit mask, built on PyTorch SDPA.

    The upstream FlashAttention path accepts sequence lengths but not arbitrary
    [query, key] masks, which FastWAM needs for video/action MoT routing; this
    wrapper therefore replaces only the final attention kernel.

    Args:
        q, k, v: Tensors of shape [B, S, num_heads * head_dim].
        num_heads: Number of heads the channel dimension is split into.
        ctx_mask: Optional mask passed to SDPA as `attn_mask`.
        fp32_attention: If true, q/k/v are upcast to float32; otherwise q and k
            are cast to the dtype of `v`.

    Returns:
        Tensor of shape [B, S, num_heads * head_dim].
    """

    def split_heads(t: torch.Tensor) -> torch.Tensor:
        return rearrange(t, "b s (n d) -> b n s d", n=num_heads)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    if fp32_attention:
        q, k, v = q.float(), k.float(), v.float()
    else:
        q, k = q.to(dtype=v.dtype), k.to(dtype=v.dtype)
    attended = functional.scaled_dot_product_attention(q, k, v, attn_mask=ctx_mask)
    return rearrange(attended, "b n s d -> b s (n d)", n=num_heads)
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply the affine modulation ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||
|
||||
|
||||
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
    """Return the Wan sampling sigmas for the given step count and shift.

    Delegates to ``get_sampling_sigmas`` from the vendored Wan flow-matching
    solvers; the import is performed at call time rather than at module import.
    """
    from .wan.utils.fm_solvers import get_sampling_sigmas

    sigmas = get_sampling_sigmas(num_inference_steps, shift)
    return sigmas
|
||||
|
||||
|
||||
class WanContinuousFlowMatchScheduler:
|
||||
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10):
|
||||
if num_train_timesteps <= 0:
|
||||
raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}")
|
||||
if shift <= 0:
|
||||
raise ValueError(f"`shift` must be positive, got {shift}")
|
||||
self.num_train_timesteps = int(num_train_timesteps)
|
||||
self.shift = float(shift)
|
||||
self.eps = float(eps)
|
||||
self._y_min, self._weight_norm_const = self._precompute_training_weight_stats()
|
||||
|
||||
@staticmethod
|
||||
def _phi(u: torch.Tensor, shift: float) -> torch.Tensor:
|
||||
return shift * u / (1.0 + (shift - 1.0) * u)
|
||||
|
||||
def _precompute_training_weight_stats(self) -> tuple[float, float]:
|
||||
steps = self.num_train_timesteps
|
||||
u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1]
|
||||
t_grid = self._phi(u_grid, self.shift) * float(steps)
|
||||
y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2)
|
||||
y_min = float(y_grid.min().item())
|
||||
y_shifted_grid = y_grid - y_min
|
||||
norm_const = float(y_shifted_grid.mean().item())
|
||||
return y_min, norm_const
|
||||
|
||||
def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"`batch_size` must be positive, got {batch_size}")
|
||||
u = torch.rand((batch_size,), device=device, dtype=torch.float32)
|
||||
sigma = self._phi(u, self.shift)
|
||||
timestep = sigma * float(self.num_train_timesteps)
|
||||
return timestep.to(dtype=dtype)
|
||||
|
||||
def training_weight(self, timestep: torch.Tensor) -> torch.Tensor:
|
||||
t = timestep.to(dtype=torch.float32)
|
||||
steps = float(self.num_train_timesteps)
|
||||
y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2)
|
||||
y_shifted = y - self._y_min
|
||||
weight = y_shifted / (self._weight_norm_const + self.eps)
|
||||
if weight.numel() == 1:
|
||||
return weight.reshape(())
|
||||
return weight
|
||||
|
||||
def add_noise(
|
||||
self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
sigma = (timestep / float(self.num_train_timesteps)).to(
|
||||
original_samples.device, dtype=original_samples.dtype
|
||||
)
|
||||
if sigma.ndim == 0:
|
||||
return (1 - sigma) * original_samples + sigma * noise
|
||||
sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1)))
|
||||
return (1 - sigma) * original_samples + sigma * noise
|
||||
|
||||
@staticmethod
|
||||
def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
||||
del timestep
|
||||
return noise - sample
|
||||
|
||||
def build_inference_schedule(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
shift_override: float | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}")
|
||||
shift = self.shift if shift_override is None else float(shift_override)
|
||||
if shift <= 0:
|
||||
raise ValueError(f"`shift` must be positive, got {shift}")
|
||||
|
||||
sigma_steps = torch.as_tensor(
|
||||
_get_wan_sampling_sigmas(num_inference_steps, shift),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
timesteps = sigma_steps * float(self.num_train_timesteps)
|
||||
sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)])
|
||||
deltas = sigma_next - sigma_steps
|
||||
return timesteps.to(dtype=dtype), deltas.to(dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
|
||||
delta = delta.to(sample.device, dtype=sample.dtype)
|
||||
if delta.ndim == 0:
|
||||
return sample + model_output * delta
|
||||
delta = delta.view(-1, *([1] * (sample.ndim - 1)))
|
||||
return sample + model_output * delta
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Build rotary-embedding frequencies by delegating to ``rope_params``.

    Note the argument order expected by ``rope_params``: ``(end, dim, theta)``.
    """
    freqs = rope_params(end, dim, theta)
    return freqs
|
||||
|
||||
|
||||
def apply_dense_rope(x: torch.Tensor, freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Rotate packed-head features by multiplying with ``freqs`` in the complex domain.

    Args:
        x: Features laid out as ``[batch, seq, num_heads * head_dim]``.
        freqs: Rotation factors multiplied element-wise with the complex view
            of ``x`` (must broadcast against ``[batch, seq, num_heads, head_dim / 2]``).
        num_heads: Number of heads packed into the last dimension of ``x``.

    Returns:
        Rotated features in the packed ``[batch, seq, num_heads * head_dim]``
        layout, cast back to the dtype of ``x``.
    """
    per_head = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    batch, seq_len, heads = per_head.shape[0], per_head.shape[1], per_head.shape[2]

    # Pair adjacent channels into complex numbers; the product is done in float32.
    paired = per_head.to(torch.float32).reshape(batch, seq_len, heads, -1, 2)
    as_complex = torch.view_as_complex(paired)

    if freqs.device.type == "npu":
        freqs = freqs.to(torch.complex64)

    rotated = torch.view_as_real(as_complex * freqs).flatten(2)
    return rotated.to(per_head.dtype)
|
||||
|
||||
|
||||
def _linear_input(linear: nn.Linear, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(dtype=linear.weight.dtype)
|
||||
|
||||
|
||||
def _wan_layer_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Apply ``norm`` to ``x``, computing affine ``WanLayerNorm`` layers in float32.

    For a ``WanLayerNorm`` that carries a weight, the input, weight and bias
    are upcast to float32 for the layer-norm computation and the result is
    cast back to ``x``'s dtype. Every other module is simply called on ``x``.
    """
    if not isinstance(norm, WanLayerNorm) or norm.weight is None:
        return norm(x)

    weight_fp32 = norm.weight.float()
    bias_fp32 = None if norm.bias is None else norm.bias.float()
    normalized = functional.layer_norm(x.float(), norm.normalized_shape, weight_fp32, bias_fp32, norm.eps)
    return normalized.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def create_group_causal_attn_mask(
    num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal"
) -> torch.Tensor:
    """Build a boolean [query, key] mask over tokens grouped by temporal index.

    Args:
        num_temporal_groups: Number of temporal groups shared by queries and keys.
        num_query_per_group: Query tokens in each group.
        num_key_per_group: Key tokens in each group.
        mode: ``"causal"`` lets a query see keys of its own and earlier groups;
            ``"group_diagonal"`` lets a query see only keys of its own group.

    Returns:
        Bool tensor of shape
        ``(num_temporal_groups * num_query_per_group, num_temporal_groups * num_key_per_group)``
        where ``True`` marks an allowed query/key pair.

    Raises:
        ValueError: On an unknown ``mode`` or a non-positive size.
        RuntimeError: If the resulting mask has an unexpected shape.
    """
    if mode not in ("causal", "group_diagonal"):
        raise ValueError(f"`mode` must be 'causal' or 'group_diagonal', got {mode}.")
    size_checks = (
        ("num_temporal_groups", num_temporal_groups),
        ("num_query_per_group", num_query_per_group),
        ("num_key_per_group", num_key_per_group),
    )
    for label, value in size_checks:
        if value <= 0:
            raise ValueError(f"`{label}` must be positive, got {value}.")

    group_ids = torch.arange(num_temporal_groups)
    # Column vector of query group ids against a row vector of key group ids.
    query_groups = group_ids.repeat_interleave(num_query_per_group).unsqueeze(1)
    key_groups = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group).unsqueeze(0)

    attn_mask = (query_groups >= key_groups) if mode == "causal" else (query_groups == key_groups)

    expected_shape = (
        num_temporal_groups * num_query_per_group,
        num_temporal_groups * num_key_per_group,
    )
    if attn_mask.shape != expected_shape:
        raise RuntimeError("Attention mask shape mismatch.")
    return attn_mask
|
||||
|
||||
|
||||
class FastWAMAttentionBlock(WanAttentionBlock):
    """Wan attention block with FastWAM's arbitrary boolean mask support.

    The block reuses the projection / norm / FFN submodules of the Wan block
    but runs attention through ``fastwam_masked_attention`` so that explicit
    masks can be applied to both self- and cross-attention.
    """

    def __init__(
        self,
        hidden_dim: int,
        attn_head_dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        fp32_attention: bool = True,
    ):
        """Build the block.

        Args:
            hidden_dim: Width of the residual stream.
            attn_head_dim: Per-head attention width.
            num_heads: Number of attention heads.
            ffn_dim: Hidden width of the feed-forward network.
            eps: Epsilon forwarded to the normalization layers.
            fp32_attention: Whether attention kernels run in float32.
        """
        attention_dim = attn_head_dim * num_heads
        if hidden_dim == attention_dim:
            # Attention width equals the hidden width: defer to the upstream
            # Wan block constructor.
            super().__init__(
                dim=hidden_dim,
                ffn_dim=ffn_dim,
                num_heads=num_heads,
                window_size=(-1, -1),
                qk_norm=True,
                cross_attn_norm=True,
                eps=eps,
            )
        else:
            # Attention width differs from the hidden width: build the
            # submodules by hand with projections hidden_dim -> attention_dim.
            # NOTE(review): attribute names presumably mirror the upstream
            # block so the shared methods below work on both paths - confirm.
            nn.Module.__init__(self)
            self.dim = hidden_dim
            self.ffn_dim = ffn_dim
            self.num_heads = num_heads
            self.window_size = (-1, -1)
            self.qk_norm = True
            self.cross_attn_norm = True
            self.eps = eps
            self.norm1 = WanLayerNorm(hidden_dim, eps)
            self.self_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
            self.norm3 = WanLayerNorm(hidden_dim, eps, elementwise_affine=True)
            self.cross_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
            self.norm2 = WanLayerNorm(hidden_dim, eps)
            self.ffn = nn.Sequential(
                nn.Linear(hidden_dim, ffn_dim),
                nn.GELU(approximate="tanh"),
                nn.Linear(ffn_dim, hidden_dim),
            )
            self.modulation = nn.Parameter(torch.randn(1, 6, hidden_dim) / hidden_dim**0.5)
        self.attn_head_dim = attn_head_dim
        self.fp32_attention = bool(fp32_attention)

    @staticmethod
    def split_modulation(block, t_mod: torch.Tensor):
        """Combine ``block.modulation`` with ``t_mod`` and split it into six terms.

        A 4D ``t_mod`` is treated as carrying a sequence dimension: the split
        happens on dim 2 and that dim is squeezed from each chunk. Otherwise
        the split happens on dim 1.

        Returns:
            ``(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)``.
        """
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1

        base_mod = block.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (base_mod + t_mod).chunk(
            6, dim=chunk_dim
        )
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2),
                scale_msa.squeeze(2),
                gate_msa.squeeze(2),
                shift_mlp.squeeze(2),
                scale_mlp.squeeze(2),
                gate_mlp.squeeze(2),
            )
        return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp

    def project_self_attention(
        self, x: torch.Tensor, freqs: torch.Tensor | dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``x`` to q/k/v for self-attention and apply rotary embedding to q and k.

        Args:
            x: Input tokens; the first two dims are read as batch and sequence.
            freqs: Either a dict with ``"grid_sizes"`` and ``"freqs"`` entries
                (handled by ``rope_apply``) or a tensor of rotation factors
                (handled by ``apply_dense_rope``).

        Returns:
            ``(q, k, v)`` in the packed-head layout.
        """
        q = self.self_attn.norm_q(self.self_attn.q(x))
        k = self.self_attn.norm_k(self.self_attn.k(x))
        v = self.self_attn.v(x)
        if isinstance(freqs, dict):
            b, s = x.shape[:2]
            # rope_apply works on per-head tensors; heads are re-packed afterwards.
            q = rope_apply(
                q.view(b, s, self.num_heads, self.attn_head_dim),
                freqs["grid_sizes"],
                freqs["freqs"],
            ).flatten(2)
            k = rope_apply(
                k.view(b, s, self.num_heads, self.attn_head_dim),
                freqs["grid_sizes"],
                freqs["freqs"],
            ).flatten(2)
        else:
            q = apply_dense_rope(q, freqs, self.num_heads)
            k = apply_dense_rope(k, freqs, self.num_heads)
        return q, k, v

    def apply_cross_attention(
        self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Cross-attend from ``x`` to ``context`` under an optional mask.

        A 3D ``context_mask`` gets a singleton dim inserted at position 1 so
        it broadcasts over attention heads.
        """
        if context_mask is not None and context_mask.dim() == 3:
            context_mask = context_mask.unsqueeze(1)
        attn = self.cross_attn
        b, n, d = x.size(0), attn.num_heads, attn.head_dim
        q = attn.norm_q(attn.q(x)).view(b, -1, n * d)
        k = attn.norm_k(attn.k(context)).view(b, -1, n * d)
        v = attn.v(context).view(b, -1, n * d)
        x = fastwam_masked_attention(
            q=q,
            k=k,
            v=v,
            num_heads=n,
            ctx_mask=context_mask,
            fp32_attention=self.fp32_attention,
        )
        # The kernel may return float32; cast to the output projection's dtype.
        return attn.o(_linear_input(attn.o, x))

    def project_self_attention_output(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the self-attention output projection after casting to its weight dtype."""
        return self.self_attn.o(_linear_input(self.self_attn.o, x))

    def apply_norm1(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` with ``norm1`` (pre self-attention)."""
        return _wan_layer_norm(self.norm1, x)

    def apply_norm2(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` with ``norm2`` (pre feed-forward)."""
        return _wan_layer_norm(self.norm2, x)

    def apply_norm3(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` with ``norm3`` (pre cross-attention)."""
        return _wan_layer_norm(self.norm3, x)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        t_mod: torch.Tensor,
        freqs: torch.Tensor | dict[str, torch.Tensor],
        context_mask: torch.Tensor | None = None,
        self_attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run modulated self-attention, cross-attention and the feed-forward network.

        Args:
            x: Input tokens.
            context: Tokens attended to by cross-attention.
            t_mod: Modulation input added to ``self.modulation`` (see
                ``split_modulation``).
            freqs: Rotary-embedding input forwarded to ``project_self_attention``.
            context_mask: Optional cross-attention mask.
            self_attn_mask: Optional self-attention mask.

        Returns:
            Updated tokens with the same residual layout as ``x``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.split_modulation(self, t_mod)
        residual_x = x
        # Self-attention on the modulated, normalized input; gated residual add.
        attn_input = modulate(self.apply_norm1(x), shift_msa, scale_msa)
        q, k, v = self.project_self_attention(attn_input, freqs)
        y = fastwam_masked_attention(
            q=q,
            k=k,
            v=v,
            num_heads=self.num_heads,
            ctx_mask=self_attn_mask,
            fp32_attention=self.fp32_attention,
        )
        x = residual_x + gate_msa * self.project_self_attention_output(y)
        # Cross-attention is added without gating or modulation.
        x = x + self.apply_cross_attention(self.apply_norm3(x), context, context_mask=context_mask)
        # Feed-forward on the modulated, normalized input; gated residual add.
        mlp_input = modulate(self.apply_norm2(x), shift_mlp, scale_mlp)
        return x + gate_mlp * self.ffn(mlp_input)
|
||||
|
||||
|
||||
class _FastWAMProjectedAttention(nn.Module):
    """Container for attention projections whose width differs from the hidden width.

    Holds q/k/v projections from ``hidden_dim`` to ``attention_dim``, an output
    projection back to ``hidden_dim``, and RMS norms for q and k. It defines
    no ``forward``; callers use the submodules directly.
    """

    def __init__(self, hidden_dim: int, attention_dim: int, num_heads: int, eps: float):
        """Create the projection and normalization submodules.

        Args:
            hidden_dim: Input/output width of the projections.
            attention_dim: Total attention width across all heads.
            num_heads: Number of attention heads.
            eps: Epsilon for the q/k RMS norms.
        """
        super().__init__()
        self.dim = hidden_dim
        self.num_heads = num_heads
        # Integer division: any remainder of attention_dim / num_heads is dropped.
        self.head_dim = attention_dim // num_heads
        self.q = nn.Linear(hidden_dim, attention_dim)
        self.k = nn.Linear(hidden_dim, attention_dim)
        self.v = nn.Linear(hidden_dim, attention_dim)
        self.o = nn.Linear(attention_dim, hidden_dim)
        self.norm_q = WanRMSNorm(attention_dim, eps=eps)
        self.norm_k = WanRMSNorm(attention_dim, eps=eps)
|
||||
|
||||
|
||||
class WanVideoDiT(WanModel):
    """Wan video DiT variant whose blocks are ``FastWAMAttentionBlock`` instances.

    Adds optional action conditioning (action tokens appended to the text
    context with a group-wise mask) and configurable video self-attention masks.
    """

    def __init__(
        self,
        hidden_dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: tuple[int, int, int],
        num_heads: int,
        attn_head_dim: int,
        num_layers: int,
        has_image_input: bool = False,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
        seperated_timestep: bool = False,
        require_vae_embedding: bool = False,
        require_clip_embedding: bool = False,
        fuse_vae_embedding_in_latents: bool = True,
        action_conditioned: bool = False,
        action_dim: int = 7,
        action_group_causal_mask_mode="causal",
        video_attention_mask_mode: str = "bidirectional",
        use_gradient_checkpointing: bool = False,
        fp32_attention: bool = True,
    ):
        """Build the model and reject configuration options FastWAM does not support.

        Raises:
            ValueError: If an unsupported feature flag is set, if VAE
                conditioning is not fused in latents, or if ``attn_head_dim``
                differs from ``hidden_dim // num_heads``.
        """
        # Accepted in the signature but not used by this implementation.
        del in_dim_control_adapter
        if has_image_input:
            raise ValueError("FastWAM currently expects Wan2.2 TI2V latents with fused image conditioning.")
        if has_image_pos_emb:
            raise ValueError("FastWAM does not support extra image positional embeddings in WanVideoDiT.")
        if has_ref_conv:
            raise ValueError("FastWAM does not support reference convolutions in WanVideoDiT.")
        if add_control_adapter:
            raise ValueError("FastWAM does not support control adapters in WanVideoDiT.")
        if require_clip_embedding:
            raise ValueError("FastWAM does not support CLIP embedding conditioning in WanVideoDiT.")
        if require_vae_embedding or not fuse_vae_embedding_in_latents:
            raise ValueError("FastWAM expects VAE conditioning to be fused in latents.")
        if attn_head_dim != hidden_dim // num_heads:
            raise ValueError(
                "`attn_head_dim` must match the upstream Wan head dimension `hidden_dim // num_heads`; "
                f"got {attn_head_dim} vs {hidden_dim // num_heads}."
            )

        super().__init__(
            model_type="ti2v",
            patch_size=patch_size,
            text_len=512,
            in_dim=in_dim,
            dim=hidden_dim,
            ffn_dim=ffn_dim,
            freq_dim=freq_dim,
            text_dim=text_dim,
            out_dim=out_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            window_size=(-1, -1),
            qk_norm=True,
            cross_attn_norm=True,
            eps=eps,
        )
        # Overwrite `self.blocks` with mask-aware FastWAM blocks.
        # NOTE(review): presumably replaces blocks created by the parent
        # constructor; init_weights() is then called on the new blocks - confirm.
        self.blocks = torch.nn.ModuleList(
            [
                FastWAMAttentionBlock(
                    hidden_dim=hidden_dim,
                    attn_head_dim=attn_head_dim,
                    num_heads=num_heads,
                    ffn_dim=ffn_dim,
                    eps=eps,
                    fp32_attention=fp32_attention,
                )
                for _ in range(num_layers)
            ]
        )
        self.init_weights()

        self.hidden_dim = hidden_dim
        self.attn_head_dim = attn_head_dim
        self.seperated_timestep = seperated_timestep
        self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
        self.video_attention_mask_mode = str(video_attention_mask_mode)
        self.action_conditioned = action_conditioned
        self.action_dim = action_dim
        self.fp32_attention = bool(fp32_attention)

        if self.action_conditioned:
            # Created after init_weights(), so this layer keeps nn.Linear's default init.
            self.action_embedding = torch.nn.Linear(action_dim, hidden_dim)
            self.action_group_causal_mask_mode = action_group_causal_mask_mode

        self.use_gradient_checkpointing = use_gradient_checkpointing
        if self.use_gradient_checkpointing:
            logger.info(
                "Using gradient checkpointing for DiT blocks. This will save memory but use more computation."
            )

    def patchify(self, x: torch.Tensor):
        """Apply the patch embedding to the latent tensor."""
        return self.patch_embedding(x)

    def _validate_forward_inputs(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        context_mask: torch.Tensor | None,
        action: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Check input ranks/shapes and normalize batch-related inputs.

        Returns:
            ``(x, timestep, context_mask)`` where ``x`` and ``timestep`` may
            have been expanded to the context batch size (eval mode only) and
            ``context_mask`` defaults to all-True when not supplied.

        Raises:
            ValueError: On any rank, shape or batch-size inconsistency.
        """
        if x.ndim != 5:
            raise ValueError(f"`latents` must be 5D [B, C, T, H, W], got shape {tuple(x.shape)}")
        num_latent_frames = x.shape[2]
        if context.ndim != 3:
            raise ValueError(f"`context` must be 3D [B, L, D], got shape {tuple(context.shape)}")
        if timestep.ndim != 1:
            raise ValueError(f"`timestep` must be 1D [B] or [1], got shape {tuple(timestep.shape)}")
        if self.action_conditioned:
            # A single latent frame without actions is allowed (text-only mode).
            allow_text_only_single_frame = num_latent_frames == 1 and action is None
            if not allow_text_only_single_frame:
                if action is None:
                    raise ValueError("Action input is required for action-conditioned model.")
                if action.ndim != 3:
                    raise ValueError(
                        f"`action` must be 3D [B, action_horizon, action_dim], got shape {tuple(action.shape)}"
                    )
                if action.shape[2] != self.action_dim:
                    raise ValueError(
                        f"`action` last dimension must be {self.action_dim}, got {action.shape[2]}"
                    )
                if num_latent_frames <= 1:
                    raise ValueError(
                        f"video length must be > 1 for action-conditioned model, got {num_latent_frames}"
                    )
                # Actions are split evenly across the frames after the first one.
                if action.shape[1] % (num_latent_frames - 1) != 0:
                    raise ValueError(
                        "action horizon must be divisible by (num_latent_frames - 1), "
                        f"got action_horizon={action.shape[1]}"
                    )
        if context_mask is None:
            context_mask = torch.ones(
                (context.shape[0], context.shape[1]), dtype=torch.bool, device=context.device
            )
        else:
            if context_mask.ndim != 2:
                raise ValueError(f"`context_mask` must be 2D [B, L], got shape {tuple(context_mask.shape)}")
            if context_mask.shape[0] != context.shape[0] or context_mask.shape[1] != context.shape[1]:
                raise ValueError(
                    "`context_mask` shape must match `context` shape [B, L], "
                    f"got {tuple(context_mask.shape)} vs {tuple(context.shape)}"
                )

        batch_size = x.shape[0]
        if batch_size != context.shape[0]:
            # Outside training, a single latent sample is broadcast to the context batch.
            if not self.training and batch_size == 1:
                x = x.expand(context.shape[0], -1, -1, -1, -1)
                batch_size = context.shape[0]
            else:
                raise ValueError(
                    f"Batch mismatch between latents and context: {batch_size} vs {context.shape[0]}."
                )

        if timestep.shape[0] not in (1, batch_size):
            raise ValueError(
                f"`timestep` length must be 1 or batch_size({batch_size}), got {timestep.shape[0]}"
            )
        if timestep.shape[0] == 1 and batch_size > 1:
            if self.training:
                raise ValueError("During training, timestep length must match batch_size.")
            timestep = timestep.expand(batch_size)
        return x, timestep, context_mask

    def build_video_to_video_mask(
        self,
        video_seq_len: int,
        video_tokens_per_frame: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Build the boolean video self-attention mask for the configured mode.

        Modes (``self.video_attention_mask_mode``):
            * ``"bidirectional"``: every token attends to every token.
            * ``"per_frame_causal"``: tokens attend to their own and earlier frames.
            * ``"first_frame_causal"``: first-frame tokens attend only to the
              first frame; all other tokens attend to everything.

        Returns:
            Bool tensor of shape ``(video_seq_len, video_seq_len)``.

        Raises:
            ValueError: On non-positive sizes, a sequence length not divisible
                by the frame size in ``per_frame_causal`` mode, or an unknown mode.
        """
        if video_seq_len <= 0:
            raise ValueError(f"`video_seq_len` must be positive, got {video_seq_len}")
        if video_tokens_per_frame <= 0:
            raise ValueError(f"`video_tokens_per_frame` must be positive, got {video_tokens_per_frame}")

        if self.video_attention_mask_mode == "bidirectional":
            return torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)

        if self.video_attention_mask_mode == "per_frame_causal":
            if video_seq_len % video_tokens_per_frame != 0:
                raise ValueError(
                    "`video_seq_len` must be divisible by `video_tokens_per_frame` in `per_frame_causal` mode, "
                    f"got {video_seq_len} and {video_tokens_per_frame}"
                )
            num_video_frames = video_seq_len // video_tokens_per_frame
            # Frame-level lower-triangular mask, expanded to token resolution.
            frame_causal = torch.tril(
                torch.ones((num_video_frames, num_video_frames), dtype=torch.bool, device=device)
            )
            return frame_causal.repeat_interleave(video_tokens_per_frame, dim=0).repeat_interleave(
                video_tokens_per_frame, dim=1
            )

        if self.video_attention_mask_mode == "first_frame_causal":
            video_mask = torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
            first_frame_tokens = min(video_tokens_per_frame, video_seq_len)
            # Block first-frame queries from seeing any later-frame keys.
            video_mask[:first_frame_tokens, first_frame_tokens:] = False
            return video_mask

        raise ValueError(f"Unsupported video attention mask mode: {self.video_attention_mask_mode}")

    def pre_dit(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        context_mask: torch.Tensor | None = None,
        action: torch.Tensor | None = None,
        fuse_vae_embedding_in_latents: bool = False,
    ) -> dict[str, Any]:
        """Prepare tokens, time modulation, context and masks for the DiT blocks.

        Returns:
            Dict with keys ``"tokens"``, ``"freqs"``, ``"t"``, ``"t_mod"``,
            ``"context"``, ``"context_mask"`` and ``"meta"`` (grid sizes,
            tokens per frame, batch size).

        Raises:
            ValueError: On invalid inputs (see ``_validate_forward_inputs``) or
                latent spatial sizes not divisible by the patch size.
            NotImplementedError: Unless ``self.seperated_timestep`` and
                ``fuse_vae_embedding_in_latents`` are both true.
        """
        x, timestep, context_mask = self._validate_forward_inputs(
            x=x,
            timestep=timestep,
            context=context,
            context_mask=context_mask,
            action=action,
        )
        # Inputs are cast to the dtype of the patch-embedding weights.
        model_dtype = self.patch_embedding.weight.dtype
        x = x.to(dtype=model_dtype)
        context = context.to(dtype=model_dtype)
        if action is not None:
            action = action.to(dtype=model_dtype)

        batch_size = x.shape[0]
        patch_h = int(self.patch_size[1])
        patch_w = int(self.patch_size[2])
        if x.shape[3] % patch_h != 0 or x.shape[4] % patch_w != 0:
            raise ValueError(
                "Latent spatial shape must be divisible by DiT patch size, "
                f"got HxW=({x.shape[3]}, {x.shape[4]}), patch=({patch_h}, {patch_w})"
            )
        tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)

        if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
            raise NotImplementedError(
                "FastWAM currently requires separated timesteps with fused VAE latents."
            )

        # One timestep per token; every token of the first latent frame is set to 0.
        # NOTE(review): presumably because the first frame is the clean
        # conditioning frame - confirm.
        token_timesteps = torch.ones(
            (batch_size, x.shape[2], tokens_per_frame),
            dtype=model_dtype,
            device=timestep.device,
        ) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
        token_timesteps[:, 0, :] = 0
        token_timesteps = token_timesteps.reshape(batch_size, -1)
        # Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
        # Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
        # Upstream guarantees this via an fp32 autocast region, so it holds even when the
        # model runs in bf16. Mirror that here, then cast the per-block modulation back to
        # model_dtype so the bf16 attention blocks are not upcast to fp32.
        with torch.amp.autocast("cuda", dtype=torch.float32):
            token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
            t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
            t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
        t_mod = t_mod.to(dtype=model_dtype)

        x = self.patchify(x)
        f, h, w = x.shape[2:]

        context = self.text_embedding(context)
        # Number of text tokens; action tokens (if any) are appended after them.
        context_len = context.shape[1]
        if self.action_conditioned and action is not None:
            action_len = action.shape[1]
            action_emb = self.action_embedding(action)
            action_pos_embed = sinusoidal_embedding_1d(
                self.hidden_dim, torch.arange(action_len, device=action_emb.device)
            ).to(dtype=action_emb.dtype)
            action_emb = action_emb + action_pos_embed.unsqueeze(0)
            context = torch.cat([context, action_emb], dim=1)

            # One temporal group per latent frame after the first.
            num_temporal_groups = f - 1
            if num_temporal_groups <= 0:
                raise ValueError(
                    "Action-conditioned context mask requires at least 2 latent frames when `action` is provided."
                )
            if action_emb.shape[1] % num_temporal_groups != 0:
                raise ValueError(
                    f"Action embedding length {action_emb.shape[1]} must be divisible by "
                    f"number of temporal groups {num_temporal_groups}"
                )
            action_group_mask = create_group_causal_attn_mask(
                num_temporal_groups=num_temporal_groups,
                num_query_per_group=tokens_per_frame,
                num_key_per_group=action_len // num_temporal_groups,
                mode=self.action_group_causal_mask_mode,
            ).to(context.device)

            seq_len = f * h * w
            final_context_mask = torch.zeros(
                (batch_size, seq_len, context.shape[1]), dtype=torch.bool, device=context.device
            )
            # Text keys: the per-sample text mask is repeated for every video query.
            final_context_mask[:, :, :context_len] = context_mask.unsqueeze(1).expand(-1, seq_len, -1)
            # Action keys: first-frame queries see no actions (rows stay False);
            # later frames follow the group mask.
            final_context_mask[:, tokens_per_frame:, context_len:] = action_group_mask.unsqueeze(0).expand(
                batch_size, -1, -1
            )
            context_mask = final_context_mask
        elif self.action_conditioned and action is None:
            if f != 1:
                raise ValueError(
                    "Action-conditioned model requires `action` unless running single-frame text-only mode "
                    "with num_latent_frames=1."
                )
            context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
        else:
            context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)

        x_tokens = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        grid_sizes = torch.tensor([[f, h, w]] * batch_size, dtype=torch.long, device=x_tokens.device)
        freqs = {"grid_sizes": grid_sizes, "freqs": self.freqs.to(x_tokens.device)}

        return {
            "tokens": x_tokens,
            "freqs": freqs,
            "t": t,
            "t_mod": t_mod,
            "context": context,
            "context_mask": context_mask,
            "meta": {
                "grid_sizes": grid_sizes,
                "tokens_per_frame": tokens_per_frame,
                "batch_size": batch_size,
            },
        }

    def post_dit(self, x_tokens: torch.Tensor, pre_state: dict[str, Any]) -> torch.Tensor:
        """Apply the output head and unpatchify tokens back into a stacked tensor.

        Args:
            x_tokens: Tokens produced by the DiT blocks.
            pre_state: The dict returned by ``pre_dit`` (``"t"`` and
                ``"meta"["grid_sizes"]`` are read).
        """
        x = self.head(x_tokens, pre_state["t"])
        return torch.stack(super().unpatchify(x, pre_state["meta"]["grid_sizes"]))

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        context_mask: torch.Tensor | None = None,
        action: torch.Tensor | None = None,
        fuse_vae_embedding_in_latents: bool = False,
    ):
        """Run the full DiT: ``pre_dit``, all attention blocks, then ``post_dit``.

        Note that ``pre_dit`` raises ``NotImplementedError`` unless
        ``fuse_vae_embedding_in_latents`` is passed as true (and the model was
        built with ``seperated_timestep=True``).
        """
        pre_state = self.pre_dit(
            x=x,
            timestep=timestep,
            context=context,
            context_mask=context_mask,
            action=action,
            fuse_vae_embedding_in_latents=fuse_vae_embedding_in_latents,
        )
        x_tokens = pre_state["tokens"]
        context_emb = pre_state["context"]
        t_mod = pre_state["t_mod"]
        freqs = pre_state["freqs"]
        context_attn_mask = pre_state["context_mask"]
        # In bidirectional mode no mask is built; None is passed to the blocks.
        self_attn_mask = (
            self.build_video_to_video_mask(
                video_seq_len=x_tokens.shape[1],
                video_tokens_per_frame=int(pre_state["meta"]["tokens_per_frame"]),
                device=x_tokens.device,
            )
            if self.video_attention_mask_mode != "bidirectional"
            else None
        )

        for block in self.blocks:
            if self.use_gradient_checkpointing:
                x_tokens = gradient_checkpoint_forward(
                    block,
                    self.use_gradient_checkpointing,
                    x_tokens,
                    context_emb,
                    t_mod,
                    freqs,
                    context_mask=context_attn_mask,
                    self_attn_mask=self_attn_mask,
                )
            else:
                x_tokens = block(
                    x_tokens,
                    context_emb,
                    t_mod,
                    freqs,
                    context_mask=context_attn_mask,
                    self_attn_mask=self_attn_mask,
                )

        return self.post_dit(x_tokens, pre_state)
|
||||
|
||||
|
||||
# Public names exported by this module (used by `from <module> import *`).
__all__ = [
    "FastWAMAttentionBlock",
    "WanContinuousFlowMatchScheduler",
    "WanVideoDiT",
    "apply_dense_rope",
    "create_group_causal_attn_mask",
    "fastwam_masked_attention",
    "gradient_checkpoint_forward",
    "modulate",
    "precompute_freqs_cis",
    "sinusoidal_embedding_1d",
]
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -23,6 +23,7 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -49,6 +50,7 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -66,6 +68,8 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
|
||||
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -229,7 +258,13 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,10 +56,14 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -171,64 +174,6 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
if teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
    self,
    ctx: RolloutContext,
    robot,
    events: dict,
    features: dict,
    fps: float,
    control_time_s: float,
    dataset,
    single_task: str,
) -> None:
    """Drive the robot with the policy for one episode and record the frames.

    Every tick reads an observation, asks for the next action and, when one
    was produced, appends the observation/action pair to ``dataset`` tagged
    with ``single_task``. The loop stops once ``control_time_s`` seconds have
    elapsed, when ``events["exit_early"]`` is raised (the flag is cleared
    here), or when the runtime shutdown event is set.
    """
    interp = self._interpolator
    period = interp.get_control_interval(fps)

    elapsed = 0.0
    episode_start = time.perf_counter()

    while elapsed < control_time_s:
        tick_start = time.perf_counter()

        if events["exit_early"]:
            # Consume the flag so it does not leak into the next phase.
            events["exit_early"] = False
            break

        if ctx.runtime.shutdown_event.is_set():
            break

        raw_obs = robot.get_observation()
        processed_obs = self._process_observation_and_notify(ctx.processors, raw_obs)

        # A truthy warm-up result means this tick must not act or record.
        if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, tick_start, period):
            continue

        action = send_next_action(processed_obs, raw_obs, ctx, interp)

        if action is not None:
            frame = {
                **build_dataset_frame(features, processed_obs, prefix=OBS_STR),
                **build_dataset_frame(features, action, prefix=ACTION),
                "task": single_task,
            }
            dataset.add_frame(frame)
            self._log_telemetry(processed_obs, action, ctx.runtime)

        # Pace the loop: sleep for whatever is left of the control period.
        tick_duration = time.perf_counter() - tick_start
        remaining = period - tick_duration
        if remaining < 0:
            logger.warning(
                f"Record loop is running slower ({1 / tick_duration:.1f} Hz) than the target FPS ({fps} Hz). "
                "Dataset frames might be dropped and robot control might be unstable. "
                "Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
                "3) CPU starvation"
            )
        precise_sleep(max(remaining, 0.0))
        elapsed = time.perf_counter() - episode_start
|
||||
|
||||
def _reset_loop(
    self,
    ctx: RolloutContext,
    robot,
    teleop,
    events: dict,
    fps: float,
    control_time_s: float,
    display_data: bool,
    display_compressed: bool,
) -> None:
    """Reset-phase loop: teleop drives the robot if available, no recording.

    Args:
        ctx: Rollout context; its processors and runtime shutdown event are used.
        robot: Robot that is observed and, when a teleop is present, commanded.
        teleop: Teleoperator providing actions, or ``None`` to only observe.
        events: Shared event flags; ``exit_early`` ends the loop and is cleared.
        fps: Target loop frequency in Hz.
        control_time_s: Maximum duration of the reset phase in seconds.
        display_data: When true, the processed observation (and the teleop
            action, if any) is sent to ``log_rerun_data`` each tick.
        display_compressed: Forwarded as ``compress_images`` to ``log_rerun_data``.
    """
    processors = ctx.processors
    control_interval = 1.0 / fps

    timestamp = 0.0
    start_t = time.perf_counter()

    while timestamp < control_time_s:
        loop_start = time.perf_counter()

        if events["exit_early"]:
            events["exit_early"] = False
            break

        if ctx.runtime.shutdown_event.is_set():
            break

        obs = robot.get_observation()

        # Bind the action on every tick: without a teleop there is nothing to
        # send, but the display branch below still references this name.
        act_teleop = None
        if teleop is not None:
            act = teleop.get_action()
            act_teleop = processors.teleop_action_processor((act, obs))
            robot_action = processors.robot_action_processor((act_teleop, obs))
            robot.send_action(robot_action)

        if display_data:
            obs_processed = processors.robot_observation_processor(obs)
            log_rerun_data(
                observation=obs_processed,
                action=act_teleop,
                compress_images=display_compressed,
            )

        # Sleep for the remainder of the control period, if any.
        dt = time.perf_counter() - loop_start
        sleep_t = control_interval - dt
        precise_sleep(max(sleep_t, 0.0))
        timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
    """Wrap up the session: stop the listener, finalize and optionally upload
    the dataset, then release the hardware."""
    cfg = ctx.runtime.cfg
    play_sounds = cfg.play_sounds

    log_say("Stop recording", play_sounds, blocking=True)

    # A listener only needs stopping when one was started on a non-headless host.
    if not is_headless() and self._listener is not None:
        self._listener.stop()

    if ctx.data.dataset is not None:
        logger.info("Finalizing dataset...")
        ctx.data.dataset.finalize()

    # Upload only when requested and a dataset exists; announce only on success.
    push_requested = cfg.dataset is not None and cfg.dataset.push_to_hub
    if push_requested and ctx.data.dataset is not None:
        uploaded = safe_push_to_hub(
            ctx.data.dataset,
            tags=cfg.dataset.tags,
            private=cfg.dataset.private,
        )
        if uploaded:
            logger.info("Dataset uploaded to hub")
            log_say("Dataset uploaded to hub", play_sounds)

    self._teardown_hardware(
        ctx.hardware,
        return_to_initial_position=cfg.return_to_initial_position,
    )
    log_say("Exiting", play_sounds)
    logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
@@ -111,6 +112,18 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
|
||||
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
|
||||
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class FakeFastWAMCore(nn.Module):
    """Tiny stand-in for the FastWAM core model used by the policy tests."""

    def __init__(self):
        super().__init__()
        # One small registered layer so the module owns real parameters.
        self.dit = nn.Linear(2, 2)

    def training_loss(self, sample):
        """Validate the sample layout and return a constant loss plus metrics."""
        assert sample["video"].ndim == 5
        assert sample["context"].ndim == 3
        zeroed_actions = sample[ACTION].sum() * 0.0
        loss = zeroed_actions + torch.tensor(1.0)
        metrics = {"loss_action": 1.0}
        return loss, metrics

    def infer_action(self, **kwargs):
        """Return an all-ones action chunk of the requested horizon."""
        horizon = kwargs["action_horizon"]
        return {"action": torch.ones(1, horizon, 3)}
|
||||
|
||||
|
||||
def test_fastwam_is_registered_and_publicly_exported():
    """The "fastwam" type resolves to the FastWAM config and policy classes."""
    config_kwargs = {
        "action_dim": 3,
        "proprio_dim": 2,
        "action_horizon": 4,
        "n_action_steps": 2,
        "base_model_id": None,
    }
    cfg = make_policy_config("fastwam", **config_kwargs)

    assert isinstance(cfg, FastWAMConfig)
    assert cfg.type == "fastwam"
    assert get_policy_class("fastwam") is FastWAMPolicy
|
||||
|
||||
|
||||
def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
    """Default config exposes the expected features and rejects invalid setups."""
    cfg = FastWAMConfig()
    cfg.save_pretrained(tmp_path)
    config_file = tmp_path / "config.json"
    saved = json.loads(config_file.read_text())

    assert saved["pretrained_path"] is None
    assert cfg.image_features["observation.images.image"].type == FeatureType.VISUAL
    assert cfg.action_feature.shape == (7,)
    assert cfg.robot_state_feature.shape == (8,)

    # A config without any image input must be refused.
    state_only_inputs = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}
    with pytest.raises(ValueError, match="image feature"):
        FastWAMConfig(input_features=state_only_inputs)

    # So must a tokenizer id other than the supported one.
    with pytest.raises(ValueError, match="tokenizer_model_id"):
        FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
|
||||
|
||||
|
||||
def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_path):
    """Check preprocessing of image/state inputs and the reloaded action toggle."""
    image_key = "observation.images.image"
    input_features = {
        image_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 2, 2)),
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
    }
    output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}
    cfg = FastWAMConfig(
        action_dim=3,
        proprio_dim=2,
        action_horizon=4,
        n_action_steps=2,
        image_size=(2, 2),
        device="cpu",
        toggle_action_dimensions=[-1],
        input_features=input_features,
        output_features=output_features,
        base_model_id=None,
    )

    image_stats = {"mean": torch.full((3, 1, 1), 0.2), "std": torch.full((3, 1, 1), 0.1)}
    state_stats = {"mean": torch.tensor([1.0, 3.0]), "std": torch.tensor([2.0, 4.0])}
    action_stats = {"mean": torch.zeros(3), "std": torch.ones(3)}
    dataset_stats = {image_key: image_stats, OBS_STATE: state_stats, ACTION: action_stats}

    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)

    # Three identical 2x2 channels, unbatched.
    channel = torch.tensor([[0.0, 0.5], [1.0, 0.5]])
    raw_batch = {
        image_key: torch.stack([channel, channel, channel]),
        OBS_STATE: torch.tensor([3.0, 7.0]),
    }
    processed = preprocessor(raw_batch)

    # Round-trip both pipelines through disk and reload the postprocessor.
    preprocessor.save_pretrained(tmp_path, config_filename="policy_preprocessor.json")
    postprocessor.save_pretrained(tmp_path, config_filename="policy_postprocessor.json")
    _, loaded_postprocessor = make_pre_post_processors(cfg, pretrained_path=str(tmp_path))

    expected_channel = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])
    expected_image = torch.stack([expected_channel, expected_channel, expected_channel]).unsqueeze(0)

    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"
    assert torch.allclose(processed[image_key], expected_image)
    assert torch.allclose(processed[OBS_STATE], torch.tensor([[1.0, 1.0]]))
    # The caller's statistics must not have been modified.
    assert torch.equal(dataset_stats[image_key]["mean"], torch.full((3, 1, 1), 0.2))

    has_toggle_step = any(
        isinstance(step, FastWAMActionToggleProcessorStep) for step in loaded_postprocessor.steps
    )
    assert has_toggle_step

    toggled = loaded_postprocessor(torch.tensor([[0.25, 0.5, 1.0]]))
    assert torch.equal(toggled, torch.tensor([[0.25, 0.5, -1.0]]))
|
||||
|
||||
|
||||
def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
    """Forward returns the core loss; prediction calls the core once per sample."""
    calls = []

    class CapturingCore(FakeFastWAMCore):
        """Record what each inference call received and tag the output with its index."""

        def infer_action(self, **kwargs):
            record = {
                "image_shape": tuple(kwargs["input_image"].shape),
                "proprio_shape": tuple(kwargs["proprio"].shape),
                "prompt": kwargs["prompt"],
            }
            calls.append(record)
            fill_value = float(len(calls))
            return {"action": torch.full((1, kwargs["action_horizon"], 3), fill_value)}

    monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore())

    image_key = "observation.images.image"
    cfg = FastWAMConfig(
        action_dim=3,
        proprio_dim=2,
        action_horizon=4,
        n_action_steps=2,
        image_size=(16, 16),
        input_features={
            image_key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
        },
        output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
        base_model_id=None,
    )
    policy = FastWAMPolicy(cfg)

    training_batch = {
        image_key: torch.zeros(1, 3, 16, 16),
        OBS_STATE: torch.zeros(1, 2),
        ACTION: torch.zeros(1, 4, 3),
        "context": torch.zeros(1, 5, 4096),
        "context_mask": torch.ones(1, 5, dtype=torch.bool),
    }
    output = policy.forward(training_batch)

    inference_batch = {
        image_key: torch.stack([torch.zeros(3, 16, 16), torch.ones(3, 16, 16)]),
        OBS_STATE: torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
        "task": ["task 0", "task 1"],
    }
    action = policy.predict_action_chunk(inference_batch)

    assert output["loss"].item() == 1.0
    assert output["loss_action"].item() == 1.0
    assert action.shape == (2, 4, 3)
    # Sample i carries the index of the i-th core call.
    assert action[:, 0, 0].tolist() == [1.0, 2.0]

    image_shapes = [call["image_shape"] for call in calls]
    proprio_shapes = [call["proprio_shape"] for call in calls]
    prompts = [call["prompt"] for call in calls]
    assert image_shapes == [(1, 3, 16, 16), (1, 3, 16, 16)]
    assert proprio_shapes == [(1, 2), (1, 2)]
    assert prompts == [
        cfg.prompt_template.format(task="task 0"),
        cfg.prompt_template.format(task="task 1"),
    ]
|
||||
|
||||
|
||||
class CoreWithFrozenComponents(FakeFastWAMCore):
    """Fake core carrying a frozen VAE and text encoder as *unregistered* attributes.

    Both are attached with `object.__setattr__`, so they stay out of
    `state_dict()` and out of saved checkpoints, while the `_apply` override
    still forwards conversions such as `.to()` to them.
    """

    def __init__(self):
        super().__init__()
        for attr_name in ("vae", "text_encoder"):
            # Bypass nn.Module.__setattr__ so the layer is not registered as a submodule.
            object.__setattr__(self, attr_name, nn.Linear(2, 2))
        for frozen in (self.vae, self.text_encoder):
            frozen.requires_grad_(False)

    def _apply(self, fn, *args, **kwargs):
        super()._apply(fn, *args, **kwargs)
        # The unregistered components are invisible to the base implementation.
        for frozen in (self.vae, self.text_encoder):
            frozen._apply(fn)
        return self
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
    """Loading a saved policy restores its weights without calling the Wan2.2 loader."""
    cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)

    def build_core(self, config):
        # Freshly built cores always start from 0.5 so overwritten weights are detectable.
        fresh_core = CoreWithFrozenComponents()
        with torch.no_grad():
            fresh_core.dit.weight.fill_(0.5)
        return fresh_core

    monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)

    reference = FastWAMPolicy(cfg)
    with torch.no_grad():
        reference.model.dit.weight.fill_(1.25)  # a distinctive, trained-looking weight
    reference.save_pretrained(tmp_path)

    def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
        # Building from Wan2.2 must never happen on a checkpoint load.
        raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")

    monkeypatch.setattr(
        "lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
        fail_if_wan_pretrained_is_loaded,
    )

    policy = FastWAMPolicy.from_pretrained(tmp_path)

    assert isinstance(policy.model, CoreWithFrozenComponents)
    # The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
    loaded_weight = policy.model.dit.weight
    assert torch.allclose(loaded_weight, torch.full_like(loaded_weight, 1.25))
|
||||
|
||||
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
    """A saved checkpoint contains the DiT weights but no frozen VAE/text-encoder data."""
    cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
    monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
    policy = FastWAMPolicy(cfg)

    save_dir = tmp_path / "saved"
    policy.save_pretrained(save_dir)

    weights_file = save_dir / "model.safetensors"
    assert weights_file.is_file()
    # No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
    assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
    assert not (save_dir / "google").exists()

    with safe_open(weights_file, framework="pt") as handle:
        keys = set(handle.keys())

    def has_prefix(prefix):
        return any(key.startswith(prefix) for key in keys)

    # Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
    # encoder are excluded (loaded from the diffusers/transformers repos at init).
    assert has_prefix("model.dit.")
    assert not has_prefix("model.vae.")
    assert not has_prefix("model.text_encoder.")
|
||||
|
||||
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
    """Frozen components are not parameters of the policy, yet `.to()` still reaches them."""
    cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
    monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
    policy = FastWAMPolicy(cfg)

    # Unregistered: excluded from state_dict and from the optimizer's parameter set.
    frozen_prefixes = ("model.vae.", "model.text_encoder.")
    for state_key in policy.state_dict():
        assert not state_key.startswith(frozen_prefixes)
    for param_name, _ in policy.named_parameters():
        assert "vae" not in param_name
        assert "text_encoder" not in param_name

    # ...but the `_apply` override still carries them through `.to()` (dtype stands in
    # for device on a CPU box), so they never strand off the rest of the model.
    policy.to(torch.float64)
    core = policy.model
    assert core.dit.weight.dtype == torch.float64  # registered
    assert core.vae.weight.dtype == torch.float64  # unregistered, moved via _apply
    assert core.text_encoder.weight.dtype == torch.float64
|
||||
|
||||
|
||||
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
    """A saved FastWAM config reloads through the generic config loader with its features."""
    original = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448), base_model_id=None)
    original.save_pretrained(tmp_path)

    loaded = PreTrainedConfig.from_pretrained(tmp_path)

    assert loaded.type == "fastwam"
    image_feature = loaded.image_features["observation.images.image"]
    assert image_feature.type == FeatureType.VISUAL
    assert loaded.action_feature.shape == (7,)
    assert loaded.robot_state_feature.shape == (8,)
|
||||
|
||||
|
||||
def test_vae_adapter_empty_build_encode_decode_shapes():
    """Offline shape check of the diffusers-backed VAE adapter with random weights.

    Exercises the encode/decode contract (48 latent channels, 16x spatial and
    4x temporal compression, list-or-batch input, output range) without
    downloading any weights. Numerical fidelity against the original Wan VAE
    is verified separately with real weights on GPU.
    """
    pytest.importorskip("diffusers")
    from diffusers import AutoencoderKLWan

    from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38

    # Production always loads a real pretrained VAE from the diffusers repo; here
    # the same architecture is instantiated with random weights and dummy
    # standardization stats, which is enough for the shape/scaling contract.
    latent_channels = 48
    arch = dict(
        base_dim=160,
        decoder_base_dim=256,
        z_dim=latent_channels,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0,
        is_residual=True,
        in_channels=12,
        out_channels=12,
        patch_size=2,
        scale_factor_spatial=16,
        scale_factor_temporal=4,
        clip_output=False,
        latents_mean=[0.0] * latent_channels,
        latents_std=[1.0] * latent_channels,
    )
    raw = AutoencoderKLWan.from_config(arch)
    vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw)
    assert vae.z_dim == 48
    assert vae.upsampling_factor == 16
    assert vae.temporal_downsample_factor == 4

    video = torch.rand(1, 3, 5, 32, 32) * 2 - 1  # [B,C,T,H,W] in [-1,1]
    latents = vae.encode(video)
    assert latents.shape == (1, 48, 2, 2, 2)  # T'=(5-1)//4+1, H'=W'=32//16

    decoded = vae.decode(latents)
    assert decoded.shape[0] == 1
    assert decoded.shape[1] == 3
    assert decoded.shape[-2:] == (32, 32)
    assert decoded.min() >= -1.0
    assert decoded.max() <= 1.0

    # list input is accepted and equals the batched path
    latents_from_list = vae.encode([video[0]])
    assert torch.equal(latents_from_list, latents)
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
    """Mock step whose tensor state is kept out of its constructor config."""

    def __init__(
        self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
    ):
        self.name = name
        self.scale = scale
        # The state stays unset until an initial value is given or a state dict is loaded.
        self.tensor_state: torch.Tensor | None = (
            None if initial_value is None else torch.tensor([initial_value], dtype=torch.float32)
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Pass the transition through untouched."""
        return transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments only; tensor state is deliberately left out."""
        return {"name": self.name, "scale": self.scale}

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return the tensor state, or an empty dict while it is still unset."""
        return {} if self.tensor_state is None else {"tensor_state": self.tensor_state}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Store a copy of the provided tensor state."""
        incoming = state["tensor_state"]
        self.tensor_state = incoming.clone()

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Leave the feature description as it is."""
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
    """Registry-backed variant of the lazy tensor state step, for serialization tests
    that resolve steps by registry name."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
    """The config returned in memory equals the JSON that save_pretrained writes."""
    steps = [
        MockStep(name="stateless"),
        MockLazyTensorStateStep(name="stateful", initial_value=4.0),
    ]
    pipeline = DataProcessorPipeline(steps, name="Memory Pipeline")

    in_memory_config = pipeline.get_config()

    # Calling it again must produce the same result.
    assert pipeline.get_config() == in_memory_config

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        saved_config = json.loads((Path(tmp_dir) / "memory_pipeline.json").read_text())

    assert in_memory_config == saved_config
    step_entries = in_memory_config["steps"]
    assert "state_file" not in step_entries[0]
    assert step_entries[1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
    """The in-memory state agrees with the safetensors file save_pretrained writes."""
    step = MockLazyTensorStateStep(initial_value=7.0)
    pipeline = DataProcessorPipeline([step], name="Stateful Pipeline")
    state_key = "stateful_pipeline_step_0"
    state_filename = "stateful_pipeline_step_0.safetensors"

    in_memory_state_dict = pipeline.state_dict()

    assert set(in_memory_state_dict) == {state_key}
    assert set(in_memory_state_dict[state_key]) == {"tensor_state"}

    # Mutating the returned tensor in place must leave the step's own state alone.
    in_memory_state_dict[state_key]["tensor_state"].add_(1)
    assert step.tensor_state is not None
    assert torch.equal(step.tensor_state, torch.tensor([7.0]))

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        saved_state_dict = load_file(Path(tmp_dir) / state_filename)

    torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
    """save_pretrained produces the established config and state file names."""
    pipeline = DataProcessorPipeline(
        [MockLazyTensorStateStep(initial_value=3.0)], name="Policy Preprocessor"
    )
    expected_files = ("policy_preprocessor.json", "policy_preprocessor_step_0.safetensors")

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        save_path = Path(tmp_dir)
        for filename in expected_files:
            assert (save_path / filename).exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
    """from_config rebuilds a stateful pipeline from its in-memory config and state."""
    original = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)],
        name="Roundtrip Pipeline",
    )
    config = original.get_config()
    pipeline_state_dict = original.state_dict()

    restored = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
    restored_step = restored.steps[0]

    assert len(restored) == 1
    assert isinstance(restored_step, MockLazyTensorStateStep)
    torch.testing.assert_close(restored_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
    """from_config resolves a registered step by name and restores its tensor state."""
    registry_name = "registered_lazy_tensor_state_step"
    state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
    state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"

    original = DataProcessorPipeline(
        [RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)],
        name="Registry Pipeline",
    )
    config = original.get_config()
    pipeline_state_dict = original.state_dict()

    step_entry = config["steps"][0]
    assert step_entry["registry_name"] == registry_name
    assert step_entry["state_file"] == state_filename
    assert set(pipeline_state_dict) == {state_key}

    restored = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
    restored_step = restored.steps[0]

    assert isinstance(restored_step, RegisteredLazyTensorStateStep)
    assert restored_step.tensor_state is not None
    torch.testing.assert_close(restored_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
    """A pipeline rebuilt without state starts empty and can load the state afterwards."""
    original = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="lazy", initial_value=13.0)],
        name="Lazy Pipeline",
    )
    config = original.get_config()
    pipeline_state_dict = original.state_dict()

    # Rebuild from config alone: the step has no tensor state yet.
    restored = DataProcessorPipeline.from_config(config)
    restored_step = restored.steps[0]

    assert isinstance(restored_step, MockLazyTensorStateStep)
    assert restored_step.state_dict() == {}
    restored_step_entry = restored.get_config()["steps"][0]
    assert "state_file" not in restored_step_entry

    restored.load_state_dict(pipeline_state_dict)

    torch.testing.assert_close(restored_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
    """Constructor overrides and tensor state loading both take effect independently."""
    original = DataProcessorPipeline(
        [MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)],
        name="Override Pipeline",
    )
    config = original.get_config()
    pipeline_state_dict = original.state_dict()
    step_overrides = {"MockLazyTensorStateStep": {"scale": 5.0}}

    restored = DataProcessorPipeline.from_config(
        config,
        state_dict=pipeline_state_dict,
        overrides=step_overrides,
    )
    restored_step = restored.steps[0]

    assert isinstance(restored_step, MockLazyTensorStateStep)
    # The override replaced the saved constructor argument...
    assert restored_step.scale == 5.0
    # ...while the tensor state still came from the state dict.
    torch.testing.assert_close(restored_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
    """Loading fails with KeyError when the config expects state that is absent."""
    original = DataProcessorPipeline(
        [MockLazyTensorStateStep(initial_value=19.0)], name="Missing Pipeline"
    )
    config = original.get_config()
    restored = DataProcessorPipeline.from_config(config)

    with pytest.raises(KeyError, match="missing_pipeline_step_0"):
        restored.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
    """Loading fails with KeyError when an unknown top-level state key is supplied."""
    pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
    unexpected_state = {"extra": {"tensor_state": torch.tensor([1.0])}}

    with pytest.raises(KeyError, match="extra"):
        pipeline.load_state_dict(unexpected_state)
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
    """A stateless pipeline serializes to an empty state and reloads without a name."""
    pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
    config = pipeline.get_config()

    assert pipeline.state_dict() == {}
    for step_entry in config["steps"]:
        assert "state_file" not in step_entry

    # Drop the name from the config to exercise the fallback name.
    nameless_config = {"steps": config["steps"]}
    restored = DataProcessorPipeline.from_config(nameless_config, state_dict={})

    assert restored.name == "DataProcessorPipeline"
    assert restored.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
    """from_config raises a clear ValueError for top-level values that are not dicts."""
    expected_message = "not a valid processor configuration"
    with pytest.raises(ValueError, match=expected_message):
        DataProcessorPipeline.from_config(invalid_config)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
@@ -67,6 +68,7 @@ def test_strategy_config_types():
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
assert EpisodicStrategyConfig().type == "episodic"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategy,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategy,
|
||||
EpisodicStrategyConfig,
|
||||
SentryStrategy,
|
||||
SentryStrategyConfig,
|
||||
create_strategy,
|
||||
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
|
||||
@@ -2830,6 +2830,10 @@ eo1 = [
|
||||
evaluation = [
|
||||
{ name = "av" },
|
||||
]
|
||||
fastwam = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
feetech = [
|
||||
{ name = "deepdiff" },
|
||||
{ name = "feetech-servo-sdk" },
|
||||
@@ -3123,11 +3127,13 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["deepdiff-dep"], marker = "extra == 'hardware'" },
|
||||
{ name = "lerobot", extras = ["dev"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'fastwam'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["fastwam"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
@@ -3196,6 +3202,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'fastwam'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
|
||||
@@ -3276,7 +3283,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user