mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 17:17:01 +00:00
moving and renaming files to have a cleaner file tree
This commit is contained in:
@@ -26,13 +26,13 @@ from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
from .wan_components import (
|
||||
from .wan.components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
)
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
from .wan.modular import ActionDiT, FastWAM, MoT
|
||||
from .wan.video_dit import WanVideoDiT
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# Wan2.2 Upstream Subset
|
||||
# FastWAM `wan` package
|
||||
|
||||
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
|
||||
This package holds FastWAM's model implementation. It mixes a small **vendored
|
||||
subset of the official Wan2.2 source tree** with FastWAM's own code, kept flat in
|
||||
a single directory.
|
||||
|
||||
## Vendored from Wan2.2
|
||||
|
||||
- Upstream repository: https://github.com/Wan-Video/Wan2.2
|
||||
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
|
||||
@@ -8,18 +12,22 @@ This directory contains the trimmed subset of the official Wan2.2 source tree us
|
||||
|
||||
Copied files:
|
||||
|
||||
- `wan/modules/attention.py`
|
||||
- `wan/modules/model.py`
|
||||
- `wan/modules/__init__.py`
|
||||
- `wan/utils/fm_solvers.py`
|
||||
- `wan/utils/__init__.py`
|
||||
- `attention.py` (was `wan/modules/attention.py`)
|
||||
- `model.py` (was `wan/modules/model.py`)
|
||||
- `get_sampling_sigmas` in `video_dit.py` (was `wan/utils/fm_solvers.py`), inlined
|
||||
next to its only caller.
|
||||
|
||||
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
|
||||
UMT5 text encoder, and tokenizer are no longer vendored — they come from
|
||||
This subset only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
|
||||
UMT5 text encoder, and tokenizer are no longer vendored - they come from
|
||||
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
|
||||
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
|
||||
`transformers.AutoTokenizer` (see `components.py` and `adapters.py`).
|
||||
|
||||
Current FastWAM adapters that directly reuse this vendored subset:
|
||||
## FastWAM's own code
|
||||
|
||||
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
|
||||
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
|
||||
- `video_dit.py` builds on `model` (`sinusoidal_embedding_1d`, `rope_params`,
|
||||
`rope_apply`, …) and `attention.flash_attention`. Its
|
||||
`WanContinuousFlowMatchScheduler` uses `get_sampling_sigmas` for Wan-compatible
|
||||
inference timesteps.
|
||||
- `components.py` / `adapters.py` load the VAE, text encoder, tokenizer, and the
|
||||
custom DiT weights.
|
||||
- `modular.py` defines the FastWAM model (`ActionDiT`, `MoT`, `FastWAM`, …).
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
+2
-2
@@ -36,8 +36,8 @@ if TYPE_CHECKING or _diffusers_available:
|
||||
else:
|
||||
AutoencoderKLWan = None
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
from .adapters import WanVideoVAE38
|
||||
from .video_dit import WanVideoDiT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
+2
-2
@@ -25,7 +25,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
from PIL import Image
|
||||
|
||||
from .wan_components import (
|
||||
from .components import (
|
||||
WAN22_DIFFUSERS_MODEL_ID,
|
||||
WAN_T5_TOKENIZER,
|
||||
build_wan_tokenizer,
|
||||
@@ -34,7 +34,7 @@ from .wan_components import (
|
||||
load_wan_video_dit,
|
||||
resolve_wan_dit_paths,
|
||||
)
|
||||
from .wan_video_dit import (
|
||||
from .video_dit import (
|
||||
FastWAMAttentionBlock,
|
||||
WanContinuousFlowMatchScheduler,
|
||||
fastwam_masked_attention,
|
||||
@@ -1,8 +0,0 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
|
||||
__all__ = [
|
||||
"WanModel",
|
||||
"flash_attention",
|
||||
]
|
||||
@@ -1,6 +0,0 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .fm_solvers import get_sampling_sigmas
|
||||
|
||||
__all__ = [
|
||||
"get_sampling_sigmas",
|
||||
]
|
||||
@@ -1,9 +0,0 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
|
||||
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
|
||||
sigma = shift * sigma / (1 + (shift - 1) * sigma)
|
||||
return sigma
|
||||
+11
-7
@@ -15,12 +15,13 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
from einops import rearrange
|
||||
|
||||
from .wan.modules.model import (
|
||||
from .model import (
|
||||
WanAttentionBlock,
|
||||
WanLayerNorm,
|
||||
WanModel,
|
||||
@@ -29,11 +30,18 @@ from .wan.modules.model import (
|
||||
rope_params,
|
||||
sinusoidal_embedding_1d,
|
||||
)
|
||||
from .wan.utils.fm_solvers import get_sampling_sigmas
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_sampling_sigmas(sampling_steps, shift):
|
||||
# Vendored from Wan2.2 (formerly wan/utils/fm_solvers.py); computes the
|
||||
# noise-level (sigma) schedule for Wan-compatible flow-matching inference.
|
||||
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
|
||||
sigma = shift * sigma / (1 + (shift - 1) * sigma)
|
||||
return sigma
|
||||
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs, **kwargs):
|
||||
return module(*inputs, **kwargs)
|
||||
@@ -94,10 +102,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
|
||||
return get_sampling_sigmas(num_inference_steps, shift)
|
||||
|
||||
|
||||
class WanContinuousFlowMatchScheduler:
|
||||
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
|
||||
|
||||
@@ -173,7 +177,7 @@ class WanContinuousFlowMatchScheduler:
|
||||
raise ValueError(f"`shift` must be positive, got {shift}")
|
||||
|
||||
sigma_steps = torch.as_tensor(
|
||||
_get_wan_sampling_sigmas(num_inference_steps, shift),
|
||||
get_sampling_sigmas(num_inference_steps, shift),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
@@ -257,7 +257,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
|
||||
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
|
||||
"lerobot.policies.fastwam.wan.modular.FastWAM.from_wan22_pretrained",
|
||||
fail_if_wan_pretrained_is_loaded,
|
||||
)
|
||||
|
||||
@@ -348,7 +348,7 @@ def test_vae_adapter_empty_build_encode_decode_shapes():
|
||||
pytest.importorskip("diffusers")
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
|
||||
from lerobot.policies.fastwam.wan.adapters import WanVideoVAE38
|
||||
|
||||
# Production always loads a real pretrained VAE from the diffusers repo; here we
|
||||
# build the same architecture with random weights and dummy standardization stats
|
||||
|
||||
Reference in New Issue
Block a user