feat(lingbot_va): RoboTwin eef-pose eval, single-file model, Hub checkpoints

Make the LingBot-VA port runnable on both LIBERO and RoboTwin and clean up the
package to LeRobot conventions.

- Consolidate all vendored Wan2.2 model code (transformer, attention, VAE helpers,
  flow-matching scheduler, grid utils, flex-attention) into a single
  modeling_lingbot_va.py; remove the separate wan_*/schedulers modules.
- Move the fixed action (un)normalization quantiles out of the config and into the
  post-processor (LIBERO 7-DoF + RoboTwin 16-d eef); remove the conversion script in
  favour of ready-to-use LeRobot-format checkpoints on the Hub.
- Fixes found via on-sim validation: undo LIBERO's 180-degree image flip
  (image_hflip), encode obs as a multi-frame streaming-VAE clip, reset the streaming
  VAE cache between episodes, run the transformer in config.dtype, lazy-load frozen
  VAE/UMT5 by subfolder with the text encoder on CPU.
- RoboTwin: add an end-effector-pose action mode to RoboTwinEnv (16-d per-arm
  xyz+quat+gripper deltas composed onto the initial eef pose, executed via CuRobo IK)
  and the robotwin_tshape latent layout (full-res head + half-res wrists via a second
  streaming VAE) with the upstream RoboTwin action quantiles + camera mapping.
- Predicted-video saving works for both benchmarks; docs + tests updated.

Co-authored-by: Cursor <cursoragent@cursor.com>
pepijn223
2026-06-06 15:20:51 +02:00
committed by Maxime Ellerbach
parent d600a52943
commit b81909fc28
20 changed files with 2372 additions and 2834 deletions
+4
@@ -22,6 +22,10 @@ outputs
rl
media
# Local virtualenvs (the image provides its own)
.venv
venv
# Logging
logs
+34 -17
@@ -28,7 +28,7 @@ fed back into the KV cache as the chunk is executed (closed-loop world modeling)
### What the LeRobot Integration Covers
- Standard `policy.type=lingbot_va` configuration through LeRobot.
- Checkpoint conversion from the released HuggingFace checkpoints.
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
- Autoregressive dual-stream inference behind the standard `select_action` interface
(single-environment eval, `--eval.batch_size=1`).
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
@@ -48,40 +48,57 @@ pip install -e ".[lingbot_va]"
pip install -e ".[lingbot_va,libero]"
```
## Checkpoint Conversion
## Checkpoints
The released checkpoints are diffusers-style directories
(`robbyant/lingbot-va-base`, `robbyant/lingbot-va-posttrain-robotwin`,
`robbyant/lingbot-va-posttrain-libero-long`). Convert one to LeRobot format with:
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
```bash
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
--variant libero \
--output_dir outputs/lingbot_va_libero_long
```
| Variant | LeRobot checkpoint |
|---|---|
| LIBERO-Long post-train | `pepijn223/lingbot_va_libero_long` |
| RoboTwin post-train | `pepijn223/lingbot_va_robotwin` |
| Pretrained base | `pepijn223/lingbot_va_base` |
**Packaging:** only the trainable ~5B transformer is stored in the LeRobot
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from
`config.wan_pretrained_path` at load time (defaults to the source repo). Pass
`--bundle-frozen` to copy those sub-folders next to the converted checkpoint instead.
Run conversion on a Linux machine with a CUDA GPU and enough RAM/VRAM to materialize the
transformer.
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
transformer + VAE fit on a single 24-32 GB GPU.
## Evaluation (LIBERO)
```bash
lerobot-eval \
--policy.path=outputs/lingbot_va_libero_long \
--policy.path=pepijn223/lingbot_va_libero_long \
--policy.device=cuda \
--env.type=libero --env.task=libero_10 \
--env.observation_height=128 --env.observation_width=128 \
--eval.n_episodes=50 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_libero
```
Native LeRobot eval reproduces **96% success on `libero_10` (LIBERO-Long)** (48/50 episodes).
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
single-environment eval; use `--eval.batch_size=1`.
## Evaluation (RoboTwin)
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack — use the benchmark Docker image
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
executed via CuRobo IK.
```bash
lerobot-eval \
--policy.path=pepijn223/lingbot_va_robotwin \
--policy.device=cuda \
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
--eval.n_episodes=10 --eval.batch_size=1 \
--output_dir=outputs/eval/lingbot_va_robotwin
```
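As a rough illustration of the composition step described above, the sketch below adds the predicted translation to the initial position and composes the rotations as `init * delta`, using scalar-last quaternions `[qx, qy, qz, qw]`. It is a dependency-free approximation of what the environment does with SciPy, up to the usual quaternion sign ambiguity. The helper names `quat_multiply` and `compose_eef_pose` are hypothetical and not part of LeRobot or RoboTwin.

```python
import math


def quat_multiply(q1, q2):
    """Hamilton product q1 * q2 for scalar-last quaternions [x, y, z, w]."""
    x1, y1, z1, w1 = q1
    x2, y2, z2, w2 = q2
    return [
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    ]


def compose_eef_pose(delta, init):
    """Compose one arm's 8-d delta [x, y, z, qx, qy, qz, qw, gripper] onto its initial pose."""
    # Translation is additive.
    position = [d + i for d, i in zip(delta[:3], init[:3])]
    # Rotation is composed as init * delta, then re-normalized to unit length.
    rotation = quat_multiply(init[3:7], delta[3:7])
    norm = math.sqrt(sum(c * c for c in rotation)) + 1e-8
    rotation = [c / norm for c in rotation]
    # The gripper value is taken from the prediction, not accumulated.
    return position + rotation + [delta[7]]


# Two quarter turns about z compose into a half turn about z.
s = math.sin(math.pi / 4)
quarter_turn_z = [0.0, 0.0, s, s]
init = [0.1, 0.2, 0.3] + quarter_turn_z + [0.0]
delta = [0.01, 0.0, -0.02] + quarter_turn_z + [1.0]
target = compose_eef_pose(delta, init)
```

A dual-arm 16-d action would apply the same function to the left half (`[:8]`) and the right half (`[8:]`) independently.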
### Saving predicted (imagined) videos
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
+1 -1
@@ -228,7 +228,7 @@ vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen
# LingBot-VA needs the Wan2.2 stack (AutoencoderKLWan z_dim=48 + WanTransformer3DModel config schema),
# which only exists in diffusers>=0.36. Pin the floor explicitly so a standalone `lerobot[lingbot_va]`
# install can't resolve to a pre-Wan2.2 diffusers via the looser diffusers-dep floor.
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]"]
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
+6
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
# must equal what SAPIEN actually renders.
observation_height: int = 240
observation_width: int = 320
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
action_mode: str = "joint"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
)
def __post_init__(self):
if self.action_mode == "ee":
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
for cam in cam_list:
self.features[f"pixels/{cam}"] = PolicyFeature(
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
observation_height=self.observation_height,
observation_width=self.observation_width,
episode_length=self.episode_length,
action_mode=self.action_mode,
)
+67 -5
@@ -41,10 +41,42 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
"right_camera",
)
ACTION_DIM = 14 # 7 DOF × 2 arms
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
EEF_ACTION_DIM = 16
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
DEFAULT_EPISODE_LENGTH = 300
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a single-arm predicted delta pose onto the initial pose.
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
"""
from scipy.spatial.transform import Rotation
new_r = Rotation.from_quat(new_pose[3:7])
init_r = Rotation.from_quat(init_pose[3:7])
out_rot = (init_r * new_r).as_quat()
out_trans = new_pose[:3] + init_pose[:3]
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
out = np.concatenate([left, right])
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
return out
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
DEFAULT_CAMERA_H = 240
DEFAULT_CAMERA_W = 320
@@ -234,6 +266,7 @@ class RoboTwinEnv(gym.Env):
observation_width: int | None = None,
episode_length: int = DEFAULT_EPISODE_LENGTH,
render_mode: str = "rgb_array",
action_mode: str = "joint",
):
super().__init__()
self.task_name = task_name
@@ -241,6 +274,13 @@ class RoboTwinEnv(gym.Env):
self.task_description = task_name.replace("_", " ")
self.episode_index = episode_index
self._reset_stride = n_envs
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
if action_mode not in ("joint", "ee"):
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
self.action_mode = action_mode
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
self._init_eef_pose: np.ndarray | None = None
self.camera_names = list(camera_names)
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
# The YAML-driven lookup is deferred to reset() so construction doesn't
@@ -271,7 +311,7 @@ class RoboTwinEnv(gym.Env):
}
)
self.action_space = spaces.Box(
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
)
def _ensure_env(self) -> None:
@@ -317,6 +357,17 @@ class RoboTwinEnv(gym.Env):
return {"pixels": images, "agent_pos": joint_state}
def _read_eef_pose(self) -> np.ndarray:
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
ep = self._env.get_obs()["endpose"]
pose = (
list(ep["left_endpose"])
+ [ep["left_gripper"]]
+ list(ep["right_endpose"])
+ [ep["right_gripper"]]
)
return np.asarray(pose, dtype=np.float64)
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
self._ensure_env()
super().reset(seed=seed)
@@ -330,16 +381,23 @@ class RoboTwinEnv(gym.Env):
self.episode_index += self._reset_stride
self._step_count = 0
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
if self.action_mode == "ee":
self._init_eef_pose = self._read_eef_pose()
obs = self._get_obs()
return obs, {"is_success": False, "task": self.task_name}
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
assert self._env is not None, "step() called before reset()"
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
if action.ndim != 1 or action.shape[0] != self._action_dim:
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
with torch.enable_grad():
if hasattr(self._env, "take_action"):
if self.action_mode == "ee":
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
self._env.take_action(ee_action, action_type="ee")
elif hasattr(self._env, "take_action"):
self._env.take_action(action)
else:
self._env.step(action)
@@ -398,6 +456,7 @@ def _make_env_fns(
observation_height: int,
observation_width: int,
episode_length: int,
action_mode: str = "joint",
) -> list[Callable[[], RoboTwinEnv]]:
"""Return n_envs factory callables for a single task."""
@@ -410,6 +469,7 @@ def _make_env_fns(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
return [partial(_make_one, i) for i in range(n_envs)]
@@ -423,6 +483,7 @@ def create_robotwin_envs(
observation_height: int = DEFAULT_CAMERA_H,
observation_width: int = DEFAULT_CAMERA_W,
episode_length: int = DEFAULT_EPISODE_LENGTH,
action_mode: str = "joint",
) -> dict[str, dict[int, Any]]:
"""Create vectorized RoboTwin 2.0 environments.
@@ -473,6 +534,7 @@ def create_robotwin_envs(
observation_height=observation_height,
observation_width=observation_width,
episode_length=episode_length,
action_mode=action_mode,
)
if is_async:
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
@@ -31,27 +31,6 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION
# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
LIBERO_ACTION_Q01 = [
-0.6589285731315613,
-0.84375,
-0.9375,
-0.12107142806053162,
-0.15964286029338837,
-0.26571428775787354,
-1.0,
]
LIBERO_ACTION_Q99 = [
0.8999999761581421,
0.8544642925262451,
0.9375,
0.17142857611179352,
0.1842857152223587,
0.34392857551574707,
1.0,
]
@PreTrainedConfig.register_subclass("lingbot_va")
@dataclass
@@ -84,12 +63,27 @@ class LingBotVAConfig(PreTrainedConfig):
wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
# dtype used for the transformer / VAE / text-encoder weights at inference.
dtype: str = "bfloat16" # one of "bfloat16", "float16", "float32"
# Device for the frozen UMT5-XXL text encoder. It encodes the (fixed) instruction once per
# episode, so keeping it on CPU frees ~11 GB of VRAM and lets the 5B transformer + VAE fit on
# a single 24-32 GB GPU. Set to "cuda" if you have the headroom and want faster prompt encoding.
text_encoder_device: str = "cpu"
# ── Observation cameras (order matters: latents are concatenated on width) ──
# Defaults match the LIBERO env feature keys (agentview -> image, eye-in-hand -> image2).
obs_cam_keys: list[str] = field(
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
)
# Horizontally flip the camera images before encoding. LeRobot's LIBERO env processor rotates
# frames 180° (flip H *and* W; the HuggingFaceVLA convention), but upstream LingBot-VA trains /
# evaluates on vertically-flipped-only frames (``obs[::-1]`` in evaluation/libero/client.py).
# Undoing the extra horizontal flip here realigns the input with the model's training orientation.
image_hflip: bool = False
# Latent assembly layout for the observation cameras:
# "width_concat" : encode every camera at (height, width) and concat latents on width (LIBERO).
# "robotwin_tshape" : head camera at full (height, width), the two wrist cameras at half
# resolution, assembled in a "T" (wrists side-by-side on top of the head
# on the height axis) using a second streaming VAE (RoboTwin).
camera_layout: str = "width_concat"
# ── Inference hyperparameters (LIBERO defaults) ──
n_obs_steps: int = 1
@@ -108,10 +102,9 @@ class LingBotVAConfig(PreTrainedConfig):
max_sequence_length: int = 512 # UMT5 prompt length
# Subset of the 30-d action space actually used by the benchmark (LIBERO = 7-DoF).
# The fixed action (un)normalization quantiles live in the post-processor
# (``LingBotVAActionUnnormalizeStep`` in ``processor_lingbot_va.py``), not here.
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
# Fixed quantiles for action (un)normalization on the *used* channels.
action_q01: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q01))
action_q99: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q99))
# Opt-in: VAE-decode the predicted video latents and stash them on
# ``self.last_predicted_frames`` so eval/train can save predicted-video MP4s.
@@ -140,13 +133,6 @@ class LingBotVAConfig(PreTrainedConfig):
super().__post_init__()
if self.attn_mode not in ("torch", "flashattn", "flex"):
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
if len(self.action_q01) != len(self.used_action_channel_ids) or len(self.action_q99) != len(
self.used_action_channel_ids
):
raise ValueError(
"action_q01 / action_q99 must each have one entry per used_action_channel_ids "
f"({len(self.used_action_channel_ids)}); got {len(self.action_q01)} / {len(self.action_q99)}."
)
@property
def chunk_size(self) -> int:
@@ -1,256 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a released LingBot-VA HuggingFace checkpoint to LeRobot format.
The released checkpoints are diffusers-style directories with ``transformer/``, ``vae/``,
``text_encoder/`` and ``tokenizer/`` sub-folders. This script:
1. loads the (sharded) ``transformer/`` weights with the vendored ``WanTransformer3DModel``;
2. builds a :class:`LingBotVAConfig` for the target benchmark variant;
3. instantiates a :class:`LingBotVAPolicy` and copies the transformer weights into it
(near-identity: the only key change is the ``transformer.`` prefix);
4. saves the LeRobot policy (``model.safetensors`` + ``config.json``) and its processors.
Packaging decision: only the trainable ~5B transformer is bundled into the LeRobot
``model.safetensors``. The frozen VAE + UMT5 text encoder + tokenizer (~20 GB) are NOT
copied; instead ``config.wan_pretrained_path`` records where to lazily pull them from at
load time (defaults to the source repo/dir). Pass ``--bundle-frozen`` to additionally copy
those sub-folders next to the converted checkpoint and point ``wan_pretrained_path`` at it.
Example (LIBERO-Long, the LIBERO eval gate):
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
--variant libero \
--output_dir outputs/lingbot_va_libero_long
Requires a CUDA GPU with enough RAM/VRAM to materialize the transformer; run on Linux.
"""
import argparse
import shutil
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
from lerobot.policies.lingbot_va.wan_transformer import WanTransformer3DModel
from lerobot.utils.constants import ACTION, OBS_IMAGES
# Per-benchmark variant presets (camera keys + action layout). Values mirror the upstream
# configs (wan_va/configs/va_*_cfg.py).
VARIANTS = {
"libero": {
"obs_cam_keys": [f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2"],
"height": 128,
"width": 128,
"action_per_frame": 4,
"frame_chunk_size": 4,
"attn_window": 30,
"num_inference_steps": 20,
"action_num_inference_steps": 50,
"guidance_scale": 5.0,
"action_guidance_scale": 1.0,
"snr_shift": 5.0,
"action_snr_shift": 0.05,
"used_action_channel_ids": list(range(7)),
# 7-DoF: agentview + eye-in-hand, single arm. Quantiles are the config defaults.
"image_shape": (3, 256, 256),
},
"robotwin": {
"obs_cam_keys": [
f"{OBS_IMAGES}.cam_high",
f"{OBS_IMAGES}.cam_left_wrist",
f"{OBS_IMAGES}.cam_right_wrist",
],
"height": 256,
"width": 320,
"action_per_frame": 16,
"frame_chunk_size": 2,
"attn_window": 72,
"num_inference_steps": 25,
"action_num_inference_steps": 50,
"guidance_scale": 5.0,
"action_guidance_scale": 1.0,
"snr_shift": 5.0,
"action_snr_shift": 1.0,
# RoboTwin is dual-arm; set the used channels / quantiles to match the deployed config.
"used_action_channel_ids": list(range(14)),
"image_shape": (3, 256, 256),
},
}
def _transformer_dir(checkpoint: str) -> str:
"""Return the path/repo that ``WanTransformer3DModel.from_pretrained`` should read."""
p = Path(checkpoint)
if p.is_dir():
return str(p / "transformer")
return checkpoint # HF repo id; use subfolder kwarg below
def load_source_transformer(checkpoint: str, dtype: torch.dtype) -> WanTransformer3DModel:
p = Path(checkpoint)
if p.is_dir():
return WanTransformer3DModel.from_pretrained(
str(p / "transformer"), torch_dtype=dtype, attn_mode="torch"
)
return WanTransformer3DModel.from_pretrained(
checkpoint, subfolder="transformer", torch_dtype=dtype, attn_mode="torch"
)
def build_config(variant: str, wan_pretrained_path: str, dtype: str) -> LingBotVAConfig:
preset = VARIANTS[variant]
n_used = len(preset["used_action_channel_ids"])
kwargs = {
"wan_pretrained_path": wan_pretrained_path,
"dtype": dtype,
"obs_cam_keys": preset["obs_cam_keys"],
"height": preset["height"],
"width": preset["width"],
"action_per_frame": preset["action_per_frame"],
"frame_chunk_size": preset["frame_chunk_size"],
"attn_window": preset["attn_window"],
"num_inference_steps": preset["num_inference_steps"],
"action_num_inference_steps": preset["action_num_inference_steps"],
"guidance_scale": preset["guidance_scale"],
"action_guidance_scale": preset["action_guidance_scale"],
"snr_shift": preset["snr_shift"],
"action_snr_shift": preset["action_snr_shift"],
"used_action_channel_ids": preset["used_action_channel_ids"],
"device": "cpu",
}
if variant != "libero":
# LIBERO keeps the config default quantiles; other variants need their own. Until the
# exact per-channel quantiles are wired in, use a neutral [-1, 1] mapping (no rescale).
kwargs["action_q01"] = [-1.0] * n_used
kwargs["action_q99"] = [1.0] * n_used
cfg = LingBotVAConfig(**kwargs)
# Populate input/output features (cameras + action) so validate_features passes.
img_shape = preset["image_shape"]
cfg.input_features = {
k: PolicyFeature(type=FeatureType.VISUAL, shape=img_shape) for k in preset["obs_cam_keys"]
}
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(n_used,))}
cfg.validate_features()
return cfg
def convert(
checkpoint: str, variant: str, output_dir: str, dtype: str, bundle_frozen: bool, push_to: str | None
):
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype]
out = Path(output_dir)
out.mkdir(parents=True, exist_ok=True)
# Decide where frozen modules will be pulled from at load time.
if bundle_frozen:
wan_pretrained_path = str(out)
_copy_frozen_subfolders(checkpoint, out)
else:
wan_pretrained_path = checkpoint
print(f"Building LingBot-VA config for variant '{variant}' (frozen modules from: {wan_pretrained_path})")
cfg = build_config(variant, wan_pretrained_path, dtype)
print("Loading source transformer weights ...")
src = load_source_transformer(checkpoint, torch_dtype)
src_sd = src.state_dict()
print("Instantiating LingBotVAPolicy and copying transformer weights ...")
# Build the policy without triggering frozen-module download by constructing directly.
policy = LingBotVAPolicy(cfg)
# Near-identity remap: source transformer keys -> policy "transformer.*".
remapped = {f"transformer.{k}": v for k, v in src_sd.items()}
missing, unexpected = policy.load_state_dict(remapped, strict=False)
_log_load_keys(missing, unexpected)
policy = policy.to(torch_dtype)
print(f"Saving converted policy to {out}")
policy.save_pretrained(out)
preprocessor, postprocessor = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
preprocessor.save_pretrained(out)
postprocessor.save_pretrained(out)
if push_to:
print(f"Pushing to the Hub: {push_to}")
policy.push_to_hub(push_to)
preprocessor.push_to_hub(push_to)
postprocessor.push_to_hub(push_to)
print("Done.")
def _copy_frozen_subfolders(checkpoint: str, out: Path):
p = Path(checkpoint)
if not p.is_dir():
from huggingface_hub import snapshot_download
p = Path(snapshot_download(checkpoint, allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"]))
for sub in ("vae", "text_encoder", "tokenizer"):
src_sub = p / sub
if src_sub.is_dir():
shutil.copytree(src_sub, out / sub, dirs_exist_ok=True)
print(f" bundled {sub}/")
def _log_load_keys(missing, unexpected):
# The source transformer should account for every "transformer.*" key in the policy.
if missing:
print(
f" [load_state_dict] {len(missing)} missing keys (expected: none for transformer). Sample: {missing[:5]}"
)
if unexpected:
print(f" [load_state_dict] {len(unexpected)} unexpected keys. Sample: {unexpected[:5]}")
if not missing and not unexpected:
print(" [load_state_dict] perfect match (near-identity remap).")
def main():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--checkpoint", required=True, help="HF repo id or local diffusers-style directory.")
parser.add_argument("--variant", required=True, choices=sorted(VARIANTS.keys()))
parser.add_argument("--output_dir", required=True)
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
parser.add_argument(
"--bundle-frozen",
action="store_true",
help="Copy the frozen vae/text_encoder/tokenizer next to the checkpoint instead of lazy-pulling.",
)
parser.add_argument(
"--push_to_hub", default=None, help="Optional HF repo id to push the converted policy to."
)
args = parser.parse_args()
convert(
checkpoint=args.checkpoint,
variant=args.variant,
output_dir=args.output_dir,
dtype=args.dtype,
bundle_frozen=args.bundle_frozen,
push_to=args.push_to_hub,
)
if __name__ == "__main__":
main()
File diff suppressed because it is too large
@@ -47,6 +47,60 @@ from lerobot.utils.constants import (
from .configuration_lingbot_va import LingBotVAConfig
# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
# These are the fixed (un)normalization stats baked into the released LIBERO checkpoint; they
# live here (in the processor) and are serialized into the saved post-processor config.
LIBERO_ACTION_Q01 = [
-0.6589285731315613,
-0.84375,
-0.9375,
-0.12107142806053162,
-0.15964286029338837,
-0.26571428775787354,
-1.0,
]
LIBERO_ACTION_Q99 = [
0.8999999761581421,
0.8544642925262451,
0.9375,
0.17142857611179352,
0.1842857152223587,
0.34392857551574707,
1.0,
]
# Upstream RoboTwin action quantiles, reordered to the model's used-channel layout
# [left xyz+quat (0-6), left gripper (28), right xyz+quat (7-13), right gripper (29)] = 16 channels.
# Verbatim from wan_va/configs/va_robotwin_cfg.py ``norm_stat`` (quaternion + gripper channels use the
# neutral [-1, 1] / [0, 1] mapping). Positions are quantile-scaled; rotations pass through.
ROBOTWIN_ACTION_Q01 = [
-0.06172713458538055, -3.6716461181640625e-05, -0.08783501386642456, -1.0, -1.0, -1.0, -1.0,
0.0,
-0.3547105032205582, -1.3113021850585938e-06, -0.11975435614585876, -1.0, -1.0, -1.0, -1.0,
0.0,
] # fmt: skip
ROBOTWIN_ACTION_Q99 = [
0.3462600058317184, 0.39966784834861746, 0.14745532035827624, 1.0, 1.0, 1.0, 1.0,
1.0,
0.034201726913452024, 0.39142737388610793, 0.1792279863357542, 1.0, 1.0, 1.0, 1.0,
1.0,
] # fmt: skip
def _default_action_quantiles(n_used: int) -> tuple[list[float], list[float]]:
"""Return the fixed (q01, q99) for the used action channels, by benchmark channel count.
LIBERO = 7 (single 7-DoF arm), RoboTwin = 16 (dual-arm eef pose + grippers). Falls back to a
neutral ``[-1, 1]`` mapping (no rescale) for any other channel count.
"""
if n_used == len(LIBERO_ACTION_Q01):
return list(LIBERO_ACTION_Q01), list(LIBERO_ACTION_Q99)
if n_used == len(ROBOTWIN_ACTION_Q01):
return list(ROBOTWIN_ACTION_Q01), list(ROBOTWIN_ACTION_Q99)
return [-1.0] * n_used, [1.0] * n_used
@dataclass
@ProcessorStepRegistry.register(name="lingbot_va_action_unnormalize")
@@ -94,8 +148,9 @@ def make_lingbot_va_pre_post_processors(
DeviceProcessorStep(device=config.device),
]
action_q01, action_q99 = _default_action_quantiles(len(config.used_action_channel_ids))
output_steps: list[ProcessorStep] = [
LingBotVAActionUnnormalizeStep(action_q01=config.action_q01, action_q99=config.action_q99),
LingBotVAActionUnnormalizeStep(action_q01=action_q01, action_q99=action_q99),
DeviceProcessorStep(device="cpu"),
]
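For illustration only (this block is not part of the commit): the body of `LingBotVAActionUnnormalizeStep` is not shown in this diff, so the mapping below is an assumption. It uses the common convention that model outputs in `[-1, 1]` map linearly onto `[q01, q99]`, which is at least consistent with the comment that `[-1, 1]` quantiles mean "no rescale". The function name `unnormalize_action` is hypothetical.

```python
def unnormalize_action(normalized, q01, q99):
    """Map per-channel values from [-1, 1] onto [q01, q99] (assumed convention)."""
    if not (len(normalized) == len(q01) == len(q99)):
        raise ValueError("normalized, q01 and q99 must have the same length")
    return [
        (x + 1.0) / 2.0 * (hi - lo) + lo
        for x, lo, hi in zip(normalized, q01, q99)
    ]


# With neutral [-1, 1] quantiles the mapping is the identity (rotation channels).
passthrough = unnormalize_action([0.5, -0.25], [-1.0, -1.0], [1.0, 1.0])

# With [0, 1] quantiles (gripper channels) the range is squeezed into [0, 1].
gripper = unnormalize_action([-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
```

Under this assumption, only the position channels of the RoboTwin tables are actually rescaled, which matches the "positions are quantile-scaled; rotations pass through" comment.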
@@ -1,155 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flow-matching scheduler for LingBot-VA.
Vendored verbatim from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/scheduler.py``). LingBot-VA uses
two independent instances of this scheduler at inference time, one for the video-latent
stream and one for the action stream, each with its own ``shift`` (signal-to-noise ratio
shift) and number of denoising steps.
"""
import math
import torch
__all__ = ["FlowMatchScheduler"]
class FlowMatchScheduler:
    def __init__(
        self,
        num_inference_steps=100,
        num_train_timesteps=1000,
        shift=3.0,
        sigma_max=1.0,
        sigma_min=0.003 / 1.002,
        inverse_timesteps=False,
        extra_one_step=False,
        reverse_sigmas=False,
        exponential_shift=False,
        exponential_shift_mu=None,
        shift_terminal=None,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.exponential_shift = exponential_shift
        self.exponential_shift_mu = exponential_shift_mu
        self.shift_terminal = shift_terminal
        self.set_timesteps(num_inference_steps)

    def set_timesteps(
        self,
        num_inference_steps=100,
        denoising_strength=1.0,
        training=False,
        shift=None,
        dynamic_shift_len=None,
    ):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        if self.exponential_shift:
            mu = (
                self.calculate_shift(dynamic_shift_len)
                if dynamic_shift_len is not None
                else self.exponential_shift_mu
            )
            self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
        else:
            self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.shift_terminal is not None:
            one_minus_z = 1 - self.sigmas
            scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
            self.sigmas = 1 - (one_minus_z / scale_factor)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing
            self.training = True
        else:
            self.training = False

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep, t_dim=2):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep = timestep[None]
        timestep_id = torch.argmin((self.timesteps[:, None] - timestep).abs(), dim=0)
        shape = [1] * noise.ndim
        shape[t_dim] = timestep_id.shape[0]
        sigma = self.sigmas[timestep_id].to(original_samples).view(shape)
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin(
            (self.timesteps[:, None].to(timestep.device) - timestep[None]).abs(), dim=0
        )
        weights = self.linear_timesteps_weights.to(timestep.device)[timestep_id].to(timestep.device)
        return weights

    def calculate_shift(
        self,
        image_seq_len,
        base_seq_len: int = 256,
        max_seq_len: int = 8192,
        base_shift: float = 0.5,
        max_shift: float = 0.9,
    ):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu
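The sigma schedule and Euler update used by `FlowMatchScheduler` can be illustrated without `torch`. The following is a minimal pure-Python sketch, assuming only the default (non-exponential) shift path and `denoising_strength=1.0`. The helper names `shifted_sigmas`, `add_noise` and `euler_step` are hypothetical and work on plain floats rather than tensors; they are not part of the removed module.

```python
def shifted_sigmas(num_steps, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002):
    """Linearly spaced sigmas warped by s -> shift * s / (1 + (shift - 1) * s)."""
    if num_steps == 1:
        raw = [sigma_max]
    else:
        delta = (sigma_min - sigma_max) / (num_steps - 1)
        raw = [sigma_max + i * delta for i in range(num_steps)]
    return [shift * s / (1 + (shift - 1) * s) for s in raw]


def add_noise(clean, noise, sigma):
    """Interpolate between the clean sample (sigma=0) and pure noise (sigma=1)."""
    return (1 - sigma) * clean + sigma * noise


def euler_step(sample, velocity, sigma, sigma_next):
    """One Euler step along the predicted velocity, mirroring the update in step()."""
    return sample + velocity * (sigma_next - sigma)


sigmas = shifted_sigmas(num_steps=5)
clean, noise = 2.0, -1.0
noisy = add_noise(clean, noise, sigmas[2])
# With the exact training target (noise - clean), stepping to sigma=0 recovers the clean sample.
recovered = euler_step(noisy, noise - clean, sigmas[2], 0.0)
```

The last two lines show why `training_target` returns `noise - sample`: a model that predicts this velocity exactly lets a single step to `sigma = 0` undo `add_noise`.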
@@ -1,286 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention and rotary-position-embedding modules for the LingBot-VA Wan transformer.
Vendored and lightly adapted from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). The ``torch``
SDPA backend is the default and is always available; the ``flashattn`` and ``flex``
backends are imported lazily and only required when the corresponding ``attn_mode`` is
selected. State-dict parameter names are preserved verbatim so that conversion from the
original diffusers-style checkpoint is near-identity.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# ``flash_attn`` and the flex-attention APIs are optional. We import them lazily inside the
# backends that need them so that the (default) ``torch`` SDPA path works on any platform,
# including CPU-only and macOS where neither package is available.


def custom_sdpa(q, k, v):
    """Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)


def _load_flash_attn_func():
    try:
        from flash_attn_interface import flash_attn_func
    except ImportError:
        try:
            from flash_attn import flash_attn_func
        except ImportError as e:
            raise ImportError(
                "attn_mode='flashattn' requires the `flash_attn` package, which is not installed. "
                "Install it, or use attn_mode='torch' (the default)."
            ) from e
    return flash_attn_func
class WanRotaryPosEmbed(nn.Module):
    """Rotary position embedding with separate frequency bases for frame / height / width."""

    def __init__(
        self,
        attention_head_dim: int,
        patch_size,
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3)
        self.h_dim = self.attention_head_dim // 3
        self.w_dim = self.attention_head_dim // 3
        f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
        self.f_freqs_base = f_freqs_base
        self.h_freqs_base = h_freqs_base
        self.w_freqs_base = w_freqs_base

    def _precompute_freqs_base(self):
        # freqs_base = 1.0 / (theta ** (2k / dim))
        f_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.f_dim, 2)[: (self.f_dim // 2)].double() / self.f_dim)
        )
        h_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.h_dim, 2)[: (self.h_dim // 2)].double() / self.h_dim)
        )
        w_freqs_base = 1.0 / (
            self.theta ** (torch.arange(0, self.w_dim, 2)[: (self.w_dim // 2)].double() / self.w_dim)
        )
        return f_freqs_base, h_freqs_base, w_freqs_base

    def forward(self, grid_ids):
        with torch.no_grad():
            f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base.to(grid_ids.device)
            h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base.to(grid_ids.device)
            w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base.to(grid_ids.device)
            freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
class WanAttention(nn.Module):
    """Self/cross attention with KV-caching for autoregressive streaming inference.

    Backends:

    * ``torch`` (default): standard SDPA, available everywhere.
    * ``flashattn``: FlashAttention kernels (optional dependency).
    * ``flex``: PyTorch flex-attention (optional, used for block-causal training masks).
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode="torch",
    ):
        super().__init__()
        if attn_mode == "torch":
            self.attn_op = custom_sdpa
        elif attn_mode == "flashattn":
            self.attn_op = _load_flash_attn_func()
        elif attn_mode == "flex":
            # Imported lazily to avoid a hard dependency on torch flex-attention at import time.
            from .wan_flex_attention import FlexAttnFunc

            self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None)
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}; only 'torch', 'flashattn' and 'flex' are supported"
            )
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = (
            self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
        )
        self.to_q = nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = nn.ModuleList(
            [
                nn.Linear(self.inner_dim, dim, bias=True),
                nn.Dropout(dropout),
            ]
        )
        self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        # KV cache only lives on self-attention modules (cross_attention_dim_head is None).
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        if self.attn_caches is None:
            return
        cache = self.attn_caches[cache_name]
        is_pred = cache["is_pred"]
        cache["mask"][is_pred] = False

    def clear_cache(self, cache_name):
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = {
            "k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
            "v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
            "id": torch.full((total_tolen,), -1, device=device),
            "mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
            "is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)
        if free.numel() < key_size:
            # Not enough free slots: evict the entries with the smallest (oldest) ids.
            used = mask.nonzero(as_tuple=False).squeeze(-1)
            used_ids = ids[used]
            order = torch.argsort(used_ids)
            need = key_size - free.numel()
            to_free = used[order[:need]]
            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)
        assert free.numel() >= key_size
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        ids = self.attn_caches[cache_name]["id"]
        mask = self.attn_caches[cache_name]["mask"]
        if mask.any():
            return ids[mask].max() + 1
        else:
            return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        cache = self.attn_caches[cache_name]
        key_size = key.shape[1]
        slots = self.allocate_slots(cache_name, key_size)
        new_id = self._next_cache_id(cache_name)
        cache["k"][:, slots] = key
        cache["v"][:, slots] = value
        cache["mask"][slots] = True
        cache["id"][slots] = new_id
        cache["is_pred"][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        self.attn_caches[cache_name]["mask"][slots] = False

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name="pos",
    ):
        kv_cache = (
            self.attn_caches[cache_name]
            if (self.attn_caches is not None) and (cache_name in self.attn_caches)
            else None
        )
        query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
        query = self.norm_q(query)
        query = query.unflatten(2, (self.heads, -1))
        key = self.norm_k(key)
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))
        if rotary_emb is not None:

            def apply_rotary_emb(x, freqs):
                x_out = torch.view_as_complex(
                    x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
                )
                x_out = torch.view_as_real(x_out * freqs).flatten(3)
                return x_out.to(x.dtype)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)
        slots = None
        if kv_cache is not None and kv_cache["k"] is not None:
            slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1))
            key_pool = self.attn_caches[cache_name]["k"]
            value_pool = self.attn_caches[cache_name]["v"]
            mask = self.attn_caches[cache_name]["mask"]
            valid = mask.nonzero(as_tuple=False).squeeze(-1)
            key = key_pool[:, valid]
            value = value_pool[:, valid]
        hidden_states = self.attn_op(query, key, value)
        if update_cache == 0:
            # Read-only pass: hide the slots that were just written so the cache is unchanged.
            if kv_cache is not None and kv_cache["k"] is not None:
                self.restore_cache(cache_name, slots)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states


__all__ = ["WanAttention", "WanRotaryPosEmbed", "custom_sdpa"]
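The slot bookkeeping behind `allocate_slots`, `update_cache` and `restore_cache` may be easier to follow on plain lists. Below is a minimal sketch under stated assumptions: `SlotPool` and its methods are hypothetical names, only the `id`/`mask`/`is_pred` metadata is modelled (no key or value tensors), and ties between equal ids are broken by slot index, which `torch.argsort` does not guarantee.

```python
class SlotPool:
    """Fixed-capacity pool of cache slots; every write is tagged with an increasing id."""

    def __init__(self, capacity):
        self.ids = [-1] * capacity  # write id per slot, -1 when unused or evicted
        self.live = [False] * capacity  # True while the slot is visible to attention
        self.is_pred = [False] * capacity

    def _free_slots(self):
        return [i for i, used in enumerate(self.live) if not used]

    def allocate(self, count):
        free = self._free_slots()
        if len(free) < count:
            # Evict the live slots with the smallest (oldest) write ids first.
            used = sorted((i for i, u in enumerate(self.live) if u), key=lambda i: self.ids[i])
            for i in used[: count - len(free)]:
                self.live[i] = False
                self.ids[i] = -1
            free = self._free_slots()
        if len(free) < count:
            raise ValueError("request exceeds pool capacity")
        return free[:count]

    def write(self, count, is_pred=False):
        slots = self.allocate(count)
        live_ids = [self.ids[i] for i, u in enumerate(self.live) if u]
        new_id = max(live_ids) + 1 if live_ids else 0
        for i in slots:
            self.live[i], self.ids[i], self.is_pred[i] = True, new_id, is_pred
        return slots

    def release(self, slots):
        # Mirrors restore_cache: hide the slots again after a read-only pass.
        for i in slots:
            self.live[i] = False


pool = SlotPool(capacity=4)
first = pool.write(2)
second = pool.write(2)
third = pool.write(2)  # pool is full, so the two oldest slots are evicted and reused
```

As in `update_cache`, the new id is computed after eviction, so it is always one more than the largest id still live in the pool.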
@@ -1,207 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flex-attention backend for the LingBot-VA Wan transformer (training only).
This module is imported lazily and ONLY when ``attn_mode='flex'`` is requested. It builds
the block-causal / window / noise-vs-clean attention masks used during the dual-stream
flow-matching training described in the LingBot-VA paper. Inference uses the ``torch``
SDPA backend (see :mod:`wan_attention`) which does not need flex-attention.
``torch.nn.attention.flex_attention`` requires a recent PyTorch build with the relevant
inductor support; importing this module on an unsupported build raises ``ImportError``.
"""
from collections.abc import Callable
from functools import partial
from typing import ClassVar

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention.flex_attention import (
    BlockMask,
    and_masks,
    create_block_mask,
    flex_attention,
    or_masks,
)


class FlexAttnFunc(nn.Module):
    flex_attn: ClassVar[Callable] = torch.compile(flex_attention, dynamic=True)
    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
    attention_mask: ClassVar[BlockMask] = None
    cross_attention_mask: ClassVar[BlockMask] = None

    def __init__(self, is_cross=False) -> None:
        super().__init__()
        self.is_cross = is_cross

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        q_varlen = rearrange(query[0], "s n d -> 1 n s d")
        k_varlen = rearrange(key[0], "s n d -> 1 n s d")
        v_varlen = rearrange(value[0], "s n d -> 1 n s d")
        half_dtypes = (torch.float16, torch.bfloat16)
        assert dtype in half_dtypes

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        q_varlen = half(q_varlen)
        k_varlen = half(k_varlen)
        v_varlen = half(v_varlen)
        q_varlen = q_varlen.to(v_varlen.dtype)
        k_varlen = k_varlen.to(v_varlen.dtype)
        block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask
        x_out = FlexAttnFunc.flex_attn(
            q_varlen,
            k_varlen,
            v_varlen,
            block_mask=block_mask,
            kernel_options={
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
            },
        )
        x_out = rearrange(x_out, "b n s d -> b s n d")
        return x_out

    @staticmethod
    @torch.no_grad()
    def init_mask(
        latent_shape,
        action_shape,
        padded_length,
        chunk_size,
        window_size,
        patch_size,
        device,
    ):
        torch._inductor.config.realize_opcount_threshold = 100
        B, _, L_F, L_H, L_W = latent_shape
        _, _, A_F, A_H, A_W = action_shape
        latent_seq_id = (
            torch.arange(B)[:, None, None, None]
            .expand(-1, L_F // patch_size[0], L_H // patch_size[1], L_W // patch_size[2])
            .flatten()
        )
        action_seq_id = torch.arange(B)[:, None, None, None].expand(-1, A_F, A_H, A_W).flatten()
        seq_ids = torch.cat([latent_seq_id] * 2 + [action_seq_id] * 2)
        latent_frame_id = (
            torch.arange(L_F)[None, :, None, None]
            .expand(B, -1, L_H // patch_size[1], L_W // patch_size[2])[None]
            .flatten()
        )
        action_frame_id = torch.arange(A_F)[None, :, None, None].expand(B, -1, A_H, A_W)[None].flatten()
        frame_ids = torch.cat(
            [latent_frame_id // chunk_size * 2] * 2 + [action_frame_id // chunk_size * 2 + 1] * 2
        )
        noise_ids = torch.cat(
            [
                torch.zeros_like(latent_frame_id),
                torch.ones_like(latent_frame_id),
                torch.zeros_like(action_frame_id),
                torch.ones_like(action_frame_id),
            ]
        )
        seq_ids = F.pad(seq_ids, (0, padded_length), value=-1)
        frame_ids = F.pad(frame_ids, (0, padded_length), value=-1)
        noise_ids = F.pad(noise_ids, (0, padded_length), value=-1)
        mask_mod = FlexAttnFunc._get_mask_mod(
            seq_ids.long().to(device), frame_ids.long().to(device), noise_ids.long().to(device), window_size
        )
        block_mask = FlexAttnFunc.compiled_create_block_mask(
            mask_mod, 1, 1, len(seq_ids), len(seq_ids), device=device, _compile=True
        )
        FlexAttnFunc.attention_mask = block_mask
        text_seq_ids = torch.arange(B)[:, None].expand(-1, 512).flatten()
        mask_mod_cross = FlexAttnFunc._get_cross_mask_mod(
            seq_ids.long().to(device), text_seq_ids.long().to(device)
        )
        block_mask_cross = FlexAttnFunc.compiled_create_block_mask(
            mask_mod_cross, 1, 1, len(seq_ids), len(text_seq_ids), device=device, _compile=True
        )
        FlexAttnFunc.cross_attention_mask = block_mask_cross

    @staticmethod
    @torch.no_grad()
    def _get_cross_mask_mod(seq_ids, text_seq_ids):
        def seq_mask(b, h, q_idx, kv_idx):
            return (
                (seq_ids[q_idx] == text_seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (text_seq_ids[kv_idx] >= 0)
            )

        return seq_mask

    @staticmethod
    @torch.no_grad()
    def _get_mask_mod(seq_ids, frame_ids, noise_ids, window_size):
        def seq_mask(b, h, q_idx, kv_idx):
            return (seq_ids[q_idx] == seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (seq_ids[kv_idx] >= 0)

        def block_causal_mask(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] <= frame_ids[q_idx]

        def block_causal_mask_exclude_self(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] < frame_ids[q_idx]

        def block_self_mask(b, h, q_idx, kv_idx):
            return frame_ids[kv_idx] == frame_ids[q_idx]

        def clean2clean_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 1) & (noise_ids[kv_idx] == 1)

        def noise2clean_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 1)

        def noise2noise_mask(b, h, q_idx, kv_idx):
            return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 0)

        def block_window_mask(b, h, q_idx, kv_idx, window_size: int):
            return (frame_ids[q_idx] - frame_ids[kv_idx]).abs() <= window_size

        mask_list = []
        mask_list.append(and_masks(clean2clean_mask, block_causal_mask))
        mask_list.append(and_masks(noise2clean_mask, block_causal_mask_exclude_self))
        mask_list.append(and_masks(noise2noise_mask, block_self_mask))
        mask = or_masks(*mask_list)
        mask = and_masks(mask, seq_mask)
        mask = and_masks(mask, partial(block_window_mask, window_size=window_size))
        return mask


__all__ = ["FlexAttnFunc"]
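The rule that `_get_mask_mod` composes from `and_masks`/`or_masks` can be written as a single boolean predicate. This is a small pure-Python sketch, assuming the same id conventions as `init_mask` (`noise_ids` is 0 for noisy tokens, 1 for clean tokens, -1 for padding); the function name `attention_allowed` and the toy id lists are illustrative only.

```python
def attention_allowed(q, kv, seq_ids, frame_ids, noise_ids, window_size):
    """Return True when query token q may attend to key token kv."""
    same_sequence = seq_ids[q] == seq_ids[kv] and seq_ids[q] >= 0 and seq_ids[kv] >= 0
    in_window = abs(frame_ids[q] - frame_ids[kv]) <= window_size
    q_noisy, kv_noisy = noise_ids[q] == 0, noise_ids[kv] == 0
    q_clean, kv_clean = noise_ids[q] == 1, noise_ids[kv] == 1
    # Clean tokens see clean tokens of the same or earlier blocks.
    clean_to_clean = q_clean and kv_clean and frame_ids[kv] <= frame_ids[q]
    # Noisy tokens see clean tokens of strictly earlier blocks only.
    noisy_to_clean = q_noisy and kv_clean and frame_ids[kv] < frame_ids[q]
    # Noisy tokens see noisy tokens of their own block only.
    noisy_to_noisy = q_noisy and kv_noisy and frame_ids[kv] == frame_ids[q]
    return (clean_to_clean or noisy_to_clean or noisy_to_noisy) and same_sequence and in_window


# Toy layout: [noisy block 0, noisy block 2, clean block 0, clean block 2, padding]
seq_ids = [0, 0, 0, 0, -1]
frame_ids = [0, 2, 0, 2, -1]
noise_ids = [0, 0, 1, 1, -1]
mask = [
    [attention_allowed(q, kv, seq_ids, frame_ids, noise_ids, window_size=8) for kv in range(5)]
    for q in range(5)
]
```

In this layout the noisy token of block 2 attends to itself and to the clean token of block 0, but not to the clean copy of its own block, which is what lets the noisy stream be denoised conditioned on clean history only.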
@@ -1,514 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The dual-stream Wan2.2 video-action transformer backbone for LingBot-VA.
Vendored and lightly adapted from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``).
The model keeps the diffusers ``ModelMixin``/``ConfigMixin`` mixins so the original
sharded ``transformer/`` checkpoint can be loaded with ``from_pretrained`` during
conversion, but in LeRobot it is owned as a plain ``nn.Module`` sub-component of
:class:`~lerobot.policies.lingbot_va.modeling_lingbot_va.LingBotVAPolicy`. State-dict
parameter names are preserved verbatim so conversion is near-identity.
"""
import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import (
    PixArtAlphaTextProjection,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm
from einops import rearrange

from .wan_attention import WanAttention, WanRotaryPosEmbed

__all__ = ["WanTransformer3DModel", "WanTransformerBlock", "WanTimeTextImageEmbedding"]


class WanTimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(
            num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(self, timestep: torch.Tensor, dtype=None):
        B, L = timestep.shape
        timestep = timestep.reshape(-1)
        timestep = self.timesteps_proj(timestep)
        time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).to(dtype=dtype)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "torch",
    ):
        super().__init__()
        self.attn_mode = attn_mode
        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )
        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=dim // num_heads,
            attn_mode=attn_mode,
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name="pos",
    ) -> torch.Tensor:
        temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = rearrange(
            temb_scale_shift_table, "b l n c -> b n l c"
        ).chunk(6, dim=1)
        shift_msa = shift_msa.squeeze(1)
        scale_msa = scale_msa.squeeze(1)
        gate_msa = gate_msa.squeeze(1)
        c_shift_msa = c_shift_msa.squeeze(1)
        c_scale_msa = c_scale_msa.squeeze(1)
        c_gate_msa = c_gate_msa.squeeze(1)
        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1.0 + scale_msa) + shift_msa).type_as(
            hidden_states
        )
        attn_output = self.attn1(
            norm_hidden_states,
            norm_hidden_states,
            norm_hidden_states,
            rotary_emb,
            update_cache=update_cache,
            cache_name=cache_name,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(
            norm_hidden_states,
            encoder_hidden_states,
            encoder_hidden_states,
            None,
            update_cache=0,
            cache_name=cache_name,
        )
        hidden_states = hidden_states + attn_output
        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1.0 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    """Dual-stream (video + action) Wan2.2 DiT backbone with autoregressive KV caching."""

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = [
        "patch_embedding_mlp",
        "condition_embedder",
        "condition_embedder_action",
        "norm",
    ]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "scale_shift_table_action",
        "norm1",
        "action_norm1",
        "text_norm1",
        "norm2",
        "action_norm2",
        "text_norm2",
        "norm3",
        "action_norm3",
        "text_norm3",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size=(1, 2, 2),
        num_attention_heads=24,
        attention_head_dim=128,
        in_channels=48,
        out_channels=48,
        action_dim=30,
        text_dim=4096,
        freq_dim=256,
        ffn_dim=14336,
        num_layers=30,
        cross_attn_norm=True,
        eps=1e-06,
        rope_max_seq_len=1024,
        pos_embed_seq_len=None,
        attn_mode="torch",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding_mlp = nn.Linear(
            in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim
        )
        self.action_embedder = nn.Linear(action_dim, inner_dim)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        self.condition_embedder_action = deepcopy(self.condition_embedder)
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, attn_mode=attn_mode
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    # ------------------------------------------------------------------
    # KV-cache management for autoregressive streaming inference
    # ------------------------------------------------------------------
    def clear_cache(self, cache_name):
        for block in self.blocks:
            block.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        for block in self.blocks:
            block.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(
        self,
        cache_name,
        attn_window,
        latent_token_per_chunk,
        action_token_per_chunk,
        device,
        dtype,
        batch_size,
    ):
        total_tolen = (attn_window // 2) * latent_token_per_chunk + (
            attn_window // 2
        ) * action_token_per_chunk
        for block in self.blocks:
            block.attn1.init_kv_cache(
                cache_name,
                total_tolen,
                self.num_attention_heads,
                self.attention_head_dim,
                device,
                dtype,
                batch_size,
            )

    # ------------------------------------------------------------------
    # Embedding helpers (shared by train + inference paths)
    # ------------------------------------------------------------------
    def _input_embed(self, latents, input_type="latent"):
        if input_type == "latent":
            hidden_states = rearrange(
                latents,
                "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2],
            )
            hidden_states = self.patch_embedding_mlp(hidden_states)
        elif input_type == "action":
            hidden_states = rearrange(latents, "b c f h w -> b (f h w) c")
            hidden_states = self.action_embedder(hidden_states)
        elif input_type == "text":
            hidden_states = self.condition_embedder.text_embedder(latents)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")
        return hidden_states

    def _time_embed(self, timesteps, H, W, dtype, action_mode=False):
        # Action tokens are not patchified, so each (h, w) position is its own token.
        patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
        latent_time_steps = torch.repeat_interleave(
            timesteps, (H // patch_scale_h) * (W // patch_scale_w), dim=1
        )
        current_condition_embedder = (
            self.condition_embedder_action if action_mode else self.condition_embedder
        )
        temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C
        return temb, timestep_proj
    # ------------------------------------------------------------------
    # Dual-stream training forward (flow matching). Requires attn_mode='flex'.
    # ------------------------------------------------------------------
    def forward_train(self, input_dict):
        from .wan_flex_attention import FlexAttnFunc

        input_dict["latent_dict"]["noisy_latents"] = input_dict["latent_dict"]["noisy_latents"].to(
            torch.bfloat16
        )
        input_dict["latent_dict"]["latent"] = input_dict["latent_dict"]["latent"].to(torch.bfloat16)
        input_dict["action_dict"]["noisy_latents"] = input_dict["action_dict"]["noisy_latents"].to(
            torch.bfloat16
        )
        input_dict["action_dict"]["latent"] = input_dict["action_dict"]["latent"].to(torch.bfloat16)
        latent_dict = input_dict["latent_dict"]
        action_dict = input_dict["action_dict"]
        batch_size = latent_dict["noisy_latents"].shape[0]
        latent_hidden_states = self._input_embed(latent_dict["noisy_latents"], input_type="latent").flatten(
            0, 1
        )[None]
        action_hidden_states = self._input_embed(action_dict["noisy_latents"], input_type="action").flatten(
            0, 1
        )[None]
        text_hidden_states = self._input_embed(latent_dict["text_emb"], input_type="text")
        text_hidden_states = text_hidden_states.flatten(0, 1)[None]
        condition_latent_hidden_states = self._input_embed(
            latent_dict["latent"], input_type="latent"
        ).flatten(0, 1)[None]
        condition_action_hidden_states = self._input_embed(
            action_dict["latent"], input_type="action"
        ).flatten(0, 1)[None]
        hidden_states = torch.cat(
            [
                latent_hidden_states,
                condition_latent_hidden_states,
                action_hidden_states,
                condition_action_hidden_states,
            ],
            dim=1,
        )
        latent_grid_id = latent_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
        action_grid_id = action_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
        full_grid_id = torch.cat([latent_grid_id] * 2 + [action_grid_id] * 2, dim=2)
        rotary_emb = self.rope(full_grid_id)[:, :, None]
        latent_time_steps = torch.cat(
            [latent_dict["timesteps"].flatten(0, 1), latent_dict["cond_timesteps"].flatten(0, 1)]
        )[None]
        action_time_steps = torch.cat(
            [action_dict["timesteps"].flatten(0, 1), action_dict["cond_timesteps"].flatten(0, 1)]
        )[None]
        latent_temb, latent_timestep_proj = self._time_embed(
            latent_time_steps,
            latent_dict["noisy_latents"].shape[-2],
            latent_dict["noisy_latents"].shape[-1],
            dtype=hidden_states.dtype,
            action_mode=False,
        )
        action_temb, action_timestep_proj = self._time_embed(
            action_time_steps,
            action_dict["noisy_latents"].shape[-2],
            action_dict["noisy_latents"].shape[-1],
            dtype=hidden_states.dtype,
            action_mode=True,
        )
        temb = torch.cat([latent_temb, action_temb], dim=1)
        timestep_proj = torch.cat([latent_timestep_proj, action_timestep_proj], dim=1)
        # Pad the packed sequence up to the next multiple of 128 tokens.
        total_length = hidden_states.shape[1]
        padded_length = (128 - total_length % 128) % 128
        hidden_states = F.pad(hidden_states, (0, 0, 0, padded_length))
        rotary_emb = F.pad(rotary_emb, (0, 0, 0, 0, 0, padded_length))
        temb = F.pad(temb, (0, 0, 0, padded_length))
        timestep_proj = F.pad(timestep_proj, (0, 0, 0, 0, 0, padded_length))
        split_list = [
            latent_hidden_states.shape[1],
            condition_latent_hidden_states.shape[1],
            action_hidden_states.shape[1],
            condition_action_hidden_states.shape[1],
            padded_length,
        ]
        FlexAttnFunc.init_mask(
            latent_dict["noisy_latents"].shape,
            action_dict["noisy_latents"].shape,
            padded_length,
            input_dict["chunk_size"],
            window_size=input_dict["window_size"],
            patch_size=self.patch_size,
            device=hidden_states.device,
        )
        for block in self.blocks:
            # ``update_cache`` is an integer flag (0 = do not persist); training never caches.
            hidden_states = block(
                hidden_states, text_hidden_states, timestep_proj, rotary_emb, update_cache=0
            )
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
        shift = shift.to(hidden_states.device).squeeze(1)
        scale = scale.to(hidden_states.device).squeeze(1)
        hidden_states = (self.norm_out(hidden_states.float()) * (1.0 + scale) + shift).type_as(hidden_states)
        latent_hidden_states, _, action_hidden_states, _, _ = torch.split(hidden_states, split_list, dim=1)
        latent_hidden_states = self.proj_out(latent_hidden_states)
        latent_hidden_states = rearrange(
            latent_hidden_states, "1 (b l) (n c) -> b (l n) c", n=math.prod(self.patch_size), b=batch_size
        )
        action_hidden_states = self.action_proj_out(action_hidden_states)
        action_hidden_states = rearrange(action_hidden_states, "1 (b l) c -> b l c", b=batch_size)
        return latent_hidden_states, action_hidden_states

    # ------------------------------------------------------------------
    # Single-stream inference forward (one denoising step for one stream)
    # ------------------------------------------------------------------
    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
        train_mode=False,
    ):
        if train_mode:
            return self.forward_train(input_dict)
        if action_mode:  # action input emb
            latent_hidden_states = rearrange(input_dict["noisy_latents"], "b c f h w -> b (f h w) c")
            latent_hidden_states = self.action_embedder(latent_hidden_states)  # B L1 C
        else:  # latent input emb
            latent_hidden_states = rearrange(
                input_dict["noisy_latents"],
                "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2],
            )
            latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states)
        text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"])  # B L2 C
        latent_grid_id = input_dict["grid_id"]
        rotary_emb = self.rope(latent_grid_id)[:, :, None]  # 1 L 1 C
        # Action tokens are not spatially patchified, so their per-frame token count uses a scale of 1.
        patch_scale_h, patch_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
        latent_time_steps = torch.repeat_interleave(
            input_dict["timesteps"],
            (input_dict["noisy_latents"].shape[-2] // patch_scale_h)
            * (input_dict["noisy_latents"].shape[-1] // patch_scale_w),
            dim=1,
        )  # L
        current_condition_embedder = (
            self.condition_embedder_action if action_mode else self.condition_embedder
        )
        temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C
        for block in self.blocks:
            latent_hidden_states = block(
                latent_hidden_states,
                text_hidden_states,
                timestep_proj,
                rotary_emb,
                update_cache=update_cache,
                cache_name=cache_name,
            )
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
        shift = shift.to(latent_hidden_states.device).squeeze(1)
        scale = scale.to(latent_hidden_states.device).squeeze(1)
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) * (1.0 + scale) + shift).type_as(
            latent_hidden_states
        )
        if action_mode:
            latent_hidden_states = self.action_proj_out(latent_hidden_states)
        else:
            latent_hidden_states = self.proj_out(latent_hidden_states)
            latent_hidden_states = rearrange(
                latent_hidden_states, "b l (n c) -> b (l n) c", n=math.prod(self.patch_size)
            )
        return latent_hidden_states
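
Review note: the inference `forward` above turns one timestep per latent frame into one timestep per token with `torch.repeat_interleave`. The following is a minimal plain-Python sketch of that expansion for a single batch row. The helper name `expand_timesteps` and the default `patch_size=(1, 2, 2)` are assumptions made for illustration; neither appears in this PR.

```python
def expand_timesteps(timesteps, latent_h, latent_w, patch_size=(1, 2, 2), action_mode=False):
    """Repeat each per-frame timestep once per token of that frame.

    Hypothetical helper (not part of the PR). It mirrors
    torch.repeat_interleave(timesteps, tokens_per_frame, dim=1) for one batch row.
    """
    # Action tokens are not spatially patchified, so the scale is (1, 1).
    scale_h, scale_w = (1, 1) if action_mode else (patch_size[1], patch_size[2])
    tokens_per_frame = (latent_h // scale_h) * (latent_w // scale_w)
    return [t for t in timesteps for _ in range(tokens_per_frame)]


# Two latent frames on a 4x4 latent grid with an assumed 2x2 spatial patch: 4 tokens per frame.
video_steps = expand_timesteps([900, 500], latent_h=4, latent_w=4)
# Two action frames laid out as h=3, w=1: 3 tokens per frame, no patch scaling.
action_steps = expand_timesteps([900, 500], latent_h=3, latent_w=1, action_mode=True)
```

Because the expansion is per frame, every token of a frame shares that frame's noise level, which is what lets frames inside one chunk sit at different timesteps.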
@@ -1,56 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grid-id / patch utilities for the LingBot-VA autoregressive inference loop.
Vendored verbatim from the upstream LingBot-VA repository
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/utils.py``).
"""
import torch

__all__ = ["get_mesh_id", "data_seq_to_patch"]


def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
    """Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
    p_t, p_h, p_w = patch_size
    post_patch_num_frames = latent_num_frames // p_t
    post_patch_height = latent_height // p_h
    post_patch_width = latent_width // p_w
    data_patch = data_seq.reshape(
        batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
    )
    data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6)
    data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return data_patch


def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
    """Build the (frame, height, width, stream) grid ids used to index the rotary embedding."""
    f_idx = torch.arange(f_shift, f + f_shift) * f_w
    h_idx = torch.arange(h)
    w_idx = torch.arange(w)
    ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing="ij")
    if action:
        ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1)
        ff = ff + ff_offset
        hh = torch.ones_like(hh) * -1
        ww = torch.ones_like(ww) * -1
    grid_id = torch.cat([ff.unsqueeze(0), hh.unsqueeze(0), ww.unsqueeze(0)], dim=0).flatten(1)
    grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0)
    return grid_id
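
Review note: to make the grid-id layout easier to check by hand, here is a plain-Python mirror of `get_mesh_id` that returns the same four rows as nested lists. `mesh_id_rows` is a hypothetical name introduced only for this sketch; it is meant as a reading aid, not a replacement for the torch version.

```python
from itertools import product


def mesh_id_rows(f, h, w, t, f_w=1, f_shift=0, action=False):
    """Return [frame_ids, height_ids, width_ids, stream_ids], one entry per token.

    Hypothetical plain-Python mirror of get_mesh_id. Tokens are ordered
    frame-major, then height, then width (meshgrid with indexing="ij", flattened).
    """
    frames, heights, widths = [], [], []
    for fi, hi, wi in product(range(f_shift, f + f_shift), range(h), range(w)):
        if action:
            # Action tokens get fractional frame positions strictly inside (frame, frame + 1)
            # and a sentinel of -1 on both spatial axes.
            frames.append(fi * f_w + (hi + 1) / (h + 1))
            heights.append(-1)
            widths.append(-1)
        else:
            frames.append(fi * f_w)
            heights.append(hi)
            widths.append(wi)
    return [frames, heights, widths, [t] * len(frames)]


video_ids = mesh_id_rows(f=2, h=1, w=2, t=0)
action_ids = mesh_id_rows(f=1, h=3, w=1, t=1, action=True)
```

In the action case the `h` axis indexes the actions within a frame, so three actions in frame 0 land at frame positions 0.25, 0.5, and 0.75.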
-120
@@ -1,120 +0,0 @@
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin helpers around the stock diffusers ``AutoencoderKLWan`` (Wan2.2, ``z_dim=48``).

The VAE class itself is NOT vendored; it lives in ``diffusers>=0.36``. This module
provides:

* loaders for the VAE / text encoder / tokenizer / transformer sub-checkpoints,
* the streaming-encoder wrapper used for autoregressive frame-by-frame VAE encoding
  (it caches the causal-conv state across chunks),
* latent (de)normalization helpers using the VAE's ``latents_mean`` / ``latents_std``.

Vendored and adapted from ``wan_va/modules/utils.py`` upstream.
"""
import torch

__all__ = [
    "WanVAEStreamingWrapper",
    "load_vae",
    "load_text_encoder",
    "load_tokenizer",
    "normalize_latents",
    "denormalize_latents",
    "patchify",
]


def load_vae(vae_path, torch_dtype, torch_device):
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype)
    return vae.to(torch_device)


def load_text_encoder(text_encoder_path, torch_dtype, torch_device):
    from transformers import UMT5EncoderModel

    text_encoder = UMT5EncoderModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype)
    return text_encoder.to(torch_device)


def load_tokenizer(tokenizer_path):
    from transformers import T5TokenizerFast

    return T5TokenizerFast.from_pretrained(tokenizer_path)


def patchify(x, patch_size):
    if patch_size is None or patch_size == 1:
        return x
    batch_size, channels, frames, height, width = x.shape
    x = x.view(
        batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size
    )
    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    x = x.view(
        batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size
    )
    return x


def normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    """Apply ``(x - mean) * std`` channel-wise (note: upstream passes ``1/std`` as ``latents_std``)."""
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
    latents = ((latents.float() - latents_mean) * latents_std).to(latents)
    return latents


def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std, z_dim) -> torch.Tensor:
    """Inverse of the normalization applied at encode time, for VAE decoding of predicted latents."""
    mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    return latents / inv_std + mean
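
Review note: the two helpers above use different conventions for their `latents_std` argument, which is easy to trip over. `normalize_latents` expects the caller to pass `1/std`, while `denormalize_latents` takes the raw `std`. A scalar sketch of the round trip follows; the function names and numbers are made up for illustration and stand in for one latent channel.

```python
def normalize(x, mean, inv_std):
    # Encode-time convention: the caller passes 1/std, so this multiplies.
    return (x - mean) * inv_std


def denormalize(z, mean, std):
    # Decode-time convention: the caller passes the raw std.
    # Dividing by 1/std is the same as computing z * std + mean.
    inv_std = 1.0 / std
    return z / inv_std + mean


# Illustrative per-channel statistics, not taken from any real VAE config.
mean, std = 0.5, 2.0
x = 3.0
z = normalize(x, mean, 1.0 / std)
x_back = denormalize(z, mean, std)
```

Passing the raw `std` to the normalizing side would scale the latents by `std` squared relative to what the decoder-side helper undoes, so the pairing of conventions matters.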


class WanVAEStreamingWrapper:
    """Wraps an ``AutoencoderKLWan`` encoder to support causal streaming encoding across chunks."""

    def __init__(self, vae_model):
        self.vae = vae_model
        self.encoder = vae_model.encoder
        self.quant_conv = vae_model.quant_conv
        if hasattr(self.vae, "_cached_conv_counts"):
            self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
        else:
            count = 0
            for m in self.encoder.modules():
                if m.__class__.__name__ == "WanCausalConv3d":
                    count += 1
            self.enc_conv_num = count
        self.clear_cache()

    def clear_cache(self):
        self.feat_cache = [None] * self.enc_conv_num

    def encode_chunk(self, x_chunk):
        if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None:
            x_chunk = patchify(x_chunk, self.vae.config.patch_size)
        feat_idx = [0]
        out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx)
        enc = self.quant_conv(out)
        return enc
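
Review note: the commit message lists "reset the streaming VAE cache between episodes" among the fixes. The toy below is a sketch of why that matters. It is not the Wan VAE; `causal_sum` is a hypothetical stand-in for a single causal temporal convolution, with the cache playing the role of one `feat_cache` entry.

```python
def causal_sum(frames, kernel=3, cache=None):
    """Toy causal temporal filter: out[t] is the sum of the last `kernel` frames up to t.

    `cache` holds the trailing kernel-1 frames of the previous chunk
    (zeros at the start of an episode). Returns (outputs, new_cache).
    """
    if cache is None:
        cache = [0.0] * (kernel - 1)
    padded = list(cache) + list(frames)
    out = [sum(padded[i : i + kernel]) for i in range(len(frames))]
    return out, padded[len(padded) - (kernel - 1) :]


video = [1.0, 2.0, 3.0, 4.0, 5.0]
full, _ = causal_sum(video)

# Streaming the same frames chunk by chunk, carrying the cache forward.
cache = None
streamed = []
for chunk in ([1.0, 2.0], [3.0], [4.0, 5.0]):
    out, cache = causal_sum(chunk, cache=cache)
    streamed += out

# Starting a new episode without clearing the cache contaminates its first outputs.
stale, _ = causal_sum(video, cache=cache)
```

Here `streamed` matches `full`, while `stale` differs only in its first `kernel - 1` outputs. In this toy, the leftover state from the previous episode affects exactly those first outputs, which is what clearing the cache at episode boundaries prevents.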
@@ -76,8 +76,3 @@ def test_validate_features_no_visual_raises() -> None:
def test_invalid_attn_mode_raises() -> None:
    with pytest.raises(ValueError, match="attn_mode"):
        make_config(attn_mode="banana")


def test_quantile_length_mismatch_raises() -> None:
    with pytest.raises(ValueError, match="action_q01"):
        make_config(used_action_channel_ids=[0, 1, 2], action_q01=[0.0, 0.0], action_q99=[1.0, 1.0, 1.0])
-14
@@ -36,17 +36,3 @@ def test_get_policy_class_resolves_lazily() -> None:
    cls = get_policy_class("lingbot_va")
    assert cls.name == "lingbot_va"
    assert cls.config_class is LingBotVAConfig


def test_convert_build_config_libero() -> None:
    pytest.importorskip("diffusers")
    from lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints import build_config

    cfg = build_config("libero", wan_pretrained_path="dummy/path", dtype="float32")
    assert cfg.height == 128 and cfg.width == 128
    assert cfg.used_action_channel_ids == list(range(7))
    # validate_features (called inside build_config) must have populated the action feature.
    from lerobot.utils.constants import ACTION

    assert cfg.output_features[ACTION].shape == (7,)
    assert len(cfg.obs_cam_keys) == 2
+9 -3
@@ -14,14 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pure-torch unit tests for the vendored LingBot-VA helper modules (no diffusers needed)."""
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
from __future__ import annotations
import pytest
import torch
from lerobot.policies.lingbot_va.schedulers import FlowMatchScheduler
from lerobot.policies.lingbot_va.wan_utils import data_seq_to_patch, get_mesh_id
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
from lerobot.policies.lingbot_va.modeling_lingbot_va import (  # noqa: E402
    FlowMatchScheduler,
    data_seq_to_patch,
    get_mesh_id,
)
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
+3 -2
@@ -21,6 +21,7 @@ import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
from lerobot.policies.lingbot_va.processor_lingbot_va import (
    LIBERO_ACTION_Q01,
    LingBotVAActionUnnormalizeStep,
    make_lingbot_va_pre_post_processors,
)
@@ -75,7 +76,7 @@ def test_make_pre_post_processors_names_and_steps() -> None:
def test_postprocessor_applies_unnormalization() -> None:
    cfg = _make_config()
    _, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
    # A normalized action of all -1 should map back to q01.
    # A normalized action of all -1 should map back to q01 (the LIBERO 7-DoF default quantiles).
    normed = torch.full((1, len(cfg.used_action_channel_ids)), -1.0)
    out = post(normed)
    assert torch.allclose(out, torch.tensor(cfg.action_q01).unsqueeze(0), atol=1e-4)
    assert torch.allclose(out, torch.tensor(LIBERO_ACTION_Q01).unsqueeze(0), atol=1e-4)
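
Review note: the test above only pins down one point of the un-normalization, namely that an all `-1` action maps to `q01`. A common way to get that behavior is a linear map from `[-1, 1]` to `[q01, q99]`, sketched below in plain Python. The linear form, the `unnormalize` name, and the quantile values are assumptions for illustration; the actual step is `LingBotVAActionUnnormalizeStep` and the real lower bounds are in `LIBERO_ACTION_Q01`.

```python
def unnormalize(normed, q01, q99):
    """Map each value from [-1, 1] to [q01, q99] (assumed linear quantile scaling)."""
    return [(n + 1.0) / 2.0 * (hi - lo) + lo for n, lo, hi in zip(normed, q01, q99)]


# Hypothetical 3-channel quantiles, not the LIBERO values.
q01 = [-0.5, -1.0, 0.0]
q99 = [0.5, 3.0, 1.0]

low = unnormalize([-1.0, -1.0, -1.0], q01, q99)
high = unnormalize([1.0, 1.0, 1.0], q01, q99)
mid = unnormalize([0.0, 0.0, 0.0], q01, q99)
```

Under this assumed map, `-1` recovers `q01`, `+1` recovers `q99`, and `0` lands on the midpoint of each channel's range.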
Generated
+958 -1107
File diff suppressed because it is too large