big refactor to use models from diffusers and transformers

This commit is contained in:
Maxime Ellerbach
2026-06-12 08:56:58 +00:00
parent dfc0170b4d
commit 807e364132
15 changed files with 507 additions and 2288 deletions
-2
View File
@@ -225,8 +225,6 @@ eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
fastwam = [
"lerobot[transformers-dep]",
"lerobot[diffusers-dep]",
"ftfy>=6.1.1,<7.0.0",
"regex>=2024.0.0,<2027.0.0",
]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
@@ -118,10 +118,6 @@ def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, Policy
return coerced
def _coerce_normalization_mapping(mapping: dict[str, Any]) -> dict[str, Any]:
return {key: _coerce_enum(NormalizationMode, value) for key, value in mapping.items()}
def _is_local_model_id(value: str) -> bool:
path = Path(value).expanduser()
return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists()
@@ -200,7 +196,7 @@ class FastWAMConfig(PreTrainedConfig):
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
video_dit_config: dict[str, Any] | None = None
action_dit_config: dict[str, Any] | None = None
normalization_mapping: dict[str, Any] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
@@ -220,7 +216,6 @@ class FastWAMConfig(PreTrainedConfig):
self.input_features = _coerce_policy_features(self.input_features)
self.output_features = _coerce_policy_features(self.output_features)
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
self.normalization_mapping = _coerce_normalization_mapping(self.normalization_mapping)
self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim)
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
@@ -266,6 +261,38 @@ class FastWAMConfig(PreTrainedConfig):
def get_scheduler_preset(self) -> None:
return None
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
"""Rebuild visual input features from the dataset's real camera keys.
FastWAM's `__post_init__` installs a synthetic single-image default
(`observation.images.image` at full `image_size` width). For datasets
with one or more separately-named cameras (e.g. `observation.images.top`,
`observation.images.wrist`), this hook — invoked by `make_policy` once the
dataset metadata is known — replaces that default with the actual camera
keys, each declared at the policy's native per-camera resolution
(`image_size[0]` x `image_size[1] // num_cameras`). The accompanying
resize step in `make_fastwam_pre_post_processors` resizes raw frames to
match, so heterogeneous source resolutions (e.g. 480x640) are supported.
"""
image_keys = sorted(
key
for key, feature in dataset_features.items()
if key.startswith("observation.images.")
and feature.get("dtype") in ("video", "image")
)
if not image_keys:
return
height, total_width = self.image_size
per_cam_width = total_width // len(image_keys)
new_inputs: dict[str, PolicyFeature] = {
key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width))
for key in image_keys
}
if self.proprio_dim is not None and OBS_STATE in dataset_features:
new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))
self.input_features = new_inputs
self.validate_features()
def validate_features(self) -> None:
if self.action_dim <= 0:
raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
+126 -314
View File
@@ -14,24 +14,18 @@
from __future__ import annotations
import shutil
import warnings
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.constants import OBS_STATE
from .configuration_fastwam import FastWAMConfig
if TYPE_CHECKING:
from .wan_components import WanCheckpointPaths
class FastWAMPolicy(PreTrainedPolicy):
"""LeRobot policy wrapper for FastWAM.
@@ -49,98 +43,62 @@ class FastWAMPolicy(PreTrainedPolicy):
self,
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
):
skip_wan_init = bool(kwargs.pop("_skip_wan_init", False))
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
suppress_base_init_warning = bool(kwargs.pop("_suppress_base_init_warning", False))
if not skip_wan_init and not suppress_base_init_warning:
warnings.warn(
"FastWAMPolicy(config) initializes from architecture/config and does not load pretrained "
"FastWAM weights. For training or evaluation, use `make_policy(config)` or "
"`FastWAMPolicy.from_pretrained(...)`.",
RuntimeWarning,
stacklevel=2,
)
if skip_wan_init:
self.model = _build_core_model_from_architecture(config)
else:
self.model = self._build_core_model(config)
self.model = self._build_core_model(config)
self.reset()
@classmethod
def from_pretrained(
cls,
pretrained_name_or_path: str | Path,
*,
config: FastWAMConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs: Any,
) -> FastWAMPolicy:
"""Load FastWAM weights and local Wan components from one HF directory.
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
"""Shape-aware load that supports cross-embodiment fine-tuning.
Args:
pretrained_name_or_path (str | Path): HF-format policy directory
containing `config.json`, `model.safetensors`, local Wan VAE,
local UMT5 text encoder safetensors, and tokenizer files.
config (FastWAMConfig | None): Optional config override. When
omitted, `config.json` is read from `pretrained_name_or_path`.
force_download (bool): Forwarded to LeRobot's pretrained loader.
resume_download (bool | None): Forwarded to LeRobot's pretrained loader.
proxies (dict | None): Forwarded to LeRobot's pretrained loader.
token (str | bool | None): Forwarded to LeRobot's pretrained loader.
cache_dir (str | Path | None): Forwarded to LeRobot's pretrained loader.
local_files_only (bool): Forwarded to LeRobot's pretrained loader.
revision (str | None): Forwarded to LeRobot's pretrained loader.
strict (bool): Whether safetensors loading should require an exact
match between checkpoint keys and policy module keys.
**kwargs (Any): Extra constructor arguments forwarded to
`FastWAMPolicy`.
`safetensors.load_model(strict=False)` ignores missing/unexpected keys but
still raises on a shape mismatch for a shared key. When fine-tuning from a
checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim
checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and
proprio encoder legitimately differ in shape. With `strict=False` we drop
only those shape-mismatched tensors — leaving them at their freshly
initialized values — and load every compatible tensor. With `strict=True`
the standard exact-match loader is used.
"""
from safetensors import safe_open
pretrained_path = _resolve_pretrained_directory(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = PreTrainedConfig.from_pretrained(pretrained_path)
if not isinstance(config, FastWAMConfig):
raise TypeError(f"Expected FastWAM config, got {type(config).__name__}.")
kwargs["_skip_wan_init"] = True
policy = super().from_pretrained(
pretrained_path,
config=config,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
strict=strict,
**kwargs,
)
policy.load_wan_components_from_pretrained(pretrained_path)
policy.eval()
return policy
model_state_dict = model.state_dict()
mismatched = []
with safe_open(model_file, framework="pt") as f:
checkpoint_keys = list(f.keys())
for key in checkpoint_keys:
if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple(
f.get_slice(key).get_shape()
):
mismatched.append(key)
def _save_pretrained(self, save_directory: Path) -> None:
super()._save_pretrained(save_directory)
_copy_wan_components_from_policy(policy=self, save_directory=save_directory)
if not mismatched:
return super()._load_as_safetensor(model, model_file, map_location, strict)
if strict:
raise RuntimeError(
f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
f"strict=True: {mismatched}"
)
from safetensors.torch import load_file
logging.warning(
"FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
"every compatible weight: %s",
len(mismatched),
mismatched,
)
state_dict = load_file(model_file, device="cpu")
for key in mismatched:
state_dict.pop(key, None)
model.load_state_dict(state_dict, strict=False)
if map_location and map_location != "cpu":
model.to(map_location)
return model
def get_optim_params(self) -> dict[str, Any]:
params = (
@@ -151,21 +109,41 @@ class FastWAMPolicy(PreTrainedPolicy):
params.extend(list(proprio_encoder.parameters()))
return {"params": [p for p in params if p.requires_grad]}
def load_wan_components_from_pretrained(self, pretrained_name_or_path: str | Path) -> None:
"""Attach local Wan VAE, text encoder, and tokenizer from a HF directory.
Args:
pretrained_name_or_path (str | Path): Directory containing
`Wan2.2_VAE.safetensors`, `models_t5_umt5-xxl-enc-bf16.safetensors`,
and `google/umt5-xxl/` tokenizer files.
"""
paths = resolve_wan_component_paths(pretrained_name_or_path)
_load_wan_components_into_policy(policy=self, paths=paths)
def reset(self) -> None:
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
`FastWAM.build_inputs` consumes (`video`, `action`, `context`/`context_mask`,
per-frame `proprio`).
The LeRobot training loop passes raw `observation.images.*`, a single-step
`observation.state` `[B, D]`, `action`, and a language `task` string. We do
only the translation `build_inputs` can't: stack the camera frames into a
video, encode the prompt with the (frozen) text encoder (mirroring inference,
so language-conditioned datasets need no precomputed context), and give proprio
the per-frame axis `build_inputs` indexes. All shape/presence validation is
left to `build_inputs`, the single authority on the contract.
"""
sample = dict(batch)
if "video" not in sample:
sample["video"] = _stack_video_from_images(batch, self.config)
if "context" not in sample or "context_mask" not in sample:
prompt = _prompt_from_batch(batch=batch, config=self.config)
if prompt is None:
raise KeyError(
"FastWAM training requires a `task`/`prompt` to encode text context, "
"or precomputed `context`/`context_mask` in the batch."
)
sample["context"], sample["context_mask"] = self.model.encode_prompt(prompt)
if self.config.proprio_dim is not None and "proprio" not in sample:
state = sample.get(OBS_STATE)
if state is not None:
# LeRobot gives a single-step state [B, D]; build_inputs expects
# per-frame [B, T, D] and uses frame 0, so add a T=1 axis.
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
return sample
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Compute FastWAM training loss for a LeRobot batch.
@@ -180,7 +158,7 @@ class FastWAMPolicy(PreTrainedPolicy):
key required by LeRobot and optional tensor metrics.
"""
sample = _batch_to_training_sample(batch=batch, config=self.config)
sample = self._batch_to_training_sample(batch)
loss, metrics = self.model.training_loss(sample)
output = {"loss": loss}
for key, value in (metrics or {}).items():
@@ -230,223 +208,57 @@ class FastWAMPolicy(PreTrainedPolicy):
return self._action_queue.popleft()
def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module:
return _build_core_model_from_wan22(config)
"""Build the FastWAM core for training / inference.
def _resolve_pretrained_directory(
pretrained_name_or_path: str | Path,
*,
force_download: bool,
token: str | bool | None,
cache_dir: str | Path | None,
local_files_only: bool,
revision: str | None,
) -> Path:
path = Path(pretrained_name_or_path)
if path.is_dir():
return path
from huggingface_hub import snapshot_download
from .wan_components import (
WAN_T5_CHECKPOINT,
WAN_T5_TOKENIZER,
WAN_VAE_CHECKPOINT,
)
snapshot_path = snapshot_download(
repo_id=str(pretrained_name_or_path),
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
local_files_only=local_files_only,
allow_patterns=[
"config.json",
"model.safetensors",
WAN_VAE_CHECKPOINT,
WAN_T5_CHECKPOINT,
f"{WAN_T5_TOKENIZER}/**",
],
)
return Path(snapshot_path)
def resolve_wan_component_paths(pretrained_name_or_path: str | Path) -> WanCheckpointPaths:
"""Resolve local Wan component paths stored beside FastWAM HF weights.
Args:
pretrained_name_or_path (str | Path): HF-format FastWAM directory.
Returns:
WanCheckpointPaths: Existing VAE, text encoder, and tokenizer paths.
DiT shards are intentionally optional here because FastWAM HF
checkpoints store trainable DiT weights in `model.safetensors`.
"""
from .wan_components import resolve_wan_checkpoint_paths
return resolve_wan_checkpoint_paths(
pretrained_name_or_path,
load_dit=False,
load_text_encoder=True,
)
def _load_wan_components_into_policy(policy: FastWAMPolicy, paths: WanCheckpointPaths) -> None:
from .wan_components import load_wan_text_encoder, load_wan_tokenizer, load_wan_vae
if paths.text_encoder is None or paths.tokenizer is None:
raise FileNotFoundError("FastWAM HF checkpoint requires Wan text encoder and tokenizer sidecars.")
dtype = _dtype_from_name(policy.config.torch_dtype)
device = str(policy.config.device)
policy.model.vae = load_wan_vae(paths.vae, torch_dtype=dtype, device=device)
policy.model.text_encoder = load_wan_text_encoder(paths.text_encoder, torch_dtype=dtype, device=device)
policy.model.tokenizer = load_wan_tokenizer(
paths.tokenizer,
tokenizer_max_len=int(policy.config.tokenizer_max_len),
)
model_paths = dict(getattr(policy.model, "model_paths", {}) or {})
model_paths.update(
{
"vae": str(paths.vae),
"text_encoder": str(paths.text_encoder),
"tokenizer": str(paths.tokenizer),
}
)
policy.model.model_paths = model_paths
def _copy_wan_components_from_policy(policy: FastWAMPolicy, save_directory: Path) -> None:
model_paths = getattr(policy.model, "model_paths", {}) or {}
paths = {
"vae": model_paths.get("vae"),
"text_encoder": model_paths.get("text_encoder"),
"tokenizer": model_paths.get("tokenizer"),
}
missing = [name for name, path in paths.items() if path is None]
if missing:
raise RuntimeError(
"FastWAM save_pretrained requires local Wan component paths for "
f"{missing}. Load or initialize the policy with local Wan VAE, text encoder, and tokenizer files."
Only the trainable parts (the MoT DiT and the proprio encoder) are
materialized empty here and then filled from the policy's
`model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE
and UMT5 text encoder are loaded with their real weights from the
`Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared
across checkpoints) and are intentionally excluded from `model.safetensors`
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
"""
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
_copy_component_path(Path(paths["vae"]), save_directory / Path(paths["vae"]).name)
_copy_component_path(Path(paths["text_encoder"]), save_directory / Path(paths["text_encoder"]).name)
tokenizer_source = Path(paths["tokenizer"])
_copy_component_path(tokenizer_source, save_directory / "google" / "umt5-xxl")
from .wan_video_dit import WanVideoDiT
def _copy_component_path(source: Path, destination: Path) -> None:
source = source.expanduser()
if not source.exists():
raise FileNotFoundError(f"FastWAM component path does not exist: {source}")
if source.resolve() == destination.resolve():
return
destination.parent.mkdir(parents=True, exist_ok=True)
if source.is_dir():
shutil.copytree(source, destination, dirs_exist_ok=True)
else:
shutil.copy2(source, destination)
def _build_core_model_from_wan22(config: FastWAMConfig) -> torch.nn.Module:
from .modular_fastwam import FastWAM
dtype = _dtype_from_name(config.torch_dtype)
return FastWAM.from_wan22_pretrained(
device=config.device,
torch_dtype=dtype,
model_id=config.model_id,
tokenizer_model_id=config.tokenizer_model_id,
tokenizer_max_len=config.tokenizer_max_len,
load_text_encoder=config.load_text_encoder,
proprio_dim=config.proprio_dim,
video_dit_config=config.video_dit_config,
action_dit_config=config.action_dit_config,
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
video_train_shift=float(config.video_scheduler["train_shift"]),
video_infer_shift=float(config.video_scheduler["infer_shift"]),
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
action_train_shift=float(config.action_scheduler["train_shift"]),
action_infer_shift=float(config.action_scheduler["infer_shift"]),
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
loss_lambda_video=float(config.loss["lambda_video"]),
loss_lambda_action=float(config.loss["lambda_action"]),
)
def _build_core_model_from_architecture(config: FastWAMConfig) -> torch.nn.Module:
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_video_dit import WanVideoDiT
dtype = _dtype_from_name(config.torch_dtype)
video_expert = WanVideoDiT(**config.video_dit_config).to(device=config.device, dtype=dtype)
action_expert = ActionDiT(**config.action_dit_config).to(device=config.device, dtype=dtype)
mot = MoT(
mixtures={"video": video_expert, "action": action_expert},
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
)
return FastWAM(
video_expert=video_expert,
action_expert=action_expert,
mot=mot,
vae=_FastWAMVAEPlaceholder(),
text_encoder=None,
tokenizer=None,
text_dim=int(config.video_dit_config["text_dim"]),
proprio_dim=config.proprio_dim,
device=config.device,
torch_dtype=dtype,
video_train_shift=float(config.video_scheduler["train_shift"]),
video_infer_shift=float(config.video_scheduler["infer_shift"]),
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
action_train_shift=float(config.action_scheduler["train_shift"]),
action_infer_shift=float(config.action_scheduler["infer_shift"]),
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
loss_lambda_video=float(config.loss["lambda_video"]),
loss_lambda_action=float(config.loss["lambda_action"]),
)
class _FastWAMVAEPlaceholder(torch.nn.Module):
"""Minimal VAE placeholder for checkpoint loading without Wan2.2 VAE.
Args:
temporal_downsample_factor (int): Temporal compression factor expected
by FastWAM latent shape logic.
upsampling_factor (int): Spatial compression factor expected by FastWAM.
z_dim (int): Latent channel count used by Wan2.2 TI2V VAE.
"""
temporal_downsample_factor: int = 4
upsampling_factor: int = 8
def __init__(self, z_dim: int = 48):
super().__init__()
self.model = type("VAEModelShape", (), {"z_dim": int(z_dim)})()
def encode(self, *args, **kwargs):
raise RuntimeError(
"FastWAM VAE placeholder cannot encode images; load Wan2.2 VAE for image inference."
dtype = _dtype_from_name(config.torch_dtype)
device = config.device
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype)
mot = MoT(
mixtures={"video": video_expert, "action": action_expert},
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
)
def decode(self, *args, **kwargs):
raise RuntimeError(
"FastWAM VAE placeholder cannot decode latents; load Wan2.2 VAE for video inference."
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
if config.load_text_encoder
else None
)
return FastWAM(
video_expert=video_expert,
action_expert=action_expert,
mot=mot,
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
text_encoder=text_encoder,
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
text_dim=int(config.video_dit_config["text_dim"]),
proprio_dim=config.proprio_dim,
device=device,
torch_dtype=dtype,
video_train_shift=float(config.video_scheduler["train_shift"]),
video_infer_shift=float(config.video_scheduler["infer_shift"]),
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
action_train_shift=float(config.action_scheduler["train_shift"]),
action_infer_shift=float(config.action_scheduler["infer_shift"]),
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
loss_lambda_video=float(config.loss["lambda_video"]),
loss_lambda_action=float(config.loss["lambda_action"]),
)
def _batch_to_training_sample(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Tensor]:
sample = dict(batch)
if "video" not in sample:
sample["video"] = _stack_video_from_images(batch, config)
if "proprio" not in sample and OBS_STATE in batch:
sample["proprio"] = batch[OBS_STATE]
required = {"video", ACTION, "context", "context_mask"}
missing = sorted(required - set(sample))
if missing:
raise KeyError(f"FastWAM training batch is missing keys: {missing}.")
return sample
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
+47 -29
View File
@@ -24,7 +24,13 @@ import torch.nn as nn
import torch.nn.functional as functional
from PIL import Image
from .wan_components import load_wan22_ti2v_5b_components
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
load_wan_video_dit,
resolve_wan_dit_paths,
)
from .wan_video_dit import (
FastWAMAttentionBlock,
WanContinuousFlowMatchScheduler,
@@ -846,9 +852,17 @@ class FastWAM(torch.nn.Module):
# Keep trainer compatibility: optimizer and freeze logic use `model.dit`.
self.dit = self.mot
self.vae = vae
self.text_encoder = text_encoder
# Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT
# registered as submodules. They are therefore excluded from `state_dict()`
# (lean checkpoints), `parameters()`, and DDP gradient sync, and are loaded
# with their real weights from the diffusers/transformers repos at construction.
# Device/dtype moves still reach them via the `_apply` override below.
object.__setattr__(self, "vae", vae)
object.__setattr__(self, "text_encoder", text_encoder)
self.tokenizer = tokenizer
vae.requires_grad_(False)
if text_encoder is not None:
text_encoder.requires_grad_(False)
if text_dim is None:
if self.text_encoder is None:
raise ValueError("`text_dim` is required when `text_encoder` is not loaded.")
@@ -913,18 +927,17 @@ class FastWAM(torch.nn.Module):
raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().")
if "text_dim" not in video_dit_config:
raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.")
del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl)
components = load_wan22_ti2v_5b_components(
device=device,
torch_dtype=torch_dtype,
model_id=model_id,
tokenizer_model_id=tokenizer_model_id,
tokenizer_max_len=tokenizer_max_len,
# Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from
# the diffusers conversion. This is the offline base-creation path; the
# weights it loads are then bundled into the FastWAM `model.safetensors`.
video_expert = load_wan_video_dit(
resolve_wan_dit_paths(model_id),
dit_config=video_dit_config,
load_text_encoder=load_text_encoder,
torch_dtype=torch_dtype,
device=device,
)
video_expert = components.dit
action_expert = ActionDiT(**action_dit_config).to(device=device, dtype=torch_dtype)
if int(action_expert.num_heads) != int(video_expert.num_heads):
raise ValueError("ActionDiT `num_heads` must match video expert for MoT mixed attention.")
@@ -938,13 +951,21 @@ class FastWAM(torch.nn.Module):
mot_checkpoint_mixed_attn=mot_checkpoint_mixed_attn,
)
model = cls(
vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device)
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device)
if load_text_encoder
else None
)
tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len)
return cls(
video_expert=video_expert,
action_expert=action_expert,
mot=mot,
vae=components.vae,
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_dim=int(video_dit_config["text_dim"]),
proprio_dim=proprio_dim,
device=device,
@@ -958,20 +979,17 @@ class FastWAM(torch.nn.Module):
loss_lambda_video=loss_lambda_video,
loss_lambda_action=loss_lambda_action,
)
model.model_paths = {
"video_dit": components.dit_path,
"vae": components.vae_path,
"text_encoder": components.text_encoder_path,
"tokenizer": components.tokenizer_path,
}
return model
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.mot.to(*args, **kwargs)
def _apply(self, fn, *args, **kwargs):
# `.to()` / `.cuda()` / `.cpu()` and accelerate/DDP device moves all funnel
# through `_apply`, and the parent policy reaches us via `child._apply(fn)`
# (not `child.to()`). Propagate `fn` to the *unregistered* frozen VAE / text
# encoder here so they follow the rest of the model onto the right device,
# while staying out of `state_dict()` / `parameters()`.
super()._apply(fn, *args, **kwargs)
self.vae._apply(fn)
if self.text_encoder is not None:
self.text_encoder.to(*args, **kwargs)
self.vae.to(*args, **kwargs)
self.text_encoder._apply(fn)
return self
@staticmethod
@@ -1628,7 +1646,7 @@ class FastWAM(torch.nn.Module):
latent_w = width // self.vae.upsampling_factor
generator = None if seed is None else torch.Generator(device=rand_device).manual_seed(seed)
return torch.randn(
(1, self.vae.model.z_dim, latent_t, latent_h, latent_w),
(1, self.vae.z_dim, latent_t, latent_h, latent_w),
generator=generator,
device=rand_device,
dtype=torch.float32,
@@ -24,6 +24,7 @@ from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
ImageCropResizeProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
@@ -90,9 +91,8 @@ def make_fastwam_pre_post_processors(
output processor pipelines discoverable by LeRobot.
"""
normalization_stats: dict[str, dict[str, Any]] = {
key: dict(value) for key, value in (dataset_stats or {}).items()
}
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
for key, feature in config.input_features.items():
if feature.type != FeatureType.VISUAL:
continue
@@ -101,10 +101,23 @@ def make_fastwam_pre_post_processors(
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
}
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape
for feature in config.input_features.values()
if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw))
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
*resize_steps,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
+5 -6
View File
@@ -10,17 +10,16 @@ Copied files:
- `wan/modules/attention.py`
- `wan/modules/model.py`
- `wan/modules/t5.py`
- `wan/modules/tokenizers.py`
- `wan/modules/vae2_2.py`
- `wan/modules/__init__.py`
- `wan/utils/fm_solvers.py`
- `wan/utils/__init__.py`
FastWAM-specific model glue and any larger code adapted from these modules live outside this directory. This keeps the upstream Wan2.2 code reviewable as a vendored reference subset and makes it straightforward to replace this directory with an external Wan2.2 dependency by changing import paths.
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
UMT5 text encoder, and tokenizer are no longer vendored — they come from
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
Current FastWAM adapters that directly reuse this vendored subset:
- `../wan_components.py` instantiates the upstream `wan.modules.t5.umt5_xxl` encoder factory and uses `wan.modules.tokenizers.HuggingfaceTokenizer`.
- `../wan_adapters.py` wraps `wan.modules.vae2_2.Wan2VAE` with the FastWAM tensor-batch encode/decode API.
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
@@ -1,17 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_2 import Wan2VAE
__all__ = [
"Wan2VAE",
"WanModel",
"T5Model",
"T5Encoder",
"T5Decoder",
"T5EncoderModel",
"HuggingfaceTokenizer",
"flash_attention",
]
@@ -1,489 +0,0 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as functional
from .tokenizers import HuggingfaceTokenizer
__all__ = [
"T5Model",
"T5Encoder",
"T5Decoder",
"T5EncoderModel",
]
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
return x
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5Model):
nn.init.normal_(m.token_embedding.weight, std=1.0)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
attn = functional.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
# output
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = (
None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
)
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).long()
)
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
class T5Encoder(nn.Module):
def __init__(
self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1
):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
)
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
for _ in range(num_layers)
]
)
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(
self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1
):
super().__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
)
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList(
[
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
for _ in range(num_layers)
]
)
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.size()
# causal mask
if mask is None:
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
elif mask.ndim == 2:
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(
self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.1,
):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(
self.token_embedding,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
num_buckets,
shared_pos,
dropout,
)
self.decoder = T5Decoder(
self.token_embedding,
dim,
dim_attn,
dim_ffn,
num_heads,
decoder_layers,
num_buckets,
shared_pos,
dropout,
)
self.head = nn.Linear(dim, vocab_size, bias=False)
# initialize weights
self.apply(init_weights)
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def _t5(
name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs=None,
dtype=torch.float32,
device="cpu",
**kwargs,
):
# sanity check
assert not (encoder_only and decoder_only)
tokenizer_kwargs = tokenizer_kwargs or {}
# params
if encoder_only:
model_cls = T5Encoder
kwargs["vocab"] = kwargs.pop("vocab_size")
kwargs["num_layers"] = kwargs.pop("encoder_layers")
_ = kwargs.pop("decoder_layers")
elif decoder_only:
model_cls = T5Decoder
kwargs["vocab"] = kwargs.pop("vocab_size")
kwargs["num_layers"] = kwargs.pop("decoder_layers")
_ = kwargs.pop("encoder_layers")
else:
model_cls = T5Model
# init model
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = {
"vocab_size": 256384,
"dim": 4096,
"dim_attn": 4096,
"dim_ffn": 10240,
"num_heads": 64,
"encoder_layers": 24,
"decoder_layers": 24,
"num_buckets": 32,
"shared_pos": False,
"dropout": 0.1,
}
cfg.update(**kwargs)
return _t5("umt5-xxl", **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = (
umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device)
.eval()
.requires_grad_(False)
)
logging.info(f"loading {checkpoint_path}")
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
def __call__(self, texts, device):
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
return [u[:v] for u, v in zip(context, seq_lens, strict=False)]
@@ -1,78 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ["HuggingfaceTokenizer"]
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans("", "", string.punctuation))
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop("return_mask", False)
# arguments
_kwargs = {"return_tensors": "pt"}
if self.seq_len is not None:
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == "whitespace":
text = whitespace_clean(basic_clean(text))
elif self.clean == "lower":
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == "canonicalize":
text = canonicalize(basic_clean(text))
return text
File diff suppressed because it is too large Load Diff
+52 -40
View File
@@ -14,16 +14,24 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING
import torch
from .wan.modules.vae2_2 import Wan2VAE
if TYPE_CHECKING:
from diffusers import AutoencoderKLWan
class WanVideoVAE38(torch.nn.Module):
"""Tensor-batch adapter around the official Wan2.2 VAE wrapper."""
"""FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).
16x spatial / 4x temporal compression, 48 latent channels. diffusers'
`AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/
`latents_std`), so `encode`/`decode` here apply the same standardization the
Wan reference uses `(latents - mean) / std` done in fp32 for stability.
`encode` uses the deterministic posterior mode, matching the original VAE
which returned the latent mean `mu`.
"""
upsampling_factor = 16
temporal_downsample_factor = 4
@@ -31,27 +39,35 @@ class WanVideoVAE38(torch.nn.Module):
def __init__(
self,
vae_pth: str | Path,
dtype: torch.dtype = torch.float32,
device: str | torch.device = "cuda",
*,
pretrained: AutoencoderKLWan,
) -> None:
super().__init__()
self.wan_vae = Wan2VAE(vae_pth=str(vae_pth), dtype=dtype, device=str(device))
self.model = self.wan_vae.model
self.dtype = dtype
self.device = torch.device(device)
# The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch,
# so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from
# the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path.
self.vae = pretrained.to(device=device, dtype=dtype)
def to(self, *args: Any, **kwargs: Any):
super().to(*args, **kwargs)
self.model.to(*args, **kwargs)
param = next(self.model.parameters())
self.device = param.device
self.dtype = param.dtype
self.wan_vae.device = self.device
self.wan_vae.dtype = self.dtype
self.wan_vae.scale = [scale.to(device=self.device, dtype=self.dtype) for scale in self.wan_vae.scale]
self.wan_vae.model = self.model
return self
# Read the standardization stats from the VAE's own config (diffusers populates
# these from vae/config.json) — single source of truth, no local copy. diffusers'
# encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves.
# Non-persistent: kept out of state_dict.
self.register_buffer(
"latents_mean",
torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
self.register_buffer(
"latents_std",
torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
param = next(self.vae.parameters())
return param.device, param.dtype
def encode(
self,
@@ -61,18 +77,16 @@ class WanVideoVAE38(torch.nn.Module):
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del tile_size, tile_stride
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
target_device = self.device if device is None else torch.device(device)
if target_device != self.device:
self.to(device=target_device)
if isinstance(videos, torch.Tensor):
videos = list(videos)
hidden_states = self.wan_vae.encode([video.to(self.device) for video in videos])
if hidden_states is None:
raise RuntimeError("Wan2.2 VAE encode failed; expected a list of video tensors.")
return torch.stack(hidden_states)
if isinstance(videos, (list, tuple)):
videos = torch.stack(list(videos))
dev, dtype = self._device_dtype()
mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float()
mean = self.latents_mean.float().to(mu.device)
std = self.latents_std.float().to(mu.device)
return (mu - mean) / std
def decode(
self,
@@ -82,18 +96,16 @@ class WanVideoVAE38(torch.nn.Module):
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del tile_size, tile_stride
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
target_device = self.device if device is None else torch.device(device)
if target_device != self.device:
self.to(device=target_device)
if isinstance(hidden_states, torch.Tensor):
hidden_states = list(hidden_states)
videos = self.wan_vae.decode([hidden_state.to(self.device) for hidden_state in hidden_states])
if videos is None:
raise RuntimeError("Wan2.2 VAE decode failed; expected a list of latent tensors.")
return torch.stack(videos)
if isinstance(hidden_states, (list, tuple)):
hidden_states = torch.stack(list(hidden_states))
dev, dtype = self._device_dtype()
z = hidden_states.float()
z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device)
out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample
return out.float().clamp_(-1.0, 1.0)
__all__ = ["WanVideoVAE38"]
+92 -188
View File
@@ -15,8 +15,7 @@
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -24,57 +23,108 @@ import torch
from safetensors.torch import load_file
if TYPE_CHECKING:
from .wan.modules.tokenizers import HuggingfaceTokenizer
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
logger = logging.getLogger(__name__)
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
WAN_T5_SAFE_CHECKPOINT = "models_t5_umt5-xxl-enc-bf16.safetensors"
WAN_T5_CHECKPOINT = WAN_T5_SAFE_CHECKPOINT
WAN_T5_TOKENIZER = "google/umt5-xxl"
WAN_VAE_SAFE_CHECKPOINT = "Wan2.2_VAE.safetensors"
WAN_VAE_CHECKPOINT = WAN_VAE_SAFE_CHECKPOINT
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
class WanTextEncoder(torch.nn.Module):
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
Exposes `.dim` (hidden size) and `forward(ids, mask) -> [B, L, dim]`, matching
the call in `FastWAM.encode_prompt`.
"""
def __init__(
self,
dtype: torch.dtype = torch.bfloat16,
device: str | torch.device = "cuda",
*,
pretrained: torch.nn.Module,
) -> None:
super().__init__()
# UMT5-XXL is a fixed pretrained encoder — never trained from scratch, so a real
# `UMT5EncoderModel` (with weights) must always be supplied (loaded from the
# diffusers repo by `load_pretrained_wan_text_encoder`). No random/offline build.
self.model = pretrained.to(device=device, dtype=dtype)
self.dim = int(self.model.config.d_model)
def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
return self.model(input_ids=ids, attention_mask=mask.long()).last_hidden_state
@dataclass(frozen=True)
class WanCheckpointPaths:
root: Path
dit: list[Path]
vae: Path
text_encoder: Path | None
tokenizer: Path | None
class WanTokenizer:
"""UMT5 tokenizer wrapper returning `(input_ids, attention_mask)` like the
FastWAM call site expects."""
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.seq_len = int(seq_len)
def __call__(
self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any
):
if isinstance(sequence, str):
sequence = [sequence]
out = self.tokenizer(
list(sequence),
padding="max_length",
truncation=True,
max_length=self.seq_len,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)
if return_mask:
return out.input_ids, out.attention_mask
return out.input_ids
@dataclass
class Wan22LoadedComponents:
dit: WanVideoDiT
vae: WanVideoVAE38
text_encoder: torch.nn.Module | None
tokenizer: HuggingfaceTokenizer | None
dit_path: list[str]
vae_path: str
text_encoder_path: str | None
tokenizer_path: str | None
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len))
def resolve_wan_checkpoint_dir(
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
vae = AutoencoderKLWan.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype
)
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
from transformers import UMT5EncoderModel
encoder = UMT5EncoderModel.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
)
return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
def resolve_wan_dit_paths(
model_id_or_path: str | Path,
*,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
) -> Path:
"""Return a local Wan2.2 checkpoint directory.
Local paths are used directly. Hub repos are downloaded with the same fixed
component names used by the upstream Wan2.2 inference code.
"""
) -> list[Path]:
"""Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir."""
path = Path(model_id_or_path).expanduser()
if path.is_dir():
return path
return sorted(path.glob(WAN_DIT_PATTERN))
from huggingface_hub import snapshot_download
@@ -83,51 +133,9 @@ def resolve_wan_checkpoint_dir(
revision=revision,
cache_dir=cache_dir,
local_files_only=local_files_only,
allow_patterns=[
WAN_DIT_PATTERN,
WAN_T5_CHECKPOINT,
WAN_VAE_CHECKPOINT,
f"{WAN_T5_TOKENIZER}/**",
],
)
return Path(snapshot_path)
def resolve_wan_checkpoint_paths(
checkpoint_dir: str | Path,
*,
tokenizer_dir: str | Path | None = None,
load_dit: bool = True,
load_text_encoder: bool = True,
) -> WanCheckpointPaths:
root = Path(checkpoint_dir).expanduser()
tokenizer_root = Path(tokenizer_dir).expanduser() if tokenizer_dir is not None else root
dit = sorted(root.glob(WAN_DIT_PATTERN)) if load_dit else []
vae = root / WAN_VAE_SAFE_CHECKPOINT
text_encoder = root / WAN_T5_SAFE_CHECKPOINT if load_text_encoder else None
tokenizer = tokenizer_root / WAN_T5_TOKENIZER if load_text_encoder else None
missing = []
if load_dit and len(dit) == 0:
missing.append(f"DiT ({WAN_DIT_PATTERN})")
if not vae.exists():
missing.append(f"VAE ({WAN_VAE_SAFE_CHECKPOINT})")
if load_text_encoder:
if text_encoder is None or not text_encoder.exists():
missing.append(f"text encoder ({WAN_T5_SAFE_CHECKPOINT})")
if tokenizer is None or not tokenizer.exists():
missing.append(f"tokenizer ({WAN_T5_TOKENIZER})")
if missing:
raise FileNotFoundError(
f"Incomplete Wan2.2 checkpoint directory {root}: missing {', '.join(missing)}."
)
return WanCheckpointPaths(
root=root,
dit=dit,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
allow_patterns=[WAN_DIT_PATTERN],
)
return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN))
def load_wan_video_dit(
@@ -145,107 +153,6 @@ def load_wan_video_dit(
return model.to(device=device, dtype=torch_dtype)
def load_wan_text_encoder(
checkpoint_path: str | Path,
*,
torch_dtype: torch.dtype,
device: str,
) -> torch.nn.Module:
from .wan.modules.t5 import umt5_xxl
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=torch_dtype,
device=device,
)
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.suffix != ".safetensors":
raise ValueError(f"Wan2.2 text encoder checkpoint must be safetensors, got {checkpoint_path}.")
state_dict = load_file(checkpoint_path)
model.load_state_dict(state_dict)
return model.to(device=device, dtype=torch_dtype)
def load_wan_tokenizer(tokenizer_path: str | Path, *, tokenizer_max_len: int) -> HuggingfaceTokenizer:
from .wan.modules.tokenizers import HuggingfaceTokenizer
return HuggingfaceTokenizer(
name=str(tokenizer_path),
seq_len=int(tokenizer_max_len),
clean="whitespace",
)
def load_wan_vae(checkpoint_path: str | Path, *, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
from .wan_adapters import WanVideoVAE38
return WanVideoVAE38(vae_pth=str(checkpoint_path), dtype=torch_dtype, device=device)
def load_wan22_ti2v_5b_components(
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B",
tokenizer_max_len: int = 512,
dit_config: dict[str, Any] | None = None,
load_text_encoder: bool = True,
):
logger.info("Loading Wan2.2-TI2V-5B components...")
start = time.time()
if dit_config is None:
raise ValueError("`dit_config` is required for Wan2.2-TI2V-5B loading.")
checkpoint_dir = resolve_wan_checkpoint_dir(model_id)
tokenizer_dir = (
checkpoint_dir if tokenizer_model_id == model_id else resolve_wan_checkpoint_dir(tokenizer_model_id)
)
paths = resolve_wan_checkpoint_paths(
checkpoint_dir,
tokenizer_dir=tokenizer_dir,
load_text_encoder=load_text_encoder,
)
dit = load_wan_video_dit(
paths.dit,
dit_config=dit_config,
torch_dtype=torch_dtype,
device=device,
)
vae = load_wan_vae(paths.vae, torch_dtype=torch_dtype, device=device)
text_encoder: torch.nn.Module | None = None
tokenizer: HuggingfaceTokenizer | None = None
if load_text_encoder:
if paths.text_encoder is None or paths.tokenizer is None:
raise FileNotFoundError("Wan2.2 text encoder/tokenizer paths were not resolved.")
text_encoder = load_wan_text_encoder(
paths.text_encoder,
torch_dtype=torch_dtype,
device=device,
)
tokenizer = load_wan_tokenizer(paths.tokenizer, tokenizer_max_len=tokenizer_max_len)
else:
logger.info(
"Skipping pretrained text encoder/tokenizer load (`load_text_encoder=False`); "
"training must provide cached `context/context_mask`."
)
logger.info("Finished loading Wan2.2-TI2V-5B components in %.2f seconds.", time.time() - start)
return Wan22LoadedComponents(
dit=dit,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
dit_path=[str(path) for path in paths.dit],
vae_path=str(paths.vae),
text_encoder_path=str(paths.text_encoder) if paths.text_encoder is not None else None,
tokenizer_path=str(paths.tokenizer) if paths.tokenizer is not None else None,
)
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
state_dict = {}
for path in paths:
@@ -254,17 +161,14 @@ def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor
__all__ = [
"WAN22_DIFFUSERS_MODEL_ID",
"WAN_DIT_PATTERN",
"WAN_T5_CHECKPOINT",
"WAN_T5_TOKENIZER",
"WAN_VAE_CHECKPOINT",
"Wan22LoadedComponents",
"WanCheckpointPaths",
"load_wan22_ti2v_5b_components",
"load_wan_text_encoder",
"load_wan_tokenizer",
"load_wan_vae",
"WanTextEncoder",
"WanTokenizer",
"build_wan_tokenizer",
"load_pretrained_wan_text_encoder",
"load_pretrained_wan_vae",
"load_wan_video_dit",
"resolve_wan_checkpoint_dir",
"resolve_wan_checkpoint_paths",
"resolve_wan_dit_paths",
]
+10 -5
View File
@@ -660,11 +660,16 @@ class WanVideoDiT(WanModel):
) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
token_timesteps[:, 0, :] = 0
token_timesteps = token_timesteps.reshape(batch_size, -1)
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).to(
dtype=model_dtype
)
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
# Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
# Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
# Upstream guarantees this via an fp32 autocast region, so it holds even when the
# model runs in bf16. Mirror that here, then cast the per-block modulation back to
# model_dtype so the bf16 attention blocks are not upcast to fp32.
with torch.amp.autocast("cuda", dtype=torch.float32):
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
t_mod = t_mod.to(dtype=model_dtype)
x = self.patchify(x)
f, h, w = x.shape[2:]
+126 -74
View File
@@ -18,21 +18,13 @@ import json
import pytest
import torch
from safetensors.torch import save_model
from safetensors import safe_open
from torch import nn
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.policies.fastwam import modeling_fastwam
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy, resolve_wan_component_paths
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
from lerobot.policies.fastwam.wan_components import (
WAN_DIT_PATTERN,
WAN_T5_CHECKPOINT,
WAN_T5_TOKENIZER,
WAN_VAE_CHECKPOINT,
resolve_wan_checkpoint_paths,
)
from lerobot.utils.constants import ACTION, OBS_STATE
@@ -170,8 +162,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
base_model_id=None,
)
with pytest.warns(RuntimeWarning, match="does not load pretrained FastWAM weights"):
policy = FastWAMPolicy(cfg)
policy = FastWAMPolicy(cfg)
output = policy.forward(
{
@@ -207,89 +198,96 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
]
def test_from_pretrained_loads_weights_without_initializing_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg.save_pretrained(tmp_path)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: FakeFastWAMCore())
reference_policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True)
save_model(reference_policy, str(tmp_path / "model.safetensors"))
class CoreWithFrozenComponents(FakeFastWAMCore):
"""Fake core mirroring the real one: frozen VAE / text encoder held as
*unregistered* attributes (via `object.__setattr__`) so they are excluded from
`state_dict()` and the saved checkpoint, but still moved by the `_apply` override."""
def __init__(self):
super().__init__()
object.__setattr__(self, "vae", nn.Linear(2, 2))
object.__setattr__(self, "text_encoder", nn.Linear(2, 2))
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
def _apply(self, fn, *args, **kwargs):
super()._apply(fn, *args, **kwargs)
self.vae._apply(fn)
self.text_encoder._apply(fn)
return self
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
def build_core(self, config):
core = CoreWithFrozenComponents()
with torch.no_grad():
core.dit.weight.fill_(0.5)
return core
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)
reference = FastWAMPolicy(cfg)
with torch.no_grad():
reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight
reference.save_pretrained(tmp_path)
# Building from Wan2.2 must never happen on a checkpoint load.
def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
raise AssertionError("from_pretrained must not initialize or download Wan2.2 backbone components")
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
monkeypatch.setattr(
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
fail_if_wan_pretrained_is_loaded,
)
monkeypatch.setattr(
modeling_fastwam,
"_build_core_model_from_architecture",
lambda config: FakeFastWAMCore(),
raising=False,
)
loaded_components_from = []
monkeypatch.setattr(
FastWAMPolicy,
"load_wan_components_from_pretrained",
lambda self, path: loaded_components_from.append(path),
)
policy = FastWAMPolicy.from_pretrained(tmp_path, strict=False)
policy = FastWAMPolicy.from_pretrained(tmp_path)
assert isinstance(policy.model, FakeFastWAMCore)
assert loaded_components_from == [tmp_path]
assert isinstance(policy.model, CoreWithFrozenComponents)
# The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25))
def test_save_pretrained_copies_required_wan_sidecars(monkeypatch, tmp_path):
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
source = tmp_path / "source"
tokenizer = source / WAN_T5_TOKENIZER
tokenizer.mkdir(parents=True)
vae = source / WAN_VAE_CHECKPOINT
text_encoder = source / WAN_T5_CHECKPOINT
tokenizer_file = tokenizer / "tokenizer.json"
vae.write_bytes(b"vae")
text_encoder.write_bytes(b"text")
tokenizer_file.write_text("{}")
core = FakeFastWAMCore()
core.model_paths = {
"vae": str(vae),
"text_encoder": str(text_encoder),
"tokenizer": str(tokenizer),
}
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: core)
policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
save_dir = tmp_path / "saved"
policy.save_pretrained(save_dir)
assert (save_dir / "model.safetensors").is_file()
assert (save_dir / WAN_VAE_CHECKPOINT).read_bytes() == b"vae"
assert (save_dir / WAN_T5_CHECKPOINT).read_bytes() == b"text"
assert (save_dir / WAN_T5_TOKENIZER / "tokenizer.json").read_text() == "{}"
# No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
assert not (save_dir / "google").exists()
with safe_open(save_dir / "model.safetensors", framework="pt") as f:
keys = set(f.keys())
# Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
# encoder are excluded (loaded from the diffusers/transformers repos at init).
assert any(key.startswith("model.dit.") for key in keys)
assert not any(key.startswith("model.vae.") for key in keys)
assert not any(key.startswith("model.text_encoder.") for key in keys)
def test_wan_component_resolution_uses_fixed_safetensors_layout(tmp_path):
tokenizer = tmp_path / WAN_T5_TOKENIZER
tokenizer.mkdir(parents=True)
(tmp_path / WAN_VAE_CHECKPOINT).touch()
(tmp_path / WAN_T5_CHECKPOINT).touch()
(tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors").touch()
(tokenizer / "tokenizer.json").touch()
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
paths = resolve_wan_checkpoint_paths(tmp_path)
sidecar_paths = resolve_wan_component_paths(tmp_path)
# Unregistered: excluded from state_dict and from the optimizer's parameter set.
sd = policy.state_dict()
assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd)
param_names = [n for n, _ in policy.named_parameters()]
assert not any("vae" in n or "text_encoder" in n for n in param_names)
assert paths.dit == [tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors"]
assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT
assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT
assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER
assert sidecar_paths.dit == []
assert WAN_DIT_PATTERN == "diffusion_pytorch_model*.safetensors"
(tmp_path / WAN_T5_CHECKPOINT).unlink()
with pytest.raises(FileNotFoundError, match="text encoder"):
resolve_wan_checkpoint_paths(tmp_path)
# ...but the `_apply` override still carries them through `.to()` (dtype stands in
# for device on a CPU box), so they never strand off the rest of the model.
policy.to(torch.float64)
assert policy.model.dit.weight.dtype == torch.float64 # registered
assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply
assert policy.model.text_encoder.weight.dtype == torch.float64
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
@@ -302,3 +300,57 @@ def test_pretrained_config_round_trips_fastwam_features(tmp_path):
assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL
assert loaded.action_feature.shape == (7,)
assert loaded.robot_state_feature.shape == (8,)
def test_vae_adapter_empty_build_encode_decode_shapes():
"""Offline glue check of the diffusers-backed VAE adapter (random weights).
Validates the encode/decode contract 48 latent channels, 16x spatial / 4x
temporal compression, list-or-batch input, scaling round-trip without any
weight download. (Numerical fidelity vs the original Wan VAE is a separate,
GPU + real-weights verification step.)
"""
pytest.importorskip("diffusers")
from diffusers import AutoencoderKLWan
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
# Production always loads a real pretrained VAE from the diffusers repo; here we
# build the same architecture with random weights and dummy standardization stats
# to exercise the adapter's shape/scaling contract offline (fidelity is checked
# separately, with real weights, on GPU).
arch = {
"base_dim": 160,
"decoder_base_dim": 256,
"z_dim": 48,
"dim_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_scales": [],
"temperal_downsample": [False, True, True],
"dropout": 0.0,
"is_residual": True,
"in_channels": 12,
"out_channels": 12,
"patch_size": 2,
"scale_factor_spatial": 16,
"scale_factor_temporal": 4,
"clip_output": False,
"latents_mean": [0.0] * 48,
"latents_std": [1.0] * 48,
}
raw = AutoencoderKLWan.from_config(arch)
vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw)
assert vae.z_dim == 48
assert vae.upsampling_factor == 16
assert vae.temporal_downsample_factor == 4
video = torch.rand(1, 3, 5, 32, 32) * 2 - 1 # [B,C,T,H,W] in [-1,1]
latents = vae.encode(video)
assert latents.shape == (1, 48, 2, 2, 2) # T'=(5-1)//4+1, H'=W'=32//16
decoded = vae.decode(latents)
assert decoded.shape[0] == 1 and decoded.shape[1] == 3 and decoded.shape[-2:] == (32, 32)
assert decoded.min() >= -1.0 and decoded.max() <= 1.0
# list input is accepted and equals the batched path
assert torch.equal(vae.encode([video[0]]), latents)
Generated
-18
View File
@@ -1636,18 +1636,6 @@ http = [
{ name = "aiohttp" },
]
[[package]]
name = "ftfy"
version = "6.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "wcwidth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" },
]
[[package]]
name = "future"
version = "1.0.0"
@@ -2708,7 +2696,6 @@ all = [
{ name = "faker" },
{ name = "fastapi" },
{ name = "feetech-servo-sdk" },
{ name = "ftfy" },
{ name = "grpcio" },
{ name = "grpcio-tools" },
{ name = "gym-aloha" },
@@ -2747,7 +2734,6 @@ all = [
{ name = "pyzmq" },
{ name = "qwen-vl-utils" },
{ name = "reachy2-sdk" },
{ name = "regex" },
{ name = "rerun-sdk" },
{ name = "ruff" },
{ name = "scikit-image" },
@@ -2846,8 +2832,6 @@ evaluation = [
]
fastwam = [
{ name = "diffusers" },
{ name = "ftfy" },
{ name = "regex" },
{ name = "transformers" },
]
feetech = [
@@ -3108,7 +3092,6 @@ requires-dist = [
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "ftfy", marker = "extra == 'fastwam'", specifier = ">=6.1.1,<7.0.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
@@ -3277,7 +3260,6 @@ requires-dist = [
{ name = "pyzmq", marker = "extra == 'pyzmq-dep'", specifier = ">=26.2.1,<28.0.0" },
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
{ name = "regex", marker = "extra == 'fastwam'", specifier = ">=2024.0.0,<2027.0.0" },
{ name = "requests", specifier = ">=2.32.0,<3.0.0" },
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },