mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
re-parenting of some layers to enable proper zero-3 FSDP
This commit is contained in:
@@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]:
|
||||
"attn_head_dim": 128,
|
||||
"num_layers": 30,
|
||||
"eps": 1.0e-6,
|
||||
"seperated_timestep": True,
|
||||
"separated_timestep": True,
|
||||
"use_gradient_checkpointing": False,
|
||||
"video_attention_mask_mode": "first_frame_causal",
|
||||
"action_conditioned": False,
|
||||
@@ -296,8 +296,7 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
image_keys = sorted(
|
||||
key
|
||||
for key, feature in dataset_features.items()
|
||||
if key.startswith("observation.images.")
|
||||
and feature.get("dtype") in ("video", "image")
|
||||
if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
|
||||
)
|
||||
if not image_keys:
|
||||
return
|
||||
|
||||
@@ -76,6 +76,13 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
|
||||
# so its params drop out of the optimizer (and DDP skips them).
|
||||
self.model.video_expert.requires_grad_(False)
|
||||
# The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so
|
||||
# `video_expert.requires_grad_` no longer reaches them — freeze them via the layers.
|
||||
mot = getattr(self.model, "mot", None)
|
||||
if mot is not None and getattr(mot, "layers", None) is not None:
|
||||
for layer in mot.layers:
|
||||
if "video" in layer.blocks:
|
||||
layer.blocks["video"].requires_grad_(False)
|
||||
self.reset()
|
||||
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
|
||||
# counts only eval-rollout resets (one per episode), not this __init__ one.
|
||||
@@ -354,7 +361,9 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
|
||||
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
|
||||
write_video(path, frames, fps=30)
|
||||
logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path)
|
||||
logging.info(
|
||||
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
|
||||
)
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
|
||||
"""Build the FastWAM core for training / inference.
|
||||
@@ -485,9 +494,7 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
|
||||
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
|
||||
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
|
||||
image_keys = sorted(
|
||||
k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")
|
||||
)
|
||||
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
|
||||
if not image_keys:
|
||||
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
|
||||
images = [batch[key] for key in image_keys]
|
||||
|
||||
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
@@ -258,56 +259,39 @@ class ActionDiT(nn.Module):
|
||||
return self.post_dit(x, pre_state)
|
||||
|
||||
|
||||
class MoT(nn.Module):
|
||||
class MoTLayer(nn.Module):
|
||||
"""A single MoT layer: owns one transformer block per expert and runs the cross-expert
|
||||
mixed-attention step for that layer.
|
||||
|
||||
This exists as a module — rather than the per-layer work being inlined in ``MoT``'s loop —
|
||||
so FSDP can wrap each layer as its own unit. FSDP all-gathers a wrapped module's sharded
|
||||
parameters via a hook on that module's ``forward``/``__call__``. ``MoT`` drives block
|
||||
submodules directly (the joint mixed attention concatenates Q/K/V across experts, so no
|
||||
single block's ``forward`` is ever called), so ``MoTLayer.forward`` is the only call
|
||||
boundary FSDP can hook. All three per-layer paths therefore dispatch through
|
||||
``forward(mode=...)`` so each enters via ``__call__``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mixtures: dict[str, nn.Module],
|
||||
mot_checkpoint_mixed_attn: bool = True,
|
||||
blocks: dict[str, nn.Module],
|
||||
experts: dict[str, nn.Module],
|
||||
num_heads: int,
|
||||
attn_head_dim: int,
|
||||
fp32_attention: bool,
|
||||
mot_checkpoint_mixed_attn: bool,
|
||||
):
|
||||
super().__init__()
|
||||
if not mixtures:
|
||||
raise ValueError("`mixtures` cannot be empty.")
|
||||
if "video" not in mixtures or "action" not in mixtures:
|
||||
raise ValueError("`mixtures` must include both 'video' and 'action' experts.")
|
||||
|
||||
self.mixtures = nn.ModuleDict(mixtures)
|
||||
self.expert_order = list(self.mixtures.keys())
|
||||
self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn
|
||||
if mot_checkpoint_mixed_attn:
|
||||
logger.info(
|
||||
"Using gradient checkpointing for mixture attention. This will save memory but use more computation."
|
||||
)
|
||||
|
||||
first_expert = self.mixtures[self.expert_order[0]]
|
||||
self.num_layers = len(first_expert.blocks)
|
||||
self.num_heads = first_expert.num_heads
|
||||
self.attn_head_dim = first_expert.attn_head_dim
|
||||
self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True))
|
||||
|
||||
for name in self.expert_order[1:]:
|
||||
expert = self.mixtures[name]
|
||||
if len(expert.blocks) != self.num_layers:
|
||||
raise ValueError(
|
||||
f"All experts must have same number of layers; got {self.num_layers} and {len(expert.blocks)}"
|
||||
)
|
||||
if expert.num_heads != self.num_heads:
|
||||
raise ValueError(
|
||||
f"All experts must have same num_heads; got {self.num_heads} and {expert.num_heads}"
|
||||
)
|
||||
if expert.attn_head_dim != self.attn_head_dim:
|
||||
raise ValueError(
|
||||
"All experts must have same attn_head_dim; "
|
||||
f"got {self.attn_head_dim} and {expert.attn_head_dim}"
|
||||
)
|
||||
if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention:
|
||||
raise ValueError("All experts must use the same `fp32_attention` setting.")
|
||||
|
||||
logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}")
|
||||
for name in self.expert_order:
|
||||
expert = self.mixtures[name]
|
||||
logger.info(
|
||||
f" Expert '{name}': num_params={sum(p.numel() for p in expert.parameters()) / 1e9:.2f} B"
|
||||
)
|
||||
# Registered owner of this layer's blocks (one per expert) — the FSDP wrap unit.
|
||||
self.blocks = nn.ModuleDict(blocks)
|
||||
self.expert_order = list(blocks.keys())
|
||||
# Unregistered back-references to the experts: used only to read the live
|
||||
# `use_gradient_checkpointing` flag, kept out of parameters()/state_dict().
|
||||
object.__setattr__(self, "_experts", dict(experts))
|
||||
self.num_heads = num_heads
|
||||
self.attn_head_dim = attn_head_dim
|
||||
self.fp32_attention = bool(fp32_attention)
|
||||
self.mot_checkpoint_mixed_attn = bool(mot_checkpoint_mixed_attn)
|
||||
|
||||
@staticmethod
|
||||
def _split_modulation(block, t_mod: torch.Tensor):
|
||||
@@ -394,43 +378,17 @@ class MoT(nn.Module):
|
||||
|
||||
def _build_expert_attention_io(
|
||||
self,
|
||||
expert,
|
||||
block,
|
||||
name: str,
|
||||
x: torch.Tensor,
|
||||
freqs: torch.Tensor | dict[str, torch.Tensor],
|
||||
t_mod: torch.Tensor,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
bool,
|
||||
]:
|
||||
"""Build per-expert attention tensors and post-block states.
|
||||
):
|
||||
"""Build this expert's attention tensors and post-block states for the layer.
|
||||
|
||||
Args:
|
||||
expert: Expert module that owns this `block`; only used to read
|
||||
`use_gradient_checkpointing`.
|
||||
block: Transformer block for current layer (`expert.blocks[layer_idx]`).
|
||||
x: Current expert tokens, shape [B, S, D].
|
||||
freqs: RoPE frequencies aligned with token sequence, shape [S, 1, rope_dim].
|
||||
t_mod: Time modulation tensor for this expert/layer.
|
||||
|
||||
Returns:
|
||||
q: Query after q-proj, RMSNorm, and RoPE, shape [B, S, H*Dh].
|
||||
k: Key after k-proj, RMSNorm, and RoPE, shape [B, S, H*Dh].
|
||||
v: Value after v-proj, shape [B, S, H*Dh].
|
||||
residual_x: Original input `x` for residual path in post block.
|
||||
gate_msa: Gating tensor for self-attention residual branch.
|
||||
shift_mlp: Shift tensor for MLP modulation.
|
||||
scale_mlp: Scale tensor for MLP modulation.
|
||||
gate_mlp: Gating tensor for MLP residual branch.
|
||||
use_gradient_checkpointing: Whether this expert enables checkpointing.
|
||||
Returns (q, k, v, residual_x, gate_msa, shift_mlp, scale_mlp, gate_mlp, use_gc).
|
||||
"""
|
||||
block = self.blocks[name]
|
||||
expert = self._experts[name]
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self._split_modulation(block, t_mod)
|
||||
attn_input = modulate(_apply_block_norm(block, "norm1", x), shift_msa, scale_msa)
|
||||
|
||||
@@ -461,25 +419,6 @@ class MoT(nn.Module):
|
||||
mixed_slice: torch.Tensor,
|
||||
context_payload: dict | None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply post-attention computations, with optional checkpointing.
|
||||
|
||||
Args:
|
||||
block: Transformer block for current layer.
|
||||
residual_x: Residual input tokens before attention update, shape [B, S, D].
|
||||
gate_msa: Gating tensor used after mixed self-attention.
|
||||
shift_mlp: Shift tensor for MLP input modulation.
|
||||
scale_mlp: Scale tensor for MLP input modulation.
|
||||
gate_mlp: Gating tensor used after MLP.
|
||||
use_gradient_checkpointing: If True and training, checkpoint this post block.
|
||||
mixed_slice: Mixed-attention output for this expert, shape [B, S, H*Dh].
|
||||
context_payload: Optional dict for cross-attention.
|
||||
- `context`: encoder states [B, L, D]
|
||||
- `mask`: attention mask [B, S, L] or [B, 1, S, L]
|
||||
|
||||
Returns:
|
||||
Updated expert tokens after self-attn residual, optional cross-attn, and MLP.
|
||||
"""
|
||||
|
||||
def _post_fn(
|
||||
_mixed_slice: torch.Tensor,
|
||||
_x: torch.Tensor,
|
||||
@@ -521,6 +460,256 @@ class MoT(nn.Module):
|
||||
gate_mlp,
|
||||
)
|
||||
|
||||
def forward(self, mode: str, **kwargs):
|
||||
if mode == "joint":
|
||||
return self._forward_joint(**kwargs)
|
||||
if mode == "video_prefill":
|
||||
return self._forward_video_prefill(**kwargs)
|
||||
if mode == "action_cached":
|
||||
return self._forward_action_cached(**kwargs)
|
||||
raise ValueError(f"Unknown MoTLayer forward mode: {mode!r}")
|
||||
|
||||
def _forward_joint(
|
||||
self,
|
||||
tokens_all: dict[str, torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
freqs_all: dict[str, torch.Tensor],
|
||||
context_all: dict[str, dict | None],
|
||||
t_mod_all: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
q_chunks = []
|
||||
k_chunks = []
|
||||
v_chunks = []
|
||||
cached = {}
|
||||
seq_lens = []
|
||||
|
||||
for name in self.expert_order:
|
||||
(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io(name, tokens_all[name], freqs_all[name], t_mod_all[name])
|
||||
|
||||
q_chunks.append(q)
|
||||
k_chunks.append(k)
|
||||
v_chunks.append(v)
|
||||
seq_lens.append(tokens_all[name].shape[1])
|
||||
cached[name] = {
|
||||
"residual_x": residual_x,
|
||||
"gate_msa": gate_msa,
|
||||
"shift_mlp": shift_mlp,
|
||||
"scale_mlp": scale_mlp,
|
||||
"gate_mlp": gate_mlp,
|
||||
"use_gradient_checkpointing": use_gradient_checkpointing,
|
||||
}
|
||||
|
||||
q_cat = torch.cat(q_chunks, dim=1)
|
||||
k_cat = torch.cat(k_chunks, dim=1)
|
||||
v_cat = torch.cat(v_chunks, dim=1)
|
||||
|
||||
total_seq = q_cat.shape[1]
|
||||
if attention_mask.shape[0] != total_seq:
|
||||
raise ValueError(
|
||||
f"Attention mask seq length mismatch: mask={attention_mask.shape[0]} vs tokens={total_seq}"
|
||||
)
|
||||
|
||||
mixed = self._mixed_attention(q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask)
|
||||
|
||||
out = {}
|
||||
start = 0
|
||||
for name, seq_len in zip(self.expert_order, seq_lens, strict=True):
|
||||
end = start + seq_len
|
||||
mixed_slice = mixed[:, start:end, :]
|
||||
cached_expert = cached[name]
|
||||
out[name] = self._apply_post_with_optional_checkpoint(
|
||||
block=self.blocks[name],
|
||||
residual_x=cached_expert["residual_x"],
|
||||
gate_msa=cached_expert["gate_msa"],
|
||||
shift_mlp=cached_expert["shift_mlp"],
|
||||
scale_mlp=cached_expert["scale_mlp"],
|
||||
gate_mlp=cached_expert["gate_mlp"],
|
||||
use_gradient_checkpointing=cached_expert["use_gradient_checkpointing"],
|
||||
mixed_slice=mixed_slice,
|
||||
context_payload=context_all.get(name),
|
||||
)
|
||||
start = end
|
||||
return out
|
||||
|
||||
def _forward_video_prefill(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
t_mod: torch.Tensor,
|
||||
context_payload: dict | None,
|
||||
video_attention_mask: torch.Tensor,
|
||||
):
|
||||
(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io("video", x, freqs, t_mod)
|
||||
# Video prefill uses only video self-attention mask.
|
||||
mixed = self._mixed_attention(q_cat=q, k_cat=k, v_cat=v, attention_mask=video_attention_mask)
|
||||
x_out = self._apply_post_with_optional_checkpoint(
|
||||
block=self.blocks["video"],
|
||||
residual_x=residual_x,
|
||||
gate_msa=gate_msa,
|
||||
shift_mlp=shift_mlp,
|
||||
scale_mlp=scale_mlp,
|
||||
gate_mlp=gate_mlp,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
mixed_slice=mixed,
|
||||
context_payload=context_payload,
|
||||
)
|
||||
return x_out, k, v
|
||||
|
||||
def _forward_action_cached(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
t_mod: torch.Tensor,
|
||||
context_payload: dict | None,
|
||||
k_video: torch.Tensor,
|
||||
v_video: torch.Tensor,
|
||||
action_attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
(
|
||||
q_action,
|
||||
k_action,
|
||||
v_action,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io("action", x, freqs, t_mod)
|
||||
# Mixed attention: action queries attend to cached video K/V plus current action K/V.
|
||||
k_cat = torch.cat([k_video, k_action], dim=1)
|
||||
v_cat = torch.cat([v_video, v_action], dim=1)
|
||||
mixed = self._mixed_attention(
|
||||
q_cat=q_action, k_cat=k_cat, v_cat=v_cat, attention_mask=action_attention_mask
|
||||
)
|
||||
return self._apply_post_with_optional_checkpoint(
|
||||
block=self.blocks["action"],
|
||||
residual_x=residual_x,
|
||||
gate_msa=gate_msa,
|
||||
shift_mlp=shift_mlp,
|
||||
scale_mlp=scale_mlp,
|
||||
gate_mlp=gate_mlp,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
mixed_slice=mixed,
|
||||
context_payload=context_payload,
|
||||
)
|
||||
|
||||
|
||||
class MoT(nn.Module):
|
||||
def __init__(
    self,
    mixtures: dict[str, nn.Module],
    mot_checkpoint_mixed_attn: bool = True,
):
    """Validate the experts and build one ``MoTLayer`` per transformer layer.

    Args:
        mixtures: Expert modules keyed by name; must contain ``"video"`` and
            ``"action"``. Each expert is expected to expose ``blocks``,
            ``num_heads`` and ``attn_head_dim`` (and optionally
            ``fp32_attention``, treated as ``True`` when absent).
        mot_checkpoint_mixed_attn: Stored on the module and passed to every
            ``MoTLayer``.

    Raises:
        ValueError: If ``mixtures`` is empty, lacks the ``video``/``action``
            experts, or the experts disagree on layer count, ``num_heads``,
            ``attn_head_dim`` or ``fp32_attention``.

    Side effects:
        Each expert's ``blocks`` entry is removed from its module registry and
        replaced by a plain (unregistered) list of the same block objects.
    """
    super().__init__()
    if not mixtures:
        raise ValueError("`mixtures` cannot be empty.")
    if not ("video" in mixtures and "action" in mixtures):
        raise ValueError("`mixtures` must include both 'video' and 'action' experts.")

    self.mixtures = nn.ModuleDict(mixtures)
    self.expert_order = list(self.mixtures.keys())
    self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn
    if mot_checkpoint_mixed_attn:
        logger.info(
            "Using gradient checkpointing for mixture attention. This will save memory but use more computation."
        )

    # The first expert defines the reference geometry; every other expert must match it.
    reference_name, *remaining_names = self.expert_order
    first_expert = self.mixtures[reference_name]
    self.num_layers = len(first_expert.blocks)
    self.num_heads = first_expert.num_heads
    self.attn_head_dim = first_expert.attn_head_dim
    self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True))

    for name in remaining_names:
        expert = self.mixtures[name]
        if len(expert.blocks) != self.num_layers:
            raise ValueError(
                f"All experts must have same number of layers; got {self.num_layers} and {len(expert.blocks)}"
            )
        if expert.num_heads != self.num_heads:
            raise ValueError(
                f"All experts must have same num_heads; got {self.num_heads} and {expert.num_heads}"
            )
        if expert.attn_head_dim != self.attn_head_dim:
            raise ValueError(
                "All experts must have same attn_head_dim; "
                f"got {self.attn_head_dim} and {expert.attn_head_dim}"
            )
        if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention:
            raise ValueError("All experts must use the same `fp32_attention` setting.")

    logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}")
    for name in self.expert_order:
        expert = self.mixtures[name]
        logger.info(
            f"  Expert '{name}': num_params={sum(p.numel() for p in expert.parameters()) / 1e9:.2f} B"
        )

    # One MoTLayer per layer index, each owning that layer's block from every expert.
    # MoTLayer is the intended FSDP wrap unit: only MoTLayer.forward is ever called,
    # because the cross-expert mixed attention drives block submodules directly.
    per_layer_modules = []
    for layer_idx in range(self.num_layers):
        layer_blocks = {name: self.mixtures[name].blocks[layer_idx] for name in self.expert_order}
        layer_experts = {name: self.mixtures[name] for name in self.expert_order}
        per_layer_modules.append(
            MoTLayer(
                blocks=layer_blocks,
                experts=layer_experts,
                num_heads=self.num_heads,
                attn_head_dim=self.attn_head_dim,
                fp32_attention=self.fp32_attention,
                mot_checkpoint_mixed_attn=self.mot_checkpoint_mixed_attn,
            )
        )
    self.layers = nn.ModuleList(per_layer_modules)

    # Re-parent: drop `blocks` from each expert's module registry so the layers are the
    # single registered owner (registering under both would expose the same params twice).
    for name in self.expert_order:
        expert = self.mixtures[name]
        kept_blocks = list(expert.blocks)
        del expert._modules["blocks"]
        # Keep an unregistered list so `expert.blocks[i]` / `len(expert.blocks)` still work
        # without putting the params back into the expert's parameters()/state_dict().
        object.__setattr__(expert, "blocks", kept_blocks)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Remap legacy block keys in ``state_dict``, then defer to the default loader.

    Checkpoints saved before the ``MoTLayer`` refactor key the per-layer blocks
    under the experts: ``{prefix}mixtures.<name>.blocks.<i>.<rest>``. The blocks
    are now owned by the layers: ``{prefix}layers.<i>.blocks.<name>.<rest>``.
    Matching keys are renamed in place (``state_dict`` is mutated) before the
    parent implementation runs; all other keys are left untouched.

    Args:
        state_dict: Mapping of parameter names to values; modified in place.
        prefix: Name prefix of this module inside ``state_dict``.
        *args: Passed through to the parent ``_load_from_state_dict``.
        **kwargs: Passed through to the parent ``_load_from_state_dict``.
    """
    legacy = re.compile(re.escape(prefix) + r"mixtures\.([^.]+)\.blocks\.(\d+)\.(.+)$")

    # Iterate over a snapshot of the keys because entries are popped while scanning.
    renamed = {}
    for old_key in list(state_dict.keys()):
        hit = legacy.match(old_key)
        if hit is None:
            continue
        name, layer_idx, rest = hit.groups()
        renamed[f"{prefix}layers.{layer_idx}.blocks.{name}.{rest}"] = state_dict.pop(old_key)

    state_dict.update(renamed)
    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
||||
|
||||
def prefill_video_cache(
|
||||
self,
|
||||
video_tokens: torch.Tensor,
|
||||
@@ -531,20 +720,7 @@ class MoT(nn.Module):
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
"""Prefill video branch once and cache per-layer K/V for action denoising.
|
||||
|
||||
Args:
|
||||
video_tokens: Video tokens before layer 0, shape [B, Sv, D].
|
||||
video_freqs: Video RoPE frequencies, shape [Sv, 1, rope_dim].
|
||||
video_t_mod: Video time modulation tensor.
|
||||
video_context_payload: Optional dict for video cross-attention.
|
||||
- `context`: encoder states [B, L, D]
|
||||
- `mask`: attention mask [B, Sv, L] or [B, 1, Sv, L]
|
||||
video_attention_mask: Video self-attention mask, shape [Sv, Sv].
|
||||
|
||||
Returns:
|
||||
Layer-wise cache list with length `num_layers`.
|
||||
Each entry contains:
|
||||
- `k`: video key tensor [B, Sv, H*Dh]
|
||||
- `v`: video value tensor [B, Sv, H*Dh]
|
||||
Returns a list of length ``num_layers``, each entry ``{"k": ..., "v": ...}``.
|
||||
"""
|
||||
if "video" not in self.mixtures:
|
||||
raise ValueError("MoT requires `video` expert for `prefill_video_cache`.")
|
||||
@@ -562,47 +738,16 @@ class MoT(nn.Module):
|
||||
f"mask={video_attention_mask.shape[0]} vs tokens={video_tokens.shape[1]}"
|
||||
)
|
||||
|
||||
expert = self.mixtures["video"]
|
||||
x = video_tokens
|
||||
kv_cache: list[dict[str, torch.Tensor]] = []
|
||||
for layer_idx in range(self.num_layers):
|
||||
block = expert.blocks[layer_idx]
|
||||
# Build video Q/K/V from current layer input tokens.
|
||||
(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io(
|
||||
expert=expert,
|
||||
block=block,
|
||||
for layer in self.layers:
|
||||
x, k, v = layer(
|
||||
mode="video_prefill",
|
||||
x=x,
|
||||
freqs=video_freqs,
|
||||
t_mod=video_t_mod,
|
||||
)
|
||||
# Video prefill uses only video self-attention mask.
|
||||
mixed = self._mixed_attention(
|
||||
q_cat=q,
|
||||
k_cat=k,
|
||||
v_cat=v,
|
||||
attention_mask=video_attention_mask,
|
||||
)
|
||||
# Update video tokens for the next layer and persist current layer K/V.
|
||||
x = self._apply_post_with_optional_checkpoint(
|
||||
block=block,
|
||||
residual_x=residual_x,
|
||||
gate_msa=gate_msa,
|
||||
shift_mlp=shift_mlp,
|
||||
scale_mlp=scale_mlp,
|
||||
gate_mlp=gate_mlp,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
mixed_slice=mixed,
|
||||
context_payload=video_context_payload,
|
||||
video_attention_mask=video_attention_mask,
|
||||
)
|
||||
kv_cache.append({"k": k, "v": v})
|
||||
return kv_cache
|
||||
@@ -617,22 +762,7 @@ class MoT(nn.Module):
|
||||
attention_mask: torch.Tensor,
|
||||
video_seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Run action branch with cached video K/V instead of recomputing video tokens.
|
||||
|
||||
Args:
|
||||
action_tokens: Action tokens before layer 0, shape [B, Sa, D].
|
||||
action_freqs: Action RoPE frequencies, shape [Sa, 1, rope_dim].
|
||||
action_t_mod: Action time modulation tensor.
|
||||
action_context_payload: Optional dict for action cross-attention.
|
||||
- `context`: encoder states [B, L, D]
|
||||
- `mask`: attention mask [B, Sa, L] or [B, 1, Sa, L]
|
||||
video_kv_cache: Layer-wise cached video K/V from `prefill_video_cache`.
|
||||
attention_mask: Joint [video+action] mask, shape [Sv+Sa, Sv+Sa].
|
||||
video_seq_len: Video token count `Sv` in the joint sequence prefix.
|
||||
|
||||
Returns:
|
||||
Updated action tokens after all layers, shape [B, Sa, D].
|
||||
"""
|
||||
"""Run action branch with cached video K/V instead of recomputing video tokens."""
|
||||
if "action" not in self.mixtures:
|
||||
raise ValueError("MoT requires `action` expert for `forward_action_with_video_cache`.")
|
||||
if len(video_kv_cache) != self.num_layers:
|
||||
@@ -654,56 +784,24 @@ class MoT(nn.Module):
|
||||
# Use the action query rows from the joint [video+action] mask.
|
||||
action_attention_mask = attention_mask[video_seq_len:total_seq_len, :total_seq_len]
|
||||
|
||||
expert = self.mixtures["action"]
|
||||
x = action_tokens
|
||||
for layer_idx in range(self.num_layers):
|
||||
block = expert.blocks[layer_idx]
|
||||
# Action query/key/value are still step-dependent and must be recomputed each step.
|
||||
(
|
||||
q_action,
|
||||
k_action,
|
||||
v_action,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io(
|
||||
expert=expert,
|
||||
block=block,
|
||||
x=x,
|
||||
freqs=action_freqs,
|
||||
t_mod=action_t_mod,
|
||||
)
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
layer_cache = video_kv_cache[layer_idx]
|
||||
if "k" not in layer_cache or "v" not in layer_cache:
|
||||
raise ValueError(f"`video_kv_cache[{layer_idx}]` must contain `k` and `v`.")
|
||||
|
||||
k_video = layer_cache["k"]
|
||||
v_video = layer_cache["v"]
|
||||
if k_video.shape[1] != video_seq_len or v_video.shape[1] != video_seq_len:
|
||||
raise ValueError(f"`video_kv_cache[{layer_idx}]` seq len mismatch, expected {video_seq_len}.")
|
||||
|
||||
# Mixed attention: action queries attend to cached video K/V plus current action K/V.
|
||||
k_cat = torch.cat([k_video, k_action], dim=1)
|
||||
v_cat = torch.cat([v_video, v_action], dim=1)
|
||||
mixed = self._mixed_attention(
|
||||
q_cat=q_action,
|
||||
k_cat=k_cat,
|
||||
v_cat=v_cat,
|
||||
attention_mask=action_attention_mask,
|
||||
)
|
||||
x = self._apply_post_with_optional_checkpoint(
|
||||
block=block,
|
||||
residual_x=residual_x,
|
||||
gate_msa=gate_msa,
|
||||
shift_mlp=shift_mlp,
|
||||
scale_mlp=scale_mlp,
|
||||
gate_mlp=gate_mlp,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
mixed_slice=mixed,
|
||||
x = layer(
|
||||
mode="action_cached",
|
||||
x=x,
|
||||
freqs=action_freqs,
|
||||
t_mod=action_t_mod,
|
||||
context_payload=action_context_payload,
|
||||
k_video=k_video,
|
||||
v_video=v_video,
|
||||
action_attention_mask=action_attention_mask,
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -730,94 +828,18 @@ class MoT(nn.Module):
|
||||
if attention_mask.shape[0] != attention_mask.shape[1]:
|
||||
raise ValueError(f"`attention_mask` must be square, got shape {tuple(attention_mask.shape)}")
|
||||
|
||||
# Each layer is a MoTLayer module; entering via __call__ lets FSDP all-gather that
|
||||
# layer's params (the whole point of the per-layer split).
|
||||
tokens_all = dict(embeds_all)
|
||||
|
||||
for layer_idx in range(self.num_layers):
|
||||
q_chunks = []
|
||||
k_chunks = []
|
||||
v_chunks = []
|
||||
cached = {}
|
||||
seq_lens = []
|
||||
|
||||
for name in self.expert_order:
|
||||
expert = self.mixtures[name]
|
||||
block = expert.blocks[layer_idx]
|
||||
x = tokens_all[name]
|
||||
freqs = freqs_all[name]
|
||||
t_mod = t_mod_all[name]
|
||||
|
||||
(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
residual_x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
use_gradient_checkpointing,
|
||||
) = self._build_expert_attention_io(
|
||||
expert=expert,
|
||||
block=block,
|
||||
x=x,
|
||||
freqs=freqs,
|
||||
t_mod=t_mod,
|
||||
)
|
||||
|
||||
q_chunks.append(q)
|
||||
k_chunks.append(k)
|
||||
v_chunks.append(v)
|
||||
seq_lens.append(x.shape[1])
|
||||
cached[name] = {
|
||||
"block": block,
|
||||
"residual_x": residual_x,
|
||||
"gate_msa": gate_msa,
|
||||
"shift_mlp": shift_mlp,
|
||||
"scale_mlp": scale_mlp,
|
||||
"gate_mlp": gate_mlp,
|
||||
"use_gradient_checkpointing": use_gradient_checkpointing,
|
||||
}
|
||||
|
||||
# 3. concat all tokens for mixed attention
|
||||
q_cat = torch.cat(q_chunks, dim=1)
|
||||
k_cat = torch.cat(k_chunks, dim=1)
|
||||
v_cat = torch.cat(v_chunks, dim=1)
|
||||
|
||||
total_seq = q_cat.shape[1]
|
||||
if attention_mask.shape[0] != total_seq:
|
||||
raise ValueError(
|
||||
"Attention mask seq length mismatch: "
|
||||
f"mask={attention_mask.shape[0]} vs tokens={total_seq}"
|
||||
)
|
||||
|
||||
mixed = self._mixed_attention(
|
||||
q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask
|
||||
for layer in self.layers:
|
||||
tokens_all = layer(
|
||||
mode="joint",
|
||||
tokens_all=tokens_all,
|
||||
attention_mask=attention_mask,
|
||||
freqs_all=freqs_all,
|
||||
context_all=context_all,
|
||||
t_mod_all=t_mod_all,
|
||||
)
|
||||
|
||||
start = 0
|
||||
for name, seq_len in zip(self.expert_order, seq_lens, strict=True):
|
||||
# 4. split mixed attention output and apply post-attention blocks for each expert
|
||||
end = start + seq_len
|
||||
mixed_slice = mixed[:, start:end, :]
|
||||
cached_expert = cached[name]
|
||||
block = cached_expert["block"]
|
||||
context_payload = context_all.get(name)
|
||||
|
||||
updated_tokens = self._apply_post_with_optional_checkpoint(
|
||||
block=block,
|
||||
residual_x=cached_expert["residual_x"],
|
||||
gate_msa=cached_expert["gate_msa"],
|
||||
shift_mlp=cached_expert["shift_mlp"],
|
||||
scale_mlp=cached_expert["scale_mlp"],
|
||||
gate_mlp=cached_expert["gate_mlp"],
|
||||
use_gradient_checkpointing=cached_expert["use_gradient_checkpointing"],
|
||||
mixed_slice=mixed_slice,
|
||||
context_payload=context_payload,
|
||||
)
|
||||
|
||||
tokens_all[name] = updated_tokens
|
||||
start = end
|
||||
|
||||
return tokens_all
|
||||
|
||||
|
||||
@@ -846,11 +868,20 @@ class FastWAM(torch.nn.Module):
|
||||
loss_lambda_action: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.video_expert = video_expert
|
||||
self.action_expert = action_expert
|
||||
self.mot = mot
|
||||
# Keep trainer compatibility: optimizer and freeze logic use `model.dit`.
|
||||
self.dit = self.mot
|
||||
# `video_expert` / `action_expert` are the very same module objects as
|
||||
# `mot.mixtures["video"]` / `["action"]`, and `dit` is an alias of `mot`. Registering
|
||||
# them as submodules too would give every expert tensor three names in `state_dict()`
|
||||
# (`video_expert.*`, `mot.mixtures.video.*`, `dit.mixtures.video.*`) — a 3x-bloated
|
||||
# gathered FSDP checkpoint and a doubled module tree for FSDP to traverse. Hold them as
|
||||
# plain (unregistered) attributes instead — bypassing `nn.Module.__setattr__`, like the
|
||||
# frozen vae/text_encoder below — so `mot` is the single registered owner and each tensor
|
||||
# has one canonical name (`mot.mixtures.*` / `mot.layers.*`, matching the base checkpoint).
|
||||
# Forward / freeze / optimizer code still reaches them by attribute, and device/dtype moves
|
||||
# still apply via `mot`. (optimizer + freeze logic use `model.dit`.)
|
||||
object.__setattr__(self, "video_expert", video_expert)
|
||||
object.__setattr__(self, "action_expert", action_expert)
|
||||
object.__setattr__(self, "dit", self.mot)
|
||||
|
||||
# Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT
|
||||
# registered as submodules. They are therefore excluded from `state_dict()`
|
||||
|
||||
@@ -133,9 +133,7 @@ def make_fastwam_pre_post_processors(
|
||||
|
||||
# resize visual inputs to match model expected input size, if necessary
|
||||
visual_shapes = [
|
||||
feature.shape
|
||||
for feature in config.input_features.values()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
resize_steps = []
|
||||
if visual_shapes:
|
||||
|
||||
@@ -42,6 +42,7 @@ WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
|
||||
WAN_T5_TOKENIZER = "google/umt5-xxl"
|
||||
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
||||
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
|
||||
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
|
||||
|
||||
@@ -76,7 +77,11 @@ class WanTokenizer:
|
||||
self.seq_len = int(seq_len)
|
||||
|
||||
def __call__(
|
||||
self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any
|
||||
self,
|
||||
sequence: str | Sequence[str],
|
||||
return_mask: bool = False,
|
||||
add_special_tokens: bool = True,
|
||||
**_: Any,
|
||||
):
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
@@ -99,9 +104,7 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
|
||||
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype
|
||||
)
|
||||
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
|
||||
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
|
||||
|
||||
|
||||
|
||||
@@ -425,7 +425,7 @@ class WanVideoDiT(WanModel):
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
seperated_timestep: bool = False,
|
||||
separated_timestep: bool = False,
|
||||
require_vae_embedding: bool = False,
|
||||
require_clip_embedding: bool = False,
|
||||
fuse_vae_embedding_in_latents: bool = True,
|
||||
@@ -489,7 +489,7 @@ class WanVideoDiT(WanModel):
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.attn_head_dim = attn_head_dim
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.separated_timestep = separated_timestep
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
self.video_attention_mask_mode = str(video_attention_mask_mode)
|
||||
self.action_conditioned = action_conditioned
|
||||
@@ -647,7 +647,7 @@ class WanVideoDiT(WanModel):
|
||||
)
|
||||
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
|
||||
|
||||
if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
|
||||
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
|
||||
raise NotImplementedError(
|
||||
"FastWAM currently requires separated timesteps with fused VAE latents."
|
||||
)
|
||||
|
||||
@@ -224,7 +224,15 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
|
||||
def build_core(self, config):
|
||||
core = CoreWithFrozenComponents()
|
||||
@@ -256,7 +264,15 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
|
||||
|
||||
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
@@ -278,7 +294,15 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
|
||||
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
cfg = FastWAMConfig(
|
||||
action_dim=3,
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
@@ -332,7 +356,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes():
|
||||
"dim_mult": [1, 2, 4, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_scales": [],
|
||||
"temperal_downsample": [False, True, True],
|
||||
"temporal_downsample": [False, True, True],
|
||||
"dropout": 0.0,
|
||||
"is_residual": True,
|
||||
"in_channels": 12,
|
||||
|
||||
Reference in New Issue
Block a user