mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
feat(lingbot_va): RoboTwin eef-pose eval, single-file model, Hub checkpoints
Make the LingBot-VA port runnable on both LIBERO and RoboTwin and clean up the package to LeRobot conventions. - Consolidate all vendored Wan2.2 model code (transformer, attention, VAE helpers, flow-matching scheduler, grid utils, flex-attention) into a single modeling_lingbot_va.py; remove the separate wan_*/schedulers modules. - Move the fixed action (un)normalization quantiles out of the config and into the post-processor (LIBERO 7-DoF + RoboTwin 16-d eef); remove the conversion script in favour of ready-to-use LeRobot-format checkpoints on the Hub. - Fixes found via on-sim validation: undo LIBERO's 180-degree image flip (image_hflip), encode obs as a multi-frame streaming-VAE clip, reset the streaming VAE cache between episodes, run the transformer in config.dtype, lazy-load frozen VAE/UMT5 by subfolder with the text encoder on CPU. - RoboTwin: add an end-effector-pose action mode to RoboTwinEnv (16-d per-arm xyz+quat+gripper deltas composed onto the initial eef pose, executed via CuRobo IK) and the robotwin_tshape latent layout (full-res head + half-res wrists via a second streaming VAE) with the upstream RoboTwin action quantiles + camera mapping. - Predicted-video saving works for both benchmarks; docs + tests updated. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
committed by
Maxime Ellerbach
parent
d600a52943
commit
b81909fc28
@@ -22,6 +22,10 @@ outputs
|
||||
rl
|
||||
media
|
||||
|
||||
# Local virtualenvs (the image provides its own)
|
||||
.venv
|
||||
venv
|
||||
|
||||
|
||||
# Logging
|
||||
logs
|
||||
|
||||
+34
-17
@@ -28,7 +28,7 @@ fed back into the KV cache as the chunk is executed (closed-loop world modeling)
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=lingbot_va` configuration through LeRobot.
|
||||
- Checkpoint conversion from the released HuggingFace checkpoints.
|
||||
- Ready-to-use LeRobot-format checkpoints on the Hub (converted from the released upstream ones).
|
||||
- Autoregressive dual-stream inference behind the standard `select_action` interface
|
||||
(single-environment eval, `--eval.batch_size=1`).
|
||||
- Opt-in saving of the policy's **predicted (imagined) videos** during eval / training.
|
||||
@@ -48,40 +48,57 @@ pip install -e ".[lingbot_va]"
|
||||
pip install -e ".[lingbot_va,libero]"
|
||||
```
|
||||
|
||||
## Checkpoint Conversion
|
||||
## Checkpoints
|
||||
|
||||
The released checkpoints are diffusers-style directories
|
||||
(`robbyant/lingbot-va-base`, `robbyant/lingbot-va-posttrain-robotwin`,
|
||||
`robbyant/lingbot-va-posttrain-libero-long`). Convert one to LeRobot format with:
|
||||
The released upstream checkpoints have been converted to LeRobot format and pushed to the Hub:
|
||||
|
||||
```bash
|
||||
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
|
||||
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
|
||||
--variant libero \
|
||||
--output_dir outputs/lingbot_va_libero_long
|
||||
```
|
||||
| Variant | LeRobot checkpoint |
|
||||
|---|---|
|
||||
| LIBERO-Long post-train | `pepijn223/lingbot_va_libero_long` |
|
||||
| RoboTwin post-train | `pepijn223/lingbot_va_robotwin` |
|
||||
| Pretrained base | `pepijn223/lingbot_va_base` |
|
||||
|
||||
**Packaging:** only the trainable ~5B transformer is stored in the LeRobot
|
||||
`model.safetensors`. The frozen VAE + UMT5 + tokenizer (~20 GB) are **lazily pulled** from
|
||||
`config.wan_pretrained_path` at load time (defaults to the source repo). Pass
|
||||
`--bundle-frozen` to copy those sub-folders next to the converted checkpoint instead.
|
||||
|
||||
Run conversion on a Linux machine with a CUDA GPU and enough RAM/VRAM to materialize the
|
||||
transformer.
|
||||
`config.wan_pretrained_path` at load time (defaults to the source `robbyant/*` repo). The
|
||||
UMT5-XXL text encoder runs on CPU by default (`config.text_encoder_device`) so the 5B
|
||||
transformer + VAE fit on a single 24–32 GB GPU.
|
||||
|
||||
## Evaluation (LIBERO)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=outputs/lingbot_va_libero_long \
|
||||
--policy.path=pepijn223/lingbot_va_libero_long \
|
||||
--policy.device=cuda \
|
||||
--env.type=libero --env.task=libero_10 \
|
||||
--env.observation_height=128 --env.observation_width=128 \
|
||||
--eval.n_episodes=50 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_libero
|
||||
```
|
||||
|
||||
Native LeRobot eval reproduces **96% success on `libero_10` (LIBERO-Long)** (48/50 episodes).
|
||||
|
||||
LingBot-VA's streaming inference (KV cache + observed-keyframe feedback) is implemented for
|
||||
single-environment eval; use `--eval.batch_size=1`.
|
||||
|
||||
## Evaluation (RoboTwin)
|
||||
|
||||
RoboTwin 2.0 needs the SAPIEN + CuRobo simulator stack — use the benchmark Docker image
|
||||
(`docker/Dockerfile.benchmark.robotwin`, which also needs `warp-lang==1.3.1` and CuRobo built
|
||||
with the GPU's compute capability in `TORCH_CUDA_ARCH_LIST`). RoboTwin uses **end-effector-pose
|
||||
control**, so run with `--env.action_mode=ee`: the policy predicts per-arm `xyz+quaternion+gripper`
|
||||
deltas (`robotwin_tshape` latent layout) that are composed onto the episode's initial eef pose and
|
||||
executed via CuRobo IK.
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=pepijn223/lingbot_va_robotwin \
|
||||
--policy.device=cuda \
|
||||
--env.type=robotwin --env.task=beat_block_hammer --env.action_mode=ee \
|
||||
--eval.n_episodes=10 --eval.batch_size=1 \
|
||||
--output_dir=outputs/eval/lingbot_va_robotwin
|
||||
```
|
||||
|
||||
### Saving predicted (imagined) videos
|
||||
|
||||
Set `--policy.save_predicted_video=true` to additionally VAE-decode the predicted video
|
||||
|
||||
+1
-1
@@ -228,7 +228,7 @@ vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen
|
||||
# LingBot-VA needs the Wan2.2 stack (AutoencoderKLWan z_dim=48 + WanTransformer3DModel config schema),
|
||||
# which only exists in diffusers>=0.36. Pin the floor explicitly so a standalone `lerobot[lingbot_va]`
|
||||
# install can't resolve to a pre-Wan2.2 diffusers via the looser diffusers-dep floor.
|
||||
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]"]
|
||||
lingbot_va = ["lerobot[transformers-dep]", "diffusers>=0.36.0,<0.37.0", "lerobot[imageio-dep]", "accelerate>=1.10.0,<2.0.0", "ftfy>=6.0.0,<7.0.0"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
|
||||
@@ -768,6 +768,9 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
# must equal what SAPIEN actually renders.
|
||||
observation_height: int = 240
|
||||
observation_width: int = 320
|
||||
# "joint": 14-d joint-space control. "ee": 16-d end-effector-pose deltas executed via CuRobo IK
|
||||
# (for world-model policies like LingBot-VA that predict per-arm xyz+quaternion+gripper poses).
|
||||
action_mode: str = "joint"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
@@ -784,6 +787,8 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.action_mode == "ee":
|
||||
self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(16,))
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
for cam in cam_list:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
@@ -826,6 +831,7 @@ class RoboTwinEnvConfig(EnvConfig):
|
||||
observation_height=self.observation_height,
|
||||
observation_width=self.observation_width,
|
||||
episode_length=self.episode_length,
|
||||
action_mode=self.action_mode,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,10 +41,42 @@ ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
|
||||
"right_camera",
|
||||
)
|
||||
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms (joint-space control mode)
|
||||
# End-effector-pose control mode: per arm [x, y, z, qx, qy, qz, qw, gripper] = 8, dual-arm = 16.
|
||||
# Used by world-model policies (e.g. LingBot-VA) that predict eef-pose deltas executed via CuRobo IK.
|
||||
EEF_ACTION_DIM = 16
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
DEFAULT_EPISODE_LENGTH = 300
|
||||
|
||||
|
||||
def _compose_eef_pose(new_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a single-arm predicted delta pose onto the initial pose.
|
||||
|
||||
``new_pose`` / ``init_pose`` are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``. Translation
|
||||
is added, rotation is composed (``init_R * new_R``), and the gripper is taken from the
|
||||
prediction. Mirrors ``add_eef_pose`` in the upstream LingBot-VA RoboTwin client.
|
||||
"""
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
new_r = Rotation.from_quat(new_pose[3:7])
|
||||
init_r = Rotation.from_quat(init_pose[3:7])
|
||||
out_rot = (init_r * new_r).as_quat()
|
||||
out_trans = new_pose[:3] + init_pose[:3]
|
||||
return np.concatenate([out_trans, out_rot, new_pose[7:8]])
|
||||
|
||||
|
||||
def _add_init_eef_pose(delta_pose: np.ndarray, init_pose: np.ndarray) -> np.ndarray:
|
||||
"""Compose a dual-arm (16-d) predicted delta pose onto the initial eef pose, normalizing quats."""
|
||||
left = _compose_eef_pose(delta_pose[:8], init_pose[:8])
|
||||
right = _compose_eef_pose(delta_pose[8:], init_pose[8:])
|
||||
out = np.concatenate([left, right])
|
||||
# Normalize the two quaternions (indices 3:7 and 11:15) as the upstream client does.
|
||||
out[3:7] = out[3:7] / (np.linalg.norm(out[3:7]) + 1e-8)
|
||||
out[11:15] = out[11:15] / (np.linalg.norm(out[11:15]) + 1e-8)
|
||||
return out
|
||||
|
||||
|
||||
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
|
||||
DEFAULT_CAMERA_H = 240
|
||||
DEFAULT_CAMERA_W = 320
|
||||
@@ -234,6 +266,7 @@ class RoboTwinEnv(gym.Env):
|
||||
observation_width: int | None = None,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
render_mode: str = "rgb_array",
|
||||
action_mode: str = "joint",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = task_name
|
||||
@@ -241,6 +274,13 @@ class RoboTwinEnv(gym.Env):
|
||||
self.task_description = task_name.replace("_", " ")
|
||||
self.episode_index = episode_index
|
||||
self._reset_stride = n_envs
|
||||
# "joint": 14-d joint-space actions via take_action(action). "ee": 16-d end-effector-pose
|
||||
# deltas (added onto the episode's initial eef pose) executed via take_action(.., "ee") + IK.
|
||||
if action_mode not in ("joint", "ee"):
|
||||
raise ValueError(f"action_mode must be 'joint' or 'ee'; got {action_mode!r}")
|
||||
self.action_mode = action_mode
|
||||
self._action_dim = EEF_ACTION_DIM if action_mode == "ee" else ACTION_DIM
|
||||
self._init_eef_pose: np.ndarray | None = None
|
||||
self.camera_names = list(camera_names)
|
||||
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
|
||||
# The YAML-driven lookup is deferred to reset() so construction doesn't
|
||||
@@ -271,7 +311,7 @@ class RoboTwinEnv(gym.Env):
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(self._action_dim,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
@@ -317,6 +357,17 @@ class RoboTwinEnv(gym.Env):
|
||||
|
||||
return {"pixels": images, "agent_pos": joint_state}
|
||||
|
||||
def _read_eef_pose(self) -> np.ndarray:
|
||||
"""Read the current 16-d dual-arm eef pose [left(xyz+quat)+grip, right(xyz+quat)+grip]."""
|
||||
ep = self._env.get_obs()["endpose"]
|
||||
pose = (
|
||||
list(ep["left_endpose"])
|
||||
+ [ep["left_gripper"]]
|
||||
+ list(ep["right_endpose"])
|
||||
+ [ep["right_gripper"]]
|
||||
)
|
||||
return np.asarray(pose, dtype=np.float64)
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
|
||||
self._ensure_env()
|
||||
super().reset(seed=seed)
|
||||
@@ -330,16 +381,23 @@ class RoboTwinEnv(gym.Env):
|
||||
self.episode_index += self._reset_stride
|
||||
self._step_count = 0
|
||||
|
||||
# In eef mode the policy predicts pose deltas relative to the initial eef pose.
|
||||
if self.action_mode == "ee":
|
||||
self._init_eef_pose = self._read_eef_pose()
|
||||
|
||||
obs = self._get_obs()
|
||||
return obs, {"is_success": False, "task": self.task_name}
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
assert self._env is not None, "step() called before reset()"
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
|
||||
if action.ndim != 1 or action.shape[0] != self._action_dim:
|
||||
raise ValueError(f"Expected 1-D action of shape ({self._action_dim},), got {action.shape}")
|
||||
|
||||
with torch.enable_grad():
|
||||
if hasattr(self._env, "take_action"):
|
||||
if self.action_mode == "ee":
|
||||
ee_action = _add_init_eef_pose(np.asarray(action, dtype=np.float64), self._init_eef_pose)
|
||||
self._env.take_action(ee_action, action_type="ee")
|
||||
elif hasattr(self._env, "take_action"):
|
||||
self._env.take_action(action)
|
||||
else:
|
||||
self._env.step(action)
|
||||
@@ -398,6 +456,7 @@ def _make_env_fns(
|
||||
observation_height: int,
|
||||
observation_width: int,
|
||||
episode_length: int,
|
||||
action_mode: str = "joint",
|
||||
) -> list[Callable[[], RoboTwinEnv]]:
|
||||
"""Return n_envs factory callables for a single task."""
|
||||
|
||||
@@ -410,6 +469,7 @@ def _make_env_fns(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
|
||||
return [partial(_make_one, i) for i in range(n_envs)]
|
||||
@@ -423,6 +483,7 @@ def create_robotwin_envs(
|
||||
observation_height: int = DEFAULT_CAMERA_H,
|
||||
observation_width: int = DEFAULT_CAMERA_W,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
action_mode: str = "joint",
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboTwin 2.0 environments.
|
||||
|
||||
@@ -473,6 +534,7 @@ def create_robotwin_envs(
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
action_mode=action_mode,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
|
||||
@@ -31,27 +31,6 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
|
||||
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
|
||||
LIBERO_ACTION_Q01 = [
|
||||
-0.6589285731315613,
|
||||
-0.84375,
|
||||
-0.9375,
|
||||
-0.12107142806053162,
|
||||
-0.15964286029338837,
|
||||
-0.26571428775787354,
|
||||
-1.0,
|
||||
]
|
||||
LIBERO_ACTION_Q99 = [
|
||||
0.8999999761581421,
|
||||
0.8544642925262451,
|
||||
0.9375,
|
||||
0.17142857611179352,
|
||||
0.1842857152223587,
|
||||
0.34392857551574707,
|
||||
1.0,
|
||||
]
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("lingbot_va")
|
||||
@dataclass
|
||||
@@ -84,12 +63,27 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
wan_pretrained_path: str = "robbyant/lingbot-va-posttrain-libero-long"
|
||||
# dtype used for the transformer / VAE / text-encoder weights at inference.
|
||||
dtype: str = "bfloat16" # one of "bfloat16", "float16", "float32"
|
||||
# Device for the frozen UMT5-XXL text encoder. It encodes the (fixed) instruction once per
|
||||
# episode, so keeping it on CPU frees ~11 GB of VRAM and lets the 5B transformer + VAE fit on
|
||||
# a single 24-32 GB GPU. Set to "cuda" if you have the headroom and want faster prompt encoding.
|
||||
text_encoder_device: str = "cpu"
|
||||
|
||||
# ── Observation cameras (order matters: latents are concatenated on width) ──
|
||||
# Defaults match the LIBERO env feature keys (agentview -> image, eye-in-hand -> image2).
|
||||
obs_cam_keys: list[str] = field(
|
||||
default_factory=lambda: ["observation.images.image", "observation.images.image2"]
|
||||
)
|
||||
# Horizontally flip the camera images before encoding. LeRobot's LIBERO env processor rotates
|
||||
# frames 180° (flip H *and* W; the HuggingFaceVLA convention), but upstream LingBot-VA trains /
|
||||
# evaluates on vertically-flipped-only frames (``obs[::-1]`` in evaluation/libero/client.py).
|
||||
# Undoing the extra horizontal flip here realigns the input with the model's training orientation.
|
||||
image_hflip: bool = False
|
||||
# Latent assembly layout for the observation cameras:
|
||||
# "width_concat" : encode every camera at (height, width) and concat latents on width (LIBERO).
|
||||
# "robotwin_tshape" : head camera at full (height, width), the two wrist cameras at half
|
||||
# resolution, assembled in a "T" (wrists side-by-side on top of the head
|
||||
# on the height axis) using a second streaming VAE (RoboTwin).
|
||||
camera_layout: str = "width_concat"
|
||||
|
||||
# ── Inference hyperparameters (LIBERO defaults) ──
|
||||
n_obs_steps: int = 1
|
||||
@@ -108,10 +102,9 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
max_sequence_length: int = 512 # UMT5 prompt length
|
||||
|
||||
# Subset of the 30-d action space actually used by the benchmark (LIBERO = 7-DoF).
|
||||
# The fixed action (un)normalization quantiles live in the post-processor
|
||||
# (``LingBotVAActionUnnormalizeStep`` in ``processor_lingbot_va.py``), not here.
|
||||
used_action_channel_ids: list[int] = field(default_factory=lambda: list(range(7)))
|
||||
# Fixed quantiles for action (un)normalization on the *used* channels.
|
||||
action_q01: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q01))
|
||||
action_q99: list[float] = field(default_factory=lambda: list(LIBERO_ACTION_Q99))
|
||||
|
||||
# Opt-in: VAE-decode the predicted video latents and stash them on
|
||||
# ``self.last_predicted_frames`` so eval/train can save predicted-video MP4s.
|
||||
@@ -140,13 +133,6 @@ class LingBotVAConfig(PreTrainedConfig):
|
||||
super().__post_init__()
|
||||
if self.attn_mode not in ("torch", "flashattn", "flex"):
|
||||
raise ValueError(f"attn_mode must be one of 'torch', 'flashattn', 'flex'; got {self.attn_mode!r}")
|
||||
if len(self.action_q01) != len(self.used_action_channel_ids) or len(self.action_q99) != len(
|
||||
self.used_action_channel_ids
|
||||
):
|
||||
raise ValueError(
|
||||
"action_q01 / action_q99 must each have one entry per used_action_channel_ids "
|
||||
f"({len(self.used_action_channel_ids)}); got {len(self.action_q01)} / {len(self.action_q99)}."
|
||||
)
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
|
||||
@@ -1,256 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert a released LingBot-VA HuggingFace checkpoint to LeRobot format.
|
||||
|
||||
The released checkpoints are diffusers-style directories with ``transformer/``, ``vae/``,
|
||||
``text_encoder/`` and ``tokenizer/`` sub-folders. This script:
|
||||
|
||||
1. loads the (sharded) ``transformer/`` weights with the vendored ``WanTransformer3DModel``;
|
||||
2. builds a :class:`LingBotVAConfig` for the target benchmark variant;
|
||||
3. instantiates a :class:`LingBotVAPolicy` and copies the transformer weights into it
|
||||
(near-identity: the only key change is the ``transformer.`` prefix);
|
||||
4. saves the LeRobot policy (``model.safetensors`` + ``config.json``) and its processors.
|
||||
|
||||
Packaging decision: only the trainable ~5B transformer is bundled into the LeRobot
|
||||
``model.safetensors``. The frozen VAE + UMT5 text encoder + tokenizer (~20 GB) are NOT
|
||||
copied; instead ``config.wan_pretrained_path`` records where to lazily pull them from at
|
||||
load time (defaults to the source repo/dir). Pass ``--bundle-frozen`` to additionally copy
|
||||
those sub-folders next to the converted checkpoint and point ``wan_pretrained_path`` at it.
|
||||
|
||||
Example (LIBERO-Long, the LIBERO eval gate):
|
||||
|
||||
python -m lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints \
|
||||
--checkpoint robbyant/lingbot-va-posttrain-libero-long \
|
||||
--variant libero \
|
||||
--output_dir outputs/lingbot_va_libero_long
|
||||
|
||||
Requires a CUDA GPU with enough RAM/VRAM to materialize the transformer; run on Linux.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import LingBotVAPolicy
|
||||
from lerobot.policies.lingbot_va.processor_lingbot_va import make_lingbot_va_pre_post_processors
|
||||
from lerobot.policies.lingbot_va.wan_transformer import WanTransformer3DModel
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
# Per-benchmark variant presets (camera keys + action layout). Values mirror the upstream
|
||||
# configs (wan_va/configs/va_*_cfg.py).
|
||||
VARIANTS = {
|
||||
"libero": {
|
||||
"obs_cam_keys": [f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2"],
|
||||
"height": 128,
|
||||
"width": 128,
|
||||
"action_per_frame": 4,
|
||||
"frame_chunk_size": 4,
|
||||
"attn_window": 30,
|
||||
"num_inference_steps": 20,
|
||||
"action_num_inference_steps": 50,
|
||||
"guidance_scale": 5.0,
|
||||
"action_guidance_scale": 1.0,
|
||||
"snr_shift": 5.0,
|
||||
"action_snr_shift": 0.05,
|
||||
"used_action_channel_ids": list(range(7)),
|
||||
# 7-DoF: agentview + eye-in-hand, single arm. Quantiles are the config defaults.
|
||||
"image_shape": (3, 256, 256),
|
||||
},
|
||||
"robotwin": {
|
||||
"obs_cam_keys": [
|
||||
f"{OBS_IMAGES}.cam_high",
|
||||
f"{OBS_IMAGES}.cam_left_wrist",
|
||||
f"{OBS_IMAGES}.cam_right_wrist",
|
||||
],
|
||||
"height": 256,
|
||||
"width": 320,
|
||||
"action_per_frame": 16,
|
||||
"frame_chunk_size": 2,
|
||||
"attn_window": 72,
|
||||
"num_inference_steps": 25,
|
||||
"action_num_inference_steps": 50,
|
||||
"guidance_scale": 5.0,
|
||||
"action_guidance_scale": 1.0,
|
||||
"snr_shift": 5.0,
|
||||
"action_snr_shift": 1.0,
|
||||
# RoboTwin is dual-arm; set the used channels / quantiles to match the deployed config.
|
||||
"used_action_channel_ids": list(range(14)),
|
||||
"image_shape": (3, 256, 256),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _transformer_dir(checkpoint: str) -> str:
|
||||
"""Return the path/repo that ``WanTransformer3DModel.from_pretrained`` should read."""
|
||||
p = Path(checkpoint)
|
||||
if p.is_dir():
|
||||
return str(p / "transformer")
|
||||
return checkpoint # HF repo id; use subfolder kwarg below
|
||||
|
||||
|
||||
def load_source_transformer(checkpoint: str, dtype: torch.dtype) -> WanTransformer3DModel:
|
||||
p = Path(checkpoint)
|
||||
if p.is_dir():
|
||||
return WanTransformer3DModel.from_pretrained(
|
||||
str(p / "transformer"), torch_dtype=dtype, attn_mode="torch"
|
||||
)
|
||||
return WanTransformer3DModel.from_pretrained(
|
||||
checkpoint, subfolder="transformer", torch_dtype=dtype, attn_mode="torch"
|
||||
)
|
||||
|
||||
|
||||
def build_config(variant: str, wan_pretrained_path: str, dtype: str) -> LingBotVAConfig:
|
||||
preset = VARIANTS[variant]
|
||||
n_used = len(preset["used_action_channel_ids"])
|
||||
kwargs = {
|
||||
"wan_pretrained_path": wan_pretrained_path,
|
||||
"dtype": dtype,
|
||||
"obs_cam_keys": preset["obs_cam_keys"],
|
||||
"height": preset["height"],
|
||||
"width": preset["width"],
|
||||
"action_per_frame": preset["action_per_frame"],
|
||||
"frame_chunk_size": preset["frame_chunk_size"],
|
||||
"attn_window": preset["attn_window"],
|
||||
"num_inference_steps": preset["num_inference_steps"],
|
||||
"action_num_inference_steps": preset["action_num_inference_steps"],
|
||||
"guidance_scale": preset["guidance_scale"],
|
||||
"action_guidance_scale": preset["action_guidance_scale"],
|
||||
"snr_shift": preset["snr_shift"],
|
||||
"action_snr_shift": preset["action_snr_shift"],
|
||||
"used_action_channel_ids": preset["used_action_channel_ids"],
|
||||
"device": "cpu",
|
||||
}
|
||||
if variant != "libero":
|
||||
# LIBERO keeps the config default quantiles; other variants need their own. Until the
|
||||
# exact per-channel quantiles are wired in, use a neutral [-1, 1] mapping (no rescale).
|
||||
kwargs["action_q01"] = [-1.0] * n_used
|
||||
kwargs["action_q99"] = [1.0] * n_used
|
||||
cfg = LingBotVAConfig(**kwargs)
|
||||
# Populate input/output features (cameras + action) so validate_features passes.
|
||||
img_shape = preset["image_shape"]
|
||||
cfg.input_features = {
|
||||
k: PolicyFeature(type=FeatureType.VISUAL, shape=img_shape) for k in preset["obs_cam_keys"]
|
||||
}
|
||||
cfg.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(n_used,))}
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
|
||||
|
||||
def convert(
|
||||
checkpoint: str, variant: str, output_dir: str, dtype: str, bundle_frozen: bool, push_to: str | None
|
||||
):
|
||||
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype]
|
||||
out = Path(output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Decide where frozen modules will be pulled from at load time.
|
||||
if bundle_frozen:
|
||||
wan_pretrained_path = str(out)
|
||||
_copy_frozen_subfolders(checkpoint, out)
|
||||
else:
|
||||
wan_pretrained_path = checkpoint
|
||||
|
||||
print(f"Building LingBot-VA config for variant '{variant}' (frozen modules from: {wan_pretrained_path})")
|
||||
cfg = build_config(variant, wan_pretrained_path, dtype)
|
||||
|
||||
print("Loading source transformer weights ...")
|
||||
src = load_source_transformer(checkpoint, torch_dtype)
|
||||
src_sd = src.state_dict()
|
||||
|
||||
print("Instantiating LingBotVAPolicy and copying transformer weights ...")
|
||||
# Build the policy without triggering frozen-module download by constructing directly.
|
||||
policy = LingBotVAPolicy(cfg)
|
||||
# Near-identity remap: source transformer keys -> policy "transformer.*".
|
||||
remapped = {f"transformer.{k}": v for k, v in src_sd.items()}
|
||||
missing, unexpected = policy.load_state_dict(remapped, strict=False)
|
||||
_log_load_keys(missing, unexpected)
|
||||
policy = policy.to(torch_dtype)
|
||||
|
||||
print(f"Saving converted policy to {out}")
|
||||
policy.save_pretrained(out)
|
||||
|
||||
preprocessor, postprocessor = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
preprocessor.save_pretrained(out)
|
||||
postprocessor.save_pretrained(out)
|
||||
|
||||
if push_to:
|
||||
print(f"Pushing to the Hub: {push_to}")
|
||||
policy.push_to_hub(push_to)
|
||||
preprocessor.push_to_hub(push_to)
|
||||
postprocessor.push_to_hub(push_to)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
def _copy_frozen_subfolders(checkpoint: str, out: Path):
|
||||
p = Path(checkpoint)
|
||||
if not p.is_dir():
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
p = Path(snapshot_download(checkpoint, allow_patterns=["vae/*", "text_encoder/*", "tokenizer/*"]))
|
||||
for sub in ("vae", "text_encoder", "tokenizer"):
|
||||
src_sub = p / sub
|
||||
if src_sub.is_dir():
|
||||
shutil.copytree(src_sub, out / sub, dirs_exist_ok=True)
|
||||
print(f" bundled {sub}/")
|
||||
|
||||
|
||||
def _log_load_keys(missing, unexpected):
|
||||
# The source transformer should account for every "transformer.*" key in the policy.
|
||||
if missing:
|
||||
print(
|
||||
f" [load_state_dict] {len(missing)} missing keys (expected: none for transformer). Sample: {missing[:5]}"
|
||||
)
|
||||
if unexpected:
|
||||
print(f" [load_state_dict] {len(unexpected)} unexpected keys. Sample: {unexpected[:5]}")
|
||||
if not missing and not unexpected:
|
||||
print(" [load_state_dict] perfect match (near-identity remap).")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
parser.add_argument("--checkpoint", required=True, help="HF repo id or local diffusers-style directory.")
|
||||
parser.add_argument("--variant", required=True, choices=sorted(VARIANTS.keys()))
|
||||
parser.add_argument("--output_dir", required=True)
|
||||
parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
|
||||
parser.add_argument(
|
||||
"--bundle-frozen",
|
||||
action="store_true",
|
||||
help="Copy the frozen vae/text_encoder/tokenizer next to the checkpoint instead of lazy-pulling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, help="Optional HF repo id to push the converted policy to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(
|
||||
checkpoint=args.checkpoint,
|
||||
variant=args.variant,
|
||||
output_dir=args.output_dir,
|
||||
dtype=args.dtype,
|
||||
bundle_frozen=args.bundle_frozen,
|
||||
push_to=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -47,6 +47,60 @@ from lerobot.utils.constants import (
|
||||
|
||||
from .configuration_lingbot_va import LingBotVAConfig
|
||||
|
||||
# Upstream LIBERO action-normalization quantiles (single 7-DoF arm + gripper).
|
||||
# Verbatim from wan_va/configs/va_libero_cfg.py (channels 0-6 of a 30-dim action space).
|
||||
# These are the fixed (un)normalization stats baked into the released LIBERO checkpoint; they
|
||||
# live here (in the processor) and are serialized into the saved post-processor config.
|
||||
LIBERO_ACTION_Q01 = [
|
||||
-0.6589285731315613,
|
||||
-0.84375,
|
||||
-0.9375,
|
||||
-0.12107142806053162,
|
||||
-0.15964286029338837,
|
||||
-0.26571428775787354,
|
||||
-1.0,
|
||||
]
|
||||
LIBERO_ACTION_Q99 = [
|
||||
0.8999999761581421,
|
||||
0.8544642925262451,
|
||||
0.9375,
|
||||
0.17142857611179352,
|
||||
0.1842857152223587,
|
||||
0.34392857551574707,
|
||||
1.0,
|
||||
]
|
||||
|
||||
|
||||
# Upstream RoboTwin action quantiles, reordered to the model's used-channel layout
|
||||
# [left xyz+quat (0-6), left gripper (28), right xyz+quat (7-13), right gripper (29)] = 16 channels.
|
||||
# Verbatim from wan_va/configs/va_robotwin_cfg.py ``norm_stat`` (quaternion + gripper channels use the
|
||||
# neutral [-1, 1] / [0, 1] mapping). Positions are quantile-scaled; rotations pass through.
|
||||
ROBOTWIN_ACTION_Q01 = [
|
||||
-0.06172713458538055, -3.6716461181640625e-05, -0.08783501386642456, -1.0, -1.0, -1.0, -1.0,
|
||||
0.0,
|
||||
-0.3547105032205582, -1.3113021850585938e-06, -0.11975435614585876, -1.0, -1.0, -1.0, -1.0,
|
||||
0.0,
|
||||
] # fmt: skip
|
||||
ROBOTWIN_ACTION_Q99 = [
|
||||
0.3462600058317184, 0.39966784834861746, 0.14745532035827624, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0,
|
||||
0.034201726913452024, 0.39142737388610793, 0.1792279863357542, 1.0, 1.0, 1.0, 1.0,
|
||||
1.0,
|
||||
] # fmt: skip
|
||||
|
||||
|
||||
def _default_action_quantiles(n_used: int) -> tuple[list[float], list[float]]:
|
||||
"""Return the fixed (q01, q99) for the used action channels, by benchmark channel count.
|
||||
|
||||
LIBERO = 7 (single 7-DoF arm), RoboTwin = 16 (dual-arm eef pose + grippers). Falls back to a
|
||||
neutral ``[-1, 1]`` mapping (no rescale) for any other channel count.
|
||||
"""
|
||||
if n_used == len(LIBERO_ACTION_Q01):
|
||||
return list(LIBERO_ACTION_Q01), list(LIBERO_ACTION_Q99)
|
||||
if n_used == len(ROBOTWIN_ACTION_Q01):
|
||||
return list(ROBOTWIN_ACTION_Q01), list(ROBOTWIN_ACTION_Q99)
|
||||
return [-1.0] * n_used, [1.0] * n_used
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="lingbot_va_action_unnormalize")
|
||||
@@ -94,8 +148,9 @@ def make_lingbot_va_pre_post_processors(
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
action_q01, action_q99 = _default_action_quantiles(len(config.used_action_channel_ids))
|
||||
output_steps: list[ProcessorStep] = [
|
||||
LingBotVAActionUnnormalizeStep(action_q01=config.action_q01, action_q99=config.action_q99),
|
||||
LingBotVAActionUnnormalizeStep(action_q01=action_q01, action_q99=action_q99),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Flow-matching scheduler for LingBot-VA.
|
||||
|
||||
Vendored verbatim from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/scheduler.py``). LingBot-VA uses
|
||||
two independent instances of this scheduler at inference time — one for the video-latent
|
||||
stream and one for the action stream — each with its own ``shift`` (signal-to-noise ratio
|
||||
shift) and number of denoising steps.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["FlowMatchScheduler"]
|
||||
|
||||
|
||||
class FlowMatchScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps=100,
|
||||
num_train_timesteps=1000,
|
||||
shift=3.0,
|
||||
sigma_max=1.0,
|
||||
sigma_min=0.003 / 1.002,
|
||||
inverse_timesteps=False,
|
||||
extra_one_step=False,
|
||||
reverse_sigmas=False,
|
||||
exponential_shift=False,
|
||||
exponential_shift_mu=None,
|
||||
shift_terminal=None,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.reverse_sigmas = reverse_sigmas
|
||||
self.exponential_shift = exponential_shift
|
||||
self.exponential_shift_mu = exponential_shift_mu
|
||||
self.shift_terminal = shift_terminal
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps=100,
|
||||
denoising_strength=1.0,
|
||||
training=False,
|
||||
shift=None,
|
||||
dynamic_shift_len=None,
|
||||
):
|
||||
if shift is not None:
|
||||
self.shift = shift
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
if self.exponential_shift:
|
||||
mu = (
|
||||
self.calculate_shift(dynamic_shift_len)
|
||||
if dynamic_shift_len is not None
|
||||
else self.exponential_shift_mu
|
||||
)
|
||||
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
||||
else:
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
if self.shift_terminal is not None:
|
||||
one_minus_z = 1 - self.sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
|
||||
self.sigmas = 1 - (one_minus_z / scale_factor)
|
||||
if self.reverse_sigmas:
|
||||
self.sigmas = 1 - self.sigmas
|
||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
||||
if training:
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
||||
else:
|
||||
sigma_ = self.sigmas[timestep_id + 1]
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
return prev_sample
|
||||
|
||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
model_output = (sample - sample_stablized) / sigma
|
||||
return model_output
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep, t_dim=2):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
timestep = timestep[None]
|
||||
timestep_id = torch.argmin((self.timesteps[:, None] - timestep).abs(), dim=0)
|
||||
shape = [1] * noise.ndim
|
||||
shape[t_dim] = timestep_id.shape[0]
|
||||
sigma = self.sigmas[timestep_id].to(original_samples).view(shape)
|
||||
sample = (1 - sigma) * original_samples + sigma * noise
|
||||
return sample
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
target = noise - sample
|
||||
return target
|
||||
|
||||
def training_weight(self, timestep):
|
||||
timestep_id = torch.argmin(
|
||||
(self.timesteps[:, None].to(timestep.device) - timestep[None]).abs(), dim=0
|
||||
)
|
||||
weights = self.linear_timesteps_weights.to(timestep.device)[timestep_id].to(timestep.device)
|
||||
return weights
|
||||
|
||||
def calculate_shift(
|
||||
self,
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 8192,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 0.9,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
@@ -1,286 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attention and rotary-position-embedding modules for the LingBot-VA Wan transformer.
|
||||
|
||||
Vendored and lightly adapted from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``). The ``torch``
|
||||
SDPA backend is the default and is always available; the ``flashattn`` and ``flex``
|
||||
backends are imported lazily and only required when the corresponding ``attn_mode`` is
|
||||
selected. State-dict parameter names are preserved verbatim so that conversion from the
|
||||
original diffusers-style checkpoint is near-identity.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# ``flash_attn`` and the flex-attention APIs are optional. We import them lazily inside the
|
||||
# backends that need them so that the (default) ``torch`` SDPA path works on any platform,
|
||||
# including CPU-only and macOS where neither package is available.
|
||||
|
||||
|
||||
def custom_sdpa(q, k, v):
|
||||
"""Scaled-dot-product attention operating on ``(B, S, H, D)`` tensors."""
|
||||
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
|
||||
return out.transpose(1, 2)
|
||||
|
||||
|
||||
def _load_flash_attn_func():
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func
|
||||
except ImportError:
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"attn_mode='flashattn' requires the `flash_attn` package, which is not installed. "
|
||||
"Install it, or use attn_mode='torch' (the default)."
|
||||
) from e
|
||||
return flash_attn_func
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
"""Rotary position embedding with separate frequency bases for frame / height / width."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size,
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.theta = theta
|
||||
|
||||
self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim // 3)
|
||||
self.h_dim = self.attention_head_dim // 3
|
||||
self.w_dim = self.attention_head_dim // 3
|
||||
|
||||
f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
|
||||
self.f_freqs_base = f_freqs_base
|
||||
self.h_freqs_base = h_freqs_base
|
||||
self.w_freqs_base = w_freqs_base
|
||||
|
||||
def _precompute_freqs_base(self):
|
||||
# freqs_base = 1.0 / (theta ** (2k / dim))
|
||||
f_freqs_base = 1.0 / (
|
||||
self.theta ** (torch.arange(0, self.f_dim, 2)[: (self.f_dim // 2)].double() / self.f_dim)
|
||||
)
|
||||
h_freqs_base = 1.0 / (
|
||||
self.theta ** (torch.arange(0, self.h_dim, 2)[: (self.h_dim // 2)].double() / self.h_dim)
|
||||
)
|
||||
w_freqs_base = 1.0 / (
|
||||
self.theta ** (torch.arange(0, self.w_dim, 2)[: (self.w_dim // 2)].double() / self.w_dim)
|
||||
)
|
||||
return f_freqs_base, h_freqs_base, w_freqs_base
|
||||
|
||||
def forward(self, grid_ids):
|
||||
with torch.no_grad():
|
||||
f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base.to(grid_ids.device)
|
||||
h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base.to(grid_ids.device)
|
||||
w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base.to(grid_ids.device)
|
||||
freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class WanAttention(nn.Module):
|
||||
"""Self/cross attention with KV-caching for autoregressive streaming inference.
|
||||
|
||||
Backends:
|
||||
* ``torch`` (default): standard SDPA, available everywhere.
|
||||
* ``flashattn``: FlashAttention kernels (optional dependency).
|
||||
* ``flex``: PyTorch flex-attention (optional, used for block-causal training masks).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
eps=1e-5,
|
||||
dropout=0.0,
|
||||
cross_attention_dim_head=None,
|
||||
attn_mode="torch",
|
||||
):
|
||||
super().__init__()
|
||||
if attn_mode == "torch":
|
||||
self.attn_op = custom_sdpa
|
||||
elif attn_mode == "flashattn":
|
||||
self.attn_op = _load_flash_attn_func()
|
||||
elif attn_mode == "flex":
|
||||
# Imported lazily to avoid a hard dependency on torch flex-attention at import time.
|
||||
from .wan_flex_attention import FlexAttnFunc
|
||||
|
||||
self.attn_op = FlexAttnFunc(cross_attention_dim_head is not None)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported attention mode: {attn_mode}, only support 'torch', 'flashattn' and 'flex'"
|
||||
)
|
||||
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.kv_inner_dim = (
|
||||
self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
)
|
||||
|
||||
self.to_q = nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(self.inner_dim, dim, bias=True),
|
||||
nn.Dropout(dropout),
|
||||
]
|
||||
)
|
||||
self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
# KV cache only lives on self-attention modules (cross_attention_dim_head is None).
|
||||
self.attn_caches = {} if cross_attention_dim_head is None else None
|
||||
|
||||
def clear_pred_cache(self, cache_name):
|
||||
if self.attn_caches is None:
|
||||
return
|
||||
cache = self.attn_caches[cache_name]
|
||||
is_pred = cache["is_pred"]
|
||||
cache["mask"][is_pred] = False
|
||||
|
||||
def clear_cache(self, cache_name):
|
||||
if self.attn_caches is None:
|
||||
return
|
||||
self.attn_caches[cache_name] = None
|
||||
|
||||
def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim, device, dtype, batch_size):
|
||||
if self.attn_caches is None:
|
||||
return
|
||||
self.attn_caches[cache_name] = {
|
||||
"k": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
|
||||
"v": torch.empty([batch_size, total_tolen, num_head, head_dim], device=device, dtype=dtype),
|
||||
"id": torch.full((total_tolen,), -1, device=device),
|
||||
"mask": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
|
||||
"is_pred": torch.zeros((total_tolen,), dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
def allocate_slots(self, cache_name, key_size):
|
||||
cache = self.attn_caches[cache_name]
|
||||
mask = cache["mask"]
|
||||
ids = cache["id"]
|
||||
free = (~mask).nonzero(as_tuple=False).squeeze(-1)
|
||||
|
||||
if free.numel() < key_size:
|
||||
used = mask.nonzero(as_tuple=False).squeeze(-1)
|
||||
|
||||
used_ids = ids[used]
|
||||
order = torch.argsort(used_ids)
|
||||
need = key_size - free.numel()
|
||||
to_free = used[order[:need]]
|
||||
|
||||
mask[to_free] = False
|
||||
ids[to_free] = -1
|
||||
free = (~mask).nonzero(as_tuple=False).squeeze(-1)
|
||||
|
||||
assert free.numel() >= key_size
|
||||
return free[:key_size]
|
||||
|
||||
def _next_cache_id(self, cache_name):
|
||||
ids = self.attn_caches[cache_name]["id"]
|
||||
mask = self.attn_caches[cache_name]["mask"]
|
||||
|
||||
if mask.any():
|
||||
return ids[mask].max() + 1
|
||||
else:
|
||||
return torch.tensor(0, device=ids.device, dtype=ids.dtype)
|
||||
|
||||
def update_cache(self, cache_name, key, value, is_pred):
|
||||
cache = self.attn_caches[cache_name]
|
||||
|
||||
key_size = key.shape[1]
|
||||
slots = self.allocate_slots(cache_name, key_size)
|
||||
|
||||
new_id = self._next_cache_id(cache_name)
|
||||
|
||||
cache["k"][:, slots] = key
|
||||
cache["v"][:, slots] = value
|
||||
cache["mask"][slots] = True
|
||||
cache["id"][slots] = new_id
|
||||
cache["is_pred"][slots] = is_pred
|
||||
return slots
|
||||
|
||||
def restore_cache(self, cache_name, slots):
|
||||
self.attn_caches[cache_name]["mask"][slots] = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
rotary_emb,
|
||||
update_cache=0,
|
||||
cache_name="pos",
|
||||
):
|
||||
kv_cache = (
|
||||
self.attn_caches[cache_name]
|
||||
if (self.attn_caches is not None) and (cache_name in self.attn_caches)
|
||||
else None
|
||||
)
|
||||
|
||||
query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
|
||||
query = self.norm_q(query)
|
||||
query = query.unflatten(2, (self.heads, -1))
|
||||
key = self.norm_k(key)
|
||||
key = key.unflatten(2, (self.heads, -1))
|
||||
value = value.unflatten(2, (self.heads, -1))
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(x, freqs):
|
||||
x_out = torch.view_as_complex(
|
||||
x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||
)
|
||||
x_out = torch.view_as_real(x_out * freqs).flatten(3)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
query = apply_rotary_emb(query, rotary_emb)
|
||||
key = apply_rotary_emb(key, rotary_emb)
|
||||
slots = None
|
||||
if kv_cache is not None and kv_cache["k"] is not None:
|
||||
slots = self.update_cache(cache_name, key, value, is_pred=(update_cache == 1))
|
||||
key_pool = self.attn_caches[cache_name]["k"]
|
||||
value_pool = self.attn_caches[cache_name]["v"]
|
||||
mask = self.attn_caches[cache_name]["mask"]
|
||||
valid = mask.nonzero(as_tuple=False).squeeze(-1)
|
||||
key = key_pool[:, valid]
|
||||
value = value_pool[:, valid]
|
||||
|
||||
hidden_states = self.attn_op(query, key, value)
|
||||
|
||||
if update_cache == 0:
|
||||
if kv_cache is not None and kv_cache["k"] is not None:
|
||||
self.restore_cache(cache_name, slots)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
__all__ = ["WanAttention", "WanRotaryPosEmbed", "custom_sdpa"]
|
||||
@@ -1,207 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Flex-attention backend for the LingBot-VA Wan transformer (training only).
|
||||
|
||||
This module is imported lazily and ONLY when ``attn_mode='flex'`` is requested. It builds
|
||||
the block-causal / window / noise-vs-clean attention masks used during the dual-stream
|
||||
flow-matching training described in the LingBot-VA paper. Inference uses the ``torch``
|
||||
SDPA backend (see :mod:`wan_attention`) which does not need flex-attention.
|
||||
|
||||
``torch.nn.attention.flex_attention`` requires a recent PyTorch build with the relevant
|
||||
inductor support; importing this module on an unsupported build raises ``ImportError``.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
and_masks,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
or_masks,
|
||||
)
|
||||
|
||||
|
||||
class FlexAttnFunc(nn.Module):
|
||||
flex_attn: ClassVar[Callable] = torch.compile(flex_attention, dynamic=True)
|
||||
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
|
||||
attention_mask: ClassVar[BlockMask] = None
|
||||
cross_attention_mask: ClassVar[BlockMask] = None
|
||||
|
||||
def __init__(self, is_cross=False) -> None:
|
||||
super().__init__()
|
||||
self.is_cross = is_cross
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dtype=torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
q_varlen = rearrange(query[0], "s n d -> 1 n s d")
|
||||
k_varlen = rearrange(key[0], "s n d -> 1 n s d")
|
||||
v_varlen = rearrange(value[0], "s n d -> 1 n s d")
|
||||
|
||||
half_dtypes = (torch.float16, torch.bfloat16)
|
||||
assert dtype in half_dtypes
|
||||
|
||||
def half(x):
|
||||
return x if x.dtype in half_dtypes else x.to(dtype)
|
||||
|
||||
q_varlen = half(q_varlen)
|
||||
k_varlen = half(k_varlen)
|
||||
v_varlen = half(v_varlen)
|
||||
q_varlen = q_varlen.to(v_varlen.dtype)
|
||||
k_varlen = k_varlen.to(v_varlen.dtype)
|
||||
|
||||
block_mask = FlexAttnFunc.cross_attention_mask if self.is_cross else FlexAttnFunc.attention_mask
|
||||
|
||||
x_out = FlexAttnFunc.flex_attn(
|
||||
q_varlen,
|
||||
k_varlen,
|
||||
v_varlen,
|
||||
block_mask=block_mask,
|
||||
kernel_options={
|
||||
"BLOCK_M": 64,
|
||||
"BLOCK_N": 64,
|
||||
"BLOCK_M1": 32,
|
||||
"BLOCK_N1": 64,
|
||||
"BLOCK_M2": 64,
|
||||
"BLOCK_N2": 32,
|
||||
},
|
||||
)
|
||||
|
||||
x_out = rearrange(x_out, "b n s d -> b s n d")
|
||||
return x_out
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def init_mask(
|
||||
latent_shape,
|
||||
action_shape,
|
||||
padded_length,
|
||||
chunk_size,
|
||||
window_size,
|
||||
patch_size,
|
||||
device,
|
||||
):
|
||||
torch._inductor.config.realize_opcount_threshold = 100
|
||||
B, _, L_F, L_H, L_W = latent_shape
|
||||
_, _, A_F, A_H, A_W = action_shape
|
||||
|
||||
latent_seq_id = (
|
||||
torch.arange(B)[:, None, None, None]
|
||||
.expand(-1, L_F // patch_size[0], L_H // patch_size[1], L_W // patch_size[2])
|
||||
.flatten()
|
||||
)
|
||||
action_seq_id = torch.arange(B)[:, None, None, None].expand(-1, A_F, A_H, A_W).flatten()
|
||||
seq_ids = torch.cat([latent_seq_id] * 2 + [action_seq_id] * 2)
|
||||
|
||||
latent_frame_id = (
|
||||
torch.arange(L_F)[None, :, None, None]
|
||||
.expand(B, -1, L_H // patch_size[1], L_W // patch_size[2])[None]
|
||||
.flatten()
|
||||
)
|
||||
action_frame_id = torch.arange(A_F)[None, :, None, None].expand(B, -1, A_H, A_W)[None].flatten()
|
||||
frame_ids = torch.cat(
|
||||
[latent_frame_id // chunk_size * 2] * 2 + [action_frame_id // chunk_size * 2 + 1] * 2
|
||||
)
|
||||
|
||||
noise_ids = torch.cat(
|
||||
[
|
||||
torch.zeros_like(latent_frame_id),
|
||||
torch.ones_like(latent_frame_id),
|
||||
torch.zeros_like(action_frame_id),
|
||||
torch.ones_like(action_frame_id),
|
||||
]
|
||||
)
|
||||
|
||||
seq_ids = F.pad(seq_ids, (0, padded_length), value=-1)
|
||||
frame_ids = F.pad(frame_ids, (0, padded_length), value=-1)
|
||||
noise_ids = F.pad(noise_ids, (0, padded_length), value=-1)
|
||||
|
||||
mask_mod = FlexAttnFunc._get_mask_mod(
|
||||
seq_ids.long().to(device), frame_ids.long().to(device), noise_ids.long().to(device), window_size
|
||||
)
|
||||
block_mask = FlexAttnFunc.compiled_create_block_mask(
|
||||
mask_mod, 1, 1, len(seq_ids), len(seq_ids), device=device, _compile=True
|
||||
)
|
||||
FlexAttnFunc.attention_mask = block_mask
|
||||
|
||||
text_seq_ids = torch.arange(B)[:, None].expand(-1, 512).flatten()
|
||||
mask_mod_cross = FlexAttnFunc._get_cross_mask_mod(
|
||||
seq_ids.long().to(device), text_seq_ids.long().to(device)
|
||||
)
|
||||
block_mask_cross = FlexAttnFunc.compiled_create_block_mask(
|
||||
mask_mod_cross, 1, 1, len(seq_ids), len(text_seq_ids), device=device, _compile=True
|
||||
)
|
||||
FlexAttnFunc.cross_attention_mask = block_mask_cross
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _get_cross_mask_mod(seq_ids, text_seq_ids):
|
||||
def seq_mask(b, h, q_idx, kv_idx):
|
||||
return (
|
||||
(seq_ids[q_idx] == text_seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (text_seq_ids[kv_idx] >= 0)
|
||||
)
|
||||
|
||||
return seq_mask
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _get_mask_mod(seq_ids, frame_ids, noise_ids, window_size):
|
||||
def seq_mask(b, h, q_idx, kv_idx):
|
||||
return (seq_ids[q_idx] == seq_ids[kv_idx]) & (seq_ids[q_idx] >= 0) & (seq_ids[kv_idx] >= 0)
|
||||
|
||||
def block_causal_mask(b, h, q_idx, kv_idx):
|
||||
return frame_ids[kv_idx] <= frame_ids[q_idx]
|
||||
|
||||
def block_causal_mask_exclude_self(b, h, q_idx, kv_idx):
|
||||
return frame_ids[kv_idx] < frame_ids[q_idx]
|
||||
|
||||
def block_self_mask(b, h, q_idx, kv_idx):
|
||||
return frame_ids[kv_idx] == frame_ids[q_idx]
|
||||
|
||||
def clean2clean_mask(b, h, q_idx, kv_idx):
|
||||
return (noise_ids[q_idx] == 1) & (noise_ids[kv_idx] == 1)
|
||||
|
||||
def noise2clean_mask(b, h, q_idx, kv_idx):
|
||||
return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 1)
|
||||
|
||||
def noise2noise_mask(b, h, q_idx, kv_idx):
|
||||
return (noise_ids[q_idx] == 0) & (noise_ids[kv_idx] == 0)
|
||||
|
||||
def block_window_mask(b, h, q_idx, kv_idx, window_size: int):
|
||||
return (frame_ids[q_idx] - frame_ids[kv_idx]).abs() <= window_size
|
||||
|
||||
mask_list = []
|
||||
mask_list.append(and_masks(clean2clean_mask, block_causal_mask))
|
||||
mask_list.append(and_masks(noise2clean_mask, block_causal_mask_exclude_self))
|
||||
mask_list.append(and_masks(noise2noise_mask, block_self_mask))
|
||||
mask = or_masks(*mask_list)
|
||||
mask = and_masks(mask, seq_mask)
|
||||
mask = and_masks(mask, partial(block_window_mask, window_size=window_size))
|
||||
return mask
|
||||
|
||||
|
||||
__all__ = ["FlexAttnFunc"]
|
||||
@@ -1,514 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The dual-stream Wan2.2 video-action transformer backbone for LingBot-VA.
|
||||
|
||||
Vendored and lightly adapted from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/modules/model.py``).
|
||||
|
||||
The model keeps the diffusers ``ModelMixin``/``ConfigMixin`` mixins so the original
|
||||
sharded ``transformer/`` checkpoint can be loaded with ``from_pretrained`` during
|
||||
conversion, but in LeRobot it is owned as a plain ``nn.Module`` sub-component of
|
||||
:class:`~lerobot.policies.lingbot_va.modeling_lingbot_va.LingBotVAPolicy`. State-dict
|
||||
parameter names are preserved verbatim so conversion is near-identity.
|
||||
"""
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention import FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
PixArtAlphaTextProjection,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import FP32LayerNorm
|
||||
from einops import rearrange
|
||||
|
||||
from .wan_attention import WanAttention, WanRotaryPosEmbed
|
||||
|
||||
__all__ = ["WanTransformer3DModel", "WanTransformerBlock", "WanTimeTextImageEmbedding"]
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
time_freq_dim,
|
||||
time_proj_dim,
|
||||
text_embed_dim,
|
||||
pos_embed_seq_len,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(
|
||||
num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
def forward(self, timestep: torch.Tensor, dtype=None):
|
||||
B, L = timestep.shape
|
||||
timestep = timestep.reshape(-1)
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).to(dtype=dtype)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
|
||||
|
||||
|
||||
class WanTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
cross_attention_dim_head=None,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = WanAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
cross_attention_dim_head=dim // num_heads,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
rotary_emb,
|
||||
update_cache=0,
|
||||
cache_name="pos",
|
||||
) -> torch.Tensor:
|
||||
temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = rearrange(
|
||||
temb_scale_shift_table, "b l n c -> b n l c"
|
||||
).chunk(6, dim=1)
|
||||
shift_msa = shift_msa.squeeze(1)
|
||||
scale_msa = scale_msa.squeeze(1)
|
||||
gate_msa = gate_msa.squeeze(1)
|
||||
c_shift_msa = c_shift_msa.squeeze(1)
|
||||
c_scale_msa = c_scale_msa.squeeze(1)
|
||||
c_gate_msa = c_gate_msa.squeeze(1)
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1.0 + scale_msa) + shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
norm_hidden_states,
|
||||
norm_hidden_states,
|
||||
rotary_emb,
|
||||
update_cache=update_cache,
|
||||
cache_name=cache_name,
|
||||
)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
update_cache=0,
|
||||
cache_name=cache_name,
|
||||
)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1.0 + c_scale_msa) + c_shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""Dual-stream (video + action) Wan2.2 DiT backbone with autoregressive KV caching."""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = [
|
||||
"patch_embedding_mlp",
|
||||
"condition_embedder",
|
||||
"condition_embedder_action",
|
||||
"norm",
|
||||
]
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = [
|
||||
"time_embedder",
|
||||
"scale_shift_table",
|
||||
"scale_shift_table_action",
|
||||
"norm1",
|
||||
"action_norm1",
|
||||
"text_norm1",
|
||||
"norm2",
|
||||
"action_norm2",
|
||||
"text_norm2",
|
||||
"norm3",
|
||||
"action_norm3",
|
||||
"text_norm3",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
in_channels=48,
|
||||
out_channels=48,
|
||||
action_dim=30,
|
||||
text_dim=4096,
|
||||
freq_dim=256,
|
||||
ffn_dim=14336,
|
||||
num_layers=30,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-06,
|
||||
rope_max_seq_len=1024,
|
||||
pos_embed_seq_len=None,
|
||||
attn_mode="torch",
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding_mlp = nn.Linear(
|
||||
in_channels * patch_size[0] * patch_size[1] * patch_size[2], inner_dim
|
||||
)
|
||||
self.action_embedder = nn.Linear(action_dim, inner_dim)
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
self.condition_embedder_action = deepcopy(self.condition_embedder)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, cross_attn_norm, eps, attn_mode=attn_mode
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.action_proj_out = nn.Linear(inner_dim, action_dim)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# KV-cache management for autoregressive streaming inference
|
||||
# ------------------------------------------------------------------
|
||||
def clear_cache(self, cache_name):
|
||||
for block in self.blocks:
|
||||
block.attn1.clear_cache(cache_name)
|
||||
|
||||
def clear_pred_cache(self, cache_name):
|
||||
for block in self.blocks:
|
||||
block.attn1.clear_pred_cache(cache_name)
|
||||
|
||||
def create_empty_cache(
|
||||
self,
|
||||
cache_name,
|
||||
attn_window,
|
||||
latent_token_per_chunk,
|
||||
action_token_per_chunk,
|
||||
device,
|
||||
dtype,
|
||||
batch_size,
|
||||
):
|
||||
total_tolen = (attn_window // 2) * latent_token_per_chunk + (
|
||||
attn_window // 2
|
||||
) * action_token_per_chunk
|
||||
for block in self.blocks:
|
||||
block.attn1.init_kv_cache(
|
||||
cache_name,
|
||||
total_tolen,
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
device,
|
||||
dtype,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Embedding helpers (shared by train + inference paths)
|
||||
# ------------------------------------------------------------------
|
||||
def _input_embed(self, latents, input_type="latent"):
|
||||
if input_type == "latent":
|
||||
hidden_states = rearrange(
|
||||
latents,
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self.patch_size[0],
|
||||
p2=self.patch_size[1],
|
||||
p3=self.patch_size[2],
|
||||
)
|
||||
hidden_states = self.patch_embedding_mlp(hidden_states)
|
||||
elif input_type == "action":
|
||||
hidden_states = rearrange(latents, "b c f h w -> b (f h w) c")
|
||||
hidden_states = self.action_embedder(hidden_states)
|
||||
elif input_type == "text":
|
||||
hidden_states = self.condition_embedder.text_embedder(latents)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {input_type}")
|
||||
return hidden_states
|
||||
|
||||
def _time_embed(self, timesteps, H, W, dtype, action_mode=False):
|
||||
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
latent_time_steps = torch.repeat_interleave(
|
||||
timesteps, (H // pach_scale_h) * (W // pach_scale_w), dim=1
|
||||
)
|
||||
current_condition_embedder = (
|
||||
self.condition_embedder_action if action_mode else self.condition_embedder
|
||||
)
|
||||
temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=dtype)
|
||||
timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C
|
||||
return temb, timestep_proj
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Dual-stream training forward (flow matching). Requires attn_mode='flex'.
|
||||
# ------------------------------------------------------------------
|
||||
def forward_train(self, input_dict):
|
||||
from .wan_flex_attention import FlexAttnFunc
|
||||
|
||||
input_dict["latent_dict"]["noisy_latents"] = input_dict["latent_dict"]["noisy_latents"].to(
|
||||
torch.bfloat16
|
||||
)
|
||||
input_dict["latent_dict"]["latent"] = input_dict["latent_dict"]["latent"].to(torch.bfloat16)
|
||||
input_dict["action_dict"]["noisy_latents"] = input_dict["action_dict"]["noisy_latents"].to(
|
||||
torch.bfloat16
|
||||
)
|
||||
input_dict["action_dict"]["latent"] = input_dict["action_dict"]["latent"].to(torch.bfloat16)
|
||||
|
||||
latent_dict = input_dict["latent_dict"]
|
||||
action_dict = input_dict["action_dict"]
|
||||
batch_size = latent_dict["noisy_latents"].shape[0]
|
||||
|
||||
latent_hidden_states = self._input_embed(latent_dict["noisy_latents"], input_type="latent").flatten(
|
||||
0, 1
|
||||
)[None]
|
||||
action_hidden_states = self._input_embed(action_dict["noisy_latents"], input_type="action").flatten(
|
||||
0, 1
|
||||
)[None]
|
||||
text_hidden_states = self._input_embed(latent_dict["text_emb"], input_type="text")
|
||||
|
||||
text_hidden_states = text_hidden_states.flatten(0, 1)[None]
|
||||
|
||||
condition_latent_hidden_states = self._input_embed(
|
||||
latent_dict["latent"], input_type="latent"
|
||||
).flatten(0, 1)[None]
|
||||
condition_action_hidden_states = self._input_embed(
|
||||
action_dict["latent"], input_type="action"
|
||||
).flatten(0, 1)[None]
|
||||
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
latent_hidden_states,
|
||||
condition_latent_hidden_states,
|
||||
action_hidden_states,
|
||||
condition_action_hidden_states,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
latent_grid_id = latent_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
|
||||
action_grid_id = action_dict["grid_id"].permute(1, 0, 2).flatten(1)[None]
|
||||
full_grid_id = torch.cat([latent_grid_id] * 2 + [action_grid_id] * 2, dim=2)
|
||||
|
||||
rotary_emb = self.rope(full_grid_id)[:, :, None]
|
||||
|
||||
latent_time_steps = torch.cat(
|
||||
[latent_dict["timesteps"].flatten(0, 1), latent_dict["cond_timesteps"].flatten(0, 1)]
|
||||
)[None]
|
||||
action_time_steps = torch.cat(
|
||||
[action_dict["timesteps"].flatten(0, 1), action_dict["cond_timesteps"].flatten(0, 1)]
|
||||
)[None]
|
||||
latent_temb, latent_timestep_proj = self._time_embed(
|
||||
latent_time_steps,
|
||||
latent_dict["noisy_latents"].shape[-2],
|
||||
latent_dict["noisy_latents"].shape[-1],
|
||||
dtype=hidden_states.dtype,
|
||||
action_mode=False,
|
||||
)
|
||||
action_temb, action_timestep_proj = self._time_embed(
|
||||
action_time_steps,
|
||||
action_dict["noisy_latents"].shape[-2],
|
||||
action_dict["noisy_latents"].shape[-1],
|
||||
dtype=hidden_states.dtype,
|
||||
action_mode=True,
|
||||
)
|
||||
temb = torch.cat([latent_temb, action_temb], dim=1)
|
||||
timestep_proj = torch.cat([latent_timestep_proj, action_timestep_proj], dim=1)
|
||||
|
||||
total_length = hidden_states.shape[1]
|
||||
padded_length = (128 - total_length % 128) % 128
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, padded_length))
|
||||
rotary_emb = F.pad(rotary_emb, (0, 0, 0, 0, 0, padded_length))
|
||||
temb = F.pad(temb, (0, 0, 0, padded_length))
|
||||
timestep_proj = F.pad(timestep_proj, (0, 0, 0, 0, 0, padded_length))
|
||||
|
||||
split_list = [
|
||||
latent_hidden_states.shape[1],
|
||||
condition_latent_hidden_states.shape[1],
|
||||
action_hidden_states.shape[1],
|
||||
condition_action_hidden_states.shape[1],
|
||||
padded_length,
|
||||
]
|
||||
|
||||
FlexAttnFunc.init_mask(
|
||||
latent_dict["noisy_latents"].shape,
|
||||
action_dict["noisy_latents"].shape,
|
||||
padded_length,
|
||||
input_dict["chunk_size"],
|
||||
window_size=input_dict["window_size"],
|
||||
patch_size=self.patch_size,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(
|
||||
hidden_states, text_hidden_states, timestep_proj, rotary_emb, update_cache=False
|
||||
)
|
||||
temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
|
||||
shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
|
||||
shift = shift.to(hidden_states.device).squeeze(1)
|
||||
scale = scale.to(hidden_states.device).squeeze(1)
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1.0 + scale) + shift).type_as(hidden_states)
|
||||
latent_hidden_states, _, action_hidden_states, _, _ = torch.split(hidden_states, split_list, dim=1)
|
||||
latent_hidden_states = self.proj_out(latent_hidden_states)
|
||||
latent_hidden_states = rearrange(
|
||||
latent_hidden_states, "1 (b l) (n c) -> b (l n) c", n=math.prod(self.patch_size), b=batch_size
|
||||
)
|
||||
action_hidden_states = self.action_proj_out(action_hidden_states)
|
||||
action_hidden_states = rearrange(action_hidden_states, "1 (b l) c -> b l c", b=batch_size)
|
||||
|
||||
return latent_hidden_states, action_hidden_states
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Single-stream inference forward (one denoising step for one stream)
|
||||
# ------------------------------------------------------------------
|
||||
def forward(
|
||||
self,
|
||||
input_dict,
|
||||
update_cache=0,
|
||||
cache_name="pos",
|
||||
action_mode=False,
|
||||
train_mode=False,
|
||||
):
|
||||
if train_mode:
|
||||
return self.forward_train(input_dict)
|
||||
if action_mode: # action input emb
|
||||
latent_hidden_states = rearrange(input_dict["noisy_latents"], "b c f h w -> b (f h w) c")
|
||||
latent_hidden_states = self.action_embedder(latent_hidden_states) # B L1 C
|
||||
else: # latent input emb
|
||||
latent_hidden_states = rearrange(
|
||||
input_dict["noisy_latents"],
|
||||
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
||||
p1=self.patch_size[0],
|
||||
p2=self.patch_size[1],
|
||||
p3=self.patch_size[2],
|
||||
)
|
||||
latent_hidden_states = self.patch_embedding_mlp(latent_hidden_states)
|
||||
text_hidden_states = self.condition_embedder.text_embedder(input_dict["text_emb"]) # B L2 C
|
||||
|
||||
latent_grid_id = input_dict["grid_id"]
|
||||
rotary_emb = self.rope(latent_grid_id)[:, :, None] # 1 L 1 C
|
||||
pach_scale_h, pach_scale_w = (1, 1) if action_mode else (self.patch_size[1], self.patch_size[2])
|
||||
|
||||
latent_time_steps = torch.repeat_interleave(
|
||||
input_dict["timesteps"],
|
||||
(input_dict["noisy_latents"].shape[-2] // pach_scale_h)
|
||||
* (input_dict["noisy_latents"].shape[-1] // pach_scale_w),
|
||||
dim=1,
|
||||
) # L
|
||||
current_condition_embedder = (
|
||||
self.condition_embedder_action if action_mode else self.condition_embedder
|
||||
)
|
||||
temb, timestep_proj = current_condition_embedder(latent_time_steps, dtype=latent_hidden_states.dtype)
|
||||
timestep_proj = timestep_proj.unflatten(2, (6, -1)) # B L 6 C
|
||||
|
||||
for block in self.blocks:
|
||||
latent_hidden_states = block(
|
||||
latent_hidden_states,
|
||||
text_hidden_states,
|
||||
timestep_proj,
|
||||
rotary_emb,
|
||||
update_cache=update_cache,
|
||||
cache_name=cache_name,
|
||||
)
|
||||
temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
|
||||
shift, scale = rearrange(temb_scale_shift_table, "b l n c -> b n l c").chunk(2, dim=1)
|
||||
shift = shift.to(latent_hidden_states.device).squeeze(1)
|
||||
scale = scale.to(latent_hidden_states.device).squeeze(1)
|
||||
latent_hidden_states = (self.norm_out(latent_hidden_states.float()) * (1.0 + scale) + shift).type_as(
|
||||
latent_hidden_states
|
||||
)
|
||||
|
||||
if action_mode:
|
||||
latent_hidden_states = self.action_proj_out(latent_hidden_states)
|
||||
else:
|
||||
latent_hidden_states = self.proj_out(latent_hidden_states)
|
||||
latent_hidden_states = rearrange(
|
||||
latent_hidden_states, "b l (n c) -> b (l n) c", n=math.prod(self.patch_size)
|
||||
)
|
||||
|
||||
return latent_hidden_states
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Grid-id / patch utilities for the LingBot-VA autoregressive inference loop.
|
||||
|
||||
Vendored verbatim from the upstream LingBot-VA repository
|
||||
(https://github.com/Robbyant/lingbot-va, ``wan_va/utils/utils.py``).
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["get_mesh_id", "data_seq_to_patch"]
|
||||
|
||||
|
||||
def data_seq_to_patch(patch_size, data_seq, latent_num_frames, latent_height, latent_width, batch_size=1):
|
||||
"""Reshape a flattened patch sequence back into a ``(B, C, F, H, W)`` latent grid."""
|
||||
p_t, p_h, p_w = patch_size
|
||||
post_patch_num_frames = latent_num_frames // p_t
|
||||
post_patch_height = latent_height // p_h
|
||||
post_patch_width = latent_width // p_w
|
||||
|
||||
data_patch = data_seq.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
data_patch = data_patch.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
data_patch = data_patch.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return data_patch
|
||||
|
||||
|
||||
def get_mesh_id(f, h, w, t, f_w=1, f_shift=0, action=False):
|
||||
"""Build the (frame, height, width, stream) grid ids used to index the rotary embedding."""
|
||||
f_idx = torch.arange(f_shift, f + f_shift) * f_w
|
||||
h_idx = torch.arange(h)
|
||||
w_idx = torch.arange(w)
|
||||
ff, hh, ww = torch.meshgrid(f_idx, h_idx, w_idx, indexing="ij")
|
||||
if action:
|
||||
ff_offset = (torch.ones([h]).cumsum(0) / (h + 1)).view(1, -1, 1)
|
||||
ff = ff + ff_offset
|
||||
hh = torch.ones_like(hh) * -1
|
||||
ww = torch.ones_like(ww) * -1
|
||||
|
||||
grid_id = torch.cat([ff.unsqueeze(0), hh.unsqueeze(0), ww.unsqueeze(0)], dim=0).flatten(1)
|
||||
grid_id = torch.cat([grid_id, torch.full_like(grid_id[:1], t)], dim=0)
|
||||
return grid_id
|
||||
@@ -1,120 +0,0 @@
|
||||
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Thin helpers around the stock diffusers ``AutoencoderKLWan`` (Wan2.2, ``z_dim=48``).
|
||||
|
||||
The VAE class itself is NOT vendored — it lives in ``diffusers>=0.36``. This module
|
||||
provides:
|
||||
* loaders for the VAE / text encoder / tokenizer / transformer sub-checkpoints,
|
||||
* the streaming-encoder wrapper used for autoregressive frame-by-frame VAE encoding
|
||||
(it caches the causal-conv state across chunks),
|
||||
* latent (de)normalization helpers using the VAE's ``latents_mean`` / ``latents_std``.
|
||||
|
||||
Vendored and adapted from ``wan_va/modules/utils.py`` upstream.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"WanVAEStreamingWrapper",
|
||||
"load_vae",
|
||||
"load_text_encoder",
|
||||
"load_tokenizer",
|
||||
"normalize_latents",
|
||||
"denormalize_latents",
|
||||
"patchify",
|
||||
]
|
||||
|
||||
|
||||
def load_vae(vae_path, torch_dtype, torch_device):
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype)
|
||||
return vae.to(torch_device)
|
||||
|
||||
|
||||
def load_text_encoder(text_encoder_path, torch_dtype, torch_device):
|
||||
from transformers import UMT5EncoderModel
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(text_encoder_path, torch_dtype=torch_dtype)
|
||||
return text_encoder.to(torch_device)
|
||||
|
||||
|
||||
def load_tokenizer(tokenizer_path):
|
||||
from transformers import T5TokenizerFast
|
||||
|
||||
return T5TokenizerFast.from_pretrained(tokenizer_path)
|
||||
|
||||
|
||||
def patchify(x, patch_size):
|
||||
if patch_size is None or patch_size == 1:
|
||||
return x
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
x = x.view(
|
||||
batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size
|
||||
)
|
||||
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
|
||||
x = x.view(
|
||||
batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Apply ``(x - mean) * std`` channel-wise (note: upstream passes ``1/std`` as ``latents_std``)."""
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
|
||||
latents = ((latents.float() - latents_mean) * latents_std).to(latents)
|
||||
return latents
|
||||
|
||||
|
||||
def denormalize_latents(latents: torch.Tensor, latents_mean, latents_std, z_dim) -> torch.Tensor:
|
||||
"""Inverse of the normalization applied at encode time, for VAE decoding of predicted latents."""
|
||||
mean = torch.tensor(latents_mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
inv_std = 1.0 / torch.tensor(latents_std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
return latents / inv_std + mean
|
||||
|
||||
|
||||
class WanVAEStreamingWrapper:
|
||||
"""Wraps an ``AutoencoderKLWan`` encoder to support causal streaming encoding across chunks."""
|
||||
|
||||
def __init__(self, vae_model):
|
||||
self.vae = vae_model
|
||||
self.encoder = vae_model.encoder
|
||||
self.quant_conv = vae_model.quant_conv
|
||||
|
||||
if hasattr(self.vae, "_cached_conv_counts"):
|
||||
self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
|
||||
else:
|
||||
count = 0
|
||||
for m in self.encoder.modules():
|
||||
if m.__class__.__name__ == "WanCausalConv3d":
|
||||
count += 1
|
||||
self.enc_conv_num = count
|
||||
|
||||
self.clear_cache()
|
||||
|
||||
def clear_cache(self):
|
||||
self.feat_cache = [None] * self.enc_conv_num
|
||||
|
||||
def encode_chunk(self, x_chunk):
|
||||
if hasattr(self.vae.config, "patch_size") and self.vae.config.patch_size is not None:
|
||||
x_chunk = patchify(x_chunk, self.vae.config.patch_size)
|
||||
feat_idx = [0]
|
||||
out = self.encoder(x_chunk, feat_cache=self.feat_cache, feat_idx=feat_idx)
|
||||
enc = self.quant_conv(out)
|
||||
return enc
|
||||
@@ -76,8 +76,3 @@ def test_validate_features_no_visual_raises() -> None:
|
||||
def test_invalid_attn_mode_raises() -> None:
|
||||
with pytest.raises(ValueError, match="attn_mode"):
|
||||
make_config(attn_mode="banana")
|
||||
|
||||
|
||||
def test_quantile_length_mismatch_raises() -> None:
|
||||
with pytest.raises(ValueError, match="action_q01"):
|
||||
make_config(used_action_channel_ids=[0, 1, 2], action_q01=[0.0, 0.0], action_q99=[1.0, 1.0, 1.0])
|
||||
|
||||
@@ -36,17 +36,3 @@ def test_get_policy_class_resolves_lazily() -> None:
|
||||
cls = get_policy_class("lingbot_va")
|
||||
assert cls.name == "lingbot_va"
|
||||
assert cls.config_class is LingBotVAConfig
|
||||
|
||||
|
||||
def test_convert_build_config_libero() -> None:
|
||||
pytest.importorskip("diffusers")
|
||||
from lerobot.policies.lingbot_va.convert_lingbot_va_checkpoints import build_config
|
||||
|
||||
cfg = build_config("libero", wan_pretrained_path="dummy/path", dtype="float32")
|
||||
assert cfg.height == 128 and cfg.width == 128
|
||||
assert cfg.used_action_channel_ids == list(range(7))
|
||||
# validate_features (called inside build_config) must have populated the action feature.
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
assert cfg.output_features[ACTION].shape == (7,)
|
||||
assert len(cfg.obs_cam_keys) == 2
|
||||
|
||||
@@ -14,14 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pure-torch unit tests for the vendored LingBot-VA helper modules (no diffusers needed)."""
|
||||
"""Unit tests for the vendored LingBot-VA helper code (scheduler + grid utilities)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.lingbot_va.schedulers import FlowMatchScheduler
|
||||
from lerobot.policies.lingbot_va.wan_utils import data_seq_to_patch, get_mesh_id
|
||||
pytest.importorskip("diffusers") # the model code lives in modeling_lingbot_va, which imports diffusers
|
||||
|
||||
from lerobot.policies.lingbot_va.modeling_lingbot_va import ( # noqa: E402
|
||||
FlowMatchScheduler,
|
||||
data_seq_to_patch,
|
||||
get_mesh_id,
|
||||
)
|
||||
|
||||
|
||||
def test_flow_match_scheduler_timesteps_monotone_decreasing() -> None:
|
||||
|
||||
@@ -21,6 +21,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.lingbot_va.configuration_lingbot_va import LingBotVAConfig
|
||||
from lerobot.policies.lingbot_va.processor_lingbot_va import (
|
||||
LIBERO_ACTION_Q01,
|
||||
LingBotVAActionUnnormalizeStep,
|
||||
make_lingbot_va_pre_post_processors,
|
||||
)
|
||||
@@ -75,7 +76,7 @@ def test_make_pre_post_processors_names_and_steps() -> None:
|
||||
def test_postprocessor_applies_unnormalization() -> None:
|
||||
cfg = _make_config()
|
||||
_, post = make_lingbot_va_pre_post_processors(cfg, dataset_stats=None)
|
||||
# A normalized action of all -1 should map back to q01.
|
||||
# A normalized action of all -1 should map back to q01 (the LIBERO 7-DoF default quantiles).
|
||||
normed = torch.full((1, len(cfg.used_action_channel_ids)), -1.0)
|
||||
out = post(normed)
|
||||
assert torch.allclose(out, torch.tensor(cfg.action_q01).unsqueeze(0), atol=1e-4)
|
||||
assert torch.allclose(out, torch.tensor(LIBERO_ACTION_Q01).unsqueeze(0), atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user