moving and renaming files to have a cleaner file tree

This commit is contained in:
Maxime Ellerbach
2026-06-30 11:01:33 +00:00
parent 343ab5f99c
commit 218f8a1735
13 changed files with 54 additions and 52 deletions
@@ -26,13 +26,13 @@ from lerobot.utils.constants import OBS_STATE
from lerobot.utils.import_utils import require_package
from .configuration_fastwam import FastWAMConfig
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
from .wan.components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
from .wan_video_dit import WanVideoDiT
from .wan.modular import ActionDiT, FastWAM, MoT
from .wan.video_dit import WanVideoDiT
class FastWAMPolicy(PreTrainedPolicy):
+21 -13
View File
@@ -1,6 +1,10 @@
# Wan2.2 Upstream Subset
# FastWAM `wan` package
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
This package holds FastWAM's model implementation. It mixes a small **vendored
subset of the official Wan2.2 source tree** with FastWAM's own code, kept flat in
a single directory.
## Vendored from Wan2.2
- Upstream repository: https://github.com/Wan-Video/Wan2.2
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
@@ -8,18 +12,22 @@ This directory contains the trimmed subset of the official Wan2.2 source tree us
Copied files:
- `wan/modules/attention.py`
- `wan/modules/model.py`
- `wan/modules/__init__.py`
- `wan/utils/fm_solvers.py`
- `wan/utils/__init__.py`
- `attention.py` (was `wan/modules/attention.py`)
- `model.py` (was `wan/modules/model.py`)
- `get_sampling_sigmas` in `video_dit.py` (was `wan/utils/fm_solvers.py`), inlined
next to its only caller.
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
UMT5 text encoder, and tokenizer are no longer vendored; they come from
This subset only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
UMT5 text encoder, and tokenizer are no longer vendored - they come from
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
`transformers.AutoTokenizer` (see `components.py` and `adapters.py`).
Current FastWAM adapters that directly reuse this vendored subset:
## FastWAM's own code
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
- `video_dit.py` builds on `model` (`sinusoidal_embedding_1d`, `rope_params`,
`rope_apply`, …) and `attention.flash_attention`. Its
`WanContinuousFlowMatchScheduler` uses `get_sampling_sigmas` for Wan-compatible
inference timesteps.
- `components.py` / `adapters.py` load the VAE, text encoder, tokenizer, and the
custom DiT weights.
- `modular.py` defines the FastWAM model (`ActionDiT`, `MoT`, `FastWAM`, …).
@@ -0,0 +1,13 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -36,8 +36,8 @@ if TYPE_CHECKING or _diffusers_available:
else:
AutoencoderKLWan = None
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
from .adapters import WanVideoVAE38
from .video_dit import WanVideoDiT
logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ import torch.nn as nn
import torch.nn.functional as functional
from PIL import Image
from .wan_components import (
from .components import (
WAN22_DIFFUSERS_MODEL_ID,
WAN_T5_TOKENIZER,
build_wan_tokenizer,
@@ -34,7 +34,7 @@ from .wan_components import (
load_wan_video_dit,
resolve_wan_dit_paths,
)
from .wan_video_dit import (
from .video_dit import (
FastWAMAttentionBlock,
WanContinuousFlowMatchScheduler,
fastwam_masked_attention,
@@ -1,8 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
__all__ = [
"WanModel",
"flash_attention",
]
@@ -1,6 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .fm_solvers import get_sampling_sigmas
__all__ = [
"get_sampling_sigmas",
]
@@ -1,9 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
def get_sampling_sigmas(sampling_steps: int, shift: float) -> np.ndarray:
    """Return the shifted sigma schedule used for sampling.

    Args:
        sampling_steps: Number of sigma values to produce.
        shift: Shift factor applied to the linear schedule. With ``shift == 1``
            the schedule is returned unchanged.

    Returns:
        A 1-D NumPy array of length ``sampling_steps``. The values start at 1
        and decrease; the final 0 of the underlying linear grid is dropped.
    """
    # Linear grid from 1 to 0 with ``sampling_steps + 1`` points; slicing off
    # the last point removes the terminal 0 so exactly ``sampling_steps``
    # values remain.
    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    # Shift transform: sigma -> shift * sigma / (1 + (shift - 1) * sigma).
    sigma = shift * sigma / (1 + (shift - 1) * sigma)
    return sigma
@@ -15,12 +15,13 @@
import logging
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as functional
from einops import rearrange
from .wan.modules.model import (
from .model import (
WanAttentionBlock,
WanLayerNorm,
WanModel,
@@ -29,11 +30,18 @@ from .wan.modules.model import (
rope_params,
sinusoidal_embedding_1d,
)
from .wan.utils.fm_solvers import get_sampling_sigmas
logger = logging.getLogger(__name__)
def get_sampling_sigmas(sampling_steps: int, shift: float) -> np.ndarray:
    """Return the shifted sigma schedule used for flow-matching inference.

    Vendored from Wan2.2 (formerly ``wan/utils/fm_solvers.py``); computes the
    noise-level (sigma) schedule for Wan-compatible flow-matching inference.

    Args:
        sampling_steps: Number of sigma values to produce.
        shift: Shift factor applied to the linear schedule. With ``shift == 1``
            the schedule is returned unchanged.

    Returns:
        A 1-D NumPy array of length ``sampling_steps``. The values start at 1
        and decrease; the final 0 of the underlying linear grid is dropped.
    """
    # Linear grid from 1 to 0 with ``sampling_steps + 1`` points; slicing off
    # the last point removes the terminal 0 so exactly ``sampling_steps``
    # values remain.
    sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    # Shift transform: sigma -> shift * sigma / (1 + (shift - 1) * sigma).
    sigma = shift * sigma / (1 + (shift - 1) * sigma)
    return sigma
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
@@ -94,10 +102,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
    """Return the Wan sampling sigmas as a plain list of Python floats.

    Thin wrapper around ``get_sampling_sigmas``. That helper returns a NumPy
    array, so the result is converted with ``.tolist()`` to honour the
    declared ``list[float]`` return type.

    Args:
        num_inference_steps: Number of sigma values to produce.
        shift: Shift factor forwarded to ``get_sampling_sigmas``.

    Returns:
        ``num_inference_steps`` sigma values, in the order produced by
        ``get_sampling_sigmas``.
    """
    return get_sampling_sigmas(num_inference_steps, shift).tolist()
class WanContinuousFlowMatchScheduler:
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
@@ -173,7 +177,7 @@ class WanContinuousFlowMatchScheduler:
raise ValueError(f"`shift` must be positive, got {shift}")
sigma_steps = torch.as_tensor(
_get_wan_sampling_sigmas(num_inference_steps, shift),
get_sampling_sigmas(num_inference_steps, shift),
device=device,
dtype=torch.float32,
)
@@ -257,7 +257,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
monkeypatch.setattr(
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
"lerobot.policies.fastwam.wan.modular.FastWAM.from_wan22_pretrained",
fail_if_wan_pretrained_is_loaded,
)
@@ -348,7 +348,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes():
pytest.importorskip("diffusers")
from diffusers import AutoencoderKLWan
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
from lerobot.policies.fastwam.wan.adapters import WanVideoVAE38
# Production always loads a real pretrained VAE from the diffusers repo; here we
# build the same architecture with random weights and dummy standardization stats