mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
big refactor to use models from diffusers and transformers
This commit is contained in:
@@ -219,8 +219,6 @@ eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
fastwam = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
"ftfy>=6.1.1,<7.0.0",
|
||||
"regex>=2024.0.0,<2027.0.0",
|
||||
]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
@@ -118,10 +118,6 @@ def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, Policy
|
||||
return coerced
|
||||
|
||||
|
||||
def _coerce_normalization_mapping(mapping: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with every value of `mapping` coerced via `_coerce_enum`.

    Keys are carried over unchanged; each value is passed to
    `_coerce_enum(NormalizationMode, value)`. The input mapping is not mutated.
    """
    coerced: dict[str, Any] = {}
    for name, raw_mode in mapping.items():
        coerced[name] = _coerce_enum(NormalizationMode, raw_mode)
    return coerced
|
||||
|
||||
|
||||
def _is_local_model_id(value: str) -> bool:
    """Report whether `value` should be treated as a local filesystem path.

    A value counts as local when it starts with an explicit path prefix
    (`./`, `../` or `~`), when it is an absolute path, or when a path with
    that name currently exists.

    Args:
        value: Model identifier or path string to classify.

    Returns:
        True if `value` looks like (or resolves to) a local path, else False.
    """
    # Check the textual prefixes first: `Path.expanduser()` raises RuntimeError
    # for a `~user` whose home directory cannot be determined, which previously
    # escaped from this predicate instead of yielding True.
    if value.startswith(("./", "../", "~")):
        return True
    # No leading "~" remains at this point, so user expansion would be a no-op.
    path = Path(value)
    return path.is_absolute() or path.exists()
|
||||
@@ -200,7 +196,7 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
|
||||
video_dit_config: dict[str, Any] | None = None
|
||||
action_dit_config: dict[str, Any] | None = None
|
||||
normalization_mapping: dict[str, Any] = field(
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
@@ -220,7 +216,6 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
self.input_features = _coerce_policy_features(self.input_features)
|
||||
self.output_features = _coerce_policy_features(self.output_features)
|
||||
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
|
||||
self.normalization_mapping = _coerce_normalization_mapping(self.normalization_mapping)
|
||||
self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim)
|
||||
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
|
||||
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
|
||||
@@ -266,6 +261,38 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
    """Replace the visual input features with the dataset's own camera keys.

    Every dataset key that starts with `observation.images.` and whose
    `dtype` is `"video"` or `"image"` is treated as a camera. The cameras are
    sorted by key and each one is declared as a `(3, H, W // num_cameras)`
    visual feature, where `(H, W)` is `self.image_size`. When `proprio_dim`
    is set and the dataset exposes `OBS_STATE`, a state feature of shape
    `(proprio_dim,)` is added as well. The resulting dict overwrites
    `self.input_features` and `validate_features()` is run afterwards.

    If the dataset has no matching camera key the config is left untouched.

    Per the original author notes, this hook is meant to be called by
    `make_policy` once dataset metadata is known, replacing the synthetic
    single-image default installed by `__post_init__`; raw frames are resized
    to the declared resolution by `make_fastwam_pre_post_processors`.

    Args:
        dataset_features: Mapping from feature key to a dict-like feature
            description that is queried with `.get("dtype")`.
    """
    camera_keys = []
    for name, spec in dataset_features.items():
        if not name.startswith("observation.images."):
            continue
        if spec.get("dtype") in ("video", "image"):
            camera_keys.append(name)
    if len(camera_keys) == 0:
        return
    camera_keys.sort()

    frame_height, full_width = self.image_size
    # The configured width is shared evenly between all cameras.
    camera_width = full_width // len(camera_keys)

    rebuilt: dict[str, PolicyFeature] = {}
    for name in camera_keys:
        rebuilt[name] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, frame_height, camera_width))

    if self.proprio_dim is not None and OBS_STATE in dataset_features:
        rebuilt[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))

    self.input_features = rebuilt
    self.validate_features()
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.action_dim <= 0:
|
||||
raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
|
||||
|
||||
@@ -14,24 +14,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import warnings
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_components import WanCheckpointPaths
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
"""LeRobot policy wrapper for FastWAM.
|
||||
@@ -49,98 +43,62 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
self,
|
||||
config: FastWAMConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
skip_wan_init = bool(kwargs.pop("_skip_wan_init", False))
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
suppress_base_init_warning = bool(kwargs.pop("_suppress_base_init_warning", False))
|
||||
if not skip_wan_init and not suppress_base_init_warning:
|
||||
warnings.warn(
|
||||
"FastWAMPolicy(config) initializes from architecture/config and does not load pretrained "
|
||||
"FastWAM weights. For training or evaluation, use `make_policy(config)` or "
|
||||
"`FastWAMPolicy.from_pretrained(...)`.",
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if skip_wan_init:
|
||||
self.model = _build_core_model_from_architecture(config)
|
||||
else:
|
||||
self.model = self._build_core_model(config)
|
||||
self.model = self._build_core_model(config)
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: FastWAMConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> FastWAMPolicy:
|
||||
"""Load FastWAM weights and local Wan components from one HF directory.
|
||||
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
|
||||
"""Shape-aware load that supports cross-embodiment fine-tuning.
|
||||
|
||||
Args:
|
||||
pretrained_name_or_path (str | Path): HF-format policy directory
|
||||
containing `config.json`, `model.safetensors`, local Wan VAE,
|
||||
local UMT5 text encoder safetensors, and tokenizer files.
|
||||
config (FastWAMConfig | None): Optional config override. When
|
||||
omitted, `config.json` is read from `pretrained_name_or_path`.
|
||||
force_download (bool): Forwarded to LeRobot's pretrained loader.
|
||||
resume_download (bool | None): Forwarded to LeRobot's pretrained loader.
|
||||
proxies (dict | None): Forwarded to LeRobot's pretrained loader.
|
||||
token (str | bool | None): Forwarded to LeRobot's pretrained loader.
|
||||
cache_dir (str | Path | None): Forwarded to LeRobot's pretrained loader.
|
||||
local_files_only (bool): Forwarded to LeRobot's pretrained loader.
|
||||
revision (str | None): Forwarded to LeRobot's pretrained loader.
|
||||
strict (bool): Whether safetensors loading should require an exact
|
||||
match between checkpoint keys and policy module keys.
|
||||
**kwargs (Any): Extra constructor arguments forwarded to
|
||||
`FastWAMPolicy`.
|
||||
`safetensors.load_model(strict=False)` ignores missing/unexpected keys but
|
||||
still raises on a shape mismatch for a shared key. When fine-tuning from a
|
||||
checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim
|
||||
checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and
|
||||
proprio encoder legitimately differ in shape. With `strict=False` we drop
|
||||
only those shape-mismatched tensors — leaving them at their freshly
|
||||
initialized values — and load every compatible tensor. With `strict=True`
|
||||
the standard exact-match loader is used.
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
pretrained_path = _resolve_pretrained_directory(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(pretrained_path)
|
||||
if not isinstance(config, FastWAMConfig):
|
||||
raise TypeError(f"Expected FastWAM config, got {type(config).__name__}.")
|
||||
kwargs["_skip_wan_init"] = True
|
||||
policy = super().from_pretrained(
|
||||
pretrained_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
policy.load_wan_components_from_pretrained(pretrained_path)
|
||||
policy.eval()
|
||||
return policy
|
||||
model_state_dict = model.state_dict()
|
||||
mismatched = []
|
||||
with safe_open(model_file, framework="pt") as f:
|
||||
checkpoint_keys = list(f.keys())
|
||||
for key in checkpoint_keys:
|
||||
if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple(
|
||||
f.get_slice(key).get_shape()
|
||||
):
|
||||
mismatched.append(key)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
super()._save_pretrained(save_directory)
|
||||
_copy_wan_components_from_policy(policy=self, save_directory=save_directory)
|
||||
if not mismatched:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
if strict:
|
||||
raise RuntimeError(
|
||||
f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
|
||||
f"strict=True: {mismatched}"
|
||||
)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
logging.warning(
|
||||
"FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
|
||||
"every compatible weight: %s",
|
||||
len(mismatched),
|
||||
mismatched,
|
||||
)
|
||||
state_dict = load_file(model_file, device="cpu")
|
||||
for key in mismatched:
|
||||
state_dict.pop(key, None)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if map_location and map_location != "cpu":
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict[str, Any]:
|
||||
params = (
|
||||
@@ -151,21 +109,41 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
params.extend(list(proprio_encoder.parameters()))
|
||||
return {"params": [p for p in params if p.requires_grad]}
|
||||
|
||||
def load_wan_components_from_pretrained(self, pretrained_name_or_path: str | Path) -> None:
    """Resolve the Wan sidecar paths for a checkpoint and attach them to this policy.

    The location is resolved with `resolve_wan_component_paths` and the result
    is handed to `_load_wan_components_into_policy`, which installs the
    components on `self.model`.

    Args:
        pretrained_name_or_path (str | Path): Directory expected (per the
            original documentation) to contain `Wan2.2_VAE.safetensors`,
            `models_t5_umt5-xxl-enc-bf16.safetensors`, and the
            `google/umt5-xxl/` tokenizer files.
    """
    _load_wan_components_into_policy(
        policy=self,
        paths=resolve_wan_component_paths(pretrained_name_or_path),
    )
|
||||
|
||||
def reset(self) -> None:
    """Drop any queued actions by installing a fresh, empty action queue.

    The new queue holds at most `config.n_action_steps` entries.
    """
    queue_capacity = self.config.n_action_steps
    self._action_queue: deque[Tensor] = deque(maxlen=queue_capacity)
|
||||
|
||||
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
    """Translate a LeRobot batch into the sample dict passed to the core model.

    Works on a shallow copy of `batch` and fills in only the entries that are
    missing:

    * `video`: built with `_stack_video_from_images` from the batch images.
    * `context` / `context_mask`: produced by `self.model.encode_prompt` from
      the prompt returned by `_prompt_from_batch`, whenever either of the two
      keys is absent.
    * `proprio`: taken from `OBS_STATE` when `config.proprio_dim` is set; a
      2-D state gets an extra axis at position 1, any other rank is used as
      is. According to the original notes, the downstream `build_inputs`
      expects a per-frame `[B, T, D]` tensor and reads frame 0.

    No further shape or presence validation happens here; the original notes
    leave that to `build_inputs`.

    Args:
        batch: Training batch; it is not modified.

    Returns:
        The augmented copy of `batch`.

    Raises:
        KeyError: If text context has to be encoded but no prompt is found.
    """
    converted = dict(batch)

    if "video" not in converted:
        converted["video"] = _stack_video_from_images(batch, self.config)

    has_text_context = "context" in converted and "context_mask" in converted
    if not has_text_context:
        text = _prompt_from_batch(batch=batch, config=self.config)
        if text is None:
            raise KeyError(
                "FastWAM training requires a `task`/`prompt` to encode text context, "
                "or precomputed `context`/`context_mask` in the batch."
            )
        context, context_mask = self.model.encode_prompt(text)
        converted["context"] = context
        converted["context_mask"] = context_mask

    if self.config.proprio_dim is None or "proprio" in converted:
        return converted

    state = converted.get(OBS_STATE)
    if state is None:
        return converted
    if state.ndim == 2:
        # Single-step state: insert the frame axis at position 1.
        converted["proprio"] = state.unsqueeze(1)
    else:
        converted["proprio"] = state
    return converted
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Compute FastWAM training loss for a LeRobot batch.
|
||||
|
||||
@@ -180,7 +158,7 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
key required by LeRobot and optional tensor metrics.
|
||||
"""
|
||||
|
||||
sample = _batch_to_training_sample(batch=batch, config=self.config)
|
||||
sample = self._batch_to_training_sample(batch)
|
||||
loss, metrics = self.model.training_loss(sample)
|
||||
output = {"loss": loss}
|
||||
for key, value in (metrics or {}).items():
|
||||
@@ -230,223 +208,57 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module:
|
||||
return _build_core_model_from_wan22(config)
|
||||
"""Build the FastWAM core for training / inference.
|
||||
|
||||
|
||||
def _resolve_pretrained_directory(
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool,
|
||||
token: str | bool | None,
|
||||
cache_dir: str | Path | None,
|
||||
local_files_only: bool,
|
||||
revision: str | None,
|
||||
) -> Path:
|
||||
path = Path(pretrained_name_or_path)
|
||||
if path.is_dir():
|
||||
return path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from .wan_components import (
|
||||
WAN_T5_CHECKPOINT,
|
||||
WAN_T5_TOKENIZER,
|
||||
WAN_VAE_CHECKPOINT,
|
||||
)
|
||||
|
||||
snapshot_path = snapshot_download(
|
||||
repo_id=str(pretrained_name_or_path),
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
allow_patterns=[
|
||||
"config.json",
|
||||
"model.safetensors",
|
||||
WAN_VAE_CHECKPOINT,
|
||||
WAN_T5_CHECKPOINT,
|
||||
f"{WAN_T5_TOKENIZER}/**",
|
||||
],
|
||||
)
|
||||
return Path(snapshot_path)
|
||||
|
||||
|
||||
def resolve_wan_component_paths(pretrained_name_or_path: str | Path) -> WanCheckpointPaths:
    """Look up the Wan component files that accompany a FastWAM checkpoint.

    Delegates to `wan_components.resolve_wan_checkpoint_paths` with
    `load_dit=False` and `load_text_encoder=True`.

    Args:
        pretrained_name_or_path (str | Path): HF-format FastWAM directory.

    Returns:
        WanCheckpointPaths: Paths as resolved by the helper. DiT shards are
        not requested; per the original notes, FastWAM HF checkpoints keep
        the trainable DiT weights in `model.safetensors`.
    """
    from .wan_components import resolve_wan_checkpoint_paths

    component_paths = resolve_wan_checkpoint_paths(
        pretrained_name_or_path,
        load_dit=False,
        load_text_encoder=True,
    )
    return component_paths
|
||||
|
||||
|
||||
def _load_wan_components_into_policy(policy: FastWAMPolicy, paths: WanCheckpointPaths) -> None:
    """Load the Wan VAE, text encoder and tokenizer and install them on `policy.model`.

    The dtype comes from `policy.config.torch_dtype` (via `_dtype_from_name`)
    and the device from `policy.config.device`. After loading, the source
    paths of the three components are recorded under the `vae`,
    `text_encoder` and `tokenizer` keys of `policy.model.model_paths`; any
    entries already present there are kept.

    Args:
        policy: Policy whose core model receives the components.
        paths: Resolved component locations (`vae`, `text_encoder`, `tokenizer`).

    Raises:
        FileNotFoundError: If the text encoder or tokenizer path is `None`.
    """
    from .wan_components import load_wan_text_encoder, load_wan_tokenizer, load_wan_vae

    if not (paths.text_encoder is not None and paths.tokenizer is not None):
        raise FileNotFoundError("FastWAM HF checkpoint requires Wan text encoder and tokenizer sidecars.")

    settings = policy.config
    core = policy.model
    torch_dtype = _dtype_from_name(settings.torch_dtype)
    target_device = str(settings.device)

    core.vae = load_wan_vae(paths.vae, torch_dtype=torch_dtype, device=target_device)
    core.text_encoder = load_wan_text_encoder(
        paths.text_encoder, torch_dtype=torch_dtype, device=target_device
    )
    core.tokenizer = load_wan_tokenizer(
        paths.tokenizer,
        tokenizer_max_len=int(settings.tokenizer_max_len),
    )

    # Start from whatever was recorded before, then overwrite the three sidecars.
    recorded = dict(getattr(core, "model_paths", {}) or {})
    recorded["vae"] = str(paths.vae)
    recorded["text_encoder"] = str(paths.text_encoder)
    recorded["tokenizer"] = str(paths.tokenizer)
    core.model_paths = recorded
|
||||
|
||||
|
||||
def _copy_wan_components_from_policy(policy: FastWAMPolicy, save_directory: Path) -> None:
|
||||
model_paths = getattr(policy.model, "model_paths", {}) or {}
|
||||
paths = {
|
||||
"vae": model_paths.get("vae"),
|
||||
"text_encoder": model_paths.get("text_encoder"),
|
||||
"tokenizer": model_paths.get("tokenizer"),
|
||||
}
|
||||
missing = [name for name, path in paths.items() if path is None]
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
"FastWAM save_pretrained requires local Wan component paths for "
|
||||
f"{missing}. Load or initialize the policy with local Wan VAE, text encoder, and tokenizer files."
|
||||
Only the trainable parts (the MoT DiT and the proprio encoder) are
|
||||
materialized empty here and then filled from the policy's
|
||||
`model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE
|
||||
and UMT5 text encoder are loaded with their real weights from the
|
||||
`Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared
|
||||
across checkpoints) and are intentionally excluded from `model.safetensors`
|
||||
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
|
||||
"""
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
from .wan_components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
)
|
||||
_copy_component_path(Path(paths["vae"]), save_directory / Path(paths["vae"]).name)
|
||||
_copy_component_path(Path(paths["text_encoder"]), save_directory / Path(paths["text_encoder"]).name)
|
||||
tokenizer_source = Path(paths["tokenizer"])
|
||||
_copy_component_path(tokenizer_source, save_directory / "google" / "umt5-xxl")
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
|
||||
def _copy_component_path(source: Path, destination: Path) -> None:
    """Copy a component file or directory tree to `destination`.

    `source` has `~` expanded first. Nothing is copied when source and
    destination resolve to the same location. Missing parent directories of
    `destination` are created. Directories are copied recursively and may
    merge into an existing destination; files are copied with `shutil.copy2`.

    Args:
        source: Existing file or directory to copy.
        destination: Target path of the copy.

    Raises:
        FileNotFoundError: If the expanded `source` does not exist.
    """
    src = source.expanduser()
    if not src.exists():
        raise FileNotFoundError(f"FastWAM component path does not exist: {src}")

    already_in_place = src.resolve() == destination.resolve()
    if already_in_place:
        return

    destination.parent.mkdir(parents=True, exist_ok=True)
    if src.is_dir():
        shutil.copytree(src, destination, dirs_exist_ok=True)
        return
    shutil.copy2(src, destination)
|
||||
|
||||
|
||||
def _build_core_model_from_wan22(config: FastWAMConfig) -> torch.nn.Module:
    """Construct the FastWAM core through `FastWAM.from_wan22_pretrained`.

    All arguments are read from `config`: device, dtype (converted with
    `_dtype_from_name`), model/tokenizer ids, DiT configs, the `train_shift`,
    `infer_shift` and `num_train_timesteps` entries of the video and action
    scheduler dicts, and the `lambda_video` / `lambda_action` loss weights.

    Args:
        config: FastWAM policy configuration.

    Returns:
        torch.nn.Module: The model returned by `FastWAM.from_wan22_pretrained`.
    """
    from .modular_fastwam import FastWAM

    torch_dtype = _dtype_from_name(config.torch_dtype)
    video_sched = config.video_scheduler
    action_sched = config.action_scheduler
    loss_weights = config.loss

    return FastWAM.from_wan22_pretrained(
        device=config.device,
        torch_dtype=torch_dtype,
        model_id=config.model_id,
        tokenizer_model_id=config.tokenizer_model_id,
        tokenizer_max_len=config.tokenizer_max_len,
        load_text_encoder=config.load_text_encoder,
        proprio_dim=config.proprio_dim,
        video_dit_config=config.video_dit_config,
        action_dit_config=config.action_dit_config,
        mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
        video_train_shift=float(video_sched["train_shift"]),
        video_infer_shift=float(video_sched["infer_shift"]),
        video_num_train_timesteps=int(video_sched["num_train_timesteps"]),
        action_train_shift=float(action_sched["train_shift"]),
        action_infer_shift=float(action_sched["infer_shift"]),
        action_num_train_timesteps=int(action_sched["num_train_timesteps"]),
        loss_lambda_video=float(loss_weights["lambda_video"]),
        loss_lambda_action=float(loss_weights["lambda_action"]),
    )
|
||||
|
||||
|
||||
def _build_core_model_from_architecture(config: FastWAMConfig) -> torch.nn.Module:
    """Assemble a FastWAM core from config alone, without Wan VAE or text encoder.

    The video and action experts are instantiated from their config dicts and
    moved to the configured device/dtype, then combined in a `MoT`. The VAE
    slot is filled with `_FastWAMVAEPlaceholder()`, and both the text encoder
    and tokenizer are `None`. `text_dim` is read from
    `config.video_dit_config["text_dim"]`.

    Args:
        config: FastWAM policy configuration.

    Returns:
        torch.nn.Module: The constructed `FastWAM` instance.
    """
    from .modular_fastwam import ActionDiT, FastWAM, MoT
    from .wan_video_dit import WanVideoDiT

    torch_dtype = _dtype_from_name(config.torch_dtype)

    video_expert = WanVideoDiT(**config.video_dit_config)
    video_expert = video_expert.to(device=config.device, dtype=torch_dtype)
    action_expert = ActionDiT(**config.action_dit_config)
    action_expert = action_expert.to(device=config.device, dtype=torch_dtype)

    experts = {"video": video_expert, "action": action_expert}
    mot = MoT(mixtures=experts, mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn)

    video_sched = config.video_scheduler
    action_sched = config.action_scheduler
    loss_weights = config.loss

    return FastWAM(
        video_expert=video_expert,
        action_expert=action_expert,
        mot=mot,
        vae=_FastWAMVAEPlaceholder(),
        text_encoder=None,
        tokenizer=None,
        text_dim=int(config.video_dit_config["text_dim"]),
        proprio_dim=config.proprio_dim,
        device=config.device,
        torch_dtype=torch_dtype,
        video_train_shift=float(video_sched["train_shift"]),
        video_infer_shift=float(video_sched["infer_shift"]),
        video_num_train_timesteps=int(video_sched["num_train_timesteps"]),
        action_train_shift=float(action_sched["train_shift"]),
        action_infer_shift=float(action_sched["infer_shift"]),
        action_num_train_timesteps=int(action_sched["num_train_timesteps"]),
        loss_lambda_video=float(loss_weights["lambda_video"]),
        loss_lambda_action=float(loss_weights["lambda_action"]),
    )
|
||||
|
||||
|
||||
class _FastWAMVAEPlaceholder(torch.nn.Module):
|
||||
"""Minimal VAE placeholder for checkpoint loading without Wan2.2 VAE.
|
||||
|
||||
Args:
|
||||
temporal_downsample_factor (int): Temporal compression factor expected
|
||||
by FastWAM latent shape logic.
|
||||
upsampling_factor (int): Spatial compression factor expected by FastWAM.
|
||||
z_dim (int): Latent channel count used by Wan2.2 TI2V VAE.
|
||||
"""
|
||||
|
||||
temporal_downsample_factor: int = 4
|
||||
upsampling_factor: int = 8
|
||||
|
||||
def __init__(self, z_dim: int = 48):
|
||||
super().__init__()
|
||||
self.model = type("VAEModelShape", (), {"z_dim": int(z_dim)})()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"FastWAM VAE placeholder cannot encode images; load Wan2.2 VAE for image inference."
|
||||
dtype = _dtype_from_name(config.torch_dtype)
|
||||
device = config.device
|
||||
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
|
||||
action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype)
|
||||
mot = MoT(
|
||||
mixtures={"video": video_expert, "action": action_expert},
|
||||
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
|
||||
)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"FastWAM VAE placeholder cannot decode latents; load Wan2.2 VAE for video inference."
|
||||
text_encoder = (
|
||||
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
|
||||
if config.load_text_encoder
|
||||
else None
|
||||
)
|
||||
return FastWAM(
|
||||
video_expert=video_expert,
|
||||
action_expert=action_expert,
|
||||
mot=mot,
|
||||
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
|
||||
text_dim=int(config.video_dit_config["text_dim"]),
|
||||
proprio_dim=config.proprio_dim,
|
||||
device=device,
|
||||
torch_dtype=dtype,
|
||||
video_train_shift=float(config.video_scheduler["train_shift"]),
|
||||
video_infer_shift=float(config.video_scheduler["infer_shift"]),
|
||||
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
|
||||
action_train_shift=float(config.action_scheduler["train_shift"]),
|
||||
action_infer_shift=float(config.action_scheduler["infer_shift"]),
|
||||
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
|
||||
loss_lambda_video=float(config.loss["lambda_video"]),
|
||||
loss_lambda_action=float(config.loss["lambda_action"]),
|
||||
)
|
||||
|
||||
|
||||
def _batch_to_training_sample(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Tensor]:
    """Build a FastWAM training sample from a batch and check required keys.

    A shallow copy of `batch` is extended with `video` (via
    `_stack_video_from_images`) and `proprio` (from `OBS_STATE`) when those
    keys are missing. The copy must then contain `video`, `ACTION`, `context`
    and `context_mask`.

    Args:
        batch: Input batch; it is not modified.
        config: FastWAM configuration forwarded to `_stack_video_from_images`.

    Returns:
        The augmented copy of `batch`.

    Raises:
        KeyError: If any required key is missing; the message lists the
            missing keys in sorted order.
    """
    converted = dict(batch)

    if "video" not in converted:
        converted["video"] = _stack_video_from_images(batch, config)

    if "proprio" not in converted and OBS_STATE in batch:
        converted["proprio"] = batch[OBS_STATE]

    required_keys = {"video", ACTION, "context", "context_mask"}
    absent = sorted(name for name in required_keys if name not in converted)
    if absent:
        raise KeyError(f"FastWAM training batch is missing keys: {absent}.")
    return converted
|
||||
|
||||
|
||||
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
|
||||
|
||||
@@ -24,7 +24,13 @@ import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
from PIL import Image
|
||||
|
||||
from .wan_components import load_wan22_ti2v_5b_components
|
||||
from .wan_components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
load_wan_video_dit,
|
||||
resolve_wan_dit_paths,
|
||||
)
|
||||
from .wan_video_dit import (
|
||||
FastWAMAttentionBlock,
|
||||
WanContinuousFlowMatchScheduler,
|
||||
@@ -846,9 +852,17 @@ class FastWAM(torch.nn.Module):
|
||||
# Keep trainer compatibility: optimizer and freeze logic use `model.dit`.
|
||||
self.dit = self.mot
|
||||
|
||||
self.vae = vae
|
||||
self.text_encoder = text_encoder
|
||||
# Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT
|
||||
# registered as submodules. They are therefore excluded from `state_dict()`
|
||||
# (lean checkpoints), `parameters()`, and DDP gradient sync, and are loaded
|
||||
# with their real weights from the diffusers/transformers repos at construction.
|
||||
# Device/dtype moves still reach them via the `_apply` override below.
|
||||
object.__setattr__(self, "vae", vae)
|
||||
object.__setattr__(self, "text_encoder", text_encoder)
|
||||
self.tokenizer = tokenizer
|
||||
vae.requires_grad_(False)
|
||||
if text_encoder is not None:
|
||||
text_encoder.requires_grad_(False)
|
||||
if text_dim is None:
|
||||
if self.text_encoder is None:
|
||||
raise ValueError("`text_dim` is required when `text_encoder` is not loaded.")
|
||||
@@ -913,18 +927,17 @@ class FastWAM(torch.nn.Module):
|
||||
raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().")
|
||||
if "text_dim" not in video_dit_config:
|
||||
raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.")
|
||||
del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl)
|
||||
|
||||
components = load_wan22_ti2v_5b_components(
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
model_id=model_id,
|
||||
tokenizer_model_id=tokenizer_model_id,
|
||||
tokenizer_max_len=tokenizer_max_len,
|
||||
# Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from
|
||||
# the diffusers conversion. This is the offline base-creation path; the
|
||||
# weights it loads are then bundled into the FastWAM `model.safetensors`.
|
||||
video_expert = load_wan_video_dit(
|
||||
resolve_wan_dit_paths(model_id),
|
||||
dit_config=video_dit_config,
|
||||
load_text_encoder=load_text_encoder,
|
||||
torch_dtype=torch_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
video_expert = components.dit
|
||||
action_expert = ActionDiT(**action_dit_config).to(device=device, dtype=torch_dtype)
|
||||
if int(action_expert.num_heads) != int(video_expert.num_heads):
|
||||
raise ValueError("ActionDiT `num_heads` must match video expert for MoT mixed attention.")
|
||||
@@ -938,13 +951,21 @@ class FastWAM(torch.nn.Module):
|
||||
mot_checkpoint_mixed_attn=mot_checkpoint_mixed_attn,
|
||||
)
|
||||
|
||||
model = cls(
|
||||
vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device)
|
||||
text_encoder = (
|
||||
load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device)
|
||||
if load_text_encoder
|
||||
else None
|
||||
)
|
||||
tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len)
|
||||
|
||||
return cls(
|
||||
video_expert=video_expert,
|
||||
action_expert=action_expert,
|
||||
mot=mot,
|
||||
vae=components.vae,
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_dim=int(video_dit_config["text_dim"]),
|
||||
proprio_dim=proprio_dim,
|
||||
device=device,
|
||||
@@ -958,20 +979,17 @@ class FastWAM(torch.nn.Module):
|
||||
loss_lambda_video=loss_lambda_video,
|
||||
loss_lambda_action=loss_lambda_action,
|
||||
)
|
||||
model.model_paths = {
|
||||
"video_dit": components.dit_path,
|
||||
"vae": components.vae_path,
|
||||
"text_encoder": components.text_encoder_path,
|
||||
"tokenizer": components.tokenizer_path,
|
||||
}
|
||||
return model
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super().to(*args, **kwargs)
|
||||
self.mot.to(*args, **kwargs)
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
# `.to()` / `.cuda()` / `.cpu()` and accelerate/DDP device moves all funnel
|
||||
# through `_apply`, and the parent policy reaches us via `child._apply(fn)`
|
||||
# (not `child.to()`). Propagate `fn` to the *unregistered* frozen VAE / text
|
||||
# encoder here so they follow the rest of the model onto the right device,
|
||||
# while staying out of `state_dict()` / `parameters()`.
|
||||
super()._apply(fn, *args, **kwargs)
|
||||
self.vae._apply(fn)
|
||||
if self.text_encoder is not None:
|
||||
self.text_encoder.to(*args, **kwargs)
|
||||
self.vae.to(*args, **kwargs)
|
||||
self.text_encoder._apply(fn)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
@@ -1628,7 +1646,7 @@ class FastWAM(torch.nn.Module):
|
||||
latent_w = width // self.vae.upsampling_factor
|
||||
generator = None if seed is None else torch.Generator(device=rand_device).manual_seed(seed)
|
||||
return torch.randn(
|
||||
(1, self.vae.model.z_dim, latent_t, latent_h, latent_w),
|
||||
(1, self.vae.z_dim, latent_t, latent_h, latent_w),
|
||||
generator=generator,
|
||||
device=rand_device,
|
||||
dtype=torch.float32,
|
||||
|
||||
@@ -24,6 +24,7 @@ from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
@@ -90,9 +91,8 @@ def make_fastwam_pre_post_processors(
|
||||
output processor pipelines discoverable by LeRobot.
|
||||
"""
|
||||
|
||||
normalization_stats: dict[str, dict[str, Any]] = {
|
||||
key: dict(value) for key, value in (dataset_stats or {}).items()
|
||||
}
|
||||
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
|
||||
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
|
||||
for key, feature in config.input_features.items():
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
@@ -101,10 +101,23 @@ def make_fastwam_pre_post_processors(
|
||||
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# resize visual inputs to match model expected input size, if necessary
|
||||
visual_shapes = [
|
||||
feature.shape
|
||||
for feature in config.input_features.values()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
resize_steps = []
|
||||
if visual_shapes:
|
||||
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
|
||||
resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw))
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
*resize_steps,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
|
||||
@@ -10,17 +10,16 @@ Copied files:
|
||||
|
||||
- `wan/modules/attention.py`
|
||||
- `wan/modules/model.py`
|
||||
- `wan/modules/t5.py`
|
||||
- `wan/modules/tokenizers.py`
|
||||
- `wan/modules/vae2_2.py`
|
||||
- `wan/modules/__init__.py`
|
||||
- `wan/utils/fm_solvers.py`
|
||||
- `wan/utils/__init__.py`
|
||||
|
||||
FastWAM-specific model glue and any larger code adapted from these modules live outside this directory. This keeps the upstream Wan2.2 code reviewable as a vendored reference subset and makes it straightforward to replace this directory with an external Wan2.2 dependency by changing import paths.
|
||||
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
|
||||
UMT5 text encoder, and tokenizer are no longer vendored — they come from
|
||||
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
|
||||
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
|
||||
|
||||
Current FastWAM adapters that directly reuse this vendored subset:
|
||||
|
||||
- `../wan_components.py` instantiates the upstream `wan.modules.t5.umt5_xxl` encoder factory and uses `wan.modules.tokenizers.HuggingfaceTokenizer`.
|
||||
- `../wan_adapters.py` wraps `wan.modules.vae2_2.Wan2VAE` with the FastWAM tensor-batch encode/decode API.
|
||||
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
|
||||
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
from .vae2_2 import Wan2VAE
|
||||
|
||||
__all__ = [
|
||||
"Wan2VAE",
|
||||
"WanModel",
|
||||
"T5Model",
|
||||
"T5Encoder",
|
||||
"T5Decoder",
|
||||
"T5EncoderModel",
|
||||
"HuggingfaceTokenizer",
|
||||
"flash_attention",
|
||||
]
|
||||
|
||||
@@ -1,489 +0,0 @@
|
||||
# Modified from transformers.models.t5.modeling_t5
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
|
||||
__all__ = [
|
||||
"T5Model",
|
||||
"T5Encoder",
|
||||
"T5Decoder",
|
||||
"T5EncoderModel",
|
||||
]
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
clamp = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp, max=clamp)
|
||||
return x
|
||||
|
||||
|
||||
def init_weights(m):
    """Initialise the parameters of a single T5 submodule in place.

    Written to be passed to ``nn.Module.apply``. Modules that are not one of
    the T5 building blocks handled below are left untouched.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
        return

    if isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
        return

    if isinstance(m, T5FeedForward):
        # Both input projections share the same std; the output projection
        # is scaled by the hidden width instead.
        for layer in (m.gate[0], m.fc1):
            nn.init.normal_(layer.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
        return

    if isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        for layer in (m.k, m.v):
            nn.init.normal_(layer.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
        return

    if isinstance(m, T5RelativeEmbedding):
        bucket_std = (2 * m.num_buckets * m.num_heads) ** -0.5
        nn.init.normal_(m.embedding.weight, std=bucket_std)
|
||||
|
||||
|
||||
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def forward(self, x):
        inner = x + 0.044715 * torch.pow(x, 3.0)
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * gate
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
    """Scale-only normalisation over the last dimension (no mean, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # The mean of squares is computed on a float32 copy of the input.
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)
        # Cast back down only when the learned scale is half precision.
        half_dtypes = [torch.float16, torch.bfloat16]
        if self.weight.dtype in half_dtypes:
            x = x.type_as(self.weight)
        return self.weight * x
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
    """Multi-head attention with an additive bias and no score scaling."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # Bias-free projections, created in q, k, v, o order.
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None (self-attention over ``x``).
        mask:       [B, L2] or [B, L1, L2] or None; zeros mark masked keys.
        pos_bias:   optional tensor added to the attention logits.
        """
        if context is None:
            context = x
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        def split_heads(t):
            return t.view(batch, -1, heads, head_dim)

        q = split_heads(self.q(x))
        k = split_heads(self.k(context))
        v = split_heads(self.v(context))

        # Accumulate the positional bias and the mask into one additive term.
        attn_bias = x.new_zeros(batch, heads, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            if mask.ndim == 2:
                mask = mask.view(batch, 1, 1, -1)
            else:
                mask = mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # T5 does not scale the logits; the softmax runs in float32.
        scores = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
        weights = functional.softmax(scores.float(), dim=-1).type_as(scores)
        merged = torch.einsum("bnij,bjnc->binc", weights, v)

        merged = merged.reshape(batch, -1, heads * head_dim)
        return self.dropout(self.o(merged))
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
    """Gated feed-forward block: ``fc1(x) * gate(x)`` projected back by ``fc2``."""

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # Bias-free projections; the gate branch ends in a GELU.
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = self.fc1(x) * self.gate(x)
        projected = self.fc2(self.dropout(gated))
        # Dropout is applied both before and after the output projection.
        return self.dropout(projected)
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
    """Pre-norm encoder block: self-attention followed by a feed-forward layer."""

    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # A block owns its own relative-position table only when the bias
        # is not shared (and passed in through ``pos_bias``).
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=bias))
        return fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
|
||||
|
||||
class T5CrossAttention(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention, feed-forward."""

    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # A private, unidirectional relative-position table is built only
        # when the bias is not shared across blocks.
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)

    def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=bias))
        # The positional bias is applied to self-attention only.
        attended = self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
        x = fp16_clamp(x + attended)
        return fp16_clamp(x + self.ffn(self.norm3(x)))
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return the bias for ``lq`` queries and ``lk`` keys as [1, N, Lq, Lk]."""
        device = self.embedding.weight.device
        key_pos = torch.arange(lk, device=device).unsqueeze(0)
        query_pos = torch.arange(lq, device=device).unsqueeze(1)
        buckets = self._relative_position_bucket(key_pos - query_pos)
        bias = self.embedding(buckets)  # [Lq, Lk, N]
        return bias.permute(2, 0, 1).unsqueeze(0).contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map signed relative positions (key minus query) to bucket indices."""
        if self.bidirectional:
            # Half of the buckets serve positive offsets, half the rest.
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            # Only non-positive offsets are distinguished; positive ones map to 0.
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # Distances below ``max_exact`` get one bucket each; larger ones are
        # spread logarithmically up to ``max_dist`` and capped at the last bucket.
        max_exact = num_buckets // 2
        log_ratio = torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact)
        rel_pos_large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        return rel_buckets + torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
    """Stack of ``T5SelfAttention`` blocks over token embeddings."""

    def __init__(
        self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # ``vocab`` is either a ready-made embedding to reuse or a vocabulary size.
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        # One bidirectional bias table for all blocks when positions are shared.
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        layers = [
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = T5LayerNorm(dim)

        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.dropout(self.token_embedding(ids))
        bias = None
        if self.shared_pos:
            bias = self.pos_embedding(x.size(1), x.size(1))
        for block in self.blocks:
            x = block(x, mask, pos_bias=bias)
        return self.dropout(self.norm(x))
|
||||
|
||||
|
||||
class T5Decoder(nn.Module):
    """Stack of ``T5CrossAttention`` blocks with a causal self-attention mask."""

    def __init__(
        self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1
    ):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # ``vocab`` is either a ready-made embedding to reuse or a vocabulary size.
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        # One unidirectional bias table for all blocks when positions are shared.
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        layers = [
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = T5LayerNorm(dim)

        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        _, seq_len = ids.size()

        # Build the lower-triangular mask, folding in a [B, L] padding mask if given.
        if mask is None:
            mask = torch.tril(torch.ones(1, seq_len, seq_len).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, seq_len, -1))

        x = self.dropout(self.token_embedding(ids))
        bias = None
        if self.shared_pos:
            bias = self.pos_embedding(x.size(1), x.size(1))
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=bias)
        return self.dropout(self.norm(x))
|
||||
|
||||
|
||||
class T5Model(nn.Module):
    """Encoder-decoder T5 with one token embedding shared by both stacks."""

    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # The same embedding module is handed to the encoder and the decoder.
        self.token_embedding = nn.Embedding(vocab_size, dim)
        widths = (dim, dim_attn, dim_ffn, num_heads)
        self.encoder = T5Encoder(
            self.token_embedding, *widths, encoder_layers, num_buckets, shared_pos, dropout
        )
        self.decoder = T5Decoder(
            self.token_embedding, *widths, decoder_layers, num_buckets, shared_pos, dropout
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        memory = self.encoder(encoder_ids, encoder_mask)
        decoded = self.decoder(decoder_ids, decoder_mask, memory, encoder_mask)
        return self.head(decoded)
|
||||
|
||||
|
||||
def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs=None,
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    """Build a T5 model (full, encoder-only or decoder-only) from ``kwargs``.

    ``kwargs`` use the ``T5Model`` constructor names; for the single-stack
    variants ``vocab_size`` and the matching ``*_layers`` key are renamed and
    the other ``*_layers`` key is discarded. When ``return_tokenizer`` is
    set, a ``HuggingfaceTokenizer`` for ``google/<name>`` is returned as well.
    """
    assert not (encoder_only and decoder_only)
    tokenizer_kwargs = tokenizer_kwargs or {}

    if encoder_only or decoder_only:
        if encoder_only:
            model_cls, kept, dropped = T5Encoder, "encoder_layers", "decoder_layers"
        else:
            model_cls, kept, dropped = T5Decoder, "decoder_layers", "encoder_layers"
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop(kept)
        kwargs.pop(dropped)
    else:
        model_cls = T5Model

    # Construct directly on the target device, then cast dtype.
    with torch.device(device):
        model = model_cls(**kwargs)
    model = model.to(dtype=dtype, device=device)

    if not return_tokenizer:
        return model

    from .tokenizers import HuggingfaceTokenizer

    tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
    return model, tokenizer
|
||||
|
||||
|
||||
def umt5_xxl(**kwargs):
    """Build the ``umt5-xxl`` configuration via ``_t5``.

    Keyword arguments override the defaults below and may also carry the
    extra ``_t5`` options (``encoder_only``, ``dtype``, ``device``, ...).
    """
    defaults = {
        "vocab_size": 256384,
        "dim": 4096,
        "dim_attn": 4096,
        "dim_ffn": 10240,
        "num_heads": 64,
        "encoder_layers": 24,
        "decoder_layers": 24,
        "num_buckets": 32,
        "shared_pos": False,
        "dropout": 0.1,
    }
    return _t5("umt5-xxl", **{**defaults, **kwargs})
|
||||
|
||||
|
||||
class T5EncoderModel:
|
||||
def __init__(
|
||||
self,
|
||||
text_len,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
shard_fn=None,
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
# init model
|
||||
model = (
|
||||
umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device)
|
||||
.eval()
|
||||
.requires_grad_(False)
|
||||
)
|
||||
logging.info(f"loading {checkpoint_path}")
|
||||
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
|
||||
self.model = model
|
||||
if shard_fn is not None:
|
||||
self.model = shard_fn(self.model, sync_module_states=False)
|
||||
else:
|
||||
self.model.to(self.device)
|
||||
# init tokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
|
||||
|
||||
def __call__(self, texts, device):
|
||||
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
|
||||
ids = ids.to(device)
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
context = self.model(ids, mask)
|
||||
return [u[:v] for u, v in zip(context, seq_lens, strict=False)]
|
||||
@@ -1,78 +0,0 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import html
|
||||
import string
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
__all__ = ["HuggingfaceTokenizer"]
|
||||
|
||||
|
||||
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip the ends."""
    fixed = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities are decoded as well.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lower-case ``text`` with punctuation removed and whitespace collapsed.

    Underscores become spaces first. When ``keep_punctuation_exact_string``
    is given, occurrences of that exact string survive; punctuation is only
    stripped from the pieces between them.
    """
    drop_punctuation = str.maketrans("", "", string.punctuation)
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        text = keep_punctuation_exact_string.join(piece.translate(drop_punctuation) for piece in pieces)
    else:
        text = text.translate(drop_punctuation)
    return re.sub(r"\s+", " ", text.lower()).strip()
|
||||
|
||||
|
||||
class HuggingfaceTokenizer:
    """Wrapper around ``AutoTokenizer`` with optional text cleaning.

    ``clean`` selects the pre-tokenisation cleanup (``None``, ``"whitespace"``,
    ``"lower"`` or ``"canonicalize"``); ``seq_len``, when set, makes every call
    pad and truncate to that length.
    """

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenise one string or a list of strings.

        ``return_mask=True`` returns ``(input_ids, attention_mask)``; all other
        keyword arguments are forwarded to the tokenizer and override the
        defaults assembled here.
        """
        return_mask = kwargs.pop("return_mask", False)

        call_kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            call_kwargs["padding"] = "max_length"
            call_kwargs["truncation"] = True
            call_kwargs["max_length"] = self.seq_len
        call_kwargs.update(**kwargs)

        texts = [sequence] if isinstance(sequence, str) else sequence
        if self.clean:
            texts = [self._clean(item) for item in texts]
        encoded = self.tokenizer(texts, **call_kwargs)

        if return_mask:
            return encoded.input_ids, encoded.attention_mask
        return encoded.input_ids

    def _clean(self, text):
        """Apply the cleanup selected by ``self.clean`` to one string."""
        mode = self.clean
        if mode == "canonicalize":
            return canonicalize(basic_clean(text))
        if mode in ("whitespace", "lower"):
            cleaned = whitespace_clean(basic_clean(text))
            return cleaned.lower() if mode == "lower" else cleaned
        return text
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,16 +14,24 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from .wan.modules.vae2_2 import Wan2VAE
|
||||
if TYPE_CHECKING:
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
|
||||
class WanVideoVAE38(torch.nn.Module):
|
||||
"""Tensor-batch adapter around the official Wan2.2 VAE wrapper."""
|
||||
"""FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).
|
||||
|
||||
16x spatial / 4x temporal compression, 48 latent channels. diffusers'
|
||||
`AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/
|
||||
`latents_std`), so `encode`/`decode` here apply the same standardization the
|
||||
Wan reference uses — `(latents - mean) / std` — done in fp32 for stability.
|
||||
`encode` uses the deterministic posterior mode, matching the original VAE
|
||||
which returned the latent mean `mu`.
|
||||
"""
|
||||
|
||||
upsampling_factor = 16
|
||||
temporal_downsample_factor = 4
|
||||
@@ -31,27 +39,35 @@ class WanVideoVAE38(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_pth: str | Path,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: str | torch.device = "cuda",
|
||||
*,
|
||||
pretrained: AutoencoderKLWan,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.wan_vae = Wan2VAE(vae_pth=str(vae_pth), dtype=dtype, device=str(device))
|
||||
self.model = self.wan_vae.model
|
||||
self.dtype = dtype
|
||||
self.device = torch.device(device)
|
||||
# The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch,
|
||||
# so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from
|
||||
# the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path.
|
||||
self.vae = pretrained.to(device=device, dtype=dtype)
|
||||
|
||||
def to(self, *args: Any, **kwargs: Any):
|
||||
super().to(*args, **kwargs)
|
||||
self.model.to(*args, **kwargs)
|
||||
param = next(self.model.parameters())
|
||||
self.device = param.device
|
||||
self.dtype = param.dtype
|
||||
self.wan_vae.device = self.device
|
||||
self.wan_vae.dtype = self.dtype
|
||||
self.wan_vae.scale = [scale.to(device=self.device, dtype=self.dtype) for scale in self.wan_vae.scale]
|
||||
self.wan_vae.model = self.model
|
||||
return self
|
||||
# Read the standardization stats from the VAE's own config (diffusers populates
|
||||
# these from vae/config.json) — single source of truth, no local copy. diffusers'
|
||||
# encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves.
|
||||
# Non-persistent: kept out of state_dict.
|
||||
self.register_buffer(
|
||||
"latents_mean",
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"latents_std",
|
||||
torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
|
||||
param = next(self.vae.parameters())
|
||||
return param.device, param.dtype
|
||||
|
||||
def encode(
|
||||
self,
|
||||
@@ -61,18 +77,16 @@ class WanVideoVAE38(torch.nn.Module):
|
||||
tile_size: tuple[int, int] = (34, 34),
|
||||
tile_stride: tuple[int, int] = (18, 16),
|
||||
) -> torch.Tensor:
|
||||
del tile_size, tile_stride
|
||||
del device, tile_size, tile_stride
|
||||
if tiled:
|
||||
raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
|
||||
target_device = self.device if device is None else torch.device(device)
|
||||
if target_device != self.device:
|
||||
self.to(device=target_device)
|
||||
if isinstance(videos, torch.Tensor):
|
||||
videos = list(videos)
|
||||
hidden_states = self.wan_vae.encode([video.to(self.device) for video in videos])
|
||||
if hidden_states is None:
|
||||
raise RuntimeError("Wan2.2 VAE encode failed; expected a list of video tensors.")
|
||||
return torch.stack(hidden_states)
|
||||
if isinstance(videos, (list, tuple)):
|
||||
videos = torch.stack(list(videos))
|
||||
dev, dtype = self._device_dtype()
|
||||
mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float()
|
||||
mean = self.latents_mean.float().to(mu.device)
|
||||
std = self.latents_std.float().to(mu.device)
|
||||
return (mu - mean) / std
|
||||
|
||||
def decode(
|
||||
self,
|
||||
@@ -82,18 +96,16 @@ class WanVideoVAE38(torch.nn.Module):
|
||||
tile_size: tuple[int, int] = (34, 34),
|
||||
tile_stride: tuple[int, int] = (18, 16),
|
||||
) -> torch.Tensor:
|
||||
del tile_size, tile_stride
|
||||
del device, tile_size, tile_stride
|
||||
if tiled:
|
||||
raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
|
||||
target_device = self.device if device is None else torch.device(device)
|
||||
if target_device != self.device:
|
||||
self.to(device=target_device)
|
||||
if isinstance(hidden_states, torch.Tensor):
|
||||
hidden_states = list(hidden_states)
|
||||
videos = self.wan_vae.decode([hidden_state.to(self.device) for hidden_state in hidden_states])
|
||||
if videos is None:
|
||||
raise RuntimeError("Wan2.2 VAE decode failed; expected a list of latent tensors.")
|
||||
return torch.stack(videos)
|
||||
if isinstance(hidden_states, (list, tuple)):
|
||||
hidden_states = torch.stack(list(hidden_states))
|
||||
dev, dtype = self._device_dtype()
|
||||
z = hidden_states.float()
|
||||
z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device)
|
||||
out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample
|
||||
return out.float().clamp_(-1.0, 1.0)
|
||||
|
||||
|
||||
__all__ = ["WanVideoVAE38"]
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -24,57 +23,108 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan.modules.tokenizers import HuggingfaceTokenizer
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
|
||||
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
|
||||
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
|
||||
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
|
||||
WAN_T5_SAFE_CHECKPOINT = "models_t5_umt5-xxl-enc-bf16.safetensors"
|
||||
WAN_T5_CHECKPOINT = WAN_T5_SAFE_CHECKPOINT
|
||||
WAN_T5_TOKENIZER = "google/umt5-xxl"
|
||||
WAN_VAE_SAFE_CHECKPOINT = "Wan2.2_VAE.safetensors"
|
||||
WAN_VAE_CHECKPOINT = WAN_VAE_SAFE_CHECKPOINT
|
||||
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
||||
|
||||
class WanTextEncoder(torch.nn.Module):
    """Adapter giving a `transformers.UMT5EncoderModel` the FastWAM text-encoder API.

    Provides `.dim` (the wrapped model's `config.d_model`) and
    `forward(ids, mask)`, which returns the wrapped model's
    `last_hidden_state`, as called from `FastWAM.encode_prompt`.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        device: str | torch.device = "cuda",
        *,
        pretrained: torch.nn.Module,
    ) -> None:
        super().__init__()
        # There is no random-init path: the caller must hand in an already
        # loaded encoder (see `load_pretrained_wan_text_encoder`), which is
        # only moved to the requested device/dtype here.
        encoder = pretrained.to(device=device, dtype=dtype)
        self.model = encoder
        self.dim = int(encoder.config.d_model)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # The attention mask is cast to int64 before being passed on.
        outputs = self.model(input_ids=ids, attention_mask=mask.long())
        return outputs.last_hidden_state
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WanCheckpointPaths:
|
||||
root: Path
|
||||
dit: list[Path]
|
||||
vae: Path
|
||||
text_encoder: Path | None
|
||||
tokenizer: Path | None
|
||||
class WanTokenizer:
    """UMT5 tokenizer wrapper that always pads and truncates to `seq_len`.

    Calling it returns `input_ids`, or `(input_ids, attention_mask)` when
    `return_mask` is true, which is the shape the FastWAM call site expects.
    """

    def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
        # Imported lazily so that importing this module does not require transformers.
        from transformers import AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        self.seq_len = int(seq_len)

    def __call__(
        self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any
    ):
        # A single prompt is treated as a batch of one; extra keyword
        # arguments are accepted and ignored.
        texts = [sequence] if isinstance(sequence, str) else list(sequence)
        encoded = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.seq_len,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )
        if not return_mask:
            return encoded.input_ids
        return encoded.input_ids, encoded.attention_mask
|
||||
|
||||
|
||||
@dataclass
class Wan22LoadedComponents:
    """Bundle of loaded Wan2.2 modules together with the paths they came from."""

    # Loaded modules; the text encoder and tokenizer are None when they were skipped.
    dit: WanVideoDiT
    vae: WanVideoVAE38
    text_encoder: torch.nn.Module | None
    tokenizer: HuggingfaceTokenizer | None
    # String paths of the sources used for each component above.
    dit_path: list[str]
    vae_path: str
    text_encoder_path: str | None
    tokenizer_path: str | None
|
||||
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
    """Build a `WanTokenizer` for `WAN_T5_TOKENIZER` with the given max length."""
    max_len = int(tokenizer_max_len)
    return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=max_len)
|
||||
|
||||
|
||||
def resolve_wan_checkpoint_dir(
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
    """Load real Wan2.2 VAE weights from the diffusers repo (offline base creation).

    The `vae` subfolder of `WAN22_DIFFUSERS_MODEL_ID` is loaded as an
    `AutoencoderKLWan` and wrapped in the `WanVideoVAE38` adapter.
    """
    from diffusers import AutoencoderKLWan

    from .wan_adapters import WanVideoVAE38

    pretrained_vae = AutoencoderKLWan.from_pretrained(
        WAN22_DIFFUSERS_MODEL_ID,
        subfolder="vae",
        torch_dtype=torch_dtype,
    )
    adapter = WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=pretrained_vae)
    return adapter
|
||||
|
||||
|
||||
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
    """Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation).

    The `text_encoder` subfolder of `WAN22_DIFFUSERS_MODEL_ID` is loaded as a
    `UMT5EncoderModel` and wrapped in `WanTextEncoder`.
    """
    from transformers import UMT5EncoderModel

    umt5 = UMT5EncoderModel.from_pretrained(
        WAN22_DIFFUSERS_MODEL_ID,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    wrapped = WanTextEncoder(dtype=torch_dtype, device=device, pretrained=umt5)
    return wrapped
|
||||
|
||||
|
||||
def resolve_wan_dit_paths(
|
||||
model_id_or_path: str | Path,
|
||||
*,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
) -> Path:
|
||||
"""Return a local Wan2.2 checkpoint directory.
|
||||
|
||||
Local paths are used directly. Hub repos are downloaded with the same fixed
|
||||
component names used by the upstream Wan2.2 inference code.
|
||||
"""
|
||||
|
||||
) -> list[Path]:
|
||||
"""Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir."""
|
||||
path = Path(model_id_or_path).expanduser()
|
||||
if path.is_dir():
|
||||
return path
|
||||
return sorted(path.glob(WAN_DIT_PATTERN))
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -83,51 +133,9 @@ def resolve_wan_checkpoint_dir(
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
allow_patterns=[
|
||||
WAN_DIT_PATTERN,
|
||||
WAN_T5_CHECKPOINT,
|
||||
WAN_VAE_CHECKPOINT,
|
||||
f"{WAN_T5_TOKENIZER}/**",
|
||||
],
|
||||
)
|
||||
return Path(snapshot_path)
|
||||
|
||||
|
||||
def resolve_wan_checkpoint_paths(
|
||||
checkpoint_dir: str | Path,
|
||||
*,
|
||||
tokenizer_dir: str | Path | None = None,
|
||||
load_dit: bool = True,
|
||||
load_text_encoder: bool = True,
|
||||
) -> WanCheckpointPaths:
|
||||
root = Path(checkpoint_dir).expanduser()
|
||||
tokenizer_root = Path(tokenizer_dir).expanduser() if tokenizer_dir is not None else root
|
||||
dit = sorted(root.glob(WAN_DIT_PATTERN)) if load_dit else []
|
||||
vae = root / WAN_VAE_SAFE_CHECKPOINT
|
||||
text_encoder = root / WAN_T5_SAFE_CHECKPOINT if load_text_encoder else None
|
||||
tokenizer = tokenizer_root / WAN_T5_TOKENIZER if load_text_encoder else None
|
||||
|
||||
missing = []
|
||||
if load_dit and len(dit) == 0:
|
||||
missing.append(f"DiT ({WAN_DIT_PATTERN})")
|
||||
if not vae.exists():
|
||||
missing.append(f"VAE ({WAN_VAE_SAFE_CHECKPOINT})")
|
||||
if load_text_encoder:
|
||||
if text_encoder is None or not text_encoder.exists():
|
||||
missing.append(f"text encoder ({WAN_T5_SAFE_CHECKPOINT})")
|
||||
if tokenizer is None or not tokenizer.exists():
|
||||
missing.append(f"tokenizer ({WAN_T5_TOKENIZER})")
|
||||
if missing:
|
||||
raise FileNotFoundError(
|
||||
f"Incomplete Wan2.2 checkpoint directory {root}: missing {', '.join(missing)}."
|
||||
)
|
||||
return WanCheckpointPaths(
|
||||
root=root,
|
||||
dit=dit,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
allow_patterns=[WAN_DIT_PATTERN],
|
||||
)
|
||||
return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN))
|
||||
|
||||
|
||||
def load_wan_video_dit(
|
||||
@@ -145,107 +153,6 @@ def load_wan_video_dit(
|
||||
return model.to(device=device, dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_wan_text_encoder(
    checkpoint_path: str | Path,
    *,
    torch_dtype: torch.dtype,
    device: str,
) -> torch.nn.Module:
    """Build the encoder-only UMT5-XXL model and load its safetensors weights.

    Args:
        checkpoint_path: Path to the text-encoder checkpoint; must end in
            `.safetensors`.
        torch_dtype: Dtype the encoder is built in and finally cast to.
        device: Device the encoder is built on and finally moved to.

    Returns:
        The encoder with the checkpoint's state dict loaded.

    Raises:
        ValueError: If `checkpoint_path` does not have a `.safetensors` suffix.
    """
    checkpoint_path = Path(checkpoint_path)
    # Validate first: constructing UMT5-XXL is expensive, so a wrong path must
    # fail before the model module is imported or the model is instantiated.
    if checkpoint_path.suffix != ".safetensors":
        raise ValueError(f"Wan2.2 text encoder checkpoint must be safetensors, got {checkpoint_path}.")

    from .wan.modules.t5 import umt5_xxl

    model = umt5_xxl(
        encoder_only=True,
        return_tokenizer=False,
        dtype=torch_dtype,
        device=device,
    )
    state_dict = load_file(checkpoint_path)
    model.load_state_dict(state_dict)
    return model.to(device=device, dtype=torch_dtype)
|
||||
|
||||
|
||||
def load_wan_tokenizer(tokenizer_path: str | Path, *, tokenizer_max_len: int) -> HuggingfaceTokenizer:
    """Create the vendored `HuggingfaceTokenizer` for a tokenizer path.

    The tokenizer is configured with `clean="whitespace"` and a sequence
    length of `tokenizer_max_len`.
    """
    from .wan.modules.tokenizers import HuggingfaceTokenizer

    tokenizer_kwargs = {
        "name": str(tokenizer_path),
        "seq_len": int(tokenizer_max_len),
        "clean": "whitespace",
    }
    return HuggingfaceTokenizer(**tokenizer_kwargs)
|
||||
|
||||
|
||||
def load_wan_vae(checkpoint_path: str | Path, *, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
    """Construct the `WanVideoVAE38` adapter from a VAE checkpoint path."""
    from .wan_adapters import WanVideoVAE38

    vae_location = str(checkpoint_path)
    return WanVideoVAE38(vae_pth=vae_location, dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_wan22_ti2v_5b_components(
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
    tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
    tokenizer_max_len: int = 512,
    dit_config: dict[str, Any] | None = None,
    load_text_encoder: bool = True,
):
    """Load the Wan2.2-TI2V-5B DiT, VAE and (optionally) text encoder + tokenizer.

    Args:
        device: Device every loaded module is placed on.
        torch_dtype: Dtype every loaded module is cast to.
        model_id: Checkpoint source handed to `resolve_wan_checkpoint_dir`.
        tokenizer_model_id: Tokenizer source; resolved separately only when it
            differs from `model_id`.
        tokenizer_max_len: Sequence length given to the tokenizer.
        dit_config: DiT architecture config; required.
        load_text_encoder: When False the text encoder and tokenizer are not
            loaded and are returned as None.

    Returns:
        A `Wan22LoadedComponents` holding the modules and their source paths.

    Raises:
        ValueError: If `dit_config` is None.
        FileNotFoundError: If the text encoder was requested but its paths were
            not resolved.
    """
    logger.info("Loading Wan2.2-TI2V-5B components...")
    # Monotonic clock: the elapsed time below must not be skewed by wall-clock jumps.
    start = time.perf_counter()

    if dit_config is None:
        raise ValueError("`dit_config` is required for Wan2.2-TI2V-5B loading.")

    checkpoint_dir = resolve_wan_checkpoint_dir(model_id)
    # Reuse the already-resolved directory when both ids point at the same source.
    tokenizer_dir = (
        checkpoint_dir if tokenizer_model_id == model_id else resolve_wan_checkpoint_dir(tokenizer_model_id)
    )
    paths = resolve_wan_checkpoint_paths(
        checkpoint_dir,
        tokenizer_dir=tokenizer_dir,
        load_text_encoder=load_text_encoder,
    )

    dit = load_wan_video_dit(
        paths.dit,
        dit_config=dit_config,
        torch_dtype=torch_dtype,
        device=device,
    )
    vae = load_wan_vae(paths.vae, torch_dtype=torch_dtype, device=device)

    text_encoder: torch.nn.Module | None = None
    tokenizer: HuggingfaceTokenizer | None = None
    if load_text_encoder:
        if paths.text_encoder is None or paths.tokenizer is None:
            raise FileNotFoundError("Wan2.2 text encoder/tokenizer paths were not resolved.")
        text_encoder = load_wan_text_encoder(
            paths.text_encoder,
            torch_dtype=torch_dtype,
            device=device,
        )
        tokenizer = load_wan_tokenizer(paths.tokenizer, tokenizer_max_len=tokenizer_max_len)
    else:
        logger.info(
            "Skipping pretrained text encoder/tokenizer load (`load_text_encoder=False`); "
            "training must provide cached `context/context_mask`."
        )

    logger.info("Finished loading Wan2.2-TI2V-5B components in %.2f seconds.", time.perf_counter() - start)
    return Wan22LoadedComponents(
        dit=dit,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        dit_path=[str(path) for path in paths.dit],
        vae_path=str(paths.vae),
        text_encoder_path=str(paths.text_encoder) if paths.text_encoder is not None else None,
        tokenizer_path=str(paths.tokenizer) if paths.tokenizer is not None else None,
    )
|
||||
|
||||
|
||||
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
for path in paths:
|
||||
@@ -254,17 +161,14 @@ def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WAN22_DIFFUSERS_MODEL_ID",
|
||||
"WAN_DIT_PATTERN",
|
||||
"WAN_T5_CHECKPOINT",
|
||||
"WAN_T5_TOKENIZER",
|
||||
"WAN_VAE_CHECKPOINT",
|
||||
"Wan22LoadedComponents",
|
||||
"WanCheckpointPaths",
|
||||
"load_wan22_ti2v_5b_components",
|
||||
"load_wan_text_encoder",
|
||||
"load_wan_tokenizer",
|
||||
"load_wan_vae",
|
||||
"WanTextEncoder",
|
||||
"WanTokenizer",
|
||||
"build_wan_tokenizer",
|
||||
"load_pretrained_wan_text_encoder",
|
||||
"load_pretrained_wan_vae",
|
||||
"load_wan_video_dit",
|
||||
"resolve_wan_checkpoint_dir",
|
||||
"resolve_wan_checkpoint_paths",
|
||||
"resolve_wan_dit_paths",
|
||||
]
|
||||
|
||||
@@ -660,11 +660,16 @@ class WanVideoDiT(WanModel):
|
||||
) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
|
||||
token_timesteps[:, 0, :] = 0
|
||||
token_timesteps = token_timesteps.reshape(batch_size, -1)
|
||||
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).to(
|
||||
dtype=model_dtype
|
||||
)
|
||||
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
|
||||
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
|
||||
# Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
|
||||
# Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
|
||||
# Upstream guarantees this via an fp32 autocast region, so it holds even when the
|
||||
# model runs in bf16. Mirror that here, then cast the per-block modulation back to
|
||||
# model_dtype so the bf16 attention blocks are not upcast to fp32.
|
||||
with torch.amp.autocast("cuda", dtype=torch.float32):
|
||||
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
|
||||
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
|
||||
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
|
||||
t_mod = t_mod.to(dtype=model_dtype)
|
||||
|
||||
x = self.patchify(x)
|
||||
f, h, w = x.shape[2:]
|
||||
|
||||
@@ -18,21 +18,13 @@ import json
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import save_model
|
||||
from safetensors import safe_open
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.fastwam import modeling_fastwam
|
||||
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy, resolve_wan_component_paths
|
||||
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
|
||||
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
|
||||
from lerobot.policies.fastwam.wan_components import (
|
||||
WAN_DIT_PATTERN,
|
||||
WAN_T5_CHECKPOINT,
|
||||
WAN_T5_TOKENIZER,
|
||||
WAN_VAE_CHECKPOINT,
|
||||
resolve_wan_checkpoint_paths,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
@@ -170,8 +162,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
base_model_id=None,
|
||||
)
|
||||
with pytest.warns(RuntimeWarning, match="does not load pretrained FastWAM weights"):
|
||||
policy = FastWAMPolicy(cfg)
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
output = policy.forward(
|
||||
{
|
||||
@@ -207,89 +198,96 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
]
|
||||
|
||||
|
||||
def test_from_pretrained_loads_weights_without_initializing_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg.save_pretrained(tmp_path)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: FakeFastWAMCore())
|
||||
reference_policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True)
|
||||
save_model(reference_policy, str(tmp_path / "model.safetensors"))
|
||||
class CoreWithFrozenComponents(FakeFastWAMCore):
    """Fake core that mirrors how the real one holds its frozen sub-modules.

    The frozen VAE / text encoder are stored as *unregistered* attributes (via
    `object.__setattr__`), so they are left out of `state_dict()` and of the
    saved checkpoint, yet the `_apply` override still moves them.
    """

    def __init__(self):
        super().__init__()
        for attr_name in ("vae", "text_encoder"):
            frozen = nn.Linear(2, 2)
            # Bypass nn.Module.__setattr__ so the layer is not registered as a sub-module.
            object.__setattr__(self, attr_name, frozen)
            frozen.requires_grad_(False)

    def _apply(self, fn, *args, **kwargs):
        super()._apply(fn, *args, **kwargs)
        # Unregistered modules are invisible to the base implementation; apply by hand.
        for frozen in (self.vae, self.text_encoder):
            frozen._apply(fn)
        return self
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
|
||||
def build_core(self, config):
|
||||
core = CoreWithFrozenComponents()
|
||||
with torch.no_grad():
|
||||
core.dit.weight.fill_(0.5)
|
||||
return core
|
||||
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)
|
||||
|
||||
reference = FastWAMPolicy(cfg)
|
||||
with torch.no_grad():
|
||||
reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight
|
||||
reference.save_pretrained(tmp_path)
|
||||
|
||||
# Building from Wan2.2 must never happen on a checkpoint load.
|
||||
def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
|
||||
raise AssertionError("from_pretrained must not initialize or download Wan2.2 backbone components")
|
||||
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
|
||||
fail_if_wan_pretrained_is_loaded,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_fastwam,
|
||||
"_build_core_model_from_architecture",
|
||||
lambda config: FakeFastWAMCore(),
|
||||
raising=False,
|
||||
)
|
||||
loaded_components_from = []
|
||||
monkeypatch.setattr(
|
||||
FastWAMPolicy,
|
||||
"load_wan_components_from_pretrained",
|
||||
lambda self, path: loaded_components_from.append(path),
|
||||
)
|
||||
|
||||
policy = FastWAMPolicy.from_pretrained(tmp_path, strict=False)
|
||||
policy = FastWAMPolicy.from_pretrained(tmp_path)
|
||||
|
||||
assert isinstance(policy.model, FakeFastWAMCore)
|
||||
assert loaded_components_from == [tmp_path]
|
||||
assert isinstance(policy.model, CoreWithFrozenComponents)
|
||||
# The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
|
||||
assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25))
|
||||
|
||||
|
||||
def test_save_pretrained_copies_required_wan_sidecars(monkeypatch, tmp_path):
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
source = tmp_path / "source"
|
||||
tokenizer = source / WAN_T5_TOKENIZER
|
||||
tokenizer.mkdir(parents=True)
|
||||
vae = source / WAN_VAE_CHECKPOINT
|
||||
text_encoder = source / WAN_T5_CHECKPOINT
|
||||
tokenizer_file = tokenizer / "tokenizer.json"
|
||||
vae.write_bytes(b"vae")
|
||||
text_encoder.write_bytes(b"text")
|
||||
tokenizer_file.write_text("{}")
|
||||
core = FakeFastWAMCore()
|
||||
core.model_paths = {
|
||||
"vae": str(vae),
|
||||
"text_encoder": str(text_encoder),
|
||||
"tokenizer": str(tokenizer),
|
||||
}
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: core)
|
||||
policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
save_dir = tmp_path / "saved"
|
||||
policy.save_pretrained(save_dir)
|
||||
|
||||
assert (save_dir / "model.safetensors").is_file()
|
||||
assert (save_dir / WAN_VAE_CHECKPOINT).read_bytes() == b"vae"
|
||||
assert (save_dir / WAN_T5_CHECKPOINT).read_bytes() == b"text"
|
||||
assert (save_dir / WAN_T5_TOKENIZER / "tokenizer.json").read_text() == "{}"
|
||||
# No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
|
||||
assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
|
||||
assert not (save_dir / "google").exists()
|
||||
|
||||
with safe_open(save_dir / "model.safetensors", framework="pt") as f:
|
||||
keys = set(f.keys())
|
||||
# Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
|
||||
# encoder are excluded (loaded from the diffusers/transformers repos at init).
|
||||
assert any(key.startswith("model.dit.") for key in keys)
|
||||
assert not any(key.startswith("model.vae.") for key in keys)
|
||||
assert not any(key.startswith("model.text_encoder.") for key in keys)
|
||||
|
||||
|
||||
def test_wan_component_resolution_uses_fixed_safetensors_layout(tmp_path):
|
||||
tokenizer = tmp_path / WAN_T5_TOKENIZER
|
||||
tokenizer.mkdir(parents=True)
|
||||
(tmp_path / WAN_VAE_CHECKPOINT).touch()
|
||||
(tmp_path / WAN_T5_CHECKPOINT).touch()
|
||||
(tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors").touch()
|
||||
(tokenizer / "tokenizer.json").touch()
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
paths = resolve_wan_checkpoint_paths(tmp_path)
|
||||
sidecar_paths = resolve_wan_component_paths(tmp_path)
|
||||
# Unregistered: excluded from state_dict and from the optimizer's parameter set.
|
||||
sd = policy.state_dict()
|
||||
assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd)
|
||||
param_names = [n for n, _ in policy.named_parameters()]
|
||||
assert not any("vae" in n or "text_encoder" in n for n in param_names)
|
||||
|
||||
assert paths.dit == [tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors"]
|
||||
assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT
|
||||
assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT
|
||||
assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER
|
||||
assert sidecar_paths.dit == []
|
||||
assert WAN_DIT_PATTERN == "diffusion_pytorch_model*.safetensors"
|
||||
|
||||
(tmp_path / WAN_T5_CHECKPOINT).unlink()
|
||||
with pytest.raises(FileNotFoundError, match="text encoder"):
|
||||
resolve_wan_checkpoint_paths(tmp_path)
|
||||
# ...but the `_apply` override still carries them through `.to()` (dtype stands in
|
||||
# for device on a CPU box), so they never strand off the rest of the model.
|
||||
policy.to(torch.float64)
|
||||
assert policy.model.dit.weight.dtype == torch.float64 # registered
|
||||
assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply
|
||||
assert policy.model.text_encoder.weight.dtype == torch.float64
|
||||
|
||||
|
||||
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
|
||||
@@ -302,3 +300,57 @@ def test_pretrained_config_round_trips_fastwam_features(tmp_path):
|
||||
assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL
|
||||
assert loaded.action_feature.shape == (7,)
|
||||
assert loaded.robot_state_feature.shape == (8,)
|
||||
|
||||
|
||||
def test_vae_adapter_empty_build_encode_decode_shapes():
    """Offline glue check of the diffusers-backed VAE adapter (random weights).

    Covers the encode/decode contract: 48 latent channels, 16x spatial / 4x
    temporal compression, list-or-batch input and the scaling round-trip, all
    without downloading weights. Numerical fidelity against the original Wan
    VAE is a separate verification step (GPU + real weights).
    """
    pytest.importorskip("diffusers")
    from diffusers import AutoencoderKLWan

    from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38

    # Production always loads a real pretrained VAE from the diffusers repo. Here the
    # same architecture is built with random weights and dummy standardization stats,
    # which is enough to exercise the adapter's shape/scaling contract offline.
    vae_config = {
        "base_dim": 160,
        "decoder_base_dim": 256,
        "z_dim": 48,
        "dim_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "attn_scales": [],
        "temperal_downsample": [False, True, True],
        "dropout": 0.0,
        "is_residual": True,
        "in_channels": 12,
        "out_channels": 12,
        "patch_size": 2,
        "scale_factor_spatial": 16,
        "scale_factor_temporal": 4,
        "clip_output": False,
        "latents_mean": [0.0] * 48,
        "latents_std": [1.0] * 48,
    }
    random_vae = AutoencoderKLWan.from_config(vae_config)
    vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=random_vae)
    assert vae.z_dim == 48
    assert vae.upsampling_factor == 16
    assert vae.temporal_downsample_factor == 4

    clip = torch.rand(1, 3, 5, 32, 32) * 2 - 1  # [B,C,T,H,W] in [-1,1]
    encoded = vae.encode(clip)
    assert encoded.shape == (1, 48, 2, 2, 2)  # T'=(5-1)//4+1, H'=W'=32//16

    reconstructed = vae.decode(encoded)
    assert reconstructed.shape[0] == 1
    assert reconstructed.shape[1] == 3
    assert reconstructed.shape[-2:] == (32, 32)
    assert reconstructed.min() >= -1.0
    assert reconstructed.max() <= 1.0

    # A list of unbatched videos is accepted and matches the batched path.
    assert torch.equal(vae.encode([clip[0]]), encoded)
|
||||
|
||||
@@ -1636,18 +1636,6 @@ http = [
|
||||
{ name = "aiohttp" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ftfy"
|
||||
version = "6.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "future"
|
||||
version = "1.0.0"
|
||||
@@ -2708,7 +2696,6 @@ all = [
|
||||
{ name = "faker" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "feetech-servo-sdk" },
|
||||
{ name = "ftfy" },
|
||||
{ name = "grpcio" },
|
||||
{ name = "grpcio-tools" },
|
||||
{ name = "gym-aloha" },
|
||||
@@ -2747,7 +2734,6 @@ all = [
|
||||
{ name = "pyzmq" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "reachy2-sdk" },
|
||||
{ name = "regex" },
|
||||
{ name = "rerun-sdk" },
|
||||
{ name = "ruff" },
|
||||
{ name = "scikit-image" },
|
||||
@@ -2846,8 +2832,6 @@ evaluation = [
|
||||
]
|
||||
fastwam = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "ftfy" },
|
||||
{ name = "regex" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
feetech = [
|
||||
@@ -3108,7 +3092,6 @@ requires-dist = [
|
||||
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
|
||||
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "ftfy", marker = "extra == 'fastwam'", specifier = ">=6.1.1,<7.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
@@ -3277,7 +3260,6 @@ requires-dist = [
|
||||
{ name = "pyzmq", marker = "extra == 'pyzmq-dep'", specifier = ">=26.2.1,<28.0.0" },
|
||||
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
|
||||
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
|
||||
{ name = "regex", marker = "extra == 'fastwam'", specifier = ">=2024.0.0,<2027.0.0" },
|
||||
{ name = "requests", specifier = ">=2.32.0,<3.0.0" },
|
||||
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },
|
||||
|
||||
Reference in New Issue
Block a user