adding lazy imports

This commit is contained in:
Maxime Ellerbach
2026-06-18 14:17:09 +00:00
parent b72cddeea3
commit 13c88403b5
4 changed files with 41 additions and 5 deletions
@@ -23,6 +23,7 @@ from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.import_utils import require_package
from .configuration_fastwam import FastWAMConfig
from .modular_fastwam import ActionDiT, FastWAM, MoT
@@ -52,6 +53,11 @@ class FastWAMPolicy(PreTrainedPolicy):
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
):
# FastWAM's Wan2.2 backbone needs transformers (UMT5 text encoder/tokenizer) and
# diffusers (Wan VAE), both behind the `fastwam` extra. Fail fast with an actionable
# message in base installs rather than deep in Wan component construction.
require_package("transformers", extra="fastwam")
require_package("diffusers", extra="fastwam")
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
# dataset feature metadata is already applied to `config` by make_policy upstream,
# so we accept and ignore them, matching the other LeRobot policies.
@@ -1,10 +1,24 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from lerobot.utils.import_utils import _diffusers_available
if TYPE_CHECKING or _diffusers_available:
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
else:
class ModelMixin:
pass
class ConfigMixin:
pass
def register_to_config(init):
return init
from .attention import flash_attention
+16 -3
View File
@@ -22,14 +22,24 @@ from typing import TYPE_CHECKING, Any
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, UMT5EncoderModel
from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer, UMT5EncoderModel
else:
AutoTokenizer = None
UMT5EncoderModel = None
if TYPE_CHECKING or _diffusers_available:
from diffusers import AutoencoderKLWan
else:
AutoencoderKLWan = None
if TYPE_CHECKING:
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
@@ -73,6 +83,7 @@ class WanTokenizer:
FastWAM call site expects."""
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
require_package("transformers", extra="fastwam")
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.seq_len = int(seq_len)
@@ -104,12 +115,14 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
require_package("diffusers", extra="fastwam")
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
require_package("transformers", extra="fastwam")
encoder = UMT5EncoderModel.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
)
@@ -21,6 +21,9 @@ import torch
from safetensors import safe_open
from torch import nn
pytest.importorskip("transformers", reason="fastwam requires the `fastwam` extra (transformers)")
pytest.importorskip("diffusers", reason="fastwam requires the `fastwam` extra (diffusers)")
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy