From 13c88403b5b3a99756f8bdc6fb115ed481d487e2 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Thu, 18 Jun 2026 14:17:09 +0000 Subject: [PATCH] adding lazy imports --- .../policies/fastwam/modeling_fastwam.py | 6 ++++++ .../policies/fastwam/wan/modules/model.py | 18 ++++++++++++++++-- .../policies/fastwam/wan_components.py | 19 ++++++++++++++++--- tests/policies/fastwam/test_fastwam_policy.py | 3 +++ 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 1bde87eea..292d90bd8 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -23,6 +23,7 @@ from torch import Tensor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.constants import OBS_STATE +from lerobot.utils.import_utils import require_package from .configuration_fastwam import FastWAMConfig from .modular_fastwam import ActionDiT, FastWAM, MoT @@ -52,6 +53,11 @@ class FastWAMPolicy(PreTrainedPolicy): dataset_stats: dict[str, dict[str, Tensor]] | None = None, **kwargs: Any, ): + # FastWAM's Wan2.2 backbone needs transformers (UMT5 text encoder/tokenizer) and + # diffusers (Wan VAE), both behind the `fastwam` extra. Fail fast with an actionable + # message in base installs rather than deep in Wan component construction. + require_package("transformers", extra="fastwam") + require_package("diffusers", extra="fastwam") # `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the # dataset feature metadata is already applied to `config` by make_policy upstream, # so we accept and ignore them, matching the other LeRobot policies. diff --git a/src/lerobot/policies/fastwam/wan/modules/model.py b/src/lerobot/policies/fastwam/wan/modules/model.py index 5f3c41e1a..b50480e72 100644 --- a/src/lerobot/policies/fastwam/wan/modules/model.py +++ b/src/lerobot/policies/fastwam/wan/modules/model.py @@ -1,10 +1,24 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import math +from typing import TYPE_CHECKING import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models.modeling_utils import ModelMixin + +from lerobot.utils.import_utils import _diffusers_available + +if TYPE_CHECKING or _diffusers_available: + from diffusers.configuration_utils import ConfigMixin, register_to_config + from diffusers.models.modeling_utils import ModelMixin +else: + class ModelMixin: + pass + + class ConfigMixin: + pass + + def register_to_config(init): + return init from .attention import flash_attention diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index 8a9c9631c..6d16cf5c7 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -22,14 +22,24 @@ from typing import TYPE_CHECKING, Any import torch from huggingface_hub import snapshot_download from safetensors.torch import load_file -from transformers import AutoTokenizer, UMT5EncoderModel + +from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer, UMT5EncoderModel +else: + AutoTokenizer = None + UMT5EncoderModel = None + +if TYPE_CHECKING or _diffusers_available: + from diffusers import AutoencoderKLWan +else: + AutoencoderKLWan = None if TYPE_CHECKING: from .wan_adapters import WanVideoVAE38 from .wan_video_dit import WanVideoDiT -from diffusers import AutoencoderKLWan - from .wan_adapters import WanVideoVAE38 from .wan_video_dit import WanVideoDiT @@ -73,6 +83,7 @@ class WanTokenizer: FastWAM call site expects.""" def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None: + require_package("transformers", extra="fastwam") self.tokenizer = AutoTokenizer.from_pretrained(name) self.seq_len = int(seq_len) @@ -104,12 +115,14 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer: def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: """Load real Wan2.2 VAE weights from the diffusers repo (offline base creation).""" + require_package("diffusers", extra="fastwam") vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype) return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae) def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder: """Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation).""" + require_package("transformers", extra="fastwam") encoder = UMT5EncoderModel.from_pretrained( WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype ) diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index f4abab4a8..2cb705db7 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -21,6 +21,9 @@ import torch from safetensors import safe_open from torch import nn +pytest.importorskip("transformers", reason="fastwam requires the `fastwam` extra (transformers)") +pytest.importorskip("diffusers", reason="fastwam requires the `fastwam` extra (diffusers)") + from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy