mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
adding lazy imports
This commit is contained in:
@@ -23,6 +23,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
@@ -52,6 +53,11 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
# FastWAM's Wan2.2 backbone needs transformers (UMT5 text encoder/tokenizer) and
|
||||
# diffusers (Wan VAE), both behind the `fastwam` extra. Fail fast with an actionable
|
||||
# message in base installs rather than deep in Wan component construction.
|
||||
require_package("transformers", extra="fastwam")
|
||||
require_package("diffusers", extra="fastwam")
|
||||
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
|
||||
# dataset feature metadata is already applied to `config` by make_policy upstream,
|
||||
# so we accept and ignore them, matching the other LeRobot policies.
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
else:
|
||||
class ModelMixin:
|
||||
pass
|
||||
|
||||
class ConfigMixin:
|
||||
pass
|
||||
|
||||
def register_to_config(init):
|
||||
return init
|
||||
|
||||
from .attention import flash_attention
|
||||
|
||||
|
||||
@@ -22,14 +22,24 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
UMT5EncoderModel = None
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import AutoencoderKLWan
|
||||
else:
|
||||
AutoencoderKLWan = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
@@ -73,6 +83,7 @@ class WanTokenizer:
|
||||
FastWAM call site expects."""
|
||||
|
||||
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
|
||||
require_package("transformers", extra="fastwam")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
self.seq_len = int(seq_len)
|
||||
|
||||
@@ -104,12 +115,14 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
|
||||
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
|
||||
require_package("diffusers", extra="fastwam")
|
||||
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
|
||||
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
|
||||
|
||||
|
||||
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
|
||||
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
|
||||
require_package("transformers", extra="fastwam")
|
||||
encoder = UMT5EncoderModel.from_pretrained(
|
||||
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ import torch
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
|
||||
pytest.importorskip("transformers", reason="fastwam requires the `fastwam` extra (transformers)")
|
||||
pytest.importorskip("diffusers", reason="fastwam requires the `fastwam` extra (diffusers)")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
|
||||
|
||||
Reference in New Issue
Block a user