diff --git a/pyproject.toml b/pyproject.toml index 3d2002a06..fa024cedc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,8 +219,6 @@ eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] fastwam = [ "lerobot[transformers-dep]", "lerobot[diffusers-dep]", - "ftfy>=6.1.1,<7.0.0", - "regex>=2024.0.0,<2027.0.0", ] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index e591f8c78..57ccccb7c 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -118,10 +118,6 @@ def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, Policy return coerced -def _coerce_normalization_mapping(mapping: dict[str, Any]) -> dict[str, Any]: - return {key: _coerce_enum(NormalizationMode, value) for key, value in mapping.items()} - - def _is_local_model_id(value: str) -> bool: path = Path(value).expanduser() return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists() @@ -200,7 +196,7 @@ class FastWAMConfig(PreTrainedConfig): loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0}) video_dit_config: dict[str, Any] | None = None action_dit_config: dict[str, Any] | None = None - normalization_mapping: dict[str, Any] = field( + normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, "STATE": NormalizationMode.MEAN_STD, @@ -220,7 +216,6 @@ class FastWAMConfig(PreTrainedConfig): self.input_features = _coerce_policy_features(self.input_features) self.output_features = _coerce_policy_features(self.output_features) self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions] - self.normalization_mapping = _coerce_normalization_mapping(self.normalization_mapping) self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim) self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim) self.video_dit_config["fp32_attention"] = bool(self.fp32_attention) @@ -266,6 +261,38 @@ class FastWAMConfig(PreTrainedConfig): def get_scheduler_preset(self) -> None: return None + def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None: + """Rebuild visual input features from the dataset's real camera keys. + + FastWAM's `__post_init__` installs a synthetic single-image default + (`observation.images.image` at full `image_size` width). For datasets + with one or more separately-named cameras (e.g. `observation.images.top`, + `observation.images.wrist`), this hook — invoked by `make_policy` once the + dataset metadata is known — replaces that default with the actual camera + keys, each declared at the policy's native per-camera resolution + (`image_size[0]` x `image_size[1] // num_cameras`). The accompanying + resize step in `make_fastwam_pre_post_processors` resizes raw frames to + match, so heterogeneous source resolutions (e.g. 480x640) are supported. + """ + image_keys = sorted( + key + for key, feature in dataset_features.items() + if key.startswith("observation.images.") + and feature.get("dtype") in ("video", "image") + ) + if not image_keys: + return + height, total_width = self.image_size + per_cam_width = total_width // len(image_keys) + new_inputs: dict[str, PolicyFeature] = { + key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width)) + for key in image_keys + } + if self.proprio_dim is not None and OBS_STATE in dataset_features: + new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,)) + self.input_features = new_inputs + self.validate_features() + def validate_features(self) -> None: if self.action_dim <= 0: raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.") diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 1205acf72..b64d44785 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -14,24 +14,18 @@ from __future__ import annotations -import shutil -import warnings +import logging from collections import deque -from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import torch from torch import Tensor -from lerobot.configs import PreTrainedConfig from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.constants import OBS_STATE from .configuration_fastwam import FastWAMConfig -if TYPE_CHECKING: - from .wan_components import WanCheckpointPaths - class FastWAMPolicy(PreTrainedPolicy): """LeRobot policy wrapper for FastWAM. @@ -49,98 +43,62 @@ class FastWAMPolicy(PreTrainedPolicy): self, config: FastWAMConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None, - **kwargs: Any, ): - skip_wan_init = bool(kwargs.pop("_skip_wan_init", False)) super().__init__(config, dataset_stats) config.validate_features() self.config = config self.dataset_stats = dataset_stats - suppress_base_init_warning = bool(kwargs.pop("_suppress_base_init_warning", False)) - if not skip_wan_init and not suppress_base_init_warning: - warnings.warn( - "FastWAMPolicy(config) initializes from architecture/config and does not load pretrained " - "FastWAM weights. For training or evaluation, use `make_policy(config)` or " - "`FastWAMPolicy.from_pretrained(...)`.", - RuntimeWarning, - stacklevel=2, - ) - if skip_wan_init: - self.model = _build_core_model_from_architecture(config) - else: - self.model = self._build_core_model(config) + self.model = self._build_core_model(config) self.reset() @classmethod - def from_pretrained( - cls, - pretrained_name_or_path: str | Path, - *, - config: FastWAMConfig | None = None, - force_download: bool = False, - resume_download: bool | None = None, - proxies: dict | None = None, - token: str | bool | None = None, - cache_dir: str | Path | None = None, - local_files_only: bool = False, - revision: str | None = None, - strict: bool = False, - **kwargs: Any, - ) -> FastWAMPolicy: - """Load FastWAM weights and local Wan components from one HF directory. + def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool): + """Shape-aware load that supports cross-embodiment fine-tuning. - Args: - pretrained_name_or_path (str | Path): HF-format policy directory - containing `config.json`, `model.safetensors`, local Wan VAE, - local UMT5 text encoder safetensors, and tokenizer files. - config (FastWAMConfig | None): Optional config override. When - omitted, `config.json` is read from `pretrained_name_or_path`. - force_download (bool): Forwarded to LeRobot's pretrained loader. - resume_download (bool | None): Forwarded to LeRobot's pretrained loader. - proxies (dict | None): Forwarded to LeRobot's pretrained loader. - token (str | bool | None): Forwarded to LeRobot's pretrained loader. - cache_dir (str | Path | None): Forwarded to LeRobot's pretrained loader. - local_files_only (bool): Forwarded to LeRobot's pretrained loader. - revision (str | None): Forwarded to LeRobot's pretrained loader. - strict (bool): Whether safetensors loading should require an exact - match between checkpoint keys and policy module keys. - **kwargs (Any): Extra constructor arguments forwarded to - `FastWAMPolicy`. + `safetensors.load_model(strict=False)` ignores missing/unexpected keys but + still raises on a shape mismatch for a shared key. When fine-tuning from a + checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim + checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and + proprio encoder legitimately differ in shape. With `strict=False` we drop + only those shape-mismatched tensors — leaving them at their freshly + initialized values — and load every compatible tensor. With `strict=True` + the standard exact-match loader is used. """ + from safetensors import safe_open - pretrained_path = _resolve_pretrained_directory( - pretrained_name_or_path=pretrained_name_or_path, - force_download=force_download, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - if config is None: - config = PreTrainedConfig.from_pretrained(pretrained_path) - if not isinstance(config, FastWAMConfig): - raise TypeError(f"Expected FastWAM config, got {type(config).__name__}.") - kwargs["_skip_wan_init"] = True - policy = super().from_pretrained( - pretrained_path, - config=config, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - strict=strict, - **kwargs, - ) - policy.load_wan_components_from_pretrained(pretrained_path) - policy.eval() - return policy + model_state_dict = model.state_dict() + mismatched = [] + with safe_open(model_file, framework="pt") as f: + checkpoint_keys = list(f.keys()) + for key in checkpoint_keys: + if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple( + f.get_slice(key).get_shape() + ): + mismatched.append(key) - def _save_pretrained(self, save_directory: Path) -> None: - super()._save_pretrained(save_directory) - _copy_wan_components_from_policy(policy=self, save_directory=save_directory) + if not mismatched: + return super()._load_as_safetensor(model, model_file, map_location, strict) + if strict: + raise RuntimeError( + f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under " + f"strict=True: {mismatched}" + ) + + from safetensors.torch import load_file + + logging.warning( + "FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping " + "every compatible weight: %s", + len(mismatched), + mismatched, + ) + state_dict = load_file(model_file, device="cpu") + for key in mismatched: + state_dict.pop(key, None) + model.load_state_dict(state_dict, strict=False) + if map_location and map_location != "cpu": + model.to(map_location) + return model def get_optim_params(self) -> dict[str, Any]: params = ( @@ -151,21 +109,41 @@ class FastWAMPolicy(PreTrainedPolicy): params.extend(list(proprio_encoder.parameters())) return {"params": [p for p in params if p.requires_grad]} - def load_wan_components_from_pretrained(self, pretrained_name_or_path: str | Path) -> None: - """Attach local Wan VAE, text encoder, and tokenizer from a HF directory. - - Args: - pretrained_name_or_path (str | Path): Directory containing - `Wan2.2_VAE.safetensors`, `models_t5_umt5-xxl-enc-bf16.safetensors`, - and `google/umt5-xxl/` tokenizer files. - """ - - paths = resolve_wan_component_paths(pretrained_name_or_path) - _load_wan_components_into_policy(policy=self, paths=paths) - def reset(self) -> None: self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps) + def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Adapt a standard LeRobot batch to the FastWAM-native sample that + `FastWAM.build_inputs` consumes (`video`, `action`, `context`/`context_mask`, + per-frame `proprio`). + + The LeRobot training loop passes raw `observation.images.*`, a single-step + `observation.state` `[B, D]`, `action`, and a language `task` string. We do + only the translation `build_inputs` can't: stack the camera frames into a + video, encode the prompt with the (frozen) text encoder (mirroring inference, + so language-conditioned datasets need no precomputed context), and give proprio + the per-frame axis `build_inputs` indexes. All shape/presence validation is + left to `build_inputs`, the single authority on the contract. + """ + sample = dict(batch) + if "video" not in sample: + sample["video"] = _stack_video_from_images(batch, self.config) + if "context" not in sample or "context_mask" not in sample: + prompt = _prompt_from_batch(batch=batch, config=self.config) + if prompt is None: + raise KeyError( + "FastWAM training requires a `task`/`prompt` to encode text context, " + "or precomputed `context`/`context_mask` in the batch." + ) + sample["context"], sample["context_mask"] = self.model.encode_prompt(prompt) + if self.config.proprio_dim is not None and "proprio" not in sample: + state = sample.get(OBS_STATE) + if state is not None: + # LeRobot gives a single-step state [B, D]; build_inputs expects + # per-frame [B, T, D] and uses frame 0, so add a T=1 axis. + sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state + return sample + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Compute FastWAM training loss for a LeRobot batch. @@ -180,7 +158,7 @@ class FastWAMPolicy(PreTrainedPolicy): key required by LeRobot and optional tensor metrics. """ - sample = _batch_to_training_sample(batch=batch, config=self.config) + sample = self._batch_to_training_sample(batch) loss, metrics = self.model.training_loss(sample) output = {"loss": loss} for key, value in (metrics or {}).items(): @@ -230,223 +208,57 @@ class FastWAMPolicy(PreTrainedPolicy): return self._action_queue.popleft() def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module: - return _build_core_model_from_wan22(config) + """Build the FastWAM core for training / inference. - -def _resolve_pretrained_directory( - pretrained_name_or_path: str | Path, - *, - force_download: bool, - token: str | bool | None, - cache_dir: str | Path | None, - local_files_only: bool, - revision: str | None, -) -> Path: - path = Path(pretrained_name_or_path) - if path.is_dir(): - return path - - from huggingface_hub import snapshot_download - - from .wan_components import ( - WAN_T5_CHECKPOINT, - WAN_T5_TOKENIZER, - WAN_VAE_CHECKPOINT, - ) - - snapshot_path = snapshot_download( - repo_id=str(pretrained_name_or_path), - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - token=token, - local_files_only=local_files_only, - allow_patterns=[ - "config.json", - "model.safetensors", - WAN_VAE_CHECKPOINT, - WAN_T5_CHECKPOINT, - f"{WAN_T5_TOKENIZER}/**", - ], - ) - return Path(snapshot_path) - - -def resolve_wan_component_paths(pretrained_name_or_path: str | Path) -> WanCheckpointPaths: - """Resolve local Wan component paths stored beside FastWAM HF weights. - - Args: - pretrained_name_or_path (str | Path): HF-format FastWAM directory. - - Returns: - WanCheckpointPaths: Existing VAE, text encoder, and tokenizer paths. - DiT shards are intentionally optional here because FastWAM HF - checkpoints store trainable DiT weights in `model.safetensors`. - """ - - from .wan_components import resolve_wan_checkpoint_paths - - return resolve_wan_checkpoint_paths( - pretrained_name_or_path, - load_dit=False, - load_text_encoder=True, - ) - - -def _load_wan_components_into_policy(policy: FastWAMPolicy, paths: WanCheckpointPaths) -> None: - from .wan_components import load_wan_text_encoder, load_wan_tokenizer, load_wan_vae - - if paths.text_encoder is None or paths.tokenizer is None: - raise FileNotFoundError("FastWAM HF checkpoint requires Wan text encoder and tokenizer sidecars.") - dtype = _dtype_from_name(policy.config.torch_dtype) - device = str(policy.config.device) - policy.model.vae = load_wan_vae(paths.vae, torch_dtype=dtype, device=device) - policy.model.text_encoder = load_wan_text_encoder(paths.text_encoder, torch_dtype=dtype, device=device) - policy.model.tokenizer = load_wan_tokenizer( - paths.tokenizer, - tokenizer_max_len=int(policy.config.tokenizer_max_len), - ) - model_paths = dict(getattr(policy.model, "model_paths", {}) or {}) - model_paths.update( - { - "vae": str(paths.vae), - "text_encoder": str(paths.text_encoder), - "tokenizer": str(paths.tokenizer), - } - ) - policy.model.model_paths = model_paths - - -def _copy_wan_components_from_policy(policy: FastWAMPolicy, save_directory: Path) -> None: - model_paths = getattr(policy.model, "model_paths", {}) or {} - paths = { - "vae": model_paths.get("vae"), - "text_encoder": model_paths.get("text_encoder"), - "tokenizer": model_paths.get("tokenizer"), - } - missing = [name for name, path in paths.items() if path is None] - if missing: - raise RuntimeError( - "FastWAM save_pretrained requires local Wan component paths for " - f"{missing}. Load or initialize the policy with local Wan VAE, text encoder, and tokenizer files." + Only the trainable parts (the MoT DiT and the proprio encoder) are + materialized empty here and then filled from the policy's + `model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE + and UMT5 text encoder are loaded with their real weights from the + `Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared + across checkpoints) and are intentionally excluded from `model.safetensors` + — see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`. + """ + from .modular_fastwam import ActionDiT, FastWAM, MoT + from .wan_components import ( + build_wan_tokenizer, + load_pretrained_wan_text_encoder, + load_pretrained_wan_vae, ) - _copy_component_path(Path(paths["vae"]), save_directory / Path(paths["vae"]).name) - _copy_component_path(Path(paths["text_encoder"]), save_directory / Path(paths["text_encoder"]).name) - tokenizer_source = Path(paths["tokenizer"]) - _copy_component_path(tokenizer_source, save_directory / "google" / "umt5-xxl") + from .wan_video_dit import WanVideoDiT - -def _copy_component_path(source: Path, destination: Path) -> None: - source = source.expanduser() - if not source.exists(): - raise FileNotFoundError(f"FastWAM component path does not exist: {source}") - if source.resolve() == destination.resolve(): - return - destination.parent.mkdir(parents=True, exist_ok=True) - if source.is_dir(): - shutil.copytree(source, destination, dirs_exist_ok=True) - else: - shutil.copy2(source, destination) - - -def _build_core_model_from_wan22(config: FastWAMConfig) -> torch.nn.Module: - from .modular_fastwam import FastWAM - - dtype = _dtype_from_name(config.torch_dtype) - return FastWAM.from_wan22_pretrained( - device=config.device, - torch_dtype=dtype, - model_id=config.model_id, - tokenizer_model_id=config.tokenizer_model_id, - tokenizer_max_len=config.tokenizer_max_len, - load_text_encoder=config.load_text_encoder, - proprio_dim=config.proprio_dim, - video_dit_config=config.video_dit_config, - action_dit_config=config.action_dit_config, - mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn, - video_train_shift=float(config.video_scheduler["train_shift"]), - video_infer_shift=float(config.video_scheduler["infer_shift"]), - video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]), - action_train_shift=float(config.action_scheduler["train_shift"]), - action_infer_shift=float(config.action_scheduler["infer_shift"]), - action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]), - loss_lambda_video=float(config.loss["lambda_video"]), - loss_lambda_action=float(config.loss["lambda_action"]), - ) - - -def _build_core_model_from_architecture(config: FastWAMConfig) -> torch.nn.Module: - from .modular_fastwam import ActionDiT, FastWAM, MoT - from .wan_video_dit import WanVideoDiT - - dtype = _dtype_from_name(config.torch_dtype) - video_expert = WanVideoDiT(**config.video_dit_config).to(device=config.device, dtype=dtype) - action_expert = ActionDiT(**config.action_dit_config).to(device=config.device, dtype=dtype) - mot = MoT( - mixtures={"video": video_expert, "action": action_expert}, - mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn, - ) - return FastWAM( - video_expert=video_expert, - action_expert=action_expert, - mot=mot, - vae=_FastWAMVAEPlaceholder(), - text_encoder=None, - tokenizer=None, - text_dim=int(config.video_dit_config["text_dim"]), - proprio_dim=config.proprio_dim, - device=config.device, - torch_dtype=dtype, - video_train_shift=float(config.video_scheduler["train_shift"]), - video_infer_shift=float(config.video_scheduler["infer_shift"]), - video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]), - action_train_shift=float(config.action_scheduler["train_shift"]), - action_infer_shift=float(config.action_scheduler["infer_shift"]), - action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]), - loss_lambda_video=float(config.loss["lambda_video"]), - loss_lambda_action=float(config.loss["lambda_action"]), - ) - - -class _FastWAMVAEPlaceholder(torch.nn.Module): - """Minimal VAE placeholder for checkpoint loading without Wan2.2 VAE. - - Args: - temporal_downsample_factor (int): Temporal compression factor expected - by FastWAM latent shape logic. - upsampling_factor (int): Spatial compression factor expected by FastWAM. - z_dim (int): Latent channel count used by Wan2.2 TI2V VAE. - """ - - temporal_downsample_factor: int = 4 - upsampling_factor: int = 8 - - def __init__(self, z_dim: int = 48): - super().__init__() - self.model = type("VAEModelShape", (), {"z_dim": int(z_dim)})() - - def encode(self, *args, **kwargs): - raise RuntimeError( - "FastWAM VAE placeholder cannot encode images; load Wan2.2 VAE for image inference." + dtype = _dtype_from_name(config.torch_dtype) + device = config.device + video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype) + action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype) + mot = MoT( + mixtures={"video": video_expert, "action": action_expert}, + mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn, ) - - def decode(self, *args, **kwargs): - raise RuntimeError( - "FastWAM VAE placeholder cannot decode latents; load Wan2.2 VAE for video inference." + text_encoder = ( + load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device) + if config.load_text_encoder + else None + ) + return FastWAM( + video_expert=video_expert, + action_expert=action_expert, + mot=mot, + vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device), + text_encoder=text_encoder, + tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len), + text_dim=int(config.video_dit_config["text_dim"]), + proprio_dim=config.proprio_dim, + device=device, + torch_dtype=dtype, + video_train_shift=float(config.video_scheduler["train_shift"]), + video_infer_shift=float(config.video_scheduler["infer_shift"]), + video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]), + action_train_shift=float(config.action_scheduler["train_shift"]), + action_infer_shift=float(config.action_scheduler["infer_shift"]), + action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]), + loss_lambda_video=float(config.loss["lambda_video"]), + loss_lambda_action=float(config.loss["lambda_action"]), ) - - -def _batch_to_training_sample(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Tensor]: - sample = dict(batch) - if "video" not in sample: - sample["video"] = _stack_video_from_images(batch, config) - if "proprio" not in sample and OBS_STATE in batch: - sample["proprio"] = batch[OBS_STATE] - required = {"video", ACTION, "context", "context_mask"} - missing = sorted(required - set(sample)) - if missing: - raise KeyError(f"FastWAM training batch is missing keys: {missing}.") - return sample def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]: diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/modular_fastwam.py index d9481064e..344830c6b 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/modular_fastwam.py @@ -24,7 +24,13 @@ import torch.nn as nn import torch.nn.functional as functional from PIL import Image -from .wan_components import load_wan22_ti2v_5b_components +from .wan_components import ( + build_wan_tokenizer, + load_pretrained_wan_text_encoder, + load_pretrained_wan_vae, + load_wan_video_dit, + resolve_wan_dit_paths, +) from .wan_video_dit import ( FastWAMAttentionBlock, WanContinuousFlowMatchScheduler, @@ -846,9 +852,17 @@ class FastWAM(torch.nn.Module): # Keep trainer compatibility: optimizer and freeze logic use `model.dit`. self.dit = self.mot - self.vae = vae - self.text_encoder = text_encoder + # Frozen Wan2.2 components: bypass `nn.Module.__setattr__` so they are NOT + # registered as submodules. They are therefore excluded from `state_dict()` + # (lean checkpoints), `parameters()`, and DDP gradient sync, and are loaded + # with their real weights from the diffusers/transformers repos at construction. + # Device/dtype moves still reach them via the `_apply` override below. + object.__setattr__(self, "vae", vae) + object.__setattr__(self, "text_encoder", text_encoder) self.tokenizer = tokenizer + vae.requires_grad_(False) + if text_encoder is not None: + text_encoder.requires_grad_(False) if text_dim is None: if self.text_encoder is None: raise ValueError("`text_dim` is required when `text_encoder` is not loaded.") @@ -913,18 +927,17 @@ class FastWAM(torch.nn.Module): raise ValueError("`video_dit_config` is required for FastWAM.from_wan22_pretrained().") if "text_dim" not in video_dit_config: raise ValueError("`video_dit_config['text_dim']` is required for FastWAM.") + del tokenizer_model_id # tokenizer is the stock UMT5 one (google/umt5-xxl) - components = load_wan22_ti2v_5b_components( - device=device, - torch_dtype=torch_dtype, - model_id=model_id, - tokenizer_model_id=tokenizer_model_id, - tokenizer_max_len=tokenizer_max_len, + # Custom MoT video DiT from the original Wan2.2 repo; frozen VAE / UMT5 from + # the diffusers conversion. This is the offline base-creation path; the + # weights it loads are then bundled into the FastWAM `model.safetensors`. + video_expert = load_wan_video_dit( + resolve_wan_dit_paths(model_id), dit_config=video_dit_config, - load_text_encoder=load_text_encoder, + torch_dtype=torch_dtype, + device=device, ) - - video_expert = components.dit action_expert = ActionDiT(**action_dit_config).to(device=device, dtype=torch_dtype) if int(action_expert.num_heads) != int(video_expert.num_heads): raise ValueError("ActionDiT `num_heads` must match video expert for MoT mixed attention.") @@ -938,13 +951,21 @@ class FastWAM(torch.nn.Module): mot_checkpoint_mixed_attn=mot_checkpoint_mixed_attn, ) - model = cls( + vae = load_pretrained_wan_vae(torch_dtype=torch_dtype, device=device) + text_encoder = ( + load_pretrained_wan_text_encoder(torch_dtype=torch_dtype, device=device) + if load_text_encoder + else None + ) + tokenizer = build_wan_tokenizer(tokenizer_max_len=tokenizer_max_len) + + return cls( video_expert=video_expert, action_expert=action_expert, mot=mot, - vae=components.vae, - text_encoder=components.text_encoder, - tokenizer=components.tokenizer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, text_dim=int(video_dit_config["text_dim"]), proprio_dim=proprio_dim, device=device, @@ -958,20 +979,17 @@ class FastWAM(torch.nn.Module): loss_lambda_video=loss_lambda_video, loss_lambda_action=loss_lambda_action, ) - model.model_paths = { - "video_dit": components.dit_path, - "vae": components.vae_path, - "text_encoder": components.text_encoder_path, - "tokenizer": components.tokenizer_path, - } - return model - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - self.mot.to(*args, **kwargs) + def _apply(self, fn, *args, **kwargs): + # `.to()` / `.cuda()` / `.cpu()` and accelerate/DDP device moves all funnel + # through `_apply`, and the parent policy reaches us via `child._apply(fn)` + # (not `child.to()`). Propagate `fn` to the *unregistered* frozen VAE / text + # encoder here so they follow the rest of the model onto the right device, + # while staying out of `state_dict()` / `parameters()`. + super()._apply(fn, *args, **kwargs) + self.vae._apply(fn) if self.text_encoder is not None: - self.text_encoder.to(*args, **kwargs) - self.vae.to(*args, **kwargs) + self.text_encoder._apply(fn) return self @staticmethod @@ -1628,7 +1646,7 @@ class FastWAM(torch.nn.Module): latent_w = width // self.vae.upsampling_factor generator = None if seed is None else torch.Generator(device=rand_device).manual_seed(seed) return torch.randn( - (1, self.vae.model.z_dim, latent_t, latent_h, latent_w), + (1, self.vae.z_dim, latent_t, latent_h, latent_w), generator=generator, device=rand_device, dtype=torch.float32, diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index 56ef7bfc6..8fc61446b 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -24,6 +24,7 @@ from lerobot.processor import ( ActionProcessorStep, AddBatchDimensionProcessorStep, DeviceProcessorStep, + ImageCropResizeProcessorStep, NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, @@ -90,9 +91,8 @@ def make_fastwam_pre_post_processors( output processor pipelines discoverable by LeRobot. """ - normalization_stats: dict[str, dict[str, Any]] = { - key: dict(value) for key, value in (dataset_stats or {}).items() - } + # force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1] + normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {}) for key, feature in config.input_features.items(): if feature.type != FeatureType.VISUAL: continue @@ -101,10 +101,23 @@ def make_fastwam_pre_post_processors( "mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), "std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), } + + # resize visual inputs to match model expected input size, if necessary + visual_shapes = [ + feature.shape + for feature in config.input_features.values() + if feature.type == FeatureType.VISUAL + ] + resize_steps = [] + if visual_shapes: + target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2])) + resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw)) + input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), DeviceProcessorStep(device=config.device), + *resize_steps, NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, diff --git a/src/lerobot/policies/fastwam/wan/README.md b/src/lerobot/policies/fastwam/wan/README.md index 7d0a2b169..dbd56d1f2 100644 --- a/src/lerobot/policies/fastwam/wan/README.md +++ b/src/lerobot/policies/fastwam/wan/README.md @@ -10,17 +10,16 @@ Copied files: - `wan/modules/attention.py` - `wan/modules/model.py` -- `wan/modules/t5.py` -- `wan/modules/tokenizers.py` -- `wan/modules/vae2_2.py` - `wan/modules/__init__.py` - `wan/utils/fm_solvers.py` - `wan/utils/__init__.py` -FastWAM-specific model glue and any larger code adapted from these modules live outside this directory. This keeps the upstream Wan2.2 code reviewable as a vendored reference subset and makes it straightforward to replace this directory with an external Wan2.2 dependency by changing import paths. +This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE, +UMT5 text encoder, and tokenizer are no longer vendored — they come from +`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and +`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`). Current FastWAM adapters that directly reuse this vendored subset: -- `../wan_components.py` instantiates the upstream `wan.modules.t5.umt5_xxl` encoder factory and uses `wan.modules.tokenizers.HuggingfaceTokenizer`. -- `../wan_adapters.py` wraps `wan.modules.vae2_2.Wan2VAE` with the FastWAM tensor-batch encode/decode API. +- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`. - `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps. diff --git a/src/lerobot/policies/fastwam/wan/modules/__init__.py b/src/lerobot/policies/fastwam/wan/modules/__init__.py index ecc42a31e..c4c595029 100644 --- a/src/lerobot/policies/fastwam/wan/modules/__init__.py +++ b/src/lerobot/policies/fastwam/wan/modules/__init__.py @@ -1,17 +1,8 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. from .attention import flash_attention from .model import WanModel -from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model -from .tokenizers import HuggingfaceTokenizer -from .vae2_2 import Wan2VAE __all__ = [ - "Wan2VAE", "WanModel", - "T5Model", - "T5Encoder", - "T5Decoder", - "T5EncoderModel", - "HuggingfaceTokenizer", "flash_attention", ] diff --git a/src/lerobot/policies/fastwam/wan/modules/t5.py b/src/lerobot/policies/fastwam/wan/modules/t5.py deleted file mode 100644 index c90fd3dbc..000000000 --- a/src/lerobot/policies/fastwam/wan/modules/t5.py +++ /dev/null @@ -1,489 +0,0 @@ -# Modified from transformers.models.t5.modeling_t5 -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import logging -import math - -import torch -import torch.nn as nn -import torch.nn.functional as functional - -from .tokenizers import HuggingfaceTokenizer - -__all__ = [ - "T5Model", - "T5Encoder", - "T5Decoder", - "T5EncoderModel", -] - - -def fp16_clamp(x): - if x.dtype == torch.float16 and torch.isinf(x).any(): - clamp = torch.finfo(x.dtype).max - 1000 - x = torch.clamp(x, min=-clamp, max=clamp) - return x - - -def init_weights(m): - if isinstance(m, T5LayerNorm): - nn.init.ones_(m.weight) - elif isinstance(m, T5Model): - nn.init.normal_(m.token_embedding.weight, std=1.0) - elif isinstance(m, T5FeedForward): - nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) - nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) - nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) - elif isinstance(m, T5Attention): - nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) - nn.init.normal_(m.k.weight, std=m.dim**-0.5) - nn.init.normal_(m.v.weight, std=m.dim**-0.5) - nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) - elif isinstance(m, T5RelativeEmbedding): - nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) - - -class GELU(nn.Module): - def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) - - -class T5LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-6): - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.type_as(self.weight) - return self.weight * x - - -class T5Attention(nn.Module): - def __init__(self, dim, dim_attn, num_heads, dropout=0.1): - assert dim_attn % num_heads == 0 - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.num_heads = num_heads - self.head_dim = dim_attn // num_heads - - # layers - self.q = nn.Linear(dim, dim_attn, bias=False) - self.k = nn.Linear(dim, dim_attn, bias=False) - self.v = nn.Linear(dim, dim_attn, bias=False) - self.o = nn.Linear(dim_attn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, context=None, mask=None, pos_bias=None): - """ - x: [B, L1, C]. - context: [B, L2, C] or None. - mask: [B, L2] or [B, L1, L2] or None. - """ - # check inputs - context = x if context is None else context - b, n, c = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x).view(b, -1, n, c) - k = self.k(context).view(b, -1, n, c) - v = self.v(context).view(b, -1, n, c) - - # attention bias - attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) - if pos_bias is not None: - attn_bias += pos_bias - if mask is not None: - assert mask.ndim in [2, 3] - mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) - attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) - - # compute attention (T5 does not use scaling) - attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias - attn = functional.softmax(attn.float(), dim=-1).type_as(attn) - x = torch.einsum("bnij,bjnc->binc", attn, v) - - # output - x = x.reshape(b, -1, n * c) - x = self.o(x) - x = self.dropout(x) - return x - - -class T5FeedForward(nn.Module): - def __init__(self, dim, dim_ffn, dropout=0.1): - super().__init__() - self.dim = dim - self.dim_ffn = dim_ffn - - # layers - self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) - self.fc1 = nn.Linear(dim, dim_ffn, bias=False) - self.fc2 = nn.Linear(dim_ffn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) * self.gate(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class T5SelfAttention(nn.Module): - def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = ( - None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) - ) - - def forward(self, x, mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) - x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.ffn(self.norm2(x))) - return x - - -class T5CrossAttention(nn.Module): - def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm3 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = ( - None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) - ) - - def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) - x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) - x = fp16_clamp(x + self.ffn(self.norm3(x))) - return x - - -class T5RelativeEmbedding(nn.Module): - def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super().__init__() - self.num_buckets = num_buckets - self.num_heads = num_heads - self.bidirectional = bidirectional - self.max_dist = max_dist - - # layers - self.embedding = nn.Embedding(num_buckets, num_heads) - - def forward(self, lq, lk): - device = self.embedding.weight.device - # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ - # torch.arange(lq).unsqueeze(1).to(device) - rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) - rel_pos = self._relative_position_bucket(rel_pos) - rel_pos_embeds = self.embedding(rel_pos) - rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] - return rel_pos_embeds.contiguous() - - def _relative_position_bucket(self, rel_pos): - # preprocess - if self.bidirectional: - num_buckets = self.num_buckets // 2 - rel_buckets = (rel_pos > 0).long() * num_buckets - rel_pos = torch.abs(rel_pos) - else: - num_buckets = self.num_buckets - rel_buckets = 0 - rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) - - # embeddings for small and large positions - max_exact = num_buckets // 2 - rel_pos_large = ( - max_exact - + ( - torch.log(rel_pos.float() / max_exact) - / math.log(self.max_dist / max_exact) - * (num_buckets - max_exact) - ).long() - ) - rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) - rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) - return rel_buckets - - -class T5Encoder(nn.Module): - def __init__( - self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1 - ): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) - self.pos_embedding = ( - T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None - ) - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList( - [ - T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) - for _ in range(num_layers) - ] - ) - self.norm = T5LayerNorm(dim) - - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None): - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Decoder(nn.Module): - def __init__( - self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1 - ): - super().__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) - self.pos_embedding = ( - T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None - ) - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList( - [ - T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) - for _ in range(num_layers) - ] - ) - self.norm = T5LayerNorm(dim) - - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): - b, s = ids.size() - - # causal mask - if mask is None: - mask = torch.tril(torch.ones(1, s, s).to(ids.device)) - elif mask.ndim == 2: - mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) - - # layers - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Model(nn.Module): - def __init__( - self, - vocab_size, - dim, - dim_attn, - dim_ffn, - num_heads, - encoder_layers, - decoder_layers, - num_buckets, - shared_pos=True, - dropout=0.1, - ): - super().__init__() - self.vocab_size = vocab_size - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.num_buckets = num_buckets - - # layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.encoder = T5Encoder( - self.token_embedding, - dim, - dim_attn, - dim_ffn, - num_heads, - encoder_layers, - num_buckets, - shared_pos, - dropout, - ) - self.decoder = T5Decoder( - self.token_embedding, - dim, - dim_attn, - dim_ffn, - num_heads, - decoder_layers, - num_buckets, - shared_pos, - dropout, - ) - self.head = nn.Linear(dim, vocab_size, bias=False) - - # initialize weights - self.apply(init_weights) - - def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): - x = self.encoder(encoder_ids, encoder_mask) - x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) - x = self.head(x) - return x - - -def _t5( - name, - encoder_only=False, - decoder_only=False, - return_tokenizer=False, - tokenizer_kwargs=None, - dtype=torch.float32, - device="cpu", - **kwargs, -): - # sanity check - assert not (encoder_only and decoder_only) - tokenizer_kwargs = tokenizer_kwargs or {} - - # params - if encoder_only: - model_cls = T5Encoder - kwargs["vocab"] = kwargs.pop("vocab_size") - kwargs["num_layers"] = kwargs.pop("encoder_layers") - _ = kwargs.pop("decoder_layers") - elif decoder_only: - model_cls = T5Decoder - kwargs["vocab"] = kwargs.pop("vocab_size") - kwargs["num_layers"] = kwargs.pop("decoder_layers") - _ = kwargs.pop("encoder_layers") - else: - model_cls = T5Model - - # init model - with torch.device(device): - model = model_cls(**kwargs) - - # set device - model = model.to(dtype=dtype, device=device) - - # init tokenizer - if return_tokenizer: - from .tokenizers import HuggingfaceTokenizer - - tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) - return model, tokenizer - else: - return model - - -def umt5_xxl(**kwargs): - cfg = { - "vocab_size": 256384, - "dim": 4096, - "dim_attn": 4096, - "dim_ffn": 10240, - "num_heads": 64, - "encoder_layers": 24, - "decoder_layers": 24, - "num_buckets": 32, - "shared_pos": False, - "dropout": 0.1, - } - cfg.update(**kwargs) - return _t5("umt5-xxl", **cfg) - - -class T5EncoderModel: - def __init__( - self, - text_len, - dtype=torch.bfloat16, - device=torch.cuda.current_device(), - checkpoint_path=None, - tokenizer_path=None, - shard_fn=None, - ): - self.text_len = text_len - self.dtype = dtype - self.device = device - self.checkpoint_path = checkpoint_path - self.tokenizer_path = tokenizer_path - - # init model - model = ( - umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device) - .eval() - .requires_grad_(False) - ) - logging.info(f"loading {checkpoint_path}") - model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True)) - self.model = model - if shard_fn is not None: - self.model = shard_fn(self.model, sync_module_states=False) - else: - self.model.to(self.device) - # init tokenizer - self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") - - def __call__(self, texts, device): - ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) - ids = ids.to(device) - mask = mask.to(device) - seq_lens = mask.gt(0).sum(dim=1).long() - context = self.model(ids, mask) - return [u[:v] for u, v in zip(context, seq_lens, strict=False)] diff --git a/src/lerobot/policies/fastwam/wan/modules/tokenizers.py b/src/lerobot/policies/fastwam/wan/modules/tokenizers.py deleted file mode 100644 index ec85c9753..000000000 --- a/src/lerobot/policies/fastwam/wan/modules/tokenizers.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import html -import string - -import ftfy -import regex as re -from transformers import AutoTokenizer - -__all__ = ["HuggingfaceTokenizer"] - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def canonicalize(text, keep_punctuation_exact_string=None): - text = text.replace("_", " ") - if keep_punctuation_exact_string: - text = keep_punctuation_exact_string.join( - part.translate(str.maketrans("", "", string.punctuation)) - for part in text.split(keep_punctuation_exact_string) - ) - else: - text = text.translate(str.maketrans("", "", string.punctuation)) - text = text.lower() - text = re.sub(r"\s+", " ", text) - return text.strip() - - -class HuggingfaceTokenizer: - def __init__(self, name, seq_len=None, clean=None, **kwargs): - assert clean in (None, "whitespace", "lower", "canonicalize") - self.name = name - self.seq_len = seq_len - self.clean = clean - - # init tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) - self.vocab_size = self.tokenizer.vocab_size - - def __call__(self, sequence, **kwargs): - return_mask = kwargs.pop("return_mask", False) - - # arguments - _kwargs = {"return_tensors": "pt"} - if self.seq_len is not None: - _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) - _kwargs.update(**kwargs) - - # tokenization - if isinstance(sequence, str): - sequence = [sequence] - if self.clean: - sequence = [self._clean(u) for u in sequence] - ids = self.tokenizer(sequence, **_kwargs) - - # output - if return_mask: - return ids.input_ids, ids.attention_mask - else: - return ids.input_ids - - def _clean(self, text): - if self.clean == "whitespace": - text = whitespace_clean(basic_clean(text)) - elif self.clean == "lower": - text = whitespace_clean(basic_clean(text)).lower() - elif self.clean == "canonicalize": - text = canonicalize(basic_clean(text)) - return text diff --git a/src/lerobot/policies/fastwam/wan/modules/vae2_2.py b/src/lerobot/policies/fastwam/wan/modules/vae2_2.py deleted file mode 100644 index 94b30ec4a..000000000 --- a/src/lerobot/policies/fastwam/wan/modules/vae2_2.py +++ /dev/null @@ -1,1027 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import logging -from pathlib import Path - -import torch -import torch.cuda.amp as amp -import torch.nn as nn -import torch.nn.functional as functional -from einops import rearrange -from safetensors.torch import load_file - -__all__ = [ - "Wan2VAE", -] - -CACHE_T = 2 - - -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._padding = ( - self.padding[2], - self.padding[2], - self.padding[1], - self.padding[1], - 2 * self.padding[0], - 0, - ) - self.padding = (0, 0, 0) - - def forward(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = functional.pad(x, padding) - - return super().forward(x) - - -class RMSNorm(nn.Module): - def __init__(self, dim, channel_first=True, images=True, bias=False): - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 - - def forward(self, x): - return ( - functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma - + self.bias - ) - - -class Upsample(nn.Upsample): - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - -class Resample(nn.Module): - def __init__(self, dim, mode): - assert mode in ( - "none", - "upsample2d", - "upsample3d", - "downsample2d", - "downsample3d", - ) - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == "upsample2d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim, 3, padding=1), - ) - elif mode == "upsample3d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim, 3, padding=1), - # nn.Conv2d(dim, dim//2, 3, padding=1) - ) - self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=None): - if feat_idx is None: - feat_idx = [0] - b, c, t, h, w = x.size() - if self.mode == "upsample3d" and feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat( - [torch.zeros_like(cache_x).to(cache_x.device), cache_x], - dim=2, - ) - x = self.time_conv(x) if feat_cache[idx] == "Rep" else self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.resample(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - - if self.mode == "downsample3d" and feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x - - def init_weight(self, conv): - conv_weight = conv.weight.detach().clone() - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 - conv.weight = nn.Parameter(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data.detach().clone() - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix - conv.weight = nn.Parameter(conv_weight) - nn.init.zeros_(conv.bias.data) - - -class ResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMSNorm(in_dim, images=False), - nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMSNorm(out_dim, images=False), - nn.SiLU(), - nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1), - ) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=None): - if feat_idx is None: - feat_idx = [0] - h = self.shortcut(x) - for layer in self.residual: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMSNorm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - nn.init.zeros_(self.proj.weight) - - def forward(self, x): - identity = x - b, c, t, h, w = x.size() - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.norm(x) - # compute query, key, value - q, k, v = ( - self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) - ) - - # apply attention - x = functional.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output - x = self.proj(x) - x = rearrange(x, "(b t) c h w-> b c t h w", t=t) - return x + identity - - -def patchify(x, patch_size): - if patch_size == 1: - return x - if x.dim() == 4: - x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) - elif x.dim() == 5: - x = rearrange( - x, - "b c f (h q) (w r) -> b (c r q) f h w", - q=patch_size, - r=patch_size, - ) - else: - raise ValueError(f"Invalid input shape: {x.shape}") - - return x - - -def unpatchify(x, patch_size): - if patch_size == 1: - return x - - if x.dim() == 4: - x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) - elif x.dim() == 5: - x = rearrange( - x, - "b (c r q) f h w -> b c f (h q) (w r)", - q=patch_size, - r=patch_size, - ) - return x - - -class AvgDown3D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - factor_t, - factor_s=1, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.factor_t = factor_t - self.factor_s = factor_s - self.factor = self.factor_t * self.factor_s * self.factor_s - - assert in_channels * self.factor % out_channels == 0 - self.group_size = in_channels * self.factor // out_channels - - def forward(self, x: torch.Tensor) -> torch.Tensor: - pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t - pad = (0, 0, 0, 0, pad_t, 0) - x = functional.pad(x, pad) - batch, channels, frames, height, width = x.shape - x = x.view( - batch, - channels, - frames // self.factor_t, - self.factor_t, - height // self.factor_s, - self.factor_s, - width // self.factor_s, - self.factor_s, - ) - x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() - x = x.view( - batch, - channels * self.factor, - frames // self.factor_t, - height // self.factor_s, - width // self.factor_s, - ) - x = x.view( - batch, - self.out_channels, - self.group_size, - frames // self.factor_t, - height // self.factor_s, - width // self.factor_s, - ) - x = x.mean(dim=2) - return x - - -class DupUp3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - factor_t, - factor_s=1, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.factor_t = factor_t - self.factor_s = factor_s - self.factor = self.factor_t * self.factor_s * self.factor_s - - assert out_channels * self.factor % in_channels == 0 - self.repeats = out_channels * self.factor // in_channels - - def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: - x = x.repeat_interleave(self.repeats, dim=1) - x = x.view( - x.size(0), - self.out_channels, - self.factor_t, - self.factor_s, - self.factor_s, - x.size(2), - x.size(3), - x.size(4), - ) - x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() - x = x.view( - x.size(0), - self.out_channels, - x.size(2) * self.factor_t, - x.size(4) * self.factor_s, - x.size(6) * self.factor_s, - ) - if first_chunk: - x = x[:, :, self.factor_t - 1 :, :, :] - return x - - -class DownResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False): - super().__init__() - - # Shortcut path with downsample - self.avg_shortcut = AvgDown3D( - in_dim, - out_dim, - factor_t=2 if temperal_downsample else 1, - factor_s=2 if down_flag else 1, - ) - - # Main path with residual blocks and downsample - downsamples = [] - for _ in range(mult): - downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - in_dim = out_dim - - # Add the final downsample block - if down_flag: - mode = "downsample3d" if temperal_downsample else "downsample2d" - downsamples.append(Resample(out_dim, mode=mode)) - - self.downsamples = nn.Sequential(*downsamples) - - def forward(self, x, feat_cache=None, feat_idx=None): - if feat_idx is None: - feat_idx = [0] - x_copy = x.clone() - for module in self.downsamples: - x = module(x, feat_cache, feat_idx) - - return x + self.avg_shortcut(x_copy) - - -class UpResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False): - super().__init__() - # Shortcut path with upsample - if up_flag: - self.avg_shortcut = DupUp3D( - in_dim, - out_dim, - factor_t=2 if temperal_upsample else 1, - factor_s=2 if up_flag else 1, - ) - else: - self.avg_shortcut = None - - # Main path with residual blocks and upsample - upsamples = [] - for _ in range(mult): - upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - in_dim = out_dim - - # Add the final upsample block - if up_flag: - mode = "upsample3d" if temperal_upsample else "upsample2d" - upsamples.append(Resample(out_dim, mode=mode)) - - self.upsamples = nn.Sequential(*upsamples) - - def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): - if feat_idx is None: - feat_idx = [0] - x_main = x.clone() - for module in self.upsamples: - x_main = module(x_main, feat_cache, feat_idx) - if self.avg_shortcut is not None: - x_shortcut = self.avg_shortcut(x, first_chunk) - return x_main + x_shortcut - else: - return x_main - - -class Encoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=None, - num_res_blocks=2, - attn_scales=None, - temperal_downsample=None, - dropout=0.0, - ): - if temperal_downsample is None: - temperal_downsample = [True, True, False] - if attn_scales is None: - attn_scales = [] - if dim_mult is None: - dim_mult = [1, 2, 4, 4] - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - - # dimensions - dims = [dim * u for u in [1] + dim_mult] - scale = 1.0 - - # init block - self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) - - # downsample blocks - downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=False)): - t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False - downsamples.append( - DownResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - dropout=dropout, - mult=num_res_blocks, - temperal_downsample=t_down_flag, - down_flag=i != len(dim_mult) - 1, - ) - ) - scale /= 2.0 - self.downsamples = nn.Sequential(*downsamples) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), - AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout), - ) - - # # output blocks - self.head = nn.Sequential( - RMSNorm(out_dim, images=False), - nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1), - ) - - def forward(self, x, feat_cache=None, feat_idx=None): - - if feat_idx is None: - feat_idx = [0] - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## downsamples - for layer in self.downsamples: - x = layer(x, feat_cache, feat_idx) if feat_cache is not None else layer(x) - - ## middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - - return x - - -class Decoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=None, - num_res_blocks=2, - attn_scales=None, - temperal_upsample=None, - dropout=0.0, - ): - if temperal_upsample is None: - temperal_upsample = [False, True, True] - if attn_scales is None: - attn_scales = [] - if dim_mult is None: - dim_mult = [1, 2, 4, 4] - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_upsample = temperal_upsample - - # dimensions - dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - # init block - self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), - AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout), - ) - - # upsample blocks - upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=False)): - t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False - upsamples.append( - UpResidualBlock( - in_dim=in_dim, - out_dim=out_dim, - dropout=dropout, - mult=num_res_blocks + 1, - temperal_upsample=t_up_flag, - up_flag=i != len(dim_mult) - 1, - ) - ) - self.upsamples = nn.Sequential(*upsamples) - - # output blocks - self.head = nn.Sequential( - RMSNorm(out_dim, images=False), - nn.SiLU(), - CausalConv3d(out_dim, 12, 3, padding=1), - ) - - def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): - if feat_idx is None: - feat_idx = [0] - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: - x = layer(x, feat_cache, feat_idx, first_chunk) if feat_cache is not None else layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -def count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, CausalConv3d): - count += 1 - return count - - -class WanVAEModel(nn.Module): - def __init__( - self, - dim=160, - dec_dim=256, - z_dim=16, - dim_mult=None, - num_res_blocks=2, - attn_scales=None, - temperal_downsample=None, - dropout=0.0, - ): - if temperal_downsample is None: - temperal_downsample = [True, True, False] - if attn_scales is None: - attn_scales = [] - if dim_mult is None: - dim_mult = [1, 2, 4, 4] - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - self.temperal_upsample = temperal_downsample[::-1] - - # modules - self.encoder = Encoder3d( - dim, - z_dim * 2, - dim_mult, - num_res_blocks, - attn_scales, - self.temperal_downsample, - dropout, - ) - self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) - self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d( - dec_dim, - z_dim, - dim_mult, - num_res_blocks, - attn_scales, - self.temperal_upsample, - dropout, - ) - - def forward(self, x, scale=None): - if scale is None: - scale = [0, 1] - mu = self.encode(x, scale) - x_recon = self.decode(mu, scale) - return x_recon, mu - - def encode(self, x, scale): - self.clear_cache() - x = patchify(x, patch_size=2) - t = x.shape[2] - iter_ = 1 + (t - 1) // 4 - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, - ) - else: - out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, - ) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] - self.clear_cache() - return mu - - def decode(self, z, scale): - self.clear_cache() - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] - iter_ = z.shape[2] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder( - x[:, :, i : i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx, - first_chunk=True, - ) - else: - out_ = self.decoder( - x[:, :, i : i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx, - ) - out = torch.cat([out, out_], 2) - out = unpatchify(out, patch_size=2) - self.clear_cache() - return out - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return eps * std + mu - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs) - if deterministic: - return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - -def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): - # params - cfg = { - "dim": dim, - "z_dim": z_dim, - "dim_mult": [1, 2, 4, 4], - "num_res_blocks": 2, - "attn_scales": [], - "temperal_downsample": [True, True, True], - "dropout": 0.0, - } - cfg.update(**kwargs) - - # init model - with torch.device("meta"): - model = WanVAEModel(**cfg) - - # load checkpoint - logging.info(f"loading {pretrained_path}") - if Path(pretrained_path).suffix != ".safetensors": - raise ValueError(f"Wan2.2 VAE checkpoint must be safetensors, got {pretrained_path}.") - state_dict = load_file(pretrained_path, device=str(device)) - model.load_state_dict(state_dict, assign=True) - - return model - - -class Wan2VAE: - def __init__( - self, - z_dim=48, - c_dim=160, - vae_pth=None, - dim_mult=None, - temperal_downsample=None, - dtype=torch.float, - device="cuda", - ): - - if temperal_downsample is None: - temperal_downsample = [False, True, True] - if dim_mult is None: - dim_mult = [1, 2, 4, 4] - self.dtype = dtype - self.device = device - - mean = torch.tensor( - [ - -0.2289, - -0.0052, - -0.1323, - -0.2339, - -0.2799, - 0.0174, - 0.1838, - 0.1557, - -0.1382, - 0.0542, - 0.2813, - 0.0891, - 0.1570, - -0.0098, - 0.0375, - -0.1825, - -0.2246, - -0.1207, - -0.0698, - 0.5109, - 0.2665, - -0.2108, - -0.2158, - 0.2502, - -0.2055, - -0.0322, - 0.1109, - 0.1567, - -0.0729, - 0.0899, - -0.2799, - -0.1230, - -0.0313, - -0.1649, - 0.0117, - 0.0723, - -0.2839, - -0.2083, - -0.0520, - 0.3748, - 0.0152, - 0.1957, - 0.1433, - -0.2944, - 0.3573, - -0.0548, - -0.1681, - -0.0667, - ], - dtype=dtype, - device=device, - ) - std = torch.tensor( - [ - 0.4765, - 1.0364, - 0.4514, - 1.1677, - 0.5313, - 0.4990, - 0.4818, - 0.5013, - 0.8158, - 1.0344, - 0.5894, - 1.0901, - 0.6885, - 0.6165, - 0.8454, - 0.4978, - 0.5759, - 0.3523, - 0.7135, - 0.6804, - 0.5833, - 1.4146, - 0.8986, - 0.5659, - 0.7069, - 0.5338, - 0.4889, - 0.4917, - 0.4069, - 0.4999, - 0.6866, - 0.4093, - 0.5709, - 0.6065, - 0.6415, - 0.4944, - 0.5726, - 1.2042, - 0.5458, - 1.6887, - 0.3971, - 1.0600, - 0.3943, - 0.5537, - 0.5444, - 0.4089, - 0.7468, - 0.7744, - ], - dtype=dtype, - device=device, - ) - self.scale = [mean, 1.0 / std] - - # init model - self.model = ( - _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - dim=c_dim, - dim_mult=dim_mult, - temperal_downsample=temperal_downsample, - ) - .eval() - .requires_grad_(False) - .to(device) - ) - - def encode(self, videos): - try: - if not isinstance(videos, list): - raise TypeError("videos should be a list") - with amp.autocast(dtype=self.dtype): - return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] - except TypeError as e: - logging.info(e) - return None - - def decode(self, zs): - try: - if not isinstance(zs, list): - raise TypeError("zs should be a list") - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs - ] - except TypeError as e: - logging.info(e) - return None diff --git a/src/lerobot/policies/fastwam/wan_adapters.py b/src/lerobot/policies/fastwam/wan_adapters.py index cf267a769..4630c0612 100644 --- a/src/lerobot/policies/fastwam/wan_adapters.py +++ b/src/lerobot/policies/fastwam/wan_adapters.py @@ -14,16 +14,24 @@ from __future__ import annotations -from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING import torch -from .wan.modules.vae2_2 import Wan2VAE +if TYPE_CHECKING: + from diffusers import AutoencoderKLWan class WanVideoVAE38(torch.nn.Module): - """Tensor-batch adapter around the official Wan2.2 VAE wrapper.""" + """FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B). + + 16x spatial / 4x temporal compression, 48 latent channels. diffusers' + `AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/ + `latents_std`), so `encode`/`decode` here apply the same standardization the + Wan reference uses — `(latents - mean) / std` — done in fp32 for stability. + `encode` uses the deterministic posterior mode, matching the original VAE + which returned the latent mean `mu`. + """ upsampling_factor = 16 temporal_downsample_factor = 4 @@ -31,27 +39,35 @@ class WanVideoVAE38(torch.nn.Module): def __init__( self, - vae_pth: str | Path, dtype: torch.dtype = torch.float32, device: str | torch.device = "cuda", + *, + pretrained: AutoencoderKLWan, ) -> None: super().__init__() - self.wan_vae = Wan2VAE(vae_pth=str(vae_pth), dtype=dtype, device=str(device)) - self.model = self.wan_vae.model - self.dtype = dtype - self.device = torch.device(device) + # The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch, + # so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from + # the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path. + self.vae = pretrained.to(device=device, dtype=dtype) - def to(self, *args: Any, **kwargs: Any): - super().to(*args, **kwargs) - self.model.to(*args, **kwargs) - param = next(self.model.parameters()) - self.device = param.device - self.dtype = param.dtype - self.wan_vae.device = self.device - self.wan_vae.dtype = self.dtype - self.wan_vae.scale = [scale.to(device=self.device, dtype=self.dtype) for scale in self.wan_vae.scale] - self.wan_vae.model = self.model - return self + # Read the standardization stats from the VAE's own config (diffusers populates + # these from vae/config.json) — single source of truth, no local copy. diffusers' + # encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves. + # Non-persistent: kept out of state_dict. + self.register_buffer( + "latents_mean", + torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1), + persistent=False, + ) + self.register_buffer( + "latents_std", + torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1), + persistent=False, + ) + + def _device_dtype(self) -> tuple[torch.device, torch.dtype]: + param = next(self.vae.parameters()) + return param.device, param.dtype def encode( self, @@ -61,18 +77,16 @@ class WanVideoVAE38(torch.nn.Module): tile_size: tuple[int, int] = (34, 34), tile_stride: tuple[int, int] = (18, 16), ) -> torch.Tensor: - del tile_size, tile_stride + del device, tile_size, tile_stride if tiled: raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.") - target_device = self.device if device is None else torch.device(device) - if target_device != self.device: - self.to(device=target_device) - if isinstance(videos, torch.Tensor): - videos = list(videos) - hidden_states = self.wan_vae.encode([video.to(self.device) for video in videos]) - if hidden_states is None: - raise RuntimeError("Wan2.2 VAE encode failed; expected a list of video tensors.") - return torch.stack(hidden_states) + if isinstance(videos, (list, tuple)): + videos = torch.stack(list(videos)) + dev, dtype = self._device_dtype() + mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float() + mean = self.latents_mean.float().to(mu.device) + std = self.latents_std.float().to(mu.device) + return (mu - mean) / std def decode( self, @@ -82,18 +96,16 @@ class WanVideoVAE38(torch.nn.Module): tile_size: tuple[int, int] = (34, 34), tile_stride: tuple[int, int] = (18, 16), ) -> torch.Tensor: - del tile_size, tile_stride + del device, tile_size, tile_stride if tiled: raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.") - target_device = self.device if device is None else torch.device(device) - if target_device != self.device: - self.to(device=target_device) - if isinstance(hidden_states, torch.Tensor): - hidden_states = list(hidden_states) - videos = self.wan_vae.decode([hidden_state.to(self.device) for hidden_state in hidden_states]) - if videos is None: - raise RuntimeError("Wan2.2 VAE decode failed; expected a list of latent tensors.") - return torch.stack(videos) + if isinstance(hidden_states, (list, tuple)): + hidden_states = torch.stack(list(hidden_states)) + dev, dtype = self._device_dtype() + z = hidden_states.float() + z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device) + out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample + return out.float().clamp_(-1.0, 1.0) __all__ = ["WanVideoVAE38"] diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index 7a321cb33..41c2fdafd 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -15,8 +15,7 @@ from __future__ import annotations import logging -import time -from dataclasses import dataclass +from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Any @@ -24,57 +23,108 @@ import torch from safetensors.torch import load_file if TYPE_CHECKING: - from .wan.modules.tokenizers import HuggingfaceTokenizer from .wan_adapters import WanVideoVAE38 from .wan_video_dit import WanVideoDiT logger = logging.getLogger(__name__) +# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2 +# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text +# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one. WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors" -WAN_T5_SAFE_CHECKPOINT = "models_t5_umt5-xxl-enc-bf16.safetensors" -WAN_T5_CHECKPOINT = WAN_T5_SAFE_CHECKPOINT WAN_T5_TOKENIZER = "google/umt5-xxl" -WAN_VAE_SAFE_CHECKPOINT = "Wan2.2_VAE.safetensors" -WAN_VAE_CHECKPOINT = WAN_VAE_SAFE_CHECKPOINT +WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" + +class WanTextEncoder(torch.nn.Module): + """FastWAM text-encoder contract over `transformers.UMT5EncoderModel`. + + Exposes `.dim` (hidden size) and `forward(ids, mask) -> [B, L, dim]`, matching + the call in `FastWAM.encode_prompt`. + """ + + def __init__( + self, + dtype: torch.dtype = torch.bfloat16, + device: str | torch.device = "cuda", + *, + pretrained: torch.nn.Module, + ) -> None: + super().__init__() + # UMT5-XXL is a fixed pretrained encoder — never trained from scratch, so a real + # `UMT5EncoderModel` (with weights) must always be supplied (loaded from the + # diffusers repo by `load_pretrained_wan_text_encoder`). No random/offline build. + self.model = pretrained.to(device=device, dtype=dtype) + self.dim = int(self.model.config.d_model) + + def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return self.model(input_ids=ids, attention_mask=mask.long()).last_hidden_state -@dataclass(frozen=True) -class WanCheckpointPaths: - root: Path - dit: list[Path] - vae: Path - text_encoder: Path | None - tokenizer: Path | None +class WanTokenizer: + """UMT5 tokenizer wrapper returning `(input_ids, attention_mask)` like the + FastWAM call site expects.""" + + def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None: + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(name) + self.seq_len = int(seq_len) + + def __call__( + self, sequence: str | Sequence[str], return_mask: bool = False, add_special_tokens: bool = True, **_: Any + ): + if isinstance(sequence, str): + sequence = [sequence] + out = self.tokenizer( + list(sequence), + padding="max_length", + truncation=True, + max_length=self.seq_len, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + if return_mask: + return out.input_ids, out.attention_mask + return out.input_ids -@dataclass -class Wan22LoadedComponents: - dit: WanVideoDiT - vae: WanVideoVAE38 - text_encoder: torch.nn.Module | None - tokenizer: HuggingfaceTokenizer | None - dit_path: list[str] - vae_path: str - text_encoder_path: str | None - tokenizer_path: str | None +def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer: + return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len)) -def resolve_wan_checkpoint_dir( +def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: + """Load real Wan2.2 VAE weights from the diffusers repo (offline base creation).""" + from diffusers import AutoencoderKLWan + + from .wan_adapters import WanVideoVAE38 + + vae = AutoencoderKLWan.from_pretrained( + WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype + ) + return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae) + + +def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder: + """Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation).""" + from transformers import UMT5EncoderModel + + encoder = UMT5EncoderModel.from_pretrained( + WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype + ) + return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder) + + +def resolve_wan_dit_paths( model_id_or_path: str | Path, *, cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, -) -> Path: - """Return a local Wan2.2 checkpoint directory. - - Local paths are used directly. Hub repos are downloaded with the same fixed - component names used by the upstream Wan2.2 inference code. - """ - +) -> list[Path]: + """Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir.""" path = Path(model_id_or_path).expanduser() if path.is_dir(): - return path + return sorted(path.glob(WAN_DIT_PATTERN)) from huggingface_hub import snapshot_download @@ -83,51 +133,9 @@ def resolve_wan_checkpoint_dir( revision=revision, cache_dir=cache_dir, local_files_only=local_files_only, - allow_patterns=[ - WAN_DIT_PATTERN, - WAN_T5_CHECKPOINT, - WAN_VAE_CHECKPOINT, - f"{WAN_T5_TOKENIZER}/**", - ], - ) - return Path(snapshot_path) - - -def resolve_wan_checkpoint_paths( - checkpoint_dir: str | Path, - *, - tokenizer_dir: str | Path | None = None, - load_dit: bool = True, - load_text_encoder: bool = True, -) -> WanCheckpointPaths: - root = Path(checkpoint_dir).expanduser() - tokenizer_root = Path(tokenizer_dir).expanduser() if tokenizer_dir is not None else root - dit = sorted(root.glob(WAN_DIT_PATTERN)) if load_dit else [] - vae = root / WAN_VAE_SAFE_CHECKPOINT - text_encoder = root / WAN_T5_SAFE_CHECKPOINT if load_text_encoder else None - tokenizer = tokenizer_root / WAN_T5_TOKENIZER if load_text_encoder else None - - missing = [] - if load_dit and len(dit) == 0: - missing.append(f"DiT ({WAN_DIT_PATTERN})") - if not vae.exists(): - missing.append(f"VAE ({WAN_VAE_SAFE_CHECKPOINT})") - if load_text_encoder: - if text_encoder is None or not text_encoder.exists(): - missing.append(f"text encoder ({WAN_T5_SAFE_CHECKPOINT})") - if tokenizer is None or not tokenizer.exists(): - missing.append(f"tokenizer ({WAN_T5_TOKENIZER})") - if missing: - raise FileNotFoundError( - f"Incomplete Wan2.2 checkpoint directory {root}: missing {', '.join(missing)}." - ) - return WanCheckpointPaths( - root=root, - dit=dit, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, + allow_patterns=[WAN_DIT_PATTERN], ) + return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN)) def load_wan_video_dit( @@ -145,107 +153,6 @@ def load_wan_video_dit( return model.to(device=device, dtype=torch_dtype) -def load_wan_text_encoder( - checkpoint_path: str | Path, - *, - torch_dtype: torch.dtype, - device: str, -) -> torch.nn.Module: - from .wan.modules.t5 import umt5_xxl - - model = umt5_xxl( - encoder_only=True, - return_tokenizer=False, - dtype=torch_dtype, - device=device, - ) - checkpoint_path = Path(checkpoint_path) - if checkpoint_path.suffix != ".safetensors": - raise ValueError(f"Wan2.2 text encoder checkpoint must be safetensors, got {checkpoint_path}.") - state_dict = load_file(checkpoint_path) - model.load_state_dict(state_dict) - return model.to(device=device, dtype=torch_dtype) - - -def load_wan_tokenizer(tokenizer_path: str | Path, *, tokenizer_max_len: int) -> HuggingfaceTokenizer: - from .wan.modules.tokenizers import HuggingfaceTokenizer - - return HuggingfaceTokenizer( - name=str(tokenizer_path), - seq_len=int(tokenizer_max_len), - clean="whitespace", - ) - - -def load_wan_vae(checkpoint_path: str | Path, *, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: - from .wan_adapters import WanVideoVAE38 - - return WanVideoVAE38(vae_pth=str(checkpoint_path), dtype=torch_dtype, device=device) - - -def load_wan22_ti2v_5b_components( - device: str = "cuda", - torch_dtype: torch.dtype = torch.bfloat16, - model_id: str = "Wan-AI/Wan2.2-TI2V-5B", - tokenizer_model_id: str = "Wan-AI/Wan2.2-TI2V-5B", - tokenizer_max_len: int = 512, - dit_config: dict[str, Any] | None = None, - load_text_encoder: bool = True, -): - logger.info("Loading Wan2.2-TI2V-5B components...") - start = time.time() - - if dit_config is None: - raise ValueError("`dit_config` is required for Wan2.2-TI2V-5B loading.") - - checkpoint_dir = resolve_wan_checkpoint_dir(model_id) - tokenizer_dir = ( - checkpoint_dir if tokenizer_model_id == model_id else resolve_wan_checkpoint_dir(tokenizer_model_id) - ) - paths = resolve_wan_checkpoint_paths( - checkpoint_dir, - tokenizer_dir=tokenizer_dir, - load_text_encoder=load_text_encoder, - ) - - dit = load_wan_video_dit( - paths.dit, - dit_config=dit_config, - torch_dtype=torch_dtype, - device=device, - ) - vae = load_wan_vae(paths.vae, torch_dtype=torch_dtype, device=device) - - text_encoder: torch.nn.Module | None = None - tokenizer: HuggingfaceTokenizer | None = None - if load_text_encoder: - if paths.text_encoder is None or paths.tokenizer is None: - raise FileNotFoundError("Wan2.2 text encoder/tokenizer paths were not resolved.") - text_encoder = load_wan_text_encoder( - paths.text_encoder, - torch_dtype=torch_dtype, - device=device, - ) - tokenizer = load_wan_tokenizer(paths.tokenizer, tokenizer_max_len=tokenizer_max_len) - else: - logger.info( - "Skipping pretrained text encoder/tokenizer load (`load_text_encoder=False`); " - "training must provide cached `context/context_mask`." - ) - - logger.info("Finished loading Wan2.2-TI2V-5B components in %.2f seconds.", time.time() - start) - return Wan22LoadedComponents( - dit=dit, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - dit_path=[str(path) for path in paths.dit], - vae_path=str(paths.vae), - text_encoder_path=str(paths.text_encoder) if paths.text_encoder is not None else None, - tokenizer_path=str(paths.tokenizer) if paths.tokenizer is not None else None, - ) - - def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]: state_dict = {} for path in paths: @@ -254,17 +161,14 @@ def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor __all__ = [ + "WAN22_DIFFUSERS_MODEL_ID", "WAN_DIT_PATTERN", - "WAN_T5_CHECKPOINT", "WAN_T5_TOKENIZER", - "WAN_VAE_CHECKPOINT", - "Wan22LoadedComponents", - "WanCheckpointPaths", - "load_wan22_ti2v_5b_components", - "load_wan_text_encoder", - "load_wan_tokenizer", - "load_wan_vae", + "WanTextEncoder", + "WanTokenizer", + "build_wan_tokenizer", + "load_pretrained_wan_text_encoder", + "load_pretrained_wan_vae", "load_wan_video_dit", - "resolve_wan_checkpoint_dir", - "resolve_wan_checkpoint_paths", + "resolve_wan_dit_paths", ] diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan_video_dit.py index 6407d8d9a..d5350ea90 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan_video_dit.py @@ -660,11 +660,16 @@ class WanVideoDiT(WanModel): ) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1) token_timesteps[:, 0, :] = 0 token_timesteps = token_timesteps.reshape(batch_size, -1) - token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).to( - dtype=model_dtype - ) - t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim) - t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim)) + # Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored + # Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift). + # Upstream guarantees this via an fp32 autocast region, so it holds even when the + # model runs in bf16. Mirror that here, then cast the per-block modulation back to + # model_dtype so the bf16 attention blocks are not upcast to fp32. + with torch.amp.autocast("cuda", dtype=torch.float32): + token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float() + t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim) + t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim)) + t_mod = t_mod.to(dtype=model_dtype) x = self.patchify(x) f, h, w = x.shape[2:] diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index d750fd5f8..c3747c407 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -18,21 +18,13 @@ import json import pytest import torch -from safetensors.torch import save_model +from safetensors import safe_open from torch import nn from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors -from lerobot.policies.fastwam import modeling_fastwam -from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy, resolve_wan_component_paths +from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep -from lerobot.policies.fastwam.wan_components import ( - WAN_DIT_PATTERN, - WAN_T5_CHECKPOINT, - WAN_T5_TOKENIZER, - WAN_VAE_CHECKPOINT, - resolve_wan_checkpoint_paths, -) from lerobot.utils.constants import ACTION, OBS_STATE @@ -170,8 +162,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, base_model_id=None, ) - with pytest.warns(RuntimeWarning, match="does not load pretrained FastWAM weights"): - policy = FastWAMPolicy(cfg) + policy = FastWAMPolicy(cfg) output = policy.forward( { @@ -207,89 +198,96 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): ] -def test_from_pretrained_loads_weights_without_initializing_wan_backbone(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) - cfg.save_pretrained(tmp_path) - monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: FakeFastWAMCore()) - reference_policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True) - save_model(reference_policy, str(tmp_path / "model.safetensors")) +class CoreWithFrozenComponents(FakeFastWAMCore): + """Fake core mirroring the real one: frozen VAE / text encoder held as + *unregistered* attributes (via `object.__setattr__`) so they are excluded from + `state_dict()` and the saved checkpoint, but still moved by the `_apply` override.""" + def __init__(self): + super().__init__() + object.__setattr__(self, "vae", nn.Linear(2, 2)) + object.__setattr__(self, "text_encoder", nn.Linear(2, 2)) + self.vae.requires_grad_(False) + self.text_encoder.requires_grad_(False) + + def _apply(self, fn, *args, **kwargs): + super()._apply(fn, *args, **kwargs) + self.vae._apply(fn) + self.text_encoder._apply(fn) + return self + + +def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path): + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) + + def build_core(self, config): + core = CoreWithFrozenComponents() + with torch.no_grad(): + core.dit.weight.fill_(0.5) + return core + + monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core) + + reference = FastWAMPolicy(cfg) + with torch.no_grad(): + reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight + reference.save_pretrained(tmp_path) + + # Building from Wan2.2 must never happen on a checkpoint load. def fail_if_wan_pretrained_is_loaded(*args, **kwargs): - raise AssertionError("from_pretrained must not initialize or download Wan2.2 backbone components") + raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone") monkeypatch.setattr( "lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained", fail_if_wan_pretrained_is_loaded, ) - monkeypatch.setattr( - modeling_fastwam, - "_build_core_model_from_architecture", - lambda config: FakeFastWAMCore(), - raising=False, - ) - loaded_components_from = [] - monkeypatch.setattr( - FastWAMPolicy, - "load_wan_components_from_pretrained", - lambda self, path: loaded_components_from.append(path), - ) - policy = FastWAMPolicy.from_pretrained(tmp_path, strict=False) + policy = FastWAMPolicy.from_pretrained(tmp_path) - assert isinstance(policy.model, FakeFastWAMCore) - assert loaded_components_from == [tmp_path] + assert isinstance(policy.model, CoreWithFrozenComponents) + # The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights. + assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25)) -def test_save_pretrained_copies_required_wan_sidecars(monkeypatch, tmp_path): +def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path): cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) - source = tmp_path / "source" - tokenizer = source / WAN_T5_TOKENIZER - tokenizer.mkdir(parents=True) - vae = source / WAN_VAE_CHECKPOINT - text_encoder = source / WAN_T5_CHECKPOINT - tokenizer_file = tokenizer / "tokenizer.json" - vae.write_bytes(b"vae") - text_encoder.write_bytes(b"text") - tokenizer_file.write_text("{}") - core = FakeFastWAMCore() - core.model_paths = { - "vae": str(vae), - "text_encoder": str(text_encoder), - "tokenizer": str(tokenizer), - } - monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: core) - policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True) + monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) + policy = FastWAMPolicy(cfg) save_dir = tmp_path / "saved" policy.save_pretrained(save_dir) assert (save_dir / "model.safetensors").is_file() - assert (save_dir / WAN_VAE_CHECKPOINT).read_bytes() == b"vae" - assert (save_dir / WAN_T5_CHECKPOINT).read_bytes() == b"text" - assert (save_dir / WAN_T5_TOKENIZER / "tokenizer.json").read_text() == "{}" + # No Wan sidecar files either: the frozen backbone comes from the diffusers repo. + assert not (save_dir / "Wan2.2_VAE.safetensors").exists() + assert not (save_dir / "google").exists() + + with safe_open(save_dir / "model.safetensors", framework="pt") as f: + keys = set(f.keys()) + # Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text + # encoder are excluded (loaded from the diffusers/transformers repos at init). + assert any(key.startswith("model.dit.") for key in keys) + assert not any(key.startswith("model.vae.") for key in keys) + assert not any(key.startswith("model.text_encoder.") for key in keys) -def test_wan_component_resolution_uses_fixed_safetensors_layout(tmp_path): - tokenizer = tmp_path / WAN_T5_TOKENIZER - tokenizer.mkdir(parents=True) - (tmp_path / WAN_VAE_CHECKPOINT).touch() - (tmp_path / WAN_T5_CHECKPOINT).touch() - (tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors").touch() - (tokenizer / "tokenizer.json").touch() +def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch): + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) + monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) + policy = FastWAMPolicy(cfg) - paths = resolve_wan_checkpoint_paths(tmp_path) - sidecar_paths = resolve_wan_component_paths(tmp_path) + # Unregistered: excluded from state_dict and from the optimizer's parameter set. + sd = policy.state_dict() + assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd) + param_names = [n for n, _ in policy.named_parameters()] + assert not any("vae" in n or "text_encoder" in n for n in param_names) - assert paths.dit == [tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors"] - assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT - assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT - assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER - assert sidecar_paths.dit == [] - assert WAN_DIT_PATTERN == "diffusion_pytorch_model*.safetensors" - - (tmp_path / WAN_T5_CHECKPOINT).unlink() - with pytest.raises(FileNotFoundError, match="text encoder"): - resolve_wan_checkpoint_paths(tmp_path) + # ...but the `_apply` override still carries them through `.to()` (dtype stands in + # for device on a CPU box), so they never strand off the rest of the model. + policy.to(torch.float64) + assert policy.model.dit.weight.dtype == torch.float64 # registered + assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply + assert policy.model.text_encoder.weight.dtype == torch.float64 def test_pretrained_config_round_trips_fastwam_features(tmp_path): @@ -302,3 +300,57 @@ def test_pretrained_config_round_trips_fastwam_features(tmp_path): assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL assert loaded.action_feature.shape == (7,) assert loaded.robot_state_feature.shape == (8,) + + +def test_vae_adapter_empty_build_encode_decode_shapes(): + """Offline glue check of the diffusers-backed VAE adapter (random weights). + + Validates the encode/decode contract — 48 latent channels, 16x spatial / 4x + temporal compression, list-or-batch input, scaling round-trip — without any + weight download. (Numerical fidelity vs the original Wan VAE is a separate, + GPU + real-weights verification step.) + """ + pytest.importorskip("diffusers") + from diffusers import AutoencoderKLWan + + from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38 + + # Production always loads a real pretrained VAE from the diffusers repo; here we + # build the same architecture with random weights and dummy standardization stats + # to exercise the adapter's shape/scaling contract offline (fidelity is checked + # separately, with real weights, on GPU). + arch = { + "base_dim": 160, + "decoder_base_dim": 256, + "z_dim": 48, + "dim_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_scales": [], + "temperal_downsample": [False, True, True], + "dropout": 0.0, + "is_residual": True, + "in_channels": 12, + "out_channels": 12, + "patch_size": 2, + "scale_factor_spatial": 16, + "scale_factor_temporal": 4, + "clip_output": False, + "latents_mean": [0.0] * 48, + "latents_std": [1.0] * 48, + } + raw = AutoencoderKLWan.from_config(arch) + vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw) + assert vae.z_dim == 48 + assert vae.upsampling_factor == 16 + assert vae.temporal_downsample_factor == 4 + + video = torch.rand(1, 3, 5, 32, 32) * 2 - 1 # [B,C,T,H,W] in [-1,1] + latents = vae.encode(video) + assert latents.shape == (1, 48, 2, 2, 2) # T'=(5-1)//4+1, H'=W'=32//16 + + decoded = vae.decode(latents) + assert decoded.shape[0] == 1 and decoded.shape[1] == 3 and decoded.shape[-2:] == (32, 32) + assert decoded.min() >= -1.0 and decoded.max() <= 1.0 + + # list input is accepted and equals the batched path + assert torch.equal(vae.encode([video[0]]), latents) diff --git a/uv.lock b/uv.lock index d041a99db..c5a62dfd8 100644 --- a/uv.lock +++ b/uv.lock @@ -1636,18 +1636,6 @@ http = [ { name = "aiohttp" }, ] -[[package]] -name = "ftfy" -version = "6.3.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wcwidth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" }, -] - [[package]] name = "future" version = "1.0.0" @@ -2708,7 +2696,6 @@ all = [ { name = "faker" }, { name = "fastapi" }, { name = "feetech-servo-sdk" }, - { name = "ftfy" }, { name = "grpcio" }, { name = "grpcio-tools" }, { name = "gym-aloha" }, @@ -2747,7 +2734,6 @@ all = [ { name = "pyzmq" }, { name = "qwen-vl-utils" }, { name = "reachy2-sdk" }, - { name = "regex" }, { name = "rerun-sdk" }, { name = "ruff" }, { name = "scikit-image" }, @@ -2846,8 +2832,6 @@ evaluation = [ ] fastwam = [ { name = "diffusers" }, - { name = "ftfy" }, - { name = "regex" }, { name = "transformers" }, ] feetech = [ @@ -3108,7 +3092,6 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" }, { name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" }, { name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" }, - { name = "ftfy", marker = "extra == 'fastwam'", specifier = ">=6.1.1,<7.0.0" }, { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" }, { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" }, @@ -3277,7 +3260,6 @@ requires-dist = [ { name = "pyzmq", marker = "extra == 'pyzmq-dep'", specifier = ">=26.2.1,<28.0.0" }, { name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" }, { name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" }, - { name = "regex", marker = "extra == 'fastwam'", specifier = ">=2024.0.0,<2027.0.0" }, { name = "requests", specifier = ">=2.32.0,<3.0.0" }, { name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },