From 35c3302f4ddb8b2b1681d09ddde7682ddf8e66f5 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Mon, 15 Jun 2026 12:11:27 +0000 Subject: [PATCH] re-parenting of some layers to enable proper zero-3 FSDP --- .../policies/fastwam/configuration_fastwam.py | 5 +- .../policies/fastwam/modeling_fastwam.py | 15 +- .../policies/fastwam/modular_fastwam.py | 617 +++++++++--------- .../policies/fastwam/processor_fastwam.py | 4 +- .../policies/fastwam/wan_components.py | 11 +- src/lerobot/policies/fastwam/wan_video_dit.py | 6 +- tests/policies/fastwam/test_fastwam_policy.py | 32 +- 7 files changed, 376 insertions(+), 314 deletions(-) diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index c557b9d4f..e6527e20b 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]: "attn_head_dim": 128, "num_layers": 30, "eps": 1.0e-6, - "seperated_timestep": True, + "separated_timestep": True, "use_gradient_checkpointing": False, "video_attention_mask_mode": "first_frame_causal", "action_conditioned": False, @@ -296,8 +296,7 @@ class FastWAMConfig(PreTrainedConfig): image_keys = sorted( key for key, feature in dataset_features.items() - if key.startswith("observation.images.") - and feature.get("dtype") in ("video", "image") + if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image") ) if not image_keys: return diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 2dcee64d7..9e7124e2e 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -76,6 +76,13 @@ class FastWAMPolicy(PreTrainedPolicy): # Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad, # so its params drop out of the optimizer (and DDP skips them). self.model.video_expert.requires_grad_(False) + # The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so + # `video_expert.requires_grad_` no longer reaches them — freeze them via the layers. + mot = getattr(self.model, "mot", None) + if mot is not None and getattr(mot, "layers", None) is not None: + for layer in mot.layers: + if "video" in layer.blocks: + layer.blocks["video"].requires_grad_(False) self.reset() # TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()` # counts only eval-rollout resets (one per episode), not this __init__ one. @@ -354,7 +361,9 @@ class FastWAMPolicy(PreTrainedPolicy): path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4" frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB write_video(path, frames, fps=30) - logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path) + logging.info( + "FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path + ) def _build_core_model(self, config: FastWAMConfig) -> FastWAM: """Build the FastWAM core for training / inference. @@ -485,9 +494,7 @@ def batch_device(batch: dict[str, Any]) -> torch.device: def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor: # Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside # each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames. - image_keys = sorted( - k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad") - ) + image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")) if not image_keys: raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.") images = [batch[key] for key in image_keys] diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/modular_fastwam.py index 344830c6b..c220a1a73 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/modular_fastwam.py @@ -16,6 +16,7 @@ from __future__ import annotations import logging import os +import re from collections.abc import Sequence from typing import Any @@ -258,56 +259,39 @@ class ActionDiT(nn.Module): return self.post_dit(x, pre_state) -class MoT(nn.Module): +class MoTLayer(nn.Module): + """A single MoT layer: owns one transformer block per expert and runs the cross-expert + mixed-attention step for that layer. + + This exists as a module — rather than the per-layer work being inlined in ``MoT``'s loop — + so FSDP can wrap each layer as its own unit. FSDP all-gathers a wrapped module's sharded + parameters via a hook on that module's ``forward``/``__call__``. ``MoT`` drives block + submodules directly (the joint mixed attention concatenates Q/K/V across experts, so no + single block's ``forward`` is ever called), so ``MoTLayer.forward`` is the only call + boundary FSDP can hook. All three per-layer paths therefore dispatch through + ``forward(mode=...)`` so each enters via ``__call__``. + """ + def __init__( self, - mixtures: dict[str, nn.Module], - mot_checkpoint_mixed_attn: bool = True, + blocks: dict[str, nn.Module], + experts: dict[str, nn.Module], + num_heads: int, + attn_head_dim: int, + fp32_attention: bool, + mot_checkpoint_mixed_attn: bool, ): super().__init__() - if not mixtures: - raise ValueError("`mixtures` cannot be empty.") - if "video" not in mixtures or "action" not in mixtures: - raise ValueError("`mixtures` must include both 'video' and 'action' experts.") - - self.mixtures = nn.ModuleDict(mixtures) - self.expert_order = list(self.mixtures.keys()) - self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn - if mot_checkpoint_mixed_attn: - logger.info( - "Using gradient checkpointing for mixture attention. This will save memory but use more computation." - ) - - first_expert = self.mixtures[self.expert_order[0]] - self.num_layers = len(first_expert.blocks) - self.num_heads = first_expert.num_heads - self.attn_head_dim = first_expert.attn_head_dim - self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True)) - - for name in self.expert_order[1:]: - expert = self.mixtures[name] - if len(expert.blocks) != self.num_layers: - raise ValueError( - f"All experts must have same number of layers; got {self.num_layers} and {len(expert.blocks)}" - ) - if expert.num_heads != self.num_heads: - raise ValueError( - f"All experts must have same num_heads; got {self.num_heads} and {expert.num_heads}" - ) - if expert.attn_head_dim != self.attn_head_dim: - raise ValueError( - "All experts must have same attn_head_dim; " - f"got {self.attn_head_dim} and {expert.attn_head_dim}" - ) - if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention: - raise ValueError("All experts must use the same `fp32_attention` setting.") - - logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}") - for name in self.expert_order: - expert = self.mixtures[name] - logger.info( - f" Expert '{name}': num_params={sum(p.numel() for p in expert.parameters()) / 1e9:.2f} B" - ) + # Registered owner of this layer's blocks (one per expert) — the FSDP wrap unit. + self.blocks = nn.ModuleDict(blocks) + self.expert_order = list(blocks.keys()) + # Unregistered back-references to the experts: used only to read the live + # `use_gradient_checkpointing` flag, kept out of parameters()/state_dict(). + object.__setattr__(self, "_experts", dict(experts)) + self.num_heads = num_heads + self.attn_head_dim = attn_head_dim + self.fp32_attention = bool(fp32_attention) + self.mot_checkpoint_mixed_attn = bool(mot_checkpoint_mixed_attn) @staticmethod def _split_modulation(block, t_mod: torch.Tensor): @@ -394,43 +378,17 @@ class MoT(nn.Module): def _build_expert_attention_io( self, - expert, - block, + name: str, x: torch.Tensor, freqs: torch.Tensor | dict[str, torch.Tensor], t_mod: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - bool, - ]: - """Build per-expert attention tensors and post-block states. + ): + """Build this expert's attention tensors and post-block states for the layer. - Args: - expert: Expert module that owns this `block`; only used to read - `use_gradient_checkpointing`. - block: Transformer block for current layer (`expert.blocks[layer_idx]`). - x: Current expert tokens, shape [B, S, D]. - freqs: RoPE frequencies aligned with token sequence, shape [S, 1, rope_dim]. - t_mod: Time modulation tensor for this expert/layer. - - Returns: - q: Query after q-proj, RMSNorm, and RoPE, shape [B, S, H*Dh]. - k: Key after k-proj, RMSNorm, and RoPE, shape [B, S, H*Dh]. - v: Value after v-proj, shape [B, S, H*Dh]. - residual_x: Original input `x` for residual path in post block. - gate_msa: Gating tensor for self-attention residual branch. - shift_mlp: Shift tensor for MLP modulation. - scale_mlp: Scale tensor for MLP modulation. - gate_mlp: Gating tensor for MLP residual branch. - use_gradient_checkpointing: Whether this expert enables checkpointing. + Returns (q, k, v, residual_x, gate_msa, shift_mlp, scale_mlp, gate_mlp, use_gc). """ + block = self.blocks[name] + expert = self._experts[name] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self._split_modulation(block, t_mod) attn_input = modulate(_apply_block_norm(block, "norm1", x), shift_msa, scale_msa) @@ -461,25 +419,6 @@ class MoT(nn.Module): mixed_slice: torch.Tensor, context_payload: dict | None, ) -> torch.Tensor: - """Apply post-attention computations, with optional checkpointing. - - Args: - block: Transformer block for current layer. - residual_x: Residual input tokens before attention update, shape [B, S, D]. - gate_msa: Gating tensor used after mixed self-attention. - shift_mlp: Shift tensor for MLP input modulation. - scale_mlp: Scale tensor for MLP input modulation. - gate_mlp: Gating tensor used after MLP. - use_gradient_checkpointing: If True and training, checkpoint this post block. - mixed_slice: Mixed-attention output for this expert, shape [B, S, H*Dh]. - context_payload: Optional dict for cross-attention. - - `context`: encoder states [B, L, D] - - `mask`: attention mask [B, S, L] or [B, 1, S, L] - - Returns: - Updated expert tokens after self-attn residual, optional cross-attn, and MLP. - """ - def _post_fn( _mixed_slice: torch.Tensor, _x: torch.Tensor, @@ -521,6 +460,256 @@ class MoT(nn.Module): gate_mlp, ) + def forward(self, mode: str, **kwargs): + if mode == "joint": + return self._forward_joint(**kwargs) + if mode == "video_prefill": + return self._forward_video_prefill(**kwargs) + if mode == "action_cached": + return self._forward_action_cached(**kwargs) + raise ValueError(f"Unknown MoTLayer forward mode: {mode!r}") + + def _forward_joint( + self, + tokens_all: dict[str, torch.Tensor], + attention_mask: torch.Tensor, + freqs_all: dict[str, torch.Tensor], + context_all: dict[str, dict | None], + t_mod_all: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + q_chunks = [] + k_chunks = [] + v_chunks = [] + cached = {} + seq_lens = [] + + for name in self.expert_order: + ( + q, + k, + v, + residual_x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + use_gradient_checkpointing, + ) = self._build_expert_attention_io(name, tokens_all[name], freqs_all[name], t_mod_all[name]) + + q_chunks.append(q) + k_chunks.append(k) + v_chunks.append(v) + seq_lens.append(tokens_all[name].shape[1]) + cached[name] = { + "residual_x": residual_x, + "gate_msa": gate_msa, + "shift_mlp": shift_mlp, + "scale_mlp": scale_mlp, + "gate_mlp": gate_mlp, + "use_gradient_checkpointing": use_gradient_checkpointing, + } + + q_cat = torch.cat(q_chunks, dim=1) + k_cat = torch.cat(k_chunks, dim=1) + v_cat = torch.cat(v_chunks, dim=1) + + total_seq = q_cat.shape[1] + if attention_mask.shape[0] != total_seq: + raise ValueError( + f"Attention mask seq length mismatch: mask={attention_mask.shape[0]} vs tokens={total_seq}" + ) + + mixed = self._mixed_attention(q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask) + + out = {} + start = 0 + for name, seq_len in zip(self.expert_order, seq_lens, strict=True): + end = start + seq_len + mixed_slice = mixed[:, start:end, :] + cached_expert = cached[name] + out[name] = self._apply_post_with_optional_checkpoint( + block=self.blocks[name], + residual_x=cached_expert["residual_x"], + gate_msa=cached_expert["gate_msa"], + shift_mlp=cached_expert["shift_mlp"], + scale_mlp=cached_expert["scale_mlp"], + gate_mlp=cached_expert["gate_mlp"], + use_gradient_checkpointing=cached_expert["use_gradient_checkpointing"], + mixed_slice=mixed_slice, + context_payload=context_all.get(name), + ) + start = end + return out + + def _forward_video_prefill( + self, + x: torch.Tensor, + freqs: torch.Tensor, + t_mod: torch.Tensor, + context_payload: dict | None, + video_attention_mask: torch.Tensor, + ): + ( + q, + k, + v, + residual_x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + use_gradient_checkpointing, + ) = self._build_expert_attention_io("video", x, freqs, t_mod) + # Video prefill uses only video self-attention mask. + mixed = self._mixed_attention(q_cat=q, k_cat=k, v_cat=v, attention_mask=video_attention_mask) + x_out = self._apply_post_with_optional_checkpoint( + block=self.blocks["video"], + residual_x=residual_x, + gate_msa=gate_msa, + shift_mlp=shift_mlp, + scale_mlp=scale_mlp, + gate_mlp=gate_mlp, + use_gradient_checkpointing=use_gradient_checkpointing, + mixed_slice=mixed, + context_payload=context_payload, + ) + return x_out, k, v + + def _forward_action_cached( + self, + x: torch.Tensor, + freqs: torch.Tensor, + t_mod: torch.Tensor, + context_payload: dict | None, + k_video: torch.Tensor, + v_video: torch.Tensor, + action_attention_mask: torch.Tensor, + ) -> torch.Tensor: + ( + q_action, + k_action, + v_action, + residual_x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + use_gradient_checkpointing, + ) = self._build_expert_attention_io("action", x, freqs, t_mod) + # Mixed attention: action queries attend to cached video K/V plus current action K/V. + k_cat = torch.cat([k_video, k_action], dim=1) + v_cat = torch.cat([v_video, v_action], dim=1) + mixed = self._mixed_attention( + q_cat=q_action, k_cat=k_cat, v_cat=v_cat, attention_mask=action_attention_mask + ) + return self._apply_post_with_optional_checkpoint( + block=self.blocks["action"], + residual_x=residual_x, + gate_msa=gate_msa, + shift_mlp=shift_mlp, + scale_mlp=scale_mlp, + gate_mlp=gate_mlp, + use_gradient_checkpointing=use_gradient_checkpointing, + mixed_slice=mixed, + context_payload=context_payload, + ) + + +class MoT(nn.Module): + def __init__( + self, + mixtures: dict[str, nn.Module], + mot_checkpoint_mixed_attn: bool = True, + ): + super().__init__() + if not mixtures: + raise ValueError("`mixtures` cannot be empty.") + if "video" not in mixtures or "action" not in mixtures: + raise ValueError("`mixtures` must include both 'video' and 'action' experts.") + + self.mixtures = nn.ModuleDict(mixtures) + self.expert_order = list(self.mixtures.keys()) + self.mot_checkpoint_mixed_attn = mot_checkpoint_mixed_attn + if mot_checkpoint_mixed_attn: + logger.info( + "Using gradient checkpointing for mixture attention. This will save memory but use more computation." + ) + + first_expert = self.mixtures[self.expert_order[0]] + self.num_layers = len(first_expert.blocks) + self.num_heads = first_expert.num_heads + self.attn_head_dim = first_expert.attn_head_dim + self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True)) + + for name in self.expert_order[1:]: + expert = self.mixtures[name] + if len(expert.blocks) != self.num_layers: + raise ValueError( + f"All experts must have same number of layers; got {self.num_layers} and {len(expert.blocks)}" + ) + if expert.num_heads != self.num_heads: + raise ValueError( + f"All experts must have same num_heads; got {self.num_heads} and {expert.num_heads}" + ) + if expert.attn_head_dim != self.attn_head_dim: + raise ValueError( + "All experts must have same attn_head_dim; " + f"got {self.attn_head_dim} and {expert.attn_head_dim}" + ) + if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention: + raise ValueError("All experts must use the same `fp32_attention` setting.") + + logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}") + for name in self.expert_order: + expert = self.mixtures[name] + logger.info( + f" Expert '{name}': num_params={sum(p.numel() for p in expert.parameters()) / 1e9:.2f} B" + ) + + # One MoTLayer per layer, each owning that layer's block from every expert. This is the + # FSDP wrap unit: only MoTLayer.forward is ever called (MoT drives block submodules + # directly for the cross-expert mixed attention), so it is the boundary at which FSDP can + # all-gather a layer's params. The blocks are RE-PARENTED into the layers — removed from + # each expert's module registry — so they have a single owner; leaving them registered + # under both the expert and the layer would make FSDP try to manage the same params twice. + self.layers = nn.ModuleList( + [ + MoTLayer( + blocks={name: self.mixtures[name].blocks[layer_idx] for name in self.expert_order}, + experts={name: self.mixtures[name] for name in self.expert_order}, + num_heads=self.num_heads, + attn_head_dim=self.attn_head_dim, + fp32_attention=self.fp32_attention, + mot_checkpoint_mixed_attn=self.mot_checkpoint_mixed_attn, + ) + for layer_idx in range(self.num_layers) + ] + ) + for name in self.expert_order: + expert = self.mixtures[name] + kept_blocks = list(expert.blocks) + del expert._modules["blocks"] + # Keep an UNREGISTERED reference so the (unused) standalone `expert.forward` and any + # `len(expert.blocks)` still work, without re-adding the params to the expert's + # parameters()/state_dict() (which would double-register them with the MoTLayer owner). + object.__setattr__(expert, "blocks", kept_blocks) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # Backward-compat for checkpoints saved before the MoTLayer refactor. Then the per-layer + # blocks were keyed under the experts (`{prefix}mixtures..blocks..`, e.g. + # the released `ZibinDong/fastwam_libero_uncond_2cam224`); now they are owned by the layers + # (`{prefix}layers..blocks..`). Remap legacy keys in place so the recursion + # into `self.layers` finds them and the (now block-less) `self.mixtures` does not flag them. + legacy = re.compile(re.escape(prefix) + r"mixtures\.([^.]+)\.blocks\.(\d+)\.(.+)$") + moved = {} + for key in list(state_dict.keys()): + m = legacy.match(key) + if m is not None: + name, layer_idx, rest = m.group(1), m.group(2), m.group(3) + moved[f"{prefix}layers.{layer_idx}.blocks.{name}.{rest}"] = state_dict.pop(key) + state_dict.update(moved) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + def prefill_video_cache( self, video_tokens: torch.Tensor, @@ -531,20 +720,7 @@ class MoT(nn.Module): ) -> list[dict[str, torch.Tensor]]: """Prefill video branch once and cache per-layer K/V for action denoising. - Args: - video_tokens: Video tokens before layer 0, shape [B, Sv, D]. - video_freqs: Video RoPE frequencies, shape [Sv, 1, rope_dim]. - video_t_mod: Video time modulation tensor. - video_context_payload: Optional dict for video cross-attention. - - `context`: encoder states [B, L, D] - - `mask`: attention mask [B, Sv, L] or [B, 1, Sv, L] - video_attention_mask: Video self-attention mask, shape [Sv, Sv]. - - Returns: - Layer-wise cache list with length `num_layers`. - Each entry contains: - - `k`: video key tensor [B, Sv, H*Dh] - - `v`: video value tensor [B, Sv, H*Dh] + Returns a list of length ``num_layers``, each entry ``{"k": ..., "v": ...}``. """ if "video" not in self.mixtures: raise ValueError("MoT requires `video` expert for `prefill_video_cache`.") @@ -562,47 +738,16 @@ class MoT(nn.Module): f"mask={video_attention_mask.shape[0]} vs tokens={video_tokens.shape[1]}" ) - expert = self.mixtures["video"] x = video_tokens kv_cache: list[dict[str, torch.Tensor]] = [] - for layer_idx in range(self.num_layers): - block = expert.blocks[layer_idx] - # Build video Q/K/V from current layer input tokens. - ( - q, - k, - v, - residual_x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - use_gradient_checkpointing, - ) = self._build_expert_attention_io( - expert=expert, - block=block, + for layer in self.layers: + x, k, v = layer( + mode="video_prefill", x=x, freqs=video_freqs, t_mod=video_t_mod, - ) - # Video prefill uses only video self-attention mask. - mixed = self._mixed_attention( - q_cat=q, - k_cat=k, - v_cat=v, - attention_mask=video_attention_mask, - ) - # Update video tokens for the next layer and persist current layer K/V. - x = self._apply_post_with_optional_checkpoint( - block=block, - residual_x=residual_x, - gate_msa=gate_msa, - shift_mlp=shift_mlp, - scale_mlp=scale_mlp, - gate_mlp=gate_mlp, - use_gradient_checkpointing=use_gradient_checkpointing, - mixed_slice=mixed, context_payload=video_context_payload, + video_attention_mask=video_attention_mask, ) kv_cache.append({"k": k, "v": v}) return kv_cache @@ -617,22 +762,7 @@ class MoT(nn.Module): attention_mask: torch.Tensor, video_seq_len: int, ) -> torch.Tensor: - """Run action branch with cached video K/V instead of recomputing video tokens. - - Args: - action_tokens: Action tokens before layer 0, shape [B, Sa, D]. - action_freqs: Action RoPE frequencies, shape [Sa, 1, rope_dim]. - action_t_mod: Action time modulation tensor. - action_context_payload: Optional dict for action cross-attention. - - `context`: encoder states [B, L, D] - - `mask`: attention mask [B, Sa, L] or [B, 1, Sa, L] - video_kv_cache: Layer-wise cached video K/V from `prefill_video_cache`. - attention_mask: Joint [video+action] mask, shape [Sv+Sa, Sv+Sa]. - video_seq_len: Video token count `Sv` in the joint sequence prefix. - - Returns: - Updated action tokens after all layers, shape [B, Sa, D]. - """ + """Run action branch with cached video K/V instead of recomputing video tokens.""" if "action" not in self.mixtures: raise ValueError("MoT requires `action` expert for `forward_action_with_video_cache`.") if len(video_kv_cache) != self.num_layers: @@ -654,56 +784,24 @@ class MoT(nn.Module): # Use the action query rows from the joint [video+action] mask. action_attention_mask = attention_mask[video_seq_len:total_seq_len, :total_seq_len] - expert = self.mixtures["action"] x = action_tokens - for layer_idx in range(self.num_layers): - block = expert.blocks[layer_idx] - # Action query/key/value are still step-dependent and must be recomputed each step. - ( - q_action, - k_action, - v_action, - residual_x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - use_gradient_checkpointing, - ) = self._build_expert_attention_io( - expert=expert, - block=block, - x=x, - freqs=action_freqs, - t_mod=action_t_mod, - ) + for layer_idx, layer in enumerate(self.layers): layer_cache = video_kv_cache[layer_idx] if "k" not in layer_cache or "v" not in layer_cache: raise ValueError(f"`video_kv_cache[{layer_idx}]` must contain `k` and `v`.") - k_video = layer_cache["k"] v_video = layer_cache["v"] if k_video.shape[1] != video_seq_len or v_video.shape[1] != video_seq_len: raise ValueError(f"`video_kv_cache[{layer_idx}]` seq len mismatch, expected {video_seq_len}.") - - # Mixed attention: action queries attend to cached video K/V plus current action K/V. - k_cat = torch.cat([k_video, k_action], dim=1) - v_cat = torch.cat([v_video, v_action], dim=1) - mixed = self._mixed_attention( - q_cat=q_action, - k_cat=k_cat, - v_cat=v_cat, - attention_mask=action_attention_mask, - ) - x = self._apply_post_with_optional_checkpoint( - block=block, - residual_x=residual_x, - gate_msa=gate_msa, - shift_mlp=shift_mlp, - scale_mlp=scale_mlp, - gate_mlp=gate_mlp, - use_gradient_checkpointing=use_gradient_checkpointing, - mixed_slice=mixed, + x = layer( + mode="action_cached", + x=x, + freqs=action_freqs, + t_mod=action_t_mod, context_payload=action_context_payload, + k_video=k_video, + v_video=v_video, + action_attention_mask=action_attention_mask, ) return x @@ -730,94 +828,18 @@ class MoT(nn.Module): if attention_mask.shape[0] != attention_mask.shape[1]: raise ValueError(f"`attention_mask` must be square, got shape {tuple(attention_mask.shape)}") + # Each layer is a MoTLayer module; entering via __call__ lets FSDP all-gather that + # layer's params (the whole point of the per-layer split). tokens_all = dict(embeds_all) - - for layer_idx in range(self.num_layers): - q_chunks = [] - k_chunks = [] - v_chunks = [] - cached = {} - seq_lens = [] - - for name in self.expert_order: - expert = self.mixtures[name] - block = expert.blocks[layer_idx] - x = tokens_all[name] - freqs = freqs_all[name] - t_mod = t_mod_all[name] - - ( - q, - k, - v, - residual_x, - gate_msa, - shift_mlp, - scale_mlp, - gate_mlp, - use_gradient_checkpointing, - ) = self._build_expert_attention_io( - expert=expert, - block=block, - x=x, - freqs=freqs, - t_mod=t_mod, - ) - - q_chunks.append(q) - k_chunks.append(k) - v_chunks.append(v) - seq_lens.append(x.shape[1]) - cached[name] = { - "block": block, - "residual_x": residual_x, - "gate_msa": gate_msa, - "shift_mlp": shift_mlp, - "scale_mlp": scale_mlp, - "gate_mlp": gate_mlp, - "use_gradient_checkpointing": use_gradient_checkpointing, - } - - # 3. concat all tokens for mixed attention - q_cat = torch.cat(q_chunks, dim=1) - k_cat = torch.cat(k_chunks, dim=1) - v_cat = torch.cat(v_chunks, dim=1) - - total_seq = q_cat.shape[1] - if attention_mask.shape[0] != total_seq: - raise ValueError( - "Attention mask seq length mismatch: " - f"mask={attention_mask.shape[0]} vs tokens={total_seq}" - ) - - mixed = self._mixed_attention( - q_cat=q_cat, k_cat=k_cat, v_cat=v_cat, attention_mask=attention_mask + for layer in self.layers: + tokens_all = layer( + mode="joint", + tokens_all=tokens_all, + attention_mask=attention_mask, + freqs_all=freqs_all, + context_all=context_all, + t_mod_all=t_mod_all, ) - - start = 0 - for name, seq_len in zip(self.expert_order, seq_lens, strict=True): - # 4. split mixed attention output and apply post-attention blocks for each expert - end = start + seq_len - mixed_slice = mixed[:, start:end, :] - cached_expert = cached[name] - block = cached_expert["block"] - context_payload = context_all.get(name) - - updated_tokens = self._apply_post_with_optional_checkpoint( - block=block, - residual_x=cached_expert["residual_x"], - gate_msa=cached_expert["gate_msa"], - shift_mlp=cached_expert["shift_mlp"], - scale_mlp=cached_expert["scale_mlp"], - gate_mlp=cached_expert["gate_mlp"], - use_gradient_checkpointing=cached_expert["use_gradient_checkpointing"], - mixed_slice=mixed_slice, - context_payload=context_payload, - ) - - tokens_all[name] = updated_tokens - start = end - return tokens_all @@ -846,11 +868,20 @@ class FastWAM(torch.nn.Module): loss_lambda_action: float = 1.0, ): super().__init__() - self.video_expert = video_expert - self.action_expert = action_expert self.mot = mot - # Keep trainer compatibility: optimizer and freeze logic use `model.dit`. - self.dit = self.mot + # `video_expert` / `action_expert` are the very same module objects as + # `mot.mixtures["video"]` / `["action"]`, and `dit` is an alias of `mot`. Registering + # them as submodules too would give every expert tensor three names in `state_dict()` + # (`video_expert.*`, `mot.mixtures.video.*`, `dit.mixtures.video.*`) — a 3x-bloated + # gathered FSDP checkpoint and a doubled module tree for FSDP to traverse. Hold them as + # plain (unregistered) attributes instead — bypassing `nn.Module.__setattr__`, like the + # frozen vae/text_encoder below — so `mot` is the single registered owner and each tensor + # has one canonical name (`mot.mixtures.*` / `mot.layers.*`, matching the base checkpoint). + # Forward / freeze / optimizer code still reaches them by attribute, and device/dtype moves + # still apply via `mot`. (optimizer + freeze logic use `model.dit`.) + object.__setattr__(self, "video_expert", video_expert) + object.__setattr__(self, "action_expert", action_expert) + object.__setattr__(self, "dit", self.mot) # Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT # registered as submodules. They are therefore excluded from `state_dict()` diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index fafc80c9f..080fdb9a4 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -133,9 +133,7 @@ def make_fastwam_pre_post_processors( # resize visual inputs to match model expected input size, if necessary visual_shapes = [ - feature.shape - for feature in config.input_features.values() - if feature.type == FeatureType.VISUAL + feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL ] resize_steps = [] if visual_shapes: diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index fd6e3dc52..8a9c9631c 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -42,6 +42,7 @@ WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors" WAN_T5_TOKENIZER = "google/umt5-xxl" WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" + class WanTextEncoder(torch.nn.Module): """FastWAM text-encoder contract over `transformers.UMT5EncoderModel`. @@ -76,7 +77,11 @@ class WanTokenizer: self.seq_len = int(seq_len) def __call__( - self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any + self, + sequence: str | Sequence[str], + return_mask: bool = False, + add_special_tokens: bool = True, + **_: Any, ): if isinstance(sequence, str): sequence = [sequence] @@ -99,9 +104,7 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer: def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: """Load real Wan2.2 VAE weights from the diffusers repo (offline base creation).""" - vae = AutoencoderKLWan.from_pretrained( - WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype - ) + vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype) return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae) diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan_video_dit.py index 0b38ad816..7a777e9df 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan_video_dit.py @@ -425,7 +425,7 @@ class WanVideoDiT(WanModel): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, - seperated_timestep: bool = False, + separated_timestep: bool = False, require_vae_embedding: bool = False, require_clip_embedding: bool = False, fuse_vae_embedding_in_latents: bool = True, @@ -489,7 +489,7 @@ class WanVideoDiT(WanModel): self.hidden_dim = hidden_dim self.attn_head_dim = attn_head_dim - self.seperated_timestep = seperated_timestep + self.separated_timestep = separated_timestep self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.video_attention_mask_mode = str(video_attention_mask_mode) self.action_conditioned = action_conditioned @@ -647,7 +647,7 @@ class WanVideoDiT(WanModel): ) tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w) - if not (self.seperated_timestep and fuse_vae_embedding_in_latents): + if not (self.separated_timestep and fuse_vae_embedding_in_latents): raise NotImplementedError( "FastWAM currently requires separated timesteps with fused VAE latents." ) diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index 68ea6632b..f4abab4a8 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -224,7 +224,15 @@ class CoreWithFrozenComponents(FakeFastWAMCore): def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) + cfg = FastWAMConfig( + action_dim=3, + proprio_dim=2, + action_horizon=4, + n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, + base_model_id=None, + ) def build_core(self, config): core = CoreWithFrozenComponents() @@ -256,7 +264,15 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) + cfg = FastWAMConfig( + action_dim=3, + proprio_dim=2, + action_horizon=4, + n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, + base_model_id=None, + ) monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) policy = FastWAMPolicy(cfg) @@ -278,7 +294,15 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path): def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) + cfg = FastWAMConfig( + action_dim=3, + proprio_dim=2, + action_horizon=4, + n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, + base_model_id=None, + ) monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) policy = FastWAMPolicy(cfg) @@ -332,7 +356,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes(): "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], - "temperal_downsample": [False, True, True], + "temporal_downsample": [False, True, True], "dropout": 0.0, "is_residual": True, "in_channels": 12,