Re-parent per-layer transformer blocks into MoTLayer modules to enable proper ZeRO-3 FSDP

This commit is contained in:
Maxime Ellerbach
2026-06-15 12:11:27 +00:00
parent 56a82e578d
commit 4b6fa7d491
7 changed files with 376 additions and 314 deletions
@@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]:
"attn_head_dim": 128,
"num_layers": 30,
"eps": 1.0e-6,
"seperated_timestep": True,
"separated_timestep": True,
"use_gradient_checkpointing": False,
"video_attention_mask_mode": "first_frame_causal",
"action_conditioned": False,
@@ -296,8 +296,7 @@ class FastWAMConfig(PreTrainedConfig):
image_keys = sorted(
key
for key, feature in dataset_features.items()
if key.startswith("observation.images.")
and feature.get("dtype") in ("video", "image")
if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
)
if not image_keys:
return
@@ -76,6 +76,13 @@ class FastWAMPolicy(PreTrainedPolicy):
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
# so its params drop out of the optimizer (and DDP skips them).
self.model.video_expert.requires_grad_(False)
# The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so
# `video_expert.requires_grad_` no longer reaches them — freeze them via the layers.
mot = getattr(self.model, "mot", None)
if mot is not None and getattr(mot, "layers", None) is not None:
for layer in mot.layers:
if "video" in layer.blocks:
layer.blocks["video"].requires_grad_(False)
self.reset()
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
# counts only eval-rollout resets (one per episode), not this __init__ one.
@@ -354,7 +361,9 @@ class FastWAMPolicy(PreTrainedPolicy):
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
write_video(path, frames, fps=30)
logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path)
logging.info(
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
)
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
"""Build the FastWAM core for training / inference.
@@ -485,9 +494,7 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(
k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")
)
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
+324 -293
View File
@@ -16,6 +16,7 @@ from __future__ import annotations
import logging
import os
import re
from collections.abc import Sequence
from typing import Any
@@ -258,56 +259,39 @@ class ActionDiT(nn.Module):
return self.post_dit(x, pre_state)
class MoT(nn.Module):
class MoTLayer(nn.Module):
"""A single MoT layer: owns one transformer block per expert and runs the cross-expert
mixed-attention step for that layer.
This exists as a module — rather than the per-layer work being inlined in ``MoT``'s loop —
so FSDP can wrap each layer as its own unit. FSDP all-gathers a wrapped module's sharded
parameters via a hook on that module's ``forward``/``__call__``. The MoT computation drives
each block's submodules directly (the joint mixed attention concatenates Q/K/V across
experts, so no single block's ``forward`` is ever called), which leaves ``MoTLayer.forward``
as the only call boundary FSDP can hook. All three per-layer paths (``joint``,
``video_prefill``, ``action_cached``) therefore dispatch through ``forward(mode=...)`` so
each enters via ``__call__``.
"""
def __init__(
self,
mixtures: dict[str, nn.Module],
mot_checkpoint_mixed_attn: bool = True,
blocks: dict[str, nn.Module],
experts: dict[str, nn.Module],
num_heads: int,
attn_head_dim: int,
fp32_attention: bool,
mot_checkpoint_mixed_attn: bool,
):
super().__init__()
if not mixtures:
raise ValueError("`mixtures` cannot be empty.")
if "video" not in mixtures or "action" not in mixtures:
raise ValueError("`mixtures` must include both 'video' and 'action' experts.")
self.mixtures = nn.ModuleDict(mixtures)
self.expert_order = list(self.mixtures.keys())
self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn
if mot_checkpoint_mixed_attn:
logger.info(
"Using gradient checkpointing for mixture attention. This will save memory but use more computation."
)
first_expert = self.mixtures[self.expert_order[0]]
self.num_layers = len(first_expert.blocks)
self.num_heads = first_expert.num_heads
self.attn_head_dim = first_expert.attn_head_dim
self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True))
for name in self.expert_order[1:]:
expert = self.mixtures[name]
if len(expert.blocks) != self.num_layers:
raise ValueError(
f"All experts must have same number of layers; got {self.num_layers} and {len(expert.blocks)}"
)
if expert.num_heads != self.num_heads:
raise ValueError(
f"All experts must have same num_heads; got {self.num_heads} and {expert.num_heads}"
)
if expert.attn_head_dim != self.attn_head_dim:
raise ValueError(
"All experts must have same attn_head_dim; "
f"got {self.attn_head_dim} and {expert.attn_head_dim}"
)
if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention:
raise ValueError("All experts must use the same `fp32_attention` setting.")
logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}")
for name in self.expert_order:
expert = self.mixtures[name]
logger.info(
f" Expert '{name}': num_params={sum(p.numel() for p in expert.parameters()) / 1e9:.2f} B"
)
# Registered owner of this layer's blocks (one per expert) — the FSDP wrap unit.
self.blocks = nn.ModuleDict(blocks)
self.expert_order = list(blocks.keys())
# Unregistered back-references to the experts: used only to read the live
# `use_gradient_checkpointing` flag, kept out of parameters()/state_dict().
object.__setattr__(self, "_experts", dict(experts))
self.num_heads = num_heads
self.attn_head_dim = attn_head_dim
self.fp32_attention = bool(fp32_attention)
self.mot_checkpoint_mixed_attn = bool(mot_checkpoint_mixed_attn)
@staticmethod
def _split_modulation(block, t_mod: torch.Tensor):
@@ -394,43 +378,17 @@ class MoT(nn.Module):
def _build_expert_attention_io(
self,
expert,
block,
name: str,
x: torch.Tensor,
freqs: torch.Tensor | dict[str, torch.Tensor],
t_mod: torch.Tensor,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
bool,
]:
"""Build per-expert attention tensors and post-block states.
):
"""Build this expert's attention tensors and post-block states for the layer.
Args:
expert: Expert module that owns this `block`; only used to read
`use_gradient_checkpointing`.
block: Transformer block for current layer (`expert.blocks[layer_idx]`).
x: Current expert tokens, shape [B, S, D].
freqs: RoPE frequencies aligned with token sequence, shape [S, 1, rope_dim].
t_mod: Time modulation tensor for this expert/layer.
Returns:
q: Query after q-proj, RMSNorm, and RoPE, shape [B, S, H*Dh].
k: Key after k-proj, RMSNorm, and RoPE, shape [B, S, H*Dh].
v: Value after v-proj, shape [B, S, H*Dh].
residual_x: Original input `x` for residual path in post block.
gate_msa: Gating tensor for self-attention residual branch.
shift_mlp: Shift tensor for MLP modulation.
scale_mlp: Scale tensor for MLP modulation.
gate_mlp: Gating tensor for MLP residual branch.
use_gradient_checkpointing: Whether this expert enables checkpointing.
Returns (q, k, v, residual_x, gate_msa, shift_mlp, scale_mlp, gate_mlp, use_gc).
"""
block = self.blocks[name]
expert = self._experts[name]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self._split_modulation(block, t_mod)
attn_input = modulate(_apply_block_norm(block, "norm1", x), shift_msa, scale_msa)
@@ -461,25 +419,6 @@ class MoT(nn.Module):
mixed_slice: torch.Tensor,
context_payload: dict | None,
) -> torch.Tensor:
"""Apply post-attention computations, with optional checkpointing.
Args:
block: Transformer block for current layer.
residual_x: Residual input tokens before attention update, shape [B, S, D].
gate_msa: Gating tensor used after mixed self-attention.
shift_mlp: Shift tensor for MLP input modulation.
scale_mlp: Scale tensor for MLP input modulation.
gate_mlp: Gating tensor used after MLP.
use_gradient_checkpointing: If True and training, checkpoint this post block.
mixed_slice: Mixed-attention output for this expert, shape [B, S, H*Dh].
context_payload: Optional dict for cross-attention.
- `context`: encoder states [B, L, D]
- `mask`: attention mask [B, S, L] or [B, 1, S, L]
Returns:
Updated expert tokens after self-attn residual, optional cross-attn, and MLP.
"""
def _post_fn(
_mixed_slice: torch.Tensor,
_x: torch.Tensor,
@@ -521,6 +460,256 @@ class MoT(nn.Module):
gate_mlp,
)
def forward(self, mode: str, **kwargs):
if mode == "joint":
return self._forward_joint(**kwargs)
if mode == "video_prefill":
return self._forward_video_prefill(**kwargs)
if mode == "action_cached":
return self._forward_action_cached(**kwargs)
raise ValueError(f"Unknown MoTLayer forward mode: {mode!r}")
def _forward_joint(
    self,
    tokens_all: dict[str, torch.Tensor],
    attention_mask: torch.Tensor,
    freqs_all: dict[str, torch.Tensor],
    context_all: dict[str, dict | None],
    t_mod_all: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Run this layer for all experts at once through one shared attention call.

    For every expert in ``self.expert_order`` the per-expert Q/K/V are built, the
    Q/K/V of all experts are concatenated along dim 1, a single mixed-attention call
    is made, and the result is split back per expert and handed to that expert's
    post-attention step.

    Args:
        tokens_all: Current tokens per expert name; ``shape[1]`` is used as that
            expert's length inside the concatenated sequence.
        attention_mask: Mask for the concatenated sequence; its ``shape[0]`` must
            equal the concatenated length.
        freqs_all: Per-expert frequencies forwarded to ``_build_expert_attention_io``.
        context_all: Optional per-expert context payload, looked up with ``.get``
            (a missing name yields ``None``).
        t_mod_all: Per-expert modulation tensor forwarded to
            ``_build_expert_attention_io``.

    Returns:
        Dict mapping each expert name to its updated tokens.

    Raises:
        ValueError: If ``attention_mask.shape[0]`` differs from the concatenated length.
    """
    queries, keys, values = [], [], []
    # One record per expert, in expert order: (name, start, stop, post-attention kwargs).
    pending = []
    cursor = 0
    for name in self.expert_order:
        expert_tokens = tokens_all[name]
        (
            q,
            k,
            v,
            residual,
            g_msa,
            sh_mlp,
            sc_mlp,
            g_mlp,
            use_gc,
        ) = self._build_expert_attention_io(name, expert_tokens, freqs_all[name], t_mod_all[name])
        queries.append(q)
        keys.append(k)
        values.append(v)
        stop = cursor + expert_tokens.shape[1]
        pending.append(
            (
                name,
                cursor,
                stop,
                {
                    "residual_x": residual,
                    "gate_msa": g_msa,
                    "shift_mlp": sh_mlp,
                    "scale_mlp": sc_mlp,
                    "gate_mlp": g_mlp,
                    "use_gradient_checkpointing": use_gc,
                },
            )
        )
        cursor = stop

    q_cat = torch.cat(queries, dim=1)
    k_cat = torch.cat(keys, dim=1)
    v_cat = torch.cat(values, dim=1)

    joint_len = q_cat.shape[1]
    mask_len = attention_mask.shape[0]
    if mask_len != joint_len:
        raise ValueError(f"Attention mask seq length mismatch: mask={mask_len} vs tokens={joint_len}")

    mixed = self._mixed_attention(q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask)

    # Hand each expert its own slice of the shared attention output.
    return {
        name: self._apply_post_with_optional_checkpoint(
            block=self.blocks[name],
            mixed_slice=mixed[:, start:stop, :],
            context_payload=context_all.get(name),
            **post_kwargs,
        )
        for name, start, stop, post_kwargs in pending
    }
def _forward_video_prefill(
    self,
    x: torch.Tensor,
    freqs: torch.Tensor,
    t_mod: torch.Tensor,
    context_payload: dict | None,
    video_attention_mask: torch.Tensor,
):
    """Run the ``"video"`` expert alone for this layer and expose its K/V.

    Args:
        x: Video tokens entering this layer.
        freqs: Frequencies forwarded to ``_build_expert_attention_io``.
        t_mod: Modulation tensor forwarded to ``_build_expert_attention_io``.
        context_payload: Optional context dict forwarded to the post-attention step.
        video_attention_mask: Mask used for the video-only attention call.

    Returns:
        ``(x_out, k, v)``: the updated video tokens plus the key and value tensors
        built for this layer, so the caller can cache them.
    """
    (
        query,
        key,
        value,
        residual,
        g_msa,
        sh_mlp,
        sc_mlp,
        g_mlp,
        use_gc,
    ) = self._build_expert_attention_io("video", x, freqs, t_mod)

    # Only video tokens take part here, so the video-only mask is applied.
    attended = self._mixed_attention(
        q_cat=query, k_cat=key, v_cat=value, attention_mask=video_attention_mask
    )

    post_args = {
        "block": self.blocks["video"],
        "residual_x": residual,
        "gate_msa": g_msa,
        "shift_mlp": sh_mlp,
        "scale_mlp": sc_mlp,
        "gate_mlp": g_mlp,
        "use_gradient_checkpointing": use_gc,
        "mixed_slice": attended,
        "context_payload": context_payload,
    }
    updated = self._apply_post_with_optional_checkpoint(**post_args)
    return updated, key, value
def _forward_action_cached(
    self,
    x: torch.Tensor,
    freqs: torch.Tensor,
    t_mod: torch.Tensor,
    context_payload: dict | None,
    k_video: torch.Tensor,
    v_video: torch.Tensor,
    action_attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Run the ``"action"`` expert for this layer against cached video K/V.

    Args:
        x: Action tokens entering this layer.
        freqs: Frequencies forwarded to ``_build_expert_attention_io``.
        t_mod: Modulation tensor forwarded to ``_build_expert_attention_io``.
        context_payload: Optional context dict forwarded to the post-attention step.
        k_video: Cached video keys for this layer; placed before the action keys.
        v_video: Cached video values for this layer; placed before the action values.
        action_attention_mask: Mask used for the action-query attention call.

    Returns:
        The updated action tokens.
    """
    (
        query,
        key_action,
        value_action,
        residual,
        g_msa,
        sh_mlp,
        sc_mlp,
        g_mlp,
        use_gc,
    ) = self._build_expert_attention_io("action", x, freqs, t_mod)

    # Keys/values are [video, action] along dim 1; only action queries are used.
    keys = torch.cat([k_video, key_action], dim=1)
    values = torch.cat([v_video, value_action], dim=1)
    attended = self._mixed_attention(
        q_cat=query, k_cat=keys, v_cat=values, attention_mask=action_attention_mask
    )

    post_args = {
        "block": self.blocks["action"],
        "residual_x": residual,
        "gate_msa": g_msa,
        "shift_mlp": sh_mlp,
        "scale_mlp": sc_mlp,
        "gate_mlp": g_mlp,
        "use_gradient_checkpointing": use_gc,
        "mixed_slice": attended,
        "context_payload": context_payload,
    }
    return self._apply_post_with_optional_checkpoint(**post_args)
class MoT(nn.Module):
def __init__(
    self,
    mixtures: dict[str, nn.Module],
    mot_checkpoint_mixed_attn: bool = True,
):
    """Validate the experts, then build one ``MoTLayer`` per transformer layer.

    Args:
        mixtures: Expert modules keyed by name; must contain ``"video"`` and
            ``"action"``. Each expert is read for ``blocks``, ``num_heads``,
            ``attn_head_dim`` and (optionally, default ``True``) ``fp32_attention``.
        mot_checkpoint_mixed_attn: Stored on the instance and passed to every layer.

    Raises:
        ValueError: If ``mixtures`` is empty, lacks ``"video"``/``"action"``, or the
            experts disagree on layer count, ``num_heads``, ``attn_head_dim`` or
            ``fp32_attention``.
    """
    super().__init__()
    if not mixtures:
        raise ValueError("`mixtures` cannot be empty.")
    if not ("video" in mixtures and "action" in mixtures):
        raise ValueError("`mixtures` must include both 'video' and 'action' experts.")

    self.mixtures = nn.ModuleDict(mixtures)
    self.expert_order = list(self.mixtures.keys())
    self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn
    if mot_checkpoint_mixed_attn:
        logger.info(
            "Using gradient checkpointing for mixture attention. This will save memory but use more computation."
        )

    # The first expert sets the reference geometry; every other expert must match it.
    reference_name, *other_names = self.expert_order
    reference = self.mixtures[reference_name]
    self.num_layers = len(reference.blocks)
    self.num_heads = reference.num_heads
    self.attn_head_dim = reference.attn_head_dim
    self.fp32_attention = bool(getattr(reference, "fp32_attention", True))

    for other_name in other_names:
        candidate = self.mixtures[other_name]
        if len(candidate.blocks) != self.num_layers:
            raise ValueError(
                f"All experts must have same number of layers; got {self.num_layers} and {len(candidate.blocks)}"
            )
        if candidate.num_heads != self.num_heads:
            raise ValueError(
                f"All experts must have same num_heads; got {self.num_heads} and {candidate.num_heads}"
            )
        if candidate.attn_head_dim != self.attn_head_dim:
            raise ValueError(
                f"All experts must have same attn_head_dim; got {self.attn_head_dim} and {candidate.attn_head_dim}"
            )
        if bool(getattr(candidate, "fp32_attention", True)) != self.fp32_attention:
            raise ValueError("All experts must use the same `fp32_attention` setting.")

    # Parameter counts are logged here, while the blocks are still registered on the experts.
    logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}")
    for name in self.expert_order:
        param_count = sum(p.numel() for p in self.mixtures[name].parameters())
        logger.info(f" Expert '{name}': num_params={param_count / 1e9:.2f} B")

    # Build one MoTLayer per layer index. Each layer receives that index's block from
    # every expert and becomes the registered owner of those blocks (the intended FSDP
    # wrap unit: the layer's own forward is the call boundary for its parameters).
    per_layer = []
    for layer_idx in range(self.num_layers):
        per_layer.append(
            MoTLayer(
                blocks={name: self.mixtures[name].blocks[layer_idx] for name in self.expert_order},
                experts={name: self.mixtures[name] for name in self.expert_order},
                num_heads=self.num_heads,
                attn_head_dim=self.attn_head_dim,
                fp32_attention=self.fp32_attention,
                mot_checkpoint_mixed_attn=self.mot_checkpoint_mixed_attn,
            )
        )
    self.layers = nn.ModuleList(per_layer)

    # Re-parent: drop `blocks` from each expert's module registry so the layers are the
    # single owner (double registration would expose the same params under two paths).
    # A plain list is put back with `object.__setattr__`, which bypasses module
    # registration, so `expert.blocks` / `len(expert.blocks)` keep working without
    # re-adding the params to the expert's parameters()/state_dict().
    for name in self.expert_order:
        expert = self.mixtures[name]
        detached_blocks = list(expert.blocks)
        del expert._modules["blocks"]
        object.__setattr__(expert, "blocks", detached_blocks)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Rename legacy per-expert block keys to the per-layer layout, then load as usual.

    Checkpoints written before the ``MoTLayer`` refactor keyed each block under its
    expert (``{prefix}mixtures.<name>.blocks.<i>.<rest>``); blocks are now owned by the
    layers (``{prefix}layers.<i>.blocks.<name>.<rest>``). Matching keys are renamed
    inside ``state_dict`` itself (the dict is mutated) before delegating to the parent
    implementation, so the values are found under ``self.layers`` instead of being
    left behind under the now block-less ``self.mixtures``.

    Args:
        state_dict: Mapping being loaded; legacy keys are popped and re-inserted
            under their new names.
        prefix: Key prefix of this module within ``state_dict``.
        *args: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded unchanged to the parent implementation.
    """
    legacy_key = re.compile(re.escape(prefix) + r"mixtures\.([^.]+)\.blocks\.(\d+)\.(.+)$")

    # Snapshot the matches first: the dict is mutated while keys are renamed.
    hits = [
        (old_key, hit)
        for old_key in list(state_dict.keys())
        if (hit := legacy_key.match(old_key)) is not None
    ]

    renamed = {}
    for old_key, hit in hits:
        expert_name, layer_idx, suffix = hit.groups()
        new_key = f"{prefix}layers.{layer_idx}.blocks.{expert_name}.{suffix}"
        renamed[new_key] = state_dict.pop(old_key)
    state_dict.update(renamed)

    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def prefill_video_cache(
self,
video_tokens: torch.Tensor,
@@ -531,20 +720,7 @@ class MoT(nn.Module):
) -> list[dict[str, torch.Tensor]]:
"""Prefill video branch once and cache per-layer K/V for action denoising.
Args:
video_tokens: Video tokens before layer 0, shape [B, Sv, D].
video_freqs: Video RoPE frequencies, shape [Sv, 1, rope_dim].
video_t_mod: Video time modulation tensor.
video_context_payload: Optional dict for video cross-attention.
- `context`: encoder states [B, L, D]
- `mask`: attention mask [B, Sv, L] or [B, 1, Sv, L]
video_attention_mask: Video self-attention mask, shape [Sv, Sv].
Returns:
Layer-wise cache list with length `num_layers`.
Each entry contains:
- `k`: video key tensor [B, Sv, H*Dh]
- `v`: video value tensor [B, Sv, H*Dh]
Returns a list of length ``num_layers``, each entry ``{"k": ..., "v": ...}``.
"""
if "video" not in self.mixtures:
raise ValueError("MoT requires `video` expert for `prefill_video_cache`.")
@@ -562,47 +738,16 @@ class MoT(nn.Module):
f"mask={video_attention_mask.shape[0]} vs tokens={video_tokens.shape[1]}"
)
expert = self.mixtures["video"]
x = video_tokens
kv_cache: list[dict[str, torch.Tensor]] = []
for layer_idx in range(self.num_layers):
block = expert.blocks[layer_idx]
# Build video Q/K/V from current layer input tokens.
(
q,
k,
v,
residual_x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
use_gradient_checkpointing,
) = self._build_expert_attention_io(
expert=expert,
block=block,
for layer in self.layers:
x, k, v = layer(
mode="video_prefill",
x=x,
freqs=video_freqs,
t_mod=video_t_mod,
)
# Video prefill uses only video self-attention mask.
mixed = self._mixed_attention(
q_cat=q,
k_cat=k,
v_cat=v,
attention_mask=video_attention_mask,
)
# Update video tokens for the next layer and persist current layer K/V.
x = self._apply_post_with_optional_checkpoint(
block=block,
residual_x=residual_x,
gate_msa=gate_msa,
shift_mlp=shift_mlp,
scale_mlp=scale_mlp,
gate_mlp=gate_mlp,
use_gradient_checkpointing=use_gradient_checkpointing,
mixed_slice=mixed,
context_payload=video_context_payload,
video_attention_mask=video_attention_mask,
)
kv_cache.append({"k": k, "v": v})
return kv_cache
@@ -617,22 +762,7 @@ class MoT(nn.Module):
attention_mask: torch.Tensor,
video_seq_len: int,
) -> torch.Tensor:
"""Run action branch with cached video K/V instead of recomputing video tokens.
Args:
action_tokens: Action tokens before layer 0, shape [B, Sa, D].
action_freqs: Action RoPE frequencies, shape [Sa, 1, rope_dim].
action_t_mod: Action time modulation tensor.
action_context_payload: Optional dict for action cross-attention.
- `context`: encoder states [B, L, D]
- `mask`: attention mask [B, Sa, L] or [B, 1, Sa, L]
video_kv_cache: Layer-wise cached video K/V from `prefill_video_cache`.
attention_mask: Joint [video+action] mask, shape [Sv+Sa, Sv+Sa].
video_seq_len: Video token count `Sv` in the joint sequence prefix.
Returns:
Updated action tokens after all layers, shape [B, Sa, D].
"""
"""Run action branch with cached video K/V instead of recomputing video tokens."""
if "action" not in self.mixtures:
raise ValueError("MoT requires `action` expert for `forward_action_with_video_cache`.")
if len(video_kv_cache) != self.num_layers:
@@ -654,56 +784,24 @@ class MoT(nn.Module):
# Use the action query rows from the joint [video+action] mask.
action_attention_mask = attention_mask[video_seq_len:total_seq_len, :total_seq_len]
expert = self.mixtures["action"]
x = action_tokens
for layer_idx in range(self.num_layers):
block = expert.blocks[layer_idx]
# Action query/key/value are still step-dependent and must be recomputed each step.
(
q_action,
k_action,
v_action,
residual_x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
use_gradient_checkpointing,
) = self._build_expert_attention_io(
expert=expert,
block=block,
x=x,
freqs=action_freqs,
t_mod=action_t_mod,
)
for layer_idx, layer in enumerate(self.layers):
layer_cache = video_kv_cache[layer_idx]
if "k" not in layer_cache or "v" not in layer_cache:
raise ValueError(f"`video_kv_cache[{layer_idx}]` must contain `k` and `v`.")
k_video = layer_cache["k"]
v_video = layer_cache["v"]
if k_video.shape[1] != video_seq_len or v_video.shape[1] != video_seq_len:
raise ValueError(f"`video_kv_cache[{layer_idx}]` seq len mismatch, expected {video_seq_len}.")
# Mixed attention: action queries attend to cached video K/V plus current action K/V.
k_cat = torch.cat([k_video, k_action], dim=1)
v_cat = torch.cat([v_video, v_action], dim=1)
mixed = self._mixed_attention(
q_cat=q_action,
k_cat=k_cat,
v_cat=v_cat,
attention_mask=action_attention_mask,
)
x = self._apply_post_with_optional_checkpoint(
block=block,
residual_x=residual_x,
gate_msa=gate_msa,
shift_mlp=shift_mlp,
scale_mlp=scale_mlp,
gate_mlp=gate_mlp,
use_gradient_checkpointing=use_gradient_checkpointing,
mixed_slice=mixed,
x = layer(
mode="action_cached",
x=x,
freqs=action_freqs,
t_mod=action_t_mod,
context_payload=action_context_payload,
k_video=k_video,
v_video=v_video,
action_attention_mask=action_attention_mask,
)
return x
@@ -730,94 +828,18 @@ class MoT(nn.Module):
if attention_mask.shape[0] != attention_mask.shape[1]:
raise ValueError(f"`attention_mask` must be square, got shape {tuple(attention_mask.shape)}")
# Each layer is a MoTLayer module; entering via __call__ lets FSDP all-gather that
# layer's params (the whole point of the per-layer split).
tokens_all = dict(embeds_all)
for layer_idx in range(self.num_layers):
q_chunks = []
k_chunks = []
v_chunks = []
cached = {}
seq_lens = []
for name in self.expert_order:
expert = self.mixtures[name]
block = expert.blocks[layer_idx]
x = tokens_all[name]
freqs = freqs_all[name]
t_mod = t_mod_all[name]
(
q,
k,
v,
residual_x,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
use_gradient_checkpointing,
) = self._build_expert_attention_io(
expert=expert,
block=block,
x=x,
freqs=freqs,
t_mod=t_mod,
)
q_chunks.append(q)
k_chunks.append(k)
v_chunks.append(v)
seq_lens.append(x.shape[1])
cached[name] = {
"block": block,
"residual_x": residual_x,
"gate_msa": gate_msa,
"shift_mlp": shift_mlp,
"scale_mlp": scale_mlp,
"gate_mlp": gate_mlp,
"use_gradient_checkpointing": use_gradient_checkpointing,
}
# 3. concat all tokens for mixed attention
q_cat = torch.cat(q_chunks, dim=1)
k_cat = torch.cat(k_chunks, dim=1)
v_cat = torch.cat(v_chunks, dim=1)
total_seq = q_cat.shape[1]
if attention_mask.shape[0] != total_seq:
raise ValueError(
"Attention mask seq length mismatch: "
f"mask={attention_mask.shape[0]} vs tokens={total_seq}"
)
mixed = self._mixed_attention(
q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask
for layer in self.layers:
tokens_all = layer(
mode="joint",
tokens_all=tokens_all,
attention_mask=attention_mask,
freqs_all=freqs_all,
context_all=context_all,
t_mod_all=t_mod_all,
)
start = 0
for name, seq_len in zip(self.expert_order, seq_lens, strict=True):
# 4. split mixed attention output and apply post-attention blocks for each expert
end = start + seq_len
mixed_slice = mixed[:, start:end, :]
cached_expert = cached[name]
block = cached_expert["block"]
context_payload = context_all.get(name)
updated_tokens = self._apply_post_with_optional_checkpoint(
block=block,
residual_x=cached_expert["residual_x"],
gate_msa=cached_expert["gate_msa"],
shift_mlp=cached_expert["shift_mlp"],
scale_mlp=cached_expert["scale_mlp"],
gate_mlp=cached_expert["gate_mlp"],
use_gradient_checkpointing=cached_expert["use_gradient_checkpointing"],
mixed_slice=mixed_slice,
context_payload=context_payload,
)
tokens_all[name] = updated_tokens
start = end
return tokens_all
@@ -846,11 +868,20 @@ class FastWAM(torch.nn.Module):
loss_lambda_action: float = 1.0,
):
super().__init__()
self.video_expert = video_expert
self.action_expert = action_expert
self.mot = mot
# Keep trainer compatibility: optimizer and freeze logic use `model.dit`.
self.dit = self.mot
# `video_expert` / `action_expert` are the very same module objects as
# `mot.mixtures["video"]` / `["action"]`, and `dit` is an alias of `mot`. Registering
# them as submodules too would give every expert tensor three names in `state_dict()`
# (`video_expert.*`, `mot.mixtures.video.*`, `dit.mixtures.video.*`) — a 3x-bloated
# gathered FSDP checkpoint and a doubled module tree for FSDP to traverse. Hold them as
# plain (unregistered) attributes instead — bypassing `nn.Module.__setattr__`, like the
# frozen vae/text_encoder below — so `mot` is the single registered owner and each tensor
# has one canonical name (`mot.mixtures.*` / `mot.layers.*`, matching the base checkpoint).
# Forward / freeze / optimizer code still reaches them by attribute, and device/dtype moves
# still apply via `mot`. (optimizer + freeze logic use `model.dit`.)
object.__setattr__(self, "video_expert", video_expert)
object.__setattr__(self, "action_expert", action_expert)
object.__setattr__(self, "dit", self.mot)
# Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT
# registered as submodules. They are therefore excluded from `state_dict()`
@@ -133,9 +133,7 @@ def make_fastwam_pre_post_processors(
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape
for feature in config.input_features.values()
if feature.type == FeatureType.VISUAL
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
@@ -42,6 +42,7 @@ WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
WAN_T5_TOKENIZER = "google/umt5-xxl"
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
class WanTextEncoder(torch.nn.Module):
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
@@ -76,7 +77,11 @@ class WanTokenizer:
self.seq_len = int(seq_len)
def __call__(
self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any
self,
sequence: str | Sequence[str],
return_mask: bool = False,
add_special_tokens: bool = True,
**_: Any,
):
if isinstance(sequence, str):
sequence = [sequence]
@@ -99,9 +104,7 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
vae = AutoencoderKLWan.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype
)
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
@@ -425,7 +425,7 @@ class WanVideoDiT(WanModel):
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
seperated_timestep: bool = False,
separated_timestep: bool = False,
require_vae_embedding: bool = False,
require_clip_embedding: bool = False,
fuse_vae_embedding_in_latents: bool = True,
@@ -489,7 +489,7 @@ class WanVideoDiT(WanModel):
self.hidden_dim = hidden_dim
self.attn_head_dim = attn_head_dim
self.seperated_timestep = seperated_timestep
self.separated_timestep = separated_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.video_attention_mask_mode = str(video_attention_mask_mode)
self.action_conditioned = action_conditioned
@@ -647,7 +647,7 @@ class WanVideoDiT(WanModel):
)
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
raise NotImplementedError(
"FastWAM currently requires separated timesteps with fused VAE latents."
)
+28 -4
View File
@@ -224,7 +224,15 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
def build_core(self, config):
core = CoreWithFrozenComponents()
@@ -256,7 +264,15 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
@@ -278,7 +294,15 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
@@ -332,7 +356,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes():
"dim_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_scales": [],
"temperal_downsample": [False, True, True],
"temporal_downsample": [False, True, True],
"dropout": 0.0,
"is_residual": True,
"in_channels": 12,