From 218f8a17352686eb31d99334278ffb4fe69c2751 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Tue, 30 Jun 2026 11:01:33 +0000 Subject: [PATCH] moving and renaming files to have a cleaner file tree --- .../policies/fastwam/modeling_fastwam.py | 6 ++-- src/lerobot/policies/fastwam/wan/README.md | 34 ++++++++++++------- src/lerobot/policies/fastwam/wan/__init__.py | 13 +++++++ .../{wan_adapters.py => wan/adapters.py} | 0 .../fastwam/wan/{modules => }/attention.py | 0 .../{wan_components.py => wan/components.py} | 4 +-- .../fastwam/wan/{modules => }/model.py | 0 .../{modular_fastwam.py => wan/modular.py} | 4 +-- .../policies/fastwam/wan/modules/__init__.py | 8 ----- .../policies/fastwam/wan/utils/__init__.py | 6 ---- .../policies/fastwam/wan/utils/fm_solvers.py | 9 ----- .../{wan_video_dit.py => wan/video_dit.py} | 18 ++++++---- tests/policies/fastwam/test_fastwam_policy.py | 4 +-- 13 files changed, 54 insertions(+), 52 deletions(-) create mode 100644 src/lerobot/policies/fastwam/wan/__init__.py rename src/lerobot/policies/fastwam/{wan_adapters.py => wan/adapters.py} (100%) rename src/lerobot/policies/fastwam/wan/{modules => }/attention.py (100%) rename src/lerobot/policies/fastwam/{wan_components.py => wan/components.py} (98%) rename src/lerobot/policies/fastwam/wan/{modules => }/model.py (100%) rename src/lerobot/policies/fastwam/{modular_fastwam.py => wan/modular.py} (99%) delete mode 100644 src/lerobot/policies/fastwam/wan/modules/__init__.py delete mode 100644 src/lerobot/policies/fastwam/wan/utils/__init__.py delete mode 100644 src/lerobot/policies/fastwam/wan/utils/fm_solvers.py rename src/lerobot/policies/fastwam/{wan_video_dit.py => wan/video_dit.py} (98%) diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 9822c9834..30dda1b3f 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -26,13 +26,13 @@ from lerobot.utils.constants import OBS_STATE from lerobot.utils.import_utils import require_package from .configuration_fastwam import FastWAMConfig -from .modular_fastwam import ActionDiT, FastWAM, MoT -from .wan_components import ( +from .wan.components import ( build_wan_tokenizer, load_pretrained_wan_text_encoder, load_pretrained_wan_vae, ) -from .wan_video_dit import WanVideoDiT +from .wan.modular import ActionDiT, FastWAM, MoT +from .wan.video_dit import WanVideoDiT class FastWAMPolicy(PreTrainedPolicy): diff --git a/src/lerobot/policies/fastwam/wan/README.md b/src/lerobot/policies/fastwam/wan/README.md index dbd56d1f2..552288553 100644 --- a/src/lerobot/policies/fastwam/wan/README.md +++ b/src/lerobot/policies/fastwam/wan/README.md @@ -1,6 +1,10 @@ -# Wan2.2 Upstream Subset +# FastWAM `wan` package -This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM. +This package holds FastWAM's model implementation. It mixes a small **vendored +subset of the official Wan2.2 source tree** with FastWAM's own code, kept flat in +a single directory. + +## Vendored from Wan2.2 - Upstream repository: https://github.com/Wan-Video/Wan2.2 - Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc` @@ -8,18 +12,22 @@ This directory contains the trimmed subset of the official Wan2.2 source tree us Copied files: -- `wan/modules/attention.py` -- `wan/modules/model.py` -- `wan/modules/__init__.py` -- `wan/utils/fm_solvers.py` -- `wan/utils/__init__.py` +- `attention.py` (was `wan/modules/attention.py`) +- `model.py` (was `wan/modules/model.py`) +- `get_sampling_sigmas` in `video_dit.py` (was `wan/utils/fm_solvers.py`), inlined + next to its only caller. -This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE, -UMT5 text encoder, and tokenizer are no longer vendored — they come from +This subset only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE, +UMT5 text encoder, and tokenizer are no longer vendored - they come from `diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and -`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`). +`transformers.AutoTokenizer` (see `components.py` and `adapters.py`). -Current FastWAM adapters that directly reuse this vendored subset: +## FastWAM's own code -- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`. -- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps. +- `video_dit.py` builds on `model` (`sinusoidal_embedding_1d`, `rope_params`, + `rope_apply`, …) and `attention.flash_attention`. Its + `WanContinuousFlowMatchScheduler` uses `get_sampling_sigmas` for Wan-compatible + inference timesteps. +- `components.py` / `adapters.py` load the VAE, text encoder, tokenizer, and the + custom DiT weights. +- `modular.py` defines the FastWAM model (`ActionDiT`, `MoT`, `FastWAM`, …). diff --git a/src/lerobot/policies/fastwam/wan/__init__.py b/src/lerobot/policies/fastwam/wan/__init__.py new file mode 100644 index 000000000..f52df1bd7 --- /dev/null +++ b/src/lerobot/policies/fastwam/wan/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/policies/fastwam/wan_adapters.py b/src/lerobot/policies/fastwam/wan/adapters.py similarity index 100% rename from src/lerobot/policies/fastwam/wan_adapters.py rename to src/lerobot/policies/fastwam/wan/adapters.py diff --git a/src/lerobot/policies/fastwam/wan/modules/attention.py b/src/lerobot/policies/fastwam/wan/attention.py similarity index 100% rename from src/lerobot/policies/fastwam/wan/modules/attention.py rename to src/lerobot/policies/fastwam/wan/attention.py diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan/components.py similarity index 98% rename from src/lerobot/policies/fastwam/wan_components.py rename to src/lerobot/policies/fastwam/wan/components.py index a69f21fe0..821669975 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan/components.py @@ -36,8 +36,8 @@ if TYPE_CHECKING or _diffusers_available: else: AutoencoderKLWan = None -from .wan_adapters import WanVideoVAE38 -from .wan_video_dit import WanVideoDiT +from .adapters import WanVideoVAE38 +from .video_dit import WanVideoDiT logger = logging.getLogger(__name__) diff --git a/src/lerobot/policies/fastwam/wan/modules/model.py b/src/lerobot/policies/fastwam/wan/model.py similarity index 100% rename from src/lerobot/policies/fastwam/wan/modules/model.py rename to src/lerobot/policies/fastwam/wan/model.py diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/wan/modular.py similarity index 99% rename from src/lerobot/policies/fastwam/modular_fastwam.py rename to src/lerobot/policies/fastwam/wan/modular.py index d82c8b69d..00cfb7fc1 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/wan/modular.py @@ -25,7 +25,7 @@ import torch.nn as nn import torch.nn.functional as functional from PIL import Image -from .wan_components import ( +from .components import ( WAN22_DIFFUSERS_MODEL_ID, WAN_T5_TOKENIZER, build_wan_tokenizer, @@ -34,7 +34,7 @@ from .wan_components import ( load_wan_video_dit, resolve_wan_dit_paths, ) -from .wan_video_dit import ( +from .video_dit import ( FastWAMAttentionBlock, WanContinuousFlowMatchScheduler, fastwam_masked_attention, diff --git a/src/lerobot/policies/fastwam/wan/modules/__init__.py b/src/lerobot/policies/fastwam/wan/modules/__init__.py deleted file mode 100644 index c4c595029..000000000 --- a/src/lerobot/policies/fastwam/wan/modules/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -from .attention import flash_attention -from .model import WanModel - -__all__ = [ - "WanModel", - "flash_attention", -] diff --git a/src/lerobot/policies/fastwam/wan/utils/__init__.py b/src/lerobot/policies/fastwam/wan/utils/__init__.py deleted file mode 100644 index ba223fe65..000000000 --- a/src/lerobot/policies/fastwam/wan/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -from .fm_solvers import get_sampling_sigmas - -__all__ = [ - "get_sampling_sigmas", -] diff --git a/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py b/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py deleted file mode 100644 index 42b453590..000000000 --- a/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -import numpy as np - - -def get_sampling_sigmas(sampling_steps, shift): - sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] - sigma = shift * sigma / (1 + (shift - 1) * sigma) - return sigma diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan/video_dit.py similarity index 98% rename from src/lerobot/policies/fastwam/wan_video_dit.py rename to src/lerobot/policies/fastwam/wan/video_dit.py index 0b38ad816..7602af223 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan/video_dit.py @@ -15,12 +15,13 @@ import logging from typing import Any +import numpy as np import torch import torch.nn as nn import torch.nn.functional as functional from einops import rearrange -from .wan.modules.model import ( +from .model import ( WanAttentionBlock, WanLayerNorm, WanModel, @@ -29,11 +30,18 @@ from .wan.modules.model import ( rope_params, sinusoidal_embedding_1d, ) -from .wan.utils.fm_solvers import get_sampling_sigmas logger = logging.getLogger(__name__) +def get_sampling_sigmas(sampling_steps, shift): + # Vendored from Wan2.2 (formerly wan/utils/fm_solvers.py); computes the + # noise-level (sigma) schedule for Wan-compatible flow-matching inference. + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = shift * sigma / (1 + (shift - 1) * sigma) + return sigma + + def create_custom_forward(module): def custom_forward(*inputs, **kwargs): return module(*inputs, **kwargs) @@ -94,10 +102,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): return x * (1 + scale) + shift -def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]: - return get_sampling_sigmas(num_inference_steps, shift) - - class WanContinuousFlowMatchScheduler: """Continuous-time Flow-Matching scheduler with shift-based Wan sampling.""" @@ -173,7 +177,7 @@ class WanContinuousFlowMatchScheduler: raise ValueError(f"`shift` must be positive, got {shift}") sigma_steps = torch.as_tensor( - _get_wan_sampling_sigmas(num_inference_steps, shift), + get_sampling_sigmas(num_inference_steps, shift), device=device, dtype=torch.float32, ) diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index 2d132e985..5958f1be3 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -257,7 +257,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone") monkeypatch.setattr( - "lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained", + "lerobot.policies.fastwam.wan.modular.FastWAM.from_wan22_pretrained", fail_if_wan_pretrained_is_loaded, ) @@ -348,7 +348,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes(): pytest.importorskip("diffusers") from diffusers import AutoencoderKLWan - from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38 + from lerobot.policies.fastwam.wan.adapters import WanVideoVAE38 # Production always loads a real pretrained VAE from the diffusers repo; here we # build the same architecture with random weights and dummy standardization stats