diff --git a/docs/source/fastwam.mdx b/docs/source/fastwam.mdx index 18796053d..f36ed1e4e 100644 --- a/docs/source/fastwam.mdx +++ b/docs/source/fastwam.mdx @@ -128,9 +128,6 @@ evaluation pipeline: --policy.toggle_action_dimensions='[-1]' ``` -`policy.invert_dimensions` remains available for older checkpoints or robot -setups that only need a sign inversion. - ## Results Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224): @@ -145,19 +142,6 @@ Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://hug Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300` (1x H20 140 GB). -## Reproducibility Checklist - -For a PR adding or updating FastWAM results, include: - -- the training dataset repo id -- the LeRobot-format checkpoint repo id -- the exact `lerobot-train` command -- the exact `lerobot-eval` or `lerobot-rollout` command -- the number of evaluation episodes -- the GPU type and count - -The upstream Fast-WAM release provides reference checkpoints and benchmark assets at `yuanty/fastwam`; LeRobot eval numbers should be reported from a converted LeRobot-format checkpoint so reviewers can reproduce them with the commands above. - ## References - [Fast-WAM paper](https://arxiv.org/abs/2603.16666) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 9c420135b..5f20f1da8 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -327,10 +327,6 @@ def make_pre_post_processors( to_output=transition_to_policy_action, ) _reconnect_relative_absolute_steps(preprocessor, postprocessor) - if isinstance(policy_cfg, FastWAMConfig): - from .fastwam.processor_fastwam import migrate_fastwam_postprocessor - - postprocessor = migrate_fastwam_postprocessor(postprocessor, policy_cfg) return preprocessor, postprocessor # Create a new processor based on policy type diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index 496a218db..e591f8c78 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -14,8 +14,7 @@ from __future__ import annotations -import json -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -29,9 +28,34 @@ from lerobot.optim import AdamWConfig from lerobot.utils.constants import ACTION, OBS_STATE WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B" +FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base" -def _default_video_dit_config(action_dim: int) -> dict[str, Any]: +_FASTWAM_VIDEO_BASE_COMPAT_KEYS = ( + "patch_size", + "in_dim", + "hidden_dim", + "ffn_dim", + "freq_dim", + "text_dim", + "out_dim", + "num_heads", + "attn_head_dim", + "num_layers", +) + +_FASTWAM_ACTION_BASE_COMPAT_KEYS = ( + "hidden_dim", + "ffn_dim", + "num_heads", + "attn_head_dim", + "num_layers", + "text_dim", + "freq_dim", +) + + +def default_video_dit_config(action_dim: int) -> dict[str, Any]: return { "patch_size": [1, 2, 2], "in_dim": 48, @@ -50,10 +74,11 @@ def _default_video_dit_config(action_dim: int) -> dict[str, Any]: "action_conditioned": False, "action_dim": action_dim, "action_group_causal_mask_mode": "group_diagonal", + "fp32_attention": True, } -def _default_action_dit_config(action_dim: int) -> dict[str, Any]: +def default_action_dit_config(action_dim: int) -> dict[str, Any]: return { "action_dim": action_dim, "hidden_dim": 1024, @@ -65,6 +90,7 @@ def _default_action_dit_config(action_dim: int) -> dict[str, Any]: "freq_dim": 256, "eps": 1.0e-6, "use_gradient_checkpointing": False, + "fp32_attention": True, } @@ -107,13 +133,18 @@ def _validate_wan_model_id(value: str, field_name: str) -> str: raise ValueError(f"`{field_name}` must be `{WAN22_MODEL_ID}` or an explicit local path, got `{value}`.") -def _coerce_pretrained_tokenizer_model_id(payload: dict[str, Any]) -> None: - tokenizer_model_id = payload.get("tokenizer_model_id") - if tokenizer_model_id is None: - return - if tokenizer_model_id == WAN22_MODEL_ID or _is_local_model_id(tokenizer_model_id): - return - payload["tokenizer_model_id"] = WAN22_MODEL_ID +def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool: + """Return whether `fastwam-base` partial weights can initialize this config.""" + + default_video_config = default_video_dit_config(config.action_dim) + default_action_config = default_action_dit_config(config.action_dim) + return all( + config.video_dit_config.get(key) == default_video_config.get(key) + for key in _FASTWAM_VIDEO_BASE_COMPAT_KEYS + ) and all( + config.action_dit_config.get(key) == default_action_config.get(key) + for key in _FASTWAM_ACTION_BASE_COMPAT_KEYS + ) @PreTrainedConfig.register_subclass("fastwam") @@ -131,8 +162,6 @@ class FastWAMConfig(PreTrainedConfig): context_len (int): Maximum text embedding token length. video_dit_config (dict[str, Any] | None): Wan video expert config. action_dit_config (dict[str, Any] | None): Action expert config. - invert_dimensions (list[int]): Action dimensions to multiply by -1 - during postprocessing. Supports negative indices. """ n_obs_steps: int = 1 @@ -145,6 +174,7 @@ class FastWAMConfig(PreTrainedConfig): context_len: int = 128 model_id: str = WAN22_MODEL_ID tokenizer_model_id: str = WAN22_MODEL_ID + base_model_id: str | None = FASTWAM_BASE_MODEL_ID tokenizer_max_len: int = 128 load_text_encoder: bool = True mot_checkpoint_mixed_attn: bool = False @@ -159,7 +189,8 @@ class FastWAMConfig(PreTrainedConfig): negative_prompt: str = "" sigma_shift: float | None = None tiled: bool = False - toggle_action_dimensions: list[int] = field(default_factory=lambda: [-1]) + fp32_attention: bool = True + toggle_action_dimensions: list[int] = field(default_factory=list) video_scheduler: dict[str, float | int] = field( default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000} ) @@ -169,7 +200,6 @@ class FastWAMConfig(PreTrainedConfig): loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0}) video_dit_config: dict[str, Any] | None = None action_dit_config: dict[str, Any] | None = None - invert_dimensions: list[int] = field(default_factory=list) normalization_mapping: dict[str, Any] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, @@ -183,33 +213,52 @@ class FastWAMConfig(PreTrainedConfig): optimizer_weight_decay: float = 1.0e-2 def __post_init__(self) -> None: - parent_post_init = getattr(super(), "__post_init__", None) - if parent_post_init is not None: - parent_post_init() + super().__post_init__() self.image_size = tuple(self.image_size) self.model_id = _validate_wan_model_id(self.model_id, "model_id") self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id") self.input_features = _coerce_policy_features(self.input_features) self.output_features = _coerce_policy_features(self.output_features) - self.invert_dimensions = [int(dim) for dim in self.invert_dimensions] self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions] self.normalization_mapping = _coerce_normalization_mapping(self.normalization_mapping) - self.video_dit_config = self.video_dit_config or _default_video_dit_config(self.action_dim) - self.action_dit_config = self.action_dit_config or _default_action_dit_config(self.action_dim) - self.input_features = self.input_features or self._default_input_features() - self.output_features = self.output_features or self._default_output_features() + self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim) + self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim) + self.video_dit_config["fp32_attention"] = bool(self.fp32_attention) + self.action_dit_config["fp32_attention"] = bool(self.fp32_attention) + if self.input_features is None: + height, width = self.image_size + self.input_features = { + "observation.images.image": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, height, width), + ) + } + if self.proprio_dim is not None: + self.input_features[OBS_STATE] = PolicyFeature( + type=FeatureType.STATE, + shape=(self.proprio_dim,), + ) + if self.output_features is None: + self.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))} self.validate_features() + if self.pretrained_path or self.use_peft or not self.base_model_id: + return + if not is_fastwam_base_compatible_config(self): + return + self.pretrained_path = Path(self.base_model_id) + self._auto_pretrained_path = True - @classmethod - def from_pretrained(cls, pretrained_name_or_path: str | Path, **_: Any) -> FastWAMConfig: - config_path = Path(pretrained_name_or_path) / "config.json" - with open(config_path, encoding="utf-8") as f: - payload = json.load(f) - payload.pop("type", None) - known_fields = {field.name for field in fields(cls)} - payload = {key: value for key, value in payload.items() if key in known_fields} - _coerce_pretrained_tokenizer_model_id(payload) - return cls(**payload) + def _save_pretrained(self, save_directory: Path) -> None: + if not getattr(self, "_auto_pretrained_path", False): + super()._save_pretrained(save_directory) + return + + pretrained_path = self.pretrained_path + self.pretrained_path = None + try: + super()._save_pretrained(save_directory) + finally: + self.pretrained_path = pretrained_path def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay) @@ -244,9 +293,6 @@ class FastWAMConfig(PreTrainedConfig): raise ValueError( f"FastWAM state feature shape must be ({self.proprio_dim},), got {state_shape}." ) - self._validate_image_feature_shapes() - - def _validate_image_feature_shapes(self) -> None: height, width = self.image_size image_width_sum = 0 for name, feature in self.image_features.items(): @@ -270,15 +316,3 @@ class FastWAMConfig(PreTrainedConfig): @property def reward_delta_indices(self) -> None: return None - - def _default_input_features(self) -> dict[str, PolicyFeature]: - height, width = self.image_size - features = { - "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, width)) - } - if self.proprio_dim is not None: - features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,)) - return features - - def _default_output_features(self) -> dict[str, PolicyFeature]: - return {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))} diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 90de7a2b4..1205acf72 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -15,14 +15,15 @@ from __future__ import annotations import shutil +import warnings from collections import deque from pathlib import Path from typing import TYPE_CHECKING, Any import torch -import torch.nn.functional as functional from torch import Tensor +from lerobot.configs import PreTrainedConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.constants import ACTION, OBS_STATE @@ -55,6 +56,15 @@ class FastWAMPolicy(PreTrainedPolicy): config.validate_features() self.config = config self.dataset_stats = dataset_stats + suppress_base_init_warning = bool(kwargs.pop("_suppress_base_init_warning", False)) + if not skip_wan_init and not suppress_base_init_warning: + warnings.warn( + "FastWAMPolicy(config) initializes from architecture/config and does not load pretrained " + "FastWAM weights. For training or evaluation, use `make_policy(config)` or " + "`FastWAMPolicy.from_pretrained(...)`.", + RuntimeWarning, + stacklevel=2, + ) if skip_wan_init: self.model = _build_core_model_from_architecture(config) else: @@ -82,7 +92,7 @@ class FastWAMPolicy(PreTrainedPolicy): Args: pretrained_name_or_path (str | Path): HF-format policy directory containing `config.json`, `model.safetensors`, local Wan VAE, - local UMT5 text encoder, and tokenizer files. + local UMT5 text encoder safetensors, and tokenizer files. config (FastWAMConfig | None): Optional config override. When omitted, `config.json` is read from `pretrained_name_or_path`. force_download (bool): Forwarded to LeRobot's pretrained loader. @@ -107,7 +117,9 @@ class FastWAMPolicy(PreTrainedPolicy): revision=revision, ) if config is None: - config = cls.config_class.from_pretrained(pretrained_path) + config = PreTrainedConfig.from_pretrained(pretrained_path) + if not isinstance(config, FastWAMConfig): + raise TypeError(f"Expected FastWAM config, got {type(config).__name__}.") kwargs["_skip_wan_init"] = True policy = super().from_pretrained( pretrained_path, @@ -144,7 +156,7 @@ class FastWAMPolicy(PreTrainedPolicy): Args: pretrained_name_or_path (str | Path): Directory containing - `Wan2.2_VAE.pth`, `models_t5_umt5-xxl-enc-bf16.pth`, + `Wan2.2_VAE.safetensors`, `models_t5_umt5-xxl-enc-bf16.safetensors`, and `google/umt5-xxl/` tokenizer files. """ @@ -171,7 +183,11 @@ class FastWAMPolicy(PreTrainedPolicy): sample = _batch_to_training_sample(batch=batch, config=self.config) loss, metrics = self.model.training_loss(sample) output = {"loss": loss} - output.update(_metrics_to_tensors(metrics=metrics, device=loss.device)) + for key, value in (metrics or {}).items(): + if isinstance(value, Tensor): + output[key] = value.to(device=loss.device) + else: + output[key] = torch.as_tensor(value, device=loss.device) return output @torch.no_grad() @@ -232,6 +248,12 @@ def _resolve_pretrained_directory( from huggingface_hub import snapshot_download + from .wan_components import ( + WAN_T5_CHECKPOINT, + WAN_T5_TOKENIZER, + WAN_VAE_CHECKPOINT, + ) + snapshot_path = snapshot_download( repo_id=str(pretrained_name_or_path), revision=revision, @@ -242,9 +264,9 @@ def _resolve_pretrained_directory( allow_patterns=[ "config.json", "model.safetensors", - "Wan2.2_VAE.pth", - "models_t5_umt5-xxl-enc-bf16.pth", - "google/umt5-xxl/**", + WAN_VAE_CHECKPOINT, + WAN_T5_CHECKPOINT, + f"{WAN_T5_TOKENIZER}/**", ], ) return Path(snapshot_path) @@ -284,10 +306,6 @@ def _load_wan_components_into_policy(policy: FastWAMPolicy, paths: WanCheckpoint paths.tokenizer, tokenizer_max_len=int(policy.config.tokenizer_max_len), ) - _record_wan_component_paths(policy=policy, paths=paths) - - -def _record_wan_component_paths(policy: FastWAMPolicy, paths: WanCheckpointPaths) -> None: model_paths = dict(getattr(policy.model, "model_paths", {}) or {}) model_paths.update( { @@ -504,18 +522,6 @@ def _dtype_from_name(name: str) -> torch.dtype: return dtype_map[name] -def _metrics_to_tensors(metrics: dict[str, Any] | None, device: torch.device) -> dict[str, Tensor]: - if metrics is None: - return {} - tensor_metrics = {} - for key, value in metrics.items(): - if isinstance(value, Tensor): - tensor_metrics[key] = value.to(device=device) - else: - tensor_metrics[key] = torch.as_tensor(value, device=device) - return tensor_metrics - - def batch_device(batch: dict[str, Any]) -> torch.device: for value in batch.values(): if isinstance(value, Tensor): @@ -555,29 +561,11 @@ def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor: if image.ndim != 4: raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.") - if image.dtype == torch.uint8: - image = image.to(dtype=torch.float32).div(255.0).mul(2.0).sub(1.0) - else: - image = image.to(dtype=torch.float32) - image_min = float(image.detach().amin().cpu()) - image_max = float(image.detach().amax().cpu()) - if image_min >= 0.0 and image_max <= 1.0: - image = image.mul(2.0).sub(1.0) - elif image_max > 2.0: - image = image.div(255.0).mul(2.0).sub(1.0) - target_h, target_w = config.image_size if tuple(image.shape[-2:]) != (target_h, target_w): - image = _center_crop_resize(image, target_h=target_h, target_w=target_w) + raise ValueError( + "FastWAM policy expects preprocessed image tensors with shape " + f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. " + "Run the FastWAM preprocessor before calling the policy." + ) return image - - -def _center_crop_resize(image: Tensor, *, target_h: int, target_w: int) -> Tensor: - _, _, height, width = image.shape - scale = max(target_h / height, target_w / width) - resized_h = round(height * scale) - resized_w = round(width * scale) - image = functional.interpolate(image, size=(resized_h, resized_w), mode="bilinear", align_corners=False) - top = max((resized_h - target_h) // 2, 0) - left = max((resized_w - target_w) // 2, 0) - return image[:, :, top : top + target_h, left : left + target_w].contiguous() diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/modular_fastwam.py index 59ef76f25..d9481064e 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/modular_fastwam.py @@ -27,7 +27,9 @@ from PIL import Image from .wan_components import load_wan22_ti2v_5b_components from .wan_video_dit import ( FastWAMAttentionBlock, + WanContinuousFlowMatchScheduler, fastwam_masked_attention, + gradient_checkpoint_forward, modulate, precompute_freqs_cis, sinusoidal_embedding_1d, @@ -62,130 +64,6 @@ def _apply_block_norm(block, name: str, x: torch.Tensor) -> torch.Tensor: return getattr(block, name)(x) -def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]: - from .wan.utils.fm_solvers import get_sampling_sigmas - - return get_sampling_sigmas(num_inference_steps, shift) - - -def create_custom_forward(module): - def custom_forward(*inputs, **kwargs): - return module(*inputs, **kwargs) - - return custom_forward - - -def gradient_checkpoint_forward( - model, - use_gradient_checkpointing, - *args, - **kwargs, -): - if use_gradient_checkpointing: - model_output = torch.utils.checkpoint.checkpoint( - create_custom_forward(model), - *args, - **kwargs, - use_reentrant=False, - ) - else: - model_output = model(*args, **kwargs) - return model_output - - -class WanContinuousFlowMatchScheduler: - """Continuous-time Flow-Matching scheduler with shift-based sampling.""" - - def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10): - if num_train_timesteps <= 0: - raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}") - if shift <= 0: - raise ValueError(f"`shift` must be positive, got {shift}") - self.num_train_timesteps = int(num_train_timesteps) - self.shift = float(shift) - self.eps = float(eps) - self._y_min, self._weight_norm_const = self._precompute_training_weight_stats() - - @staticmethod - def _phi(u: torch.Tensor, shift: float) -> torch.Tensor: - return shift * u / (1.0 + (shift - 1.0) * u) - - def _precompute_training_weight_stats(self) -> tuple[float, float]: - steps = self.num_train_timesteps - u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1] - t_grid = self._phi(u_grid, self.shift) * float(steps) - y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2) - y_min = float(y_grid.min().item()) - y_shifted_grid = y_grid - y_min - norm_const = float(y_shifted_grid.mean().item()) - return y_min, norm_const - - def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: - if batch_size <= 0: - raise ValueError(f"`batch_size` must be positive, got {batch_size}") - u = torch.rand((batch_size,), device=device, dtype=torch.float32) - sigma = self._phi(u, self.shift) - timestep = sigma * float(self.num_train_timesteps) - return timestep.to(dtype=dtype) - - def training_weight(self, timestep: torch.Tensor) -> torch.Tensor: - t = timestep.to(dtype=torch.float32) - steps = float(self.num_train_timesteps) - y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2) - y_shifted = y - self._y_min - weight = y_shifted / (self._weight_norm_const + self.eps) - if weight.numel() == 1: - return weight.reshape(()) - return weight - - def add_noise( - self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor - ) -> torch.Tensor: - sigma = (timestep / float(self.num_train_timesteps)).to( - original_samples.device, dtype=original_samples.dtype - ) - if sigma.ndim == 0: - return (1 - sigma) * original_samples + sigma * noise - sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1))) - return (1 - sigma) * original_samples + sigma * noise - - @staticmethod - def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - del timestep - return noise - sample - - def build_inference_schedule( - self, - num_inference_steps: int, - device: torch.device, - dtype: torch.dtype, - shift_override: float | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if num_inference_steps <= 0: - raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}") - shift = self.shift if shift_override is None else float(shift_override) - if shift <= 0: - raise ValueError(f"`shift` must be positive, got {shift}") - - sigma_steps = torch.as_tensor( - _get_wan_sampling_sigmas(num_inference_steps, shift), - device=device, - dtype=torch.float32, - ) - timesteps = sigma_steps * float(self.num_train_timesteps) - sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)]) - deltas = sigma_next - sigma_steps - return timesteps.to(dtype=dtype), deltas.to(dtype=dtype) - - @staticmethod - def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor: - delta = delta.to(sample.device, dtype=sample.dtype) - if delta.ndim == 0: - return sample + model_output * delta - delta = delta.view(-1, *([1] * (sample.ndim - 1))) - return sample + model_output * delta - - class ActionHead(nn.Module): def __init__(self, hidden_dim: int, out_dim: int, eps: float): super().__init__() @@ -213,6 +91,7 @@ class ActionDiT(nn.Module): attn_head_dim: int, num_layers: int, use_gradient_checkpointing: bool = False, + fp32_attention: bool = True, ): super().__init__() self.hidden_dim = hidden_dim @@ -250,12 +129,14 @@ class ActionDiT(nn.Module): num_heads=num_heads, ffn_dim=ffn_dim, eps=eps, + fp32_attention=fp32_attention, ) for _ in range(num_layers) ] ) self.head = nn.Linear(hidden_dim, action_dim) self.freqs = precompute_freqs_cis(attn_head_dim, end=1024) + self.fp32_attention = bool(fp32_attention) self.use_gradient_checkpointing = use_gradient_checkpointing @@ -395,6 +276,7 @@ class MoT(nn.Module): self.num_layers = len(first_expert.blocks) self.num_heads = first_expert.num_heads self.attn_head_dim = first_expert.attn_head_dim + self.fp32_attention = bool(getattr(first_expert, "fp32_attention", True)) for name in self.expert_order[1:]: expert = self.mixtures[name] @@ -411,6 +293,8 @@ class MoT(nn.Module): "All experts must have same attn_head_dim; " f"got {self.attn_head_dim} and {expert.attn_head_dim}" ) + if bool(getattr(expert, "fp32_attention", True)) != self.fp32_attention: + raise ValueError("All experts must use the same `fp32_attention` setting.") logger.info(f"Initialized MoT with experts: {self.expert_order}, num_layers={self.num_layers}") for name in self.expert_order: @@ -450,7 +334,14 @@ class MoT(nn.Module): attn_mask = attention_mask.to(device=q_cat.device) def _forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - return fastwam_masked_attention(q=q, k=k, v=v, num_heads=self.num_heads, ctx_mask=attn_mask) + return fastwam_masked_attention( + q=q, + k=k, + v=v, + num_heads=self.num_heads, + ctx_mask=attn_mask, + fp32_attention=self.fp32_attention, + ) if self.mot_checkpoint_mixed_attn and self.training: return torch.utils.checkpoint.checkpoint( @@ -1107,6 +998,8 @@ class FastWAM(torch.nn.Module): seq_lens = mask.gt(0).sum(dim=1).long() for i, v in enumerate(seq_lens): prompt_emb[i, v:] = 0 + # Match FastWAM/Wan2.2 context semantics: padding embeddings are zeroed, + # while cross-attention still sees a fixed-length context. mask = torch.ones_like(mask) return prompt_emb.to(device=self.device), mask @@ -1822,7 +1715,7 @@ class FastWAM(torch.nn.Module): timestep_video = step_t_video.unsqueeze(0).to(dtype=latents_video.dtype, device=self.device) timestep_action = step_t_action.unsqueeze(0).to(dtype=latents_action.dtype, device=self.device) - pred_video_posi, pred_action_posi = self._predict_joint_noise( + pred_video, pred_action = self._predict_joint_noise( latents_video=latents_video, latents_action=latents_action, timestep_video=timestep_video, @@ -1832,8 +1725,6 @@ class FastWAM(torch.nn.Module): fuse_vae_embedding_in_latents=fuse_flag, gt_action=action, ) - pred_video = pred_video_posi - pred_action = pred_action_posi latents_video = self.infer_video_scheduler.step(pred_video, step_delta_video, latents_video) latents_action = self.infer_action_scheduler.step(pred_action, step_delta_action, latents_action) @@ -1924,7 +1815,7 @@ class FastWAM(torch.nn.Module): for step_t_action, step_delta_action in zip(infer_timesteps_action, infer_deltas_action, strict=True): timestep_action = step_t_action.unsqueeze(0).to(dtype=latents_action.dtype, device=self.device) - pred_action_posi = self._predict_action_noise_with_cache( + pred_action = self._predict_action_noise_with_cache( latents_action=latents_action, timestep_action=timestep_action, context=context, @@ -1933,7 +1824,6 @@ class FastWAM(torch.nn.Module): attention_mask=attention_mask, video_seq_len=video_seq_len, ) - pred_action = pred_action_posi latents_action = self.infer_action_scheduler.step(pred_action, step_delta_action, latents_action) diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index 9ca272be3..56ef7bfc6 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -19,7 +19,7 @@ from typing import Any import torch -from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import ( ActionProcessorStep, AddBatchDimensionProcessorStep, @@ -41,36 +41,6 @@ from lerobot.utils.constants import ( from .configuration_fastwam import FastWAMConfig -@dataclass -@ProcessorStepRegistry.register(name="fastwam_action_inversion_processor") -class FastWAMActionInversionProcessorStep(ActionProcessorStep): - """Invert configured FastWAM action dimensions during postprocessing.""" - - invert_dimensions: list[int] - - def action(self, action: PolicyAction) -> PolicyAction: - if not self.invert_dimensions: - return action - processed_action = action.clone() - action_dim = int(processed_action.shape[-1]) - for dim in self.invert_dimensions: - resolved_dim = dim if dim >= 0 else action_dim + dim - if resolved_dim < 0 or resolved_dim >= action_dim: - raise ValueError( - f"FastWAM action inversion dimension {dim} is out of bounds for action dim {action_dim}." - ) - processed_action[..., resolved_dim] = -processed_action[..., resolved_dim] - return processed_action - - def get_config(self) -> dict[str, Any]: - return {"invert_dimensions": self.invert_dimensions} - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - return features - - @dataclass @ProcessorStepRegistry.register(name="fastwam_action_toggle_processor") class FastWAMActionToggleProcessorStep(ActionProcessorStep): @@ -120,6 +90,17 @@ def make_fastwam_pre_post_processors( output processor pipelines discoverable by LeRobot. """ + normalization_stats: dict[str, dict[str, Any]] = { + key: dict(value) for key, value in (dataset_stats or {}).items() + } + for key, feature in config.input_features.items(): + if feature.type != FeatureType.VISUAL: + continue + channels = int(feature.shape[0]) + normalization_stats[key] = { + "mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), + "std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), + } input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), @@ -127,7 +108,7 @@ def make_fastwam_pre_post_processors( NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, - stats=dataset_stats, + stats=normalization_stats, device=config.device, ), ] @@ -135,43 +116,14 @@ def make_fastwam_pre_post_processors( UnnormalizerProcessorStep( features=config.output_features, norm_map=config.normalization_mapping, - stats=dataset_stats, + stats=normalization_stats, ), ] if config.toggle_action_dimensions: output_steps.append( FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions) ) - elif config.invert_dimensions: - output_steps.append(FastWAMActionInversionProcessorStep(invert_dimensions=config.invert_dimensions)) output_steps.append(DeviceProcessorStep(device="cpu")) - return _build_lerobot_pipelines(input_steps=input_steps, output_steps=output_steps) - - -def migrate_fastwam_postprocessor( - postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], - config: FastWAMConfig, -) -> PolicyProcessorPipeline[PolicyAction, PolicyAction]: - """Upgrade old FastWAM postprocessor pipelines to the LIBERO toggle step.""" - - if not config.toggle_action_dimensions: - return postprocessor - - toggle_step = FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions) - steps = [ - step - for step in postprocessor.steps - if not isinstance(step, (FastWAMActionInversionProcessorStep, FastWAMActionToggleProcessorStep)) - ] - insert_at = next( - (idx for idx, step in enumerate(steps) if isinstance(step, DeviceProcessorStep)), len(steps) - ) - steps.insert(insert_at, toggle_step) - postprocessor.steps = steps - return postprocessor - - -def _build_lerobot_pipelines(input_steps: list[Any], output_steps: list[Any]) -> tuple[Any, Any]: return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, diff --git a/src/lerobot/policies/fastwam/wan/README.md b/src/lerobot/policies/fastwam/wan/README.md index 4d2234cf9..7d0a2b169 100644 --- a/src/lerobot/policies/fastwam/wan/README.md +++ b/src/lerobot/policies/fastwam/wan/README.md @@ -1,6 +1,6 @@ # Wan2.2 Upstream Subset -This directory contains an unmodified subset of the official Wan2.2 source tree. +This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM. - Upstream repository: https://github.com/Wan-Video/Wan2.2 - Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc` @@ -12,17 +12,15 @@ Copied files: - `wan/modules/model.py` - `wan/modules/t5.py` - `wan/modules/tokenizers.py` -- `wan/modules/vae2_1.py` - `wan/modules/vae2_2.py` - `wan/modules/__init__.py` - `wan/utils/fm_solvers.py` -- `wan/utils/fm_solvers_unipc.py` - `wan/utils/__init__.py` -FastWAM-specific model glue and any code adapted from these modules live outside this directory. This keeps the upstream Wan2.2 code reviewable as a vendored reference subset and makes it straightforward to replace this directory with an external Wan2.2 dependency by changing import paths. +FastWAM-specific model glue and any larger code adapted from these modules live outside this directory. This keeps the upstream Wan2.2 code reviewable as a vendored reference subset and makes it straightforward to replace this directory with an external Wan2.2 dependency by changing import paths. Current FastWAM adapters that directly reuse this vendored subset: - `../wan_components.py` instantiates the upstream `wan.modules.t5.umt5_xxl` encoder factory and uses `wan.modules.tokenizers.HuggingfaceTokenizer`. -- `../wan_adapters.py` wraps `wan.modules.vae2_2.Wan2_2_VAE` with the FastWAM tensor-batch encode/decode API. +- `../wan_adapters.py` wraps `wan.modules.vae2_2.Wan2VAE` with the FastWAM tensor-batch encode/decode API. - `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps. diff --git a/src/lerobot/policies/fastwam/wan/modules/__init__.py b/src/lerobot/policies/fastwam/wan/modules/__init__.py index 57f20b699..ecc42a31e 100644 --- a/src/lerobot/policies/fastwam/wan/modules/__init__.py +++ b/src/lerobot/policies/fastwam/wan/modules/__init__.py @@ -3,12 +3,10 @@ from .attention import flash_attention from .model import WanModel from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model from .tokenizers import HuggingfaceTokenizer -from .vae2_1 import Wan2_1_VAE -from .vae2_2 import Wan2_2_VAE +from .vae2_2 import Wan2VAE __all__ = [ - "Wan2_1_VAE", - "Wan2_2_VAE", + "Wan2VAE", "WanModel", "T5Model", "T5Encoder", diff --git a/src/lerobot/policies/fastwam/wan/modules/attention.py b/src/lerobot/policies/fastwam/wan/modules/attention.py index 7ce667e43..cf7dddea8 100644 --- a/src/lerobot/policies/fastwam/wan/modules/attention.py +++ b/src/lerobot/policies/fastwam/wan/modules/attention.py @@ -66,7 +66,7 @@ def flash_attention( q = half(q.flatten(0, 1)) q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True) else: - q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens, strict=False)])) # preprocess key, value if k_lens is None: @@ -74,8 +74,8 @@ def flash_attention( v = half(v.flatten(0, 1)) k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True) else: - k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) - v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens, strict=False)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens, strict=False)])) q = q.to(v.dtype) k = k.to(v.dtype) @@ -84,7 +84,7 @@ def flash_attention( q = q * q_scale if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: - warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.") + warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.", stacklevel=2) # apply attention if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: @@ -166,7 +166,8 @@ def attention( else: if q_lens is not None or k_lens is not None: warnings.warn( - "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance." + "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.", + stacklevel=2, ) attn_mask = None diff --git a/src/lerobot/policies/fastwam/wan/modules/model.py b/src/lerobot/policies/fastwam/wan/modules/model.py index 75b3381b8..5f3c41e1a 100644 --- a/src/lerobot/policies/fastwam/wan/modules/model.py +++ b/src/lerobot/policies/fastwam/wan/modules/model.py @@ -421,7 +421,7 @@ class WanModel(ModelMixin, ConfigMixin): self.freqs = self.freqs.to(device) if y is not None: - x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)] # embeddings x = [self.patch_embedding(u.unsqueeze(0)) for u in x] @@ -450,14 +450,14 @@ class WanModel(ModelMixin, ConfigMixin): ) # arguments - kwargs = dict( - e=e0, - seq_lens=seq_lens, - grid_sizes=grid_sizes, - freqs=self.freqs, - context=context, - context_lens=context_lens, - ) + kwargs = { + "e": e0, + "seq_lens": seq_lens, + "grid_sizes": grid_sizes, + "freqs": self.freqs, + "context": context, + "context_lens": context_lens, + } for block in self.blocks: x = block(x, **kwargs) @@ -487,10 +487,10 @@ class WanModel(ModelMixin, ConfigMixin): c = self.out_dim out = [] - for u, v in zip(x, grid_sizes.tolist()): + for u, v in zip(x, grid_sizes.tolist(), strict=False): u = u[: math.prod(v)].view(*v, *self.patch_size, c) u = torch.einsum("fhwpqrc->cfphqwr", u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size, strict=False)]) out.append(u) return out diff --git a/src/lerobot/policies/fastwam/wan/modules/t5.py b/src/lerobot/policies/fastwam/wan/modules/t5.py index e7227e831..c90fd3dbc 100644 --- a/src/lerobot/policies/fastwam/wan/modules/t5.py +++ b/src/lerobot/policies/fastwam/wan/modules/t5.py @@ -5,7 +5,7 @@ import math import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional from .tokenizers import HuggingfaceTokenizer @@ -49,7 +49,7 @@ class GELU(nn.Module): class T5LayerNorm(nn.Module): def __init__(self, dim, eps=1e-6): - super(T5LayerNorm, self).__init__() + super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) @@ -64,7 +64,7 @@ class T5LayerNorm(nn.Module): class T5Attention(nn.Module): def __init__(self, dim, dim_attn, num_heads, dropout=0.1): assert dim_attn % num_heads == 0 - super(T5Attention, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.num_heads = num_heads @@ -103,7 +103,7 @@ class T5Attention(nn.Module): # compute attention (T5 does not use scaling) attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias - attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = functional.softmax(attn.float(), dim=-1).type_as(attn) x = torch.einsum("bnij,bjnc->binc", attn, v) # output @@ -115,7 +115,7 @@ class T5Attention(nn.Module): class T5FeedForward(nn.Module): def __init__(self, dim, dim_ffn, dropout=0.1): - super(T5FeedForward, self).__init__() + super().__init__() self.dim = dim self.dim_ffn = dim_ffn @@ -135,7 +135,7 @@ class T5FeedForward(nn.Module): class T5SelfAttention(nn.Module): def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super(T5SelfAttention, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -161,7 +161,7 @@ class T5SelfAttention(nn.Module): class T5CrossAttention(nn.Module): def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super(T5CrossAttention, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -190,7 +190,7 @@ class T5CrossAttention(nn.Module): class T5RelativeEmbedding(nn.Module): def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super(T5RelativeEmbedding, self).__init__() + super().__init__() self.num_buckets = num_buckets self.num_heads = num_heads self.bidirectional = bidirectional @@ -239,7 +239,7 @@ class T5Encoder(nn.Module): def __init__( self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1 ): - super(T5Encoder, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -280,7 +280,7 @@ class T5Decoder(nn.Module): def __init__( self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1 ): - super(T5Decoder, self).__init__() + super().__init__() self.dim = dim self.dim_attn = dim_attn self.dim_ffn = dim_ffn @@ -340,7 +340,7 @@ class T5Model(nn.Module): shared_pos=True, dropout=0.1, ): - super(T5Model, self).__init__() + super().__init__() self.vocab_size = vocab_size self.dim = dim self.dim_attn = dim_attn @@ -391,13 +391,14 @@ def _t5( encoder_only=False, decoder_only=False, return_tokenizer=False, - tokenizer_kwargs={}, + tokenizer_kwargs=None, dtype=torch.float32, device="cpu", **kwargs, ): # sanity check assert not (encoder_only and decoder_only) + tokenizer_kwargs = tokenizer_kwargs or {} # params if encoder_only: @@ -431,18 +432,18 @@ def _t5( def umt5_xxl(**kwargs): - cfg = dict( - vocab_size=256384, - dim=4096, - dim_attn=4096, - dim_ffn=10240, - num_heads=64, - encoder_layers=24, - decoder_layers=24, - num_buckets=32, - shared_pos=False, - dropout=0.1, - ) + cfg = { + "vocab_size": 256384, + "dim": 4096, + "dim_attn": 4096, + "dim_ffn": 10240, + "num_heads": 64, + "encoder_layers": 24, + "decoder_layers": 24, + "num_buckets": 32, + "shared_pos": False, + "dropout": 0.1, + } cfg.update(**kwargs) return _t5("umt5-xxl", **cfg) @@ -470,7 +471,7 @@ class T5EncoderModel: .requires_grad_(False) ) logging.info(f"loading {checkpoint_path}") - model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True)) self.model = model if shard_fn is not None: self.model = shard_fn(self.model, sync_module_states=False) @@ -485,4 +486,4 @@ class T5EncoderModel: mask = mask.to(device) seq_lens = mask.gt(0).sum(dim=1).long() context = self.model(ids, mask) - return [u[:v] for u, v in zip(context, seq_lens)] + return [u[:v] for u, v in zip(context, seq_lens, strict=False)] diff --git a/src/lerobot/policies/fastwam/wan/modules/vae2_1.py b/src/lerobot/policies/fastwam/wan/modules/vae2_1.py deleted file mode 100644 index b6d99107e..000000000 --- a/src/lerobot/policies/fastwam/wan/modules/vae2_1.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import logging - -import torch -import torch.cuda.amp as amp -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - -__all__ = [ - "Wan2_1_VAE", -] - -CACHE_T = 2 - - -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._padding = ( - self.padding[2], - self.padding[2], - self.padding[1], - self.padding[1], - 2 * self.padding[0], - 0, - ) - self.padding = (0, 0, 0) - - def forward(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - - return super().forward(x) - - -class RMS_norm(nn.Module): - def __init__(self, dim, channel_first=True, images=True, bias=False): - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 - - def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - -class Resample(nn.Module): - def __init__(self, dim, mode): - assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == "upsample2d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), - ) - elif mode == "upsample3d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), - nn.Conv2d(dim, dim // 2, 3, padding=1), - ) - self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - - else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == "upsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) - if feat_cache[idx] == "Rep": - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.resample(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - - if self.mode == "downsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x - - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - -class ResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), - nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), - nn.SiLU(), - nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1), - ) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) - for layer in self.residual: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - nn.init.zeros_(self.proj.weight) - - def forward(self, x): - identity = x - b, c, t, h, w = x.size() - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.norm(x) - # compute query, key, value - q, k, v = ( - self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) - ) - - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output - x = self.proj(x) - x = rearrange(x, "(b t) c h w-> b c t h w", t=t) - return x + identity - - -class Encoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - - # dimensions - dims = [dim * u for u in [1] + dim_mult] - scale = 1.0 - - # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) - - # downsample blocks - downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - for _ in range(num_res_blocks): - downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - downsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # downsample block - if i != len(dim_mult) - 1: - mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - downsamples.append(Resample(out_dim, mode=mode)) - scale /= 2.0 - self.downsamples = nn.Sequential(*downsamples) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), - AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout), - ) - - # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) - ) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## downsamples - for layer in self.downsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -class Decoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_upsample = temperal_upsample - - # dimensions - dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2 ** (len(dim_mult) - 2) - - # init block - self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), - AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout), - ) - - # upsample blocks - upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - if i == 1 or i == 2 or i == 3: - in_dim = in_dim // 2 - for _ in range(num_res_blocks + 1): - upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - upsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # upsample block - if i != len(dim_mult) - 1: - mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - upsamples.append(Resample(out_dim, mode=mode)) - scale *= 2.0 - self.upsamples = nn.Sequential(*upsamples) - - # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1) - ) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - ## conv1 - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -def count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, CausalConv3d): - count += 1 - return count - - -class WanVAE_(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - self.temperal_upsample = temperal_downsample[::-1] - - # modules - self.encoder = Encoder3d( - dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout - ) - self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) - self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d( - dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout - ) - - def forward(self, x): - mu, log_var = self.encode(x) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z) - return x_recon, mu, log_var - - def encode(self, x, scale): - self.clear_cache() - ## cache - t = x.shape[2] - iter_ = 1 + (t - 1) // 4 - ## 对encode输入的x,按时间拆分为1、4、4、4.... - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self.encoder( - x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx - ) - else: - out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, - ) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] - self.clear_cache() - return mu - - def decode(self, z, scale): - self.clear_cache() - # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] - iter_ = z.shape[2] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder( - x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx - ) - else: - out_ = self.decoder( - x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx - ) - out = torch.cat([out, out_], 2) - self.clear_cache() - return out - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return eps * std + mu - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs) - if deterministic: - return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - -def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. - """ - # params - cfg = dict( - dim=96, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[False, True, True], - dropout=0.0, - ) - cfg.update(**kwargs) - - # init model - with torch.device("meta"): - model = WanVAE_(**cfg) - - # load checkpoint - logging.info(f"loading {pretrained_path}") - model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) - - return model - - -class Wan2_1_VAE: - def __init__(self, z_dim=16, vae_pth="cache/vae_step_411000.pth", dtype=torch.float, device="cuda"): - self.dtype = dtype - self.device = device - - mean = [ - -0.7571, - -0.7089, - -0.9113, - 0.1075, - -0.1745, - 0.9653, - -0.1517, - 1.5508, - 0.4134, - -0.0715, - 0.5517, - -0.3632, - -0.1922, - -0.9497, - 0.2503, - -0.2921, - ] - std = [ - 2.8184, - 1.4541, - 2.3275, - 2.6558, - 1.2196, - 1.7708, - 2.6052, - 2.0743, - 3.2687, - 2.1526, - 2.8652, - 1.5579, - 1.6382, - 1.1253, - 2.8251, - 1.9160, - ] - self.mean = torch.tensor(mean, dtype=dtype, device=device) - self.std = torch.tensor(std, dtype=dtype, device=device) - self.scale = [self.mean, 1.0 / self.std] - - # init model - self.model = ( - _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - ) - .eval() - .requires_grad_(False) - .to(device) - ) - - def encode(self, videos): - """ - videos: A list of videos each with shape [C, T, H, W]. - """ - with amp.autocast(dtype=self.dtype): - return [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos] - - def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs - ] diff --git a/src/lerobot/policies/fastwam/wan/modules/vae2_2.py b/src/lerobot/policies/fastwam/wan/modules/vae2_2.py index 38cf88f3d..94b30ec4a 100644 --- a/src/lerobot/policies/fastwam/wan/modules/vae2_2.py +++ b/src/lerobot/policies/fastwam/wan/modules/vae2_2.py @@ -1,14 +1,16 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import logging +from pathlib import Path import torch import torch.cuda.amp as amp import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional from einops import rearrange +from safetensors.torch import load_file __all__ = [ - "Wan2_2_VAE", + "Wan2VAE", ] CACHE_T = 2 @@ -37,12 +39,12 @@ class CausalConv3d(nn.Conv3d): cache_x = cache_x.to(x.device) x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) + x = functional.pad(x, padding) return super().forward(x) -class RMS_norm(nn.Module): +class RMSNorm(nn.Module): def __init__(self, dim, channel_first=True, images=True, bias=False): super().__init__() broadcastable_dims = (1, 1, 1) if not images else (1, 1) @@ -54,7 +56,10 @@ class RMS_norm(nn.Module): self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + return ( + functional.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + + self.bias + ) class Upsample(nn.Upsample): @@ -99,55 +104,52 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] b, c, t, h, w = x.size() - if self.mode == "upsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), - cache_x, - ], - dim=2, - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat( - [torch.zeros_like(cache_x).to(cache_x.device), cache_x], - dim=2, - ) - if feat_cache[idx] == "Rep": - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) + if self.mode == "upsample3d" and feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + x = self.time_conv(x) if feat_cache[idx] == "Rep" else self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) t = x.shape[2] x = rearrange(x, "b c t h w -> (b t) c h w") x = self.resample(x) x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - if self.mode == "downsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + if self.mode == "downsample3d" and feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 return x def init_weight(self, conv): @@ -180,17 +182,19 @@ class ResidualBlock(nn.Module): # layers self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), + RMSNorm(in_dim, images=False), nn.SiLU(), CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), + RMSNorm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), CausalConv3d(out_dim, out_dim, 3, padding=1), ) self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] h = self.shortcut(x) for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: @@ -223,7 +227,7 @@ class AttentionBlock(nn.Module): self.dim = dim # layers - self.norm = RMS_norm(dim) + self.norm = RMSNorm(dim) self.to_qkv = nn.Conv2d(dim, dim * 3, 1) self.proj = nn.Conv2d(dim, dim, 1) @@ -241,7 +245,7 @@ class AttentionBlock(nn.Module): ) # apply attention - x = F.scaled_dot_product_attention( + x = functional.scaled_dot_product_attention( q, k, v, @@ -309,33 +313,33 @@ class AvgDown3D(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t pad = (0, 0, 0, 0, pad_t, 0) - x = F.pad(x, pad) - B, C, T, H, W = x.shape + x = functional.pad(x, pad) + batch, channels, frames, height, width = x.shape x = x.view( - B, - C, - T // self.factor_t, + batch, + channels, + frames // self.factor_t, self.factor_t, - H // self.factor_s, + height // self.factor_s, self.factor_s, - W // self.factor_s, + width // self.factor_s, self.factor_s, ) x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() x = x.view( - B, - C * self.factor, - T // self.factor_t, - H // self.factor_s, - W // self.factor_s, + batch, + channels * self.factor, + frames // self.factor_t, + height // self.factor_s, + width // self.factor_s, ) x = x.view( - B, + batch, self.out_channels, self.group_size, - T // self.factor_t, - H // self.factor_s, - W // self.factor_s, + frames // self.factor_t, + height // self.factor_s, + width // self.factor_s, ) x = x.mean(dim=2) return x @@ -385,7 +389,7 @@ class DupUp3D(nn.Module): return x -class Down_ResidualBlock(nn.Module): +class DownResidualBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False): super().__init__() @@ -410,7 +414,9 @@ class Down_ResidualBlock(nn.Module): self.downsamples = nn.Sequential(*downsamples) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] x_copy = x.clone() for module in self.downsamples: x = module(x, feat_cache, feat_idx) @@ -418,7 +424,7 @@ class Down_ResidualBlock(nn.Module): return x + self.avg_shortcut(x_copy) -class Up_ResidualBlock(nn.Module): +class UpResidualBlock(nn.Module): def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False): super().__init__() # Shortcut path with upsample @@ -445,7 +451,9 @@ class Up_ResidualBlock(nn.Module): self.upsamples = nn.Sequential(*upsamples) - def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): + if feat_idx is None: + feat_idx = [0] x_main = x.clone() for module in self.upsamples: x_main = module(x_main, feat_cache, feat_idx) @@ -461,12 +469,18 @@ class Encoder3d(nn.Module): self, dim=128, z_dim=4, - dim_mult=[1, 2, 4, 4], + dim_mult=None, num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], + attn_scales=None, + temperal_downsample=None, dropout=0.0, ): + if temperal_downsample is None: + temperal_downsample = [True, True, False] + if attn_scales is None: + attn_scales = [] + if dim_mult is None: + dim_mult = [1, 2, 4, 4] super().__init__() self.dim = dim self.z_dim = z_dim @@ -484,10 +498,10 @@ class Encoder3d(nn.Module): # downsample blocks downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=False)): t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False downsamples.append( - Down_ResidualBlock( + DownResidualBlock( in_dim=in_dim, out_dim=out_dim, dropout=dropout, @@ -508,13 +522,15 @@ class Encoder3d(nn.Module): # # output blocks self.head = nn.Sequential( - RMS_norm(out_dim, images=False), + RMSNorm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1), ) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() @@ -534,10 +550,7 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) + x = layer(x, feat_cache, feat_idx) if feat_cache is not None else layer(x) ## middle for layer in self.middle: @@ -573,12 +586,18 @@ class Decoder3d(nn.Module): self, dim=128, z_dim=4, - dim_mult=[1, 2, 4, 4], + dim_mult=None, num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], + attn_scales=None, + temperal_upsample=None, dropout=0.0, ): + if temperal_upsample is None: + temperal_upsample = [False, True, True] + if attn_scales is None: + attn_scales = [] + if dim_mult is None: + dim_mult = [1, 2, 4, 4] super().__init__() self.dim = dim self.z_dim = z_dim @@ -589,7 +608,6 @@ class Decoder3d(nn.Module): # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) @@ -602,10 +620,10 @@ class Decoder3d(nn.Module): # upsample blocks upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=False)): t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False upsamples.append( - Up_ResidualBlock( + UpResidualBlock( in_dim=in_dim, out_dim=out_dim, dropout=dropout, @@ -618,12 +636,14 @@ class Decoder3d(nn.Module): # output blocks self.head = nn.Sequential( - RMS_norm(out_dim, images=False), + RMSNorm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 12, 3, padding=1), ) - def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): + if feat_idx is None: + feat_idx = [0] if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() @@ -649,10 +669,7 @@ class Decoder3d(nn.Module): ## upsamples for layer in self.upsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx, first_chunk) - else: - x = layer(x) + x = layer(x, feat_cache, feat_idx, first_chunk) if feat_cache is not None else layer(x) ## head for layer in self.head: @@ -683,18 +700,24 @@ def count_conv3d(model): return count -class WanVAE_(nn.Module): +class WanVAEModel(nn.Module): def __init__( self, dim=160, dec_dim=256, z_dim=16, - dim_mult=[1, 2, 4, 4], + dim_mult=None, num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], + attn_scales=None, + temperal_downsample=None, dropout=0.0, ): + if temperal_downsample is None: + temperal_downsample = [True, True, False] + if attn_scales is None: + attn_scales = [] + if dim_mult is None: + dim_mult = [1, 2, 4, 4] super().__init__() self.dim = dim self.z_dim = z_dim @@ -726,7 +749,9 @@ class WanVAE_(nn.Module): dropout, ) - def forward(self, x, scale=[0, 1]): + def forward(self, x, scale=None): + if scale is None: + scale = [0, 1] mu = self.encode(x, scale) x_recon = self.decode(mu, scale) return x_recon, mu @@ -811,40 +836,47 @@ class WanVAE_(nn.Module): def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): # params - cfg = dict( - dim=dim, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, True], - dropout=0.0, - ) + cfg = { + "dim": dim, + "z_dim": z_dim, + "dim_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_scales": [], + "temperal_downsample": [True, True, True], + "dropout": 0.0, + } cfg.update(**kwargs) # init model with torch.device("meta"): - model = WanVAE_(**cfg) + model = WanVAEModel(**cfg) # load checkpoint logging.info(f"loading {pretrained_path}") - model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + if Path(pretrained_path).suffix != ".safetensors": + raise ValueError(f"Wan2.2 VAE checkpoint must be safetensors, got {pretrained_path}.") + state_dict = load_file(pretrained_path, device=str(device)) + model.load_state_dict(state_dict, assign=True) return model -class Wan2_2_VAE: +class Wan2VAE: def __init__( self, z_dim=48, c_dim=160, vae_pth=None, - dim_mult=[1, 2, 4, 4], - temperal_downsample=[False, True, True], + dim_mult=None, + temperal_downsample=None, dtype=torch.float, device="cuda", ): + if temperal_downsample is None: + temperal_downsample = [False, True, True] + if dim_mult is None: + dim_mult = [1, 2, 4, 4] self.dtype = dtype self.device = device diff --git a/src/lerobot/policies/fastwam/wan/utils/__init__.py b/src/lerobot/policies/fastwam/wan/utils/__init__.py index e08f2ba7f..ba223fe65 100644 --- a/src/lerobot/policies/fastwam/wan/utils/__init__.py +++ b/src/lerobot/policies/fastwam/wan/utils/__init__.py @@ -1,15 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -from .fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .fm_solvers import get_sampling_sigmas __all__ = [ - "HuggingfaceTokenizer", "get_sampling_sigmas", - "retrieve_timesteps", - "FlowDPMSolverMultistepScheduler", - "FlowUniPCMultistepScheduler", ] diff --git a/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py b/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py index 9bcf05987..42b453590 100644 --- a/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py +++ b/src/lerobot/policies/fastwam/wan/utils/fm_solvers.py @@ -1,837 +1,9 @@ -# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py -# Convert dpm solver for flow matching # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import inspect -import math -from typing import List, Optional, Tuple, Union - import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import ( - KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput, -) -from diffusers.utils import deprecate, is_scipy_available -from diffusers.utils.torch_utils import randn_tensor - -if is_scipy_available(): - pass def get_sampling_sigmas(sampling_steps, shift): sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] sigma = shift * sigma / (1 + (shift - 1) * sigma) - return sigma - - -def retrieve_timesteps( - scheduler, - num_inference_steps=None, - device=None, - timesteps=None, - sigmas=None, - **kwargs, -): - if timesteps is not None and sigmas is not None: - raise ValueError( - "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" - ) - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. This determines the resolution of the diffusion process. - solver_order (`int`, defaults to 2): - The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored - and used in multistep updates. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts - the flow of the diffusion process. - shift (`float`, *optional*, defaults to 1.0): - A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling - process. - use_dynamic_shifting (`bool`, defaults to `False`): - Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is - applied on the fly. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent - saturation and improve photorealism. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - euler_at_final (`bool`, defaults to `False`): - Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail - richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference - steps, but sometimes may result in blurring. - final_sigmas_type (`str`, *optional*, defaults to "zero"): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. - lambda_min_clipped (`float`, defaults to `-inf`): - Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the - cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", - lower_order_final: bool = True, - euler_at_final: bool = False, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - invert_sigmas: bool = False, - ): - if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: - if algorithm_type == "deis": - self.register_to_config(algorithm_type="dpmsolver++") - else: - raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") - - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - self.register_to_config(solver_type="midpoint") - else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." - ) - - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - self._step_index = None - self._begin_index = None - - # self.sigmas = self.sigmas.to( - # "cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError( - " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" - ) - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - self._step_index = None - self._begin_index = None - # self.sigmas = self.sigmas.to( - # "cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = ( - torch.clamp(sample, -s, s) / s - ) # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps - - def _sigma_to_alpha_sigma_t(self, sigma): - return 1 - sigma, sigma - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is - designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an - integral of the data prediction model. - - The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise - prediction and data prediction models. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - x0_pred = self._threshold_sample(x0_pred) - - return x0_pred - - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - epsilon = sample - (1 - sigma_t) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - x0_pred = self._threshold_sample(x0_pred) - epsilon = model_output + x0_pred - - return epsilon - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update - def dpm_solver_first_order_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - noise: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the first-order DPMSolver (equivalent to DDIM). - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - x_t = ( - (sigma_t / sigma_s * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ( - (alpha_t / alpha_s) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t # pyright: ignore - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.Tensor], - *args, - sample: torch.Tensor = None, - noise: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the second-order multistep DPMSolver. - Args: - model_output_list (`List[torch.Tensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1 = ( - self.sigmas[self.step_index + 1], # pyright: ignore - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], # pyright: ignore - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - - m0, m1 = model_output_list[-1], model_output_list[-2] - - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0 * torch.exp(-h)) * sample - + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 - + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 - + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise - ) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise - ) - return x_t # pyright: ignore - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.Tensor], - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the third-order multistep DPMSolver. - Args: - model_output_list (`List[torch.Tensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): - A current instance of a sample created by diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - - timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError(" missing`sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index + 1], # pyright: ignore - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], # pyright: ignore - self.sigmas[self.step_index - 2], # pyright: ignore - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t # pyright: ignore - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step - def step( - self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - generator=None, - variance_noise: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep DPMSolver. - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - variance_noise (`torch.Tensor`): - Alternative to generating noise with `generator` by directly providing the noise for the variance - itself. Useful for methods such as [`LEdits++`]. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Improve numerical stability for small number of steps - lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final - or (self.config.lower_order_final and len(self.timesteps) < 15) - or self.config.final_sigmas_type == "zero" - ) - lower_order_second = ( - (self.step_index == len(self.timesteps) - 2) - and self.config.lower_order_final - and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, sample=sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - # Upcast to avoid precision issues when computing prev_sample - sample = sample.to(torch.float32) - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: - noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 - ) - elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: - noise = variance_noise.to(device=model_output.device, dtype=torch.float32) # pyright: ignore - else: - noise = None - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, sample=sample, noise=noise - ) - else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # Cast sample back to expected dtype - prev_sample = prev_sample.to(model_output.dtype) - - # upon completion increase step index by one - self._step_index += 1 # pyright: ignore - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - Args: - sample (`torch.Tensor`): - The input sample. - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/src/lerobot/policies/fastwam/wan/utils/fm_solvers_unipc.py b/src/lerobot/policies/fastwam/wan/utils/fm_solvers_unipc.py deleted file mode 100644 index c96897704..000000000 --- a/src/lerobot/policies/fastwam/wan/utils/fm_solvers_unipc.py +++ /dev/null @@ -1,765 +0,0 @@ -# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py -# Convert unipc for flow matching -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import ( - KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput, -) -from diffusers.utils import deprecate, is_scipy_available - -if is_scipy_available(): - import scipy.stats - - -class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - solver_order (`int`, default `2`): - The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` - due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for - unconditional sampling. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts - the flow of the diffusion process. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - predict_x0 (`bool`, defaults to `True`): - Whether to use the updating algorithm on the predicted x0. - solver_type (`str`, default `bh2`): - Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` - otherwise. - lower_order_final (`bool`, default `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - disable_corrector (`list`, default `[]`): - Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` - and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is - usually disabled during the first few steps. - solver_p (`SchedulerMixin`, default `None`): - Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - use_exponential_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps, as required by some model families. - final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - ): - - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh2") - else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p - self.last_sample = None - self._step_index = None - self._begin_index = None - - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError( - " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" - ) - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - self.last_sample = None - if self.solver_p: - self.solver_p.set_timesteps(self.num_inference_steps, device=device) - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = ( - torch.clamp(sample, -s, s) / s - ) # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps - - def _sigma_to_alpha_sigma_t(self, sigma): - return 1 - sigma, sigma - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - r""" - Convert the model output to the corresponding type the UniPC algorithm needs. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.Tensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - - if self.predict_x0: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - x0_pred = self._threshold_sample(x0_pred) - - return x0_pred - else: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - epsilon = sample - (1 - sigma_t) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - x0_pred = self._threshold_sample(x0_pred) - epsilon = model_output + x0_pred - - return epsilon - - def multistep_uni_p_bh_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model at the current timestep. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - order (`int`): - The order of UniP at this timestep (corresponds to the *p* in UniPC-p). - - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if order is None: - if len(args) > 2: - order = args[2] - else: - raise ValueError(" missing `order` as a required keyward argument") - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - model_output_list = self.model_outputs - - s0 = self.timestep_list[-1] - m0 = model_output_list[-1] - x = sample - - if self.solver_p: - x_t = self.solver_p.step(model_output, s0, x).prev_sample - return x_t - - sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - i # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) - else: - D1s = None - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - sigma_t * B_h * pred_res - - x_t = x_t.to(x.dtype) - return x_t - - def multistep_uni_c_bh_update( - self, - this_model_output: torch.Tensor, - *args, - last_sample: torch.Tensor = None, - this_sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniC (B(h) version). - - Args: - this_model_output (`torch.Tensor`): - The model outputs at `x_t`. - this_timestep (`int`): - The current timestep `t`. - last_sample (`torch.Tensor`): - The generated sample before the last predictor `x_{t-1}`. - this_sample (`torch.Tensor`): - The generated sample after the last predictor `x_{t}`. - order (`int`): - The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. - - Returns: - `torch.Tensor`: - The corrected sample tensor at the current timestep. - """ - this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) - if last_sample is None: - if len(args) > 1: - last_sample = args[1] - else: - raise ValueError(" missing`last_sample` as a required keyward argument") - if this_sample is None: - if len(args) > 2: - this_sample = args[2] - else: - raise ValueError(" missing`this_sample` as a required keyward argument") - if order is None: - if len(args) > 3: - order = args[3] - else: - raise ValueError(" missing`order` as a required keyward argument") - if this_timestep is not None: - deprecate( - "this_timestep", - "1.0.0", - "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - model_output_list = self.model_outputs - - m0 = model_output_list[-1] - x = last_sample - x_t = this_sample - model_t = this_model_output - - sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = this_sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - (i + 1) # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) - else: - D1s = None - - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) - x_t = x_t.to(x.dtype) - return x_t - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - def step( - self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - return_dict: bool = True, - generator=None, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep UniPC. - - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - use_corrector = ( - self.step_index > 0 - and self.step_index - 1 not in self.disable_corrector - and self.last_sample is not None # pyright: ignore - ) - - model_output_convert = self.convert_model_output(model_output, sample=sample) - if use_corrector: - sample = self.multistep_uni_c_bh_update( - this_model_output=model_output_convert, - last_sample=self.last_sample, - this_sample=sample, - order=self.this_order, - ) - - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.timestep_list[i] = self.timestep_list[i + 1] - - self.model_outputs[-1] = model_output_convert - self.timestep_list[-1] = timestep # pyright: ignore - - if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore - else: - this_order = self.config.solver_order - - self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep - assert self.this_order > 0 - - self.last_sample = sample - prev_sample = self.multistep_uni_p_bh_update( - model_output=model_output, # pass the original non-converted model output, in case solver-p is used - sample=sample, - order=self.this_order, - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 # pyright: ignore - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/src/lerobot/policies/fastwam/wan_adapters.py b/src/lerobot/policies/fastwam/wan_adapters.py index f281464a8..cf267a769 100644 --- a/src/lerobot/policies/fastwam/wan_adapters.py +++ b/src/lerobot/policies/fastwam/wan_adapters.py @@ -19,7 +19,7 @@ from typing import Any import torch -from .wan.modules.vae2_2 import Wan2_2_VAE +from .wan.modules.vae2_2 import Wan2VAE class WanVideoVAE38(torch.nn.Module): @@ -36,7 +36,7 @@ class WanVideoVAE38(torch.nn.Module): device: str | torch.device = "cuda", ) -> None: super().__init__() - self.wan_vae = Wan2_2_VAE(vae_pth=str(vae_pth), dtype=dtype, device=str(device)) + self.wan_vae = Wan2VAE(vae_pth=str(vae_pth), dtype=dtype, device=str(device)) self.model = self.wan_vae.model self.dtype = dtype self.device = torch.device(device) diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index a3041ddc7..7a321cb33 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -31,9 +31,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors" -WAN_T5_CHECKPOINT = "models_t5_umt5-xxl-enc-bf16.pth" +WAN_T5_SAFE_CHECKPOINT = "models_t5_umt5-xxl-enc-bf16.safetensors" +WAN_T5_CHECKPOINT = WAN_T5_SAFE_CHECKPOINT WAN_T5_TOKENIZER = "google/umt5-xxl" -WAN_VAE_CHECKPOINT = "Wan2.2_VAE.pth" +WAN_VAE_SAFE_CHECKPOINT = "Wan2.2_VAE.safetensors" +WAN_VAE_CHECKPOINT = WAN_VAE_SAFE_CHECKPOINT @dataclass(frozen=True) @@ -101,25 +103,24 @@ def resolve_wan_checkpoint_paths( root = Path(checkpoint_dir).expanduser() tokenizer_root = Path(tokenizer_dir).expanduser() if tokenizer_dir is not None else root dit = sorted(root.glob(WAN_DIT_PATTERN)) if load_dit else [] - vae = root / WAN_VAE_CHECKPOINT - text_encoder = root / WAN_T5_CHECKPOINT if load_text_encoder else None + vae = root / WAN_VAE_SAFE_CHECKPOINT + text_encoder = root / WAN_T5_SAFE_CHECKPOINT if load_text_encoder else None tokenizer = tokenizer_root / WAN_T5_TOKENIZER if load_text_encoder else None missing = [] if load_dit and len(dit) == 0: missing.append(f"DiT ({WAN_DIT_PATTERN})") if not vae.exists(): - missing.append(f"VAE ({WAN_VAE_CHECKPOINT})") + missing.append(f"VAE ({WAN_VAE_SAFE_CHECKPOINT})") if load_text_encoder: if text_encoder is None or not text_encoder.exists(): - missing.append(f"text encoder ({WAN_T5_CHECKPOINT})") + missing.append(f"text encoder ({WAN_T5_SAFE_CHECKPOINT})") if tokenizer is None or not tokenizer.exists(): missing.append(f"tokenizer ({WAN_T5_TOKENIZER})") if missing: raise FileNotFoundError( f"Incomplete Wan2.2 checkpoint directory {root}: missing {', '.join(missing)}." ) - return WanCheckpointPaths( root=root, dit=dit, @@ -158,7 +159,10 @@ def load_wan_text_encoder( dtype=torch_dtype, device=device, ) - state_dict = torch.load(checkpoint_path, map_location="cpu") + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.suffix != ".safetensors": + raise ValueError(f"Wan2.2 text encoder checkpoint must be safetensors, got {checkpoint_path}.") + state_dict = load_file(checkpoint_path) model.load_state_dict(state_dict) return model.to(device=device, dtype=torch_dtype) diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan_video_dit.py index 2d18389c7..6407d8d9a 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan_video_dit.py @@ -64,6 +64,7 @@ def fastwam_masked_attention( v: torch.Tensor, num_heads: int, ctx_mask: torch.Tensor | None = None, + fp32_attention: bool = True, ) -> torch.Tensor: """FastWAM masked attention wrapper for MoT masks and CPU test coverage. @@ -77,9 +78,13 @@ def fastwam_masked_attention( q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) - q = q.float() - k = k.float() - v = v.float() + if fp32_attention: + q = q.float() + k = k.float() + v = v.float() + else: + q = q.to(dtype=v.dtype) + k = k.to(dtype=v.dtype) x = functional.scaled_dot_product_attention(q, k, v, attn_mask=ctx_mask) return rearrange(x, "b n s d -> b s (n d)", n=num_heads) @@ -88,13 +93,112 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): return x * (1 + scale) + shift +def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]: + from .wan.utils.fm_solvers import get_sampling_sigmas + + return get_sampling_sigmas(num_inference_steps, shift) + + +class WanContinuousFlowMatchScheduler: + """Continuous-time Flow-Matching scheduler with shift-based Wan sampling.""" + + def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10): + if num_train_timesteps <= 0: + raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}") + if shift <= 0: + raise ValueError(f"`shift` must be positive, got {shift}") + self.num_train_timesteps = int(num_train_timesteps) + self.shift = float(shift) + self.eps = float(eps) + self._y_min, self._weight_norm_const = self._precompute_training_weight_stats() + + @staticmethod + def _phi(u: torch.Tensor, shift: float) -> torch.Tensor: + return shift * u / (1.0 + (shift - 1.0) * u) + + def _precompute_training_weight_stats(self) -> tuple[float, float]: + steps = self.num_train_timesteps + u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1] + t_grid = self._phi(u_grid, self.shift) * float(steps) + y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2) + y_min = float(y_grid.min().item()) + y_shifted_grid = y_grid - y_min + norm_const = float(y_shifted_grid.mean().item()) + return y_min, norm_const + + def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + if batch_size <= 0: + raise ValueError(f"`batch_size` must be positive, got {batch_size}") + u = torch.rand((batch_size,), device=device, dtype=torch.float32) + sigma = self._phi(u, self.shift) + timestep = sigma * float(self.num_train_timesteps) + return timestep.to(dtype=dtype) + + def training_weight(self, timestep: torch.Tensor) -> torch.Tensor: + t = timestep.to(dtype=torch.float32) + steps = float(self.num_train_timesteps) + y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2) + y_shifted = y - self._y_min + weight = y_shifted / (self._weight_norm_const + self.eps) + if weight.numel() == 1: + return weight.reshape(()) + return weight + + def add_noise( + self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor + ) -> torch.Tensor: + sigma = (timestep / float(self.num_train_timesteps)).to( + original_samples.device, dtype=original_samples.dtype + ) + if sigma.ndim == 0: + return (1 - sigma) * original_samples + sigma * noise + sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1))) + return (1 - sigma) * original_samples + sigma * noise + + @staticmethod + def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: + del timestep + return noise - sample + + def build_inference_schedule( + self, + num_inference_steps: int, + device: torch.device, + dtype: torch.dtype, + shift_override: float | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if num_inference_steps <= 0: + raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}") + shift = self.shift if shift_override is None else float(shift_override) + if shift <= 0: + raise ValueError(f"`shift` must be positive, got {shift}") + + sigma_steps = torch.as_tensor( + _get_wan_sampling_sigmas(num_inference_steps, shift), + device=device, + dtype=torch.float32, + ) + timesteps = sigma_steps * float(self.num_train_timesteps) + sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)]) + deltas = sigma_next - sigma_steps + return timesteps.to(dtype=dtype), deltas.to(dtype=dtype) + + @staticmethod + def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor: + delta = delta.to(sample.device, dtype=sample.dtype) + if delta.ndim == 0: + return sample + model_output * delta + delta = delta.view(-1, *([1] * (sample.ndim - 1))) + return sample + model_output * delta + + def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): return rope_params(end, dim, theta) def apply_dense_rope(x: torch.Tensor, freqs: torch.Tensor, num_heads: int) -> torch.Tensor: x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) - x_out = torch.view_as_complex(x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)) + x_out = torch.view_as_complex(x.to(torch.float32).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)) freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) @@ -144,7 +248,15 @@ def create_group_causal_attn_mask( class FastWAMAttentionBlock(WanAttentionBlock): """Wan attention block with FastWAM's arbitrary boolean mask support.""" - def __init__(self, hidden_dim: int, attn_head_dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): + def __init__( + self, + hidden_dim: int, + attn_head_dim: int, + num_heads: int, + ffn_dim: int, + eps: float = 1e-6, + fp32_attention: bool = True, + ): attention_dim = attn_head_dim * num_heads if hidden_dim == attention_dim: super().__init__( @@ -177,6 +289,7 @@ class FastWAMAttentionBlock(WanAttentionBlock): ) self.modulation = nn.Parameter(torch.randn(1, 6, hidden_dim) / hidden_dim**0.5) self.attn_head_dim = attn_head_dim + self.fp32_attention = bool(fp32_attention) @staticmethod def split_modulation(block, t_mod: torch.Tensor): @@ -231,7 +344,14 @@ class FastWAMAttentionBlock(WanAttentionBlock): q = attn.norm_q(attn.q(x)).view(b, -1, n * d) k = attn.norm_k(attn.k(context)).view(b, -1, n * d) v = attn.v(context).view(b, -1, n * d) - x = fastwam_masked_attention(q=q, k=k, v=v, num_heads=n, ctx_mask=context_mask) + x = fastwam_masked_attention( + q=q, + k=k, + v=v, + num_heads=n, + ctx_mask=context_mask, + fp32_attention=self.fp32_attention, + ) return attn.o(_linear_input(attn.o, x)) def project_self_attention_output(self, x: torch.Tensor) -> torch.Tensor: @@ -259,7 +379,14 @@ class FastWAMAttentionBlock(WanAttentionBlock): residual_x = x attn_input = modulate(self.apply_norm1(x), shift_msa, scale_msa) q, k, v = self.project_self_attention(attn_input, freqs) - y = fastwam_masked_attention(q=q, k=k, v=v, num_heads=self.num_heads, ctx_mask=self_attn_mask) + y = fastwam_masked_attention( + q=q, + k=k, + v=v, + num_heads=self.num_heads, + ctx_mask=self_attn_mask, + fp32_attention=self.fp32_attention, + ) x = residual_x + gate_msa * self.project_self_attention_output(y) x = x + self.apply_cross_attention(self.apply_norm3(x), context, context_mask=context_mask) mlp_input = modulate(self.apply_norm2(x), shift_mlp, scale_mlp) @@ -308,6 +435,7 @@ class WanVideoDiT(WanModel): action_group_causal_mask_mode="causal", video_attention_mask_mode: str = "bidirectional", use_gradient_checkpointing: bool = False, + fp32_attention: bool = True, ): del in_dim_control_adapter if has_image_input: @@ -353,6 +481,7 @@ class WanVideoDiT(WanModel): num_heads=num_heads, ffn_dim=ffn_dim, eps=eps, + fp32_attention=fp32_attention, ) for _ in range(num_layers) ] @@ -366,6 +495,7 @@ class WanVideoDiT(WanModel): self.video_attention_mask_mode = str(video_attention_mask_mode) self.action_conditioned = action_conditioned self.action_dim = action_dim + self.fp32_attention = bool(fp32_attention) if self.action_conditioned: self.action_embedding = torch.nn.Linear(action_dim, hidden_dim) @@ -667,6 +797,7 @@ class WanVideoDiT(WanModel): __all__ = [ "FastWAMAttentionBlock", + "WanContinuousFlowMatchScheduler", "WanVideoDiT", "apply_dense_rope", "create_group_causal_attn_mask", diff --git a/tests/policies/fastwam/test_fastwam_compliance.py b/tests/policies/fastwam/test_fastwam_compliance.py deleted file mode 100644 index c58ddb98c..000000000 --- a/tests/policies/fastwam/test_fastwam_compliance.py +++ /dev/null @@ -1,254 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import typing -from pathlib import Path - -import pytest -import torch -from torch import nn - -from lerobot.configs import FeatureType, PolicyFeature -from lerobot.policies.fastwam.configuration_fastwam import FastWAMConfig -from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy -from lerobot.policies.fastwam.processor_fastwam import make_fastwam_pre_post_processors -from lerobot.utils.constants import OBS_STATE - -ROOT = Path(__file__).resolve().parents[3] - - -def test_package_init_exports_required_symbols(): - init_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "__init__.py").read_text() - - assert "FastWAMConfig" in init_source - assert "make_fastwam_pre_post_processors" in init_source - - -def test_policy_config_is_exported_from_public_policies_package(): - import lerobot.policies as policies - - assert policies.FastWAMConfig is FastWAMConfig - assert "FastWAMConfig" in policies.__all__ - - -def test_fastwam_policy_docs_are_registered(): - readme_path = ROOT / "src" / "lerobot" / "policies" / "fastwam" / "README.md" - wan_readme_path = ROOT / "src" / "lerobot" / "policies" / "fastwam" / "wan" / "README.md" - policy_readme_path = ROOT / "docs" / "source" / "policy_fastwam_README.md" - guide_path = ROOT / "docs" / "source" / "fastwam.mdx" - toctree_path = ROOT / "docs" / "source" / "_toctree.yml" - - assert readme_path.is_symlink() - assert readme_path.resolve() == policy_readme_path.resolve() - assert wan_readme_path.exists() - wan_readme = wan_readme_path.read_text() - assert "Wan-Video/Wan2.2" in wan_readme - assert "42bf4cfaa384bc21833865abc2f9e6c0e67233dc" in wan_readme - assert policy_readme_path.exists() - assert guide_path.exists() - assert "local: fastwam" in toctree_path.read_text() - - -def test_wan_backbone_code_is_isolated_from_lerobot_adapter(): - wan_dir = ROOT / "src" / "lerobot" / "policies" / "fastwam" / "wan" - - assert (wan_dir / "modules" / "attention.py").exists() - assert (wan_dir / "modules" / "model.py").exists() - assert (wan_dir / "modules" / "t5.py").exists() - assert (wan_dir / "modules" / "tokenizers.py").exists() - assert (wan_dir / "modules" / "vae2_1.py").exists() - assert (wan_dir / "modules" / "vae2_2.py").exists() - assert (wan_dir / "utils" / "fm_solvers.py").exists() - assert (wan_dir / "utils" / "fm_solvers_unipc.py").exists() - - assert (wan_dir.parent / "wan_video_dit.py").exists() - assert (wan_dir.parent / "wan_adapters.py").exists() - assert (wan_dir.parent / "wan_components.py").exists() - assert not (wan_dir / "wan_video_dit.py").exists() - assert not (wan_dir / "wan_adapters.py").exists() - assert not (wan_dir / "wan_components.py").exists() - - -def test_fastwam_text_encoder_uses_upstream_wan_modules_directly(): - fastwam_dir = ROOT / "src" / "lerobot" / "policies" / "fastwam" - modular_source = (fastwam_dir / "modular_fastwam.py").read_text() - components_source = (fastwam_dir / "wan_components.py").read_text() - - assert not (fastwam_dir / "wan_video_text_encoder.py").exists() - assert "from .wan.modules.t5 import umt5_xxl" in components_source - assert "from .wan.modules.tokenizers import HuggingfaceTokenizer" in components_source - assert "WAN_T5_ENCODER_KWARGS" not in components_source - assert "wan_video_text_encoder" not in modular_source - - -def test_fastwam_vae_reuses_upstream_wan_modules(): - fastwam_dir = ROOT / "src" / "lerobot" / "policies" / "fastwam" - vae_source = (fastwam_dir / "wan_adapters.py").read_text() - - assert not (fastwam_dir / "wan_video_vae.py").exists() - assert "from .wan.modules.vae2_2 import Wan2_2_VAE" in vae_source - assert "mean = [" not in vae_source - assert "std = [" not in vae_source - assert "class Encoder3d_38" not in vae_source - assert "class Decoder3d_38" not in vae_source - assert "class VideoVAE38_" not in vae_source - - -def test_fastwam_component_loading_uses_fixed_wan_checkpoint_layout(): - modular_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "modular_fastwam.py").read_text() - modeling_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "modeling_fastwam.py").read_text() - components_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "wan_components.py").read_text() - - assert "class ModelConfig" not in modular_source - assert "def load_state_dict" not in modular_source - assert "WAN22_MODEL_REGISTRY" not in modular_source - assert "class ModelConfig" not in components_source - assert "class WanComponentSource" not in components_source - assert "def load_state_dict" not in components_source - assert "WAN22_MODEL_REGISTRY" not in components_source - assert "hash_model_file" not in components_source - assert "_resolve_component_sources" not in components_source - assert "origin_file_pattern" not in components_source - assert "inspect.signature" not in components_source - assert "class FastWAMWanComponentPaths" not in modeling_source - assert "def _first_existing" not in modeling_source - assert "def _missing_wan_component_names" not in modeling_source - assert "WAN_T5_CHECKPOINT" in components_source - assert "WAN_VAE_CHECKPOINT" in components_source - assert "WAN_DIT_PATTERN" in components_source - - -def test_fastwam_dit_reuses_upstream_wan_primitives(): - dit_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "wan_video_dit.py").read_text() - - assert "from .wan.modules.model import" in dit_source - assert "WanModel" in dit_source - for duplicated_symbol in [ - "def flash_attention(", - "def sinusoidal_embedding_1d(", - "def rope_apply(", - "def unpatchify(", - "def _dense_video_freqs(", - "class RMSNorm(", - "class SelfAttention(", - "class CrossAttention(", - "class Head(", - ]: - assert duplicated_symbol not in dit_source - - -def test_fastwam_inference_schedule_reuses_upstream_wan_sigmas(): - modular_source = (ROOT / "src" / "lerobot" / "policies" / "fastwam" / "modular_fastwam.py").read_text() - - assert "def _get_wan_sampling_sigmas" in modular_source - assert "from .wan.utils.fm_solvers import get_sampling_sigmas" in modular_source - assert "_get_wan_sampling_sigmas(num_inference_steps, shift)" in modular_source - - -def test_policy_config_rejects_missing_required_image_and_action_features(): - with pytest.raises(ValueError, match="image feature"): - FastWAMConfig( - input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}, - ) - - with pytest.raises(ValueError, match="action"): - FastWAMConfig( - output_features={"not_action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}, - ) - - -def test_policy_init_calls_validate_features_even_for_prebuilt_configs(monkeypatch): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - calls = [] - - def record_validate_features(): - calls.append("called") - - monkeypatch.setattr(cfg, "validate_features", record_validate_features) - monkeypatch.setattr( - FastWAMPolicy, - "_build_core_model", - lambda self, config: nn.Linear(1, 1), - ) - FastWAMPolicy(cfg) - - assert calls == ["called"] - - -def test_required_policy_entrypoints_exist_with_discoverable_names(): - assert FastWAMPolicy.config_class is FastWAMConfig - assert FastWAMPolicy.name == "fastwam" - assert callable(FastWAMPolicy.reset) - assert callable(FastWAMPolicy.get_optim_params) - assert callable(FastWAMPolicy.predict_action_chunk) - assert callable(FastWAMPolicy.select_action) - assert callable(FastWAMPolicy.forward) - assert callable(make_fastwam_pre_post_processors) - assert make_fastwam_pre_post_processors.__name__ == "make_fastwam_pre_post_processors" - - -def test_policy_constructor_and_forward_match_byo_template_contract(): - init_signature = inspect.signature(FastWAMPolicy.__init__) - - assert "dataset_stats" in init_signature.parameters - assert "core_model" not in init_signature.parameters - assert typing.get_type_hints(FastWAMPolicy.forward)["return"] == dict[str, torch.Tensor] - - -def test_saved_config_round_trips_policy_features(tmp_path): - cfg = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448)) - cfg.save_pretrained(tmp_path) - - loaded = FastWAMConfig.from_pretrained(tmp_path) - - assert loaded.type == "fastwam" - assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL - assert loaded.action_feature.shape == (7,) - assert loaded.robot_state_feature.shape == (8,) - - -def test_config_from_pretrained_ignores_unknown_fields(tmp_path): - cfg = FastWAMConfig() - cfg.save_pretrained(tmp_path) - config_path = tmp_path / "config.json" - payload = config_path.read_text() - payload = payload.replace( - '"torch_dtype": "bfloat16"', - '"torch_dtype": "bfloat16",\n "unknown_fastwam_field": true', - ) - config_path.write_text(payload) - - loaded = FastWAMConfig.from_pretrained(tmp_path) - - assert loaded.type == "fastwam" - assert not hasattr(loaded, "unknown_fastwam_field") - - -def test_config_from_pretrained_does_not_use_non_wan22_tokenizer_repo_id(tmp_path): - cfg = FastWAMConfig() - cfg.save_pretrained(tmp_path) - config_path = tmp_path / "config.json" - payload = config_path.read_text() - payload = payload.replace( - '"tokenizer_model_id": "Wan-AI/Wan2.2-TI2V-5B"', - '"tokenizer_model_id": "somebody/old-tokenizer"', - ) - config_path.write_text(payload) - - loaded = FastWAMConfig.from_pretrained(tmp_path) - - assert loaded.tokenizer_model_id == "Wan-AI/Wan2.2-TI2V-5B" diff --git a/tests/policies/fastwam/test_fastwam_factory.py b/tests/policies/fastwam/test_fastwam_factory.py deleted file mode 100644 index f253f6995..000000000 --- a/tests/policies/fastwam/test_fastwam_factory.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch - -from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors - - -def test_fastwam_is_registered_in_policy_factory(): - from lerobot.policies.fastwam.configuration_fastwam import FastWAMConfig - from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy - - cfg = make_policy_config("fastwam", action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - - assert isinstance(cfg, FastWAMConfig) - assert cfg.type == "fastwam" - assert get_policy_class("fastwam") is FastWAMPolicy - - -def test_fastwam_pre_post_processors_are_available(): - cfg = make_policy_config("fastwam", action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - - preprocessor, postprocessor = make_pre_post_processors(cfg) - - assert preprocessor.name == "policy_preprocessor" - assert postprocessor.name == "policy_postprocessor" - - -def test_fastwam_postprocessor_only_adds_action_inversion_when_configured(): - from lerobot.policies.fastwam.processor_fastwam import ( - FastWAMActionInversionProcessorStep, - FastWAMActionToggleProcessorStep, - ) - - default_cfg = make_policy_config( - "fastwam", action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2 - ) - _, default_postprocessor = make_pre_post_processors(default_cfg) - - assert any(isinstance(step, FastWAMActionToggleProcessorStep) for step in default_postprocessor.steps) - assert not any( - isinstance(step, FastWAMActionInversionProcessorStep) for step in default_postprocessor.steps - ) - - inverted_cfg = make_policy_config( - "fastwam", - action_dim=3, - proprio_dim=2, - action_horizon=4, - n_action_steps=2, - toggle_action_dimensions=[], - invert_dimensions=[-1], - ) - _, inverted_postprocessor = make_pre_post_processors(inverted_cfg) - - assert any(isinstance(step, FastWAMActionInversionProcessorStep) for step in inverted_postprocessor.steps) - - -def test_fastwam_action_inversion_processor_flips_configured_dimensions(): - from lerobot.policies.fastwam.processor_fastwam import FastWAMActionInversionProcessorStep - - processor = FastWAMActionInversionProcessorStep(invert_dimensions=[0, -1]) - action = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - - processed = processor.action(action) - - assert torch.equal(processed, torch.tensor([[-1.0, 2.0, -3.0], [-4.0, 5.0, -6.0]])) - assert torch.equal(action, torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - - -def test_fastwam_rejects_non_wan22_hub_model_ids(): - from lerobot.policies.fastwam.configuration_fastwam import FastWAMConfig - - with pytest.raises(ValueError, match="model_id"): - FastWAMConfig(model_id="somebody/other-model") diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index b7133674b..d750fd5f8 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -14,21 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + +import pytest import torch from safetensors.torch import save_model from torch import nn +from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig +from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors from lerobot.policies.fastwam import modeling_fastwam -from lerobot.policies.fastwam.configuration_fastwam import FastWAMConfig -from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy -from lerobot.policies.fastwam.modular_fastwam import ActionDiT, MoT -from lerobot.policies.fastwam.wan_video_dit import ( - FastWAMAttentionBlock, - WanVideoDiT, - fastwam_masked_attention, - precompute_freqs_cis, +from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy, resolve_wan_component_paths +from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep +from lerobot.policies.fastwam.wan_components import ( + WAN_DIT_PATTERN, + WAN_T5_CHECKPOINT, + WAN_T5_TOKENIZER, + WAN_VAE_CHECKPOINT, + resolve_wan_checkpoint_paths, ) -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION, OBS_STATE class FakeFastWAMCore(nn.Module): @@ -39,293 +44,117 @@ class FakeFastWAMCore(nn.Module): def training_loss(self, sample): assert sample["video"].ndim == 5 assert sample["context"].ndim == 3 - return sample["action"].sum() * 0.0 + torch.tensor(1.0), {"loss_action": 1.0} + return sample[ACTION].sum() * 0.0 + torch.tensor(1.0), {"loss_action": 1.0} def infer_action(self, **kwargs): - horizon = kwargs["action_horizon"] - return {"action": torch.ones(horizon, 3)} + return {"action": torch.ones(1, kwargs["action_horizon"], 3)} -def _patch_core_builder(monkeypatch): - monkeypatch.setattr( - FastWAMPolicy, - "_build_core_model", - lambda self, config: FakeFastWAMCore(), - ) - - -def test_action_attention_block_supports_mot_attention_dim_larger_than_hidden_dim(): - block = FastWAMAttentionBlock(hidden_dim=16, attn_head_dim=8, num_heads=4, ffn_dim=32) - x = torch.zeros(1, 2, 16) - context = torch.zeros(1, 3, 16) - t_mod = torch.zeros(1, 6, 16) - freqs = precompute_freqs_cis(8, end=2).view(2, 1, -1) - - output = block(x, context, t_mod, freqs) - - assert output.shape == x.shape - assert block.self_attn.q.out_features == 32 - assert block.self_attn.o.out_features == 16 - - -def test_fastwam_masked_attention_accepts_rope_float32_qk_with_bfloat16_values(): - q = torch.zeros(1, 2, 32, dtype=torch.float32) - k = torch.zeros(1, 2, 32, dtype=torch.float32) - v = torch.zeros(1, 2, 32, dtype=torch.bfloat16) - - out = fastwam_masked_attention(q=q, k=k, v=v, num_heads=4) - - assert out.dtype == torch.float32 - assert out.shape == v.shape - - -def test_fastwam_masked_attention_runs_fp32_when_cache_promotes_keys(): - q = torch.zeros(1, 2, 32, dtype=torch.bfloat16) - k = torch.zeros(1, 4, 32, dtype=torch.float32) - v = torch.zeros(1, 4, 32, dtype=torch.bfloat16) - mask = torch.ones(2, 4, dtype=torch.bool) - - out = fastwam_masked_attention(q=q, k=k, v=v, num_heads=4, ctx_mask=mask) - - assert out.dtype == torch.float32 - assert out.shape == q.shape - - -def test_attention_post_projection_casts_fp32_attention_to_block_dtype(): - block = FastWAMAttentionBlock(hidden_dim=16, attn_head_dim=8, num_heads=4, ffn_dim=32).to( - dtype=torch.bfloat16 - ) - residual = torch.zeros(1, 2, 16, dtype=torch.bfloat16) - mixed_attn = torch.zeros(1, 2, 32, dtype=torch.float32) - gate_msa = torch.ones(1, 16, dtype=torch.bfloat16) - shift_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - scale_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - gate_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - - out = MoT._apply_expert_post_block( - block=block, - residual_x=residual, - mixed_attn_out=mixed_attn, - gate_msa=gate_msa, - shift_mlp=shift_mlp, - scale_mlp=scale_mlp, - gate_mlp=gate_mlp, - context_payload=None, - ) - - assert out.dtype == torch.bfloat16 - assert out.shape == residual.shape - - -def test_attention_cross_projection_casts_fp32_attention_to_block_dtype(): - block = FastWAMAttentionBlock(hidden_dim=16, attn_head_dim=8, num_heads=4, ffn_dim=32).to( - dtype=torch.bfloat16 - ) - x = torch.zeros(1, 2, 16, dtype=torch.bfloat16) - context = torch.zeros(1, 3, 16, dtype=torch.bfloat16) - - out = block.apply_cross_attention(x, context) - - assert out.dtype == torch.bfloat16 - assert out.shape == x.shape - - -def test_attention_norm3_handles_bfloat16_affine_weights(): - block = FastWAMAttentionBlock(hidden_dim=16, attn_head_dim=8, num_heads=4, ffn_dim=32).to( - dtype=torch.bfloat16 - ) - x = torch.zeros(1, 2, 16, dtype=torch.bfloat16) - - out = block.apply_norm3(x) - - assert out.dtype == torch.bfloat16 - assert out.shape == x.shape - - -def test_attention_post_block_handles_bfloat16_cross_attention_norm(): - block = FastWAMAttentionBlock(hidden_dim=16, attn_head_dim=8, num_heads=4, ffn_dim=32).to( - dtype=torch.bfloat16 - ) - residual = torch.zeros(1, 2, 16, dtype=torch.bfloat16) - mixed_attn = torch.zeros(1, 2, 32, dtype=torch.float32) - gate_msa = torch.ones(1, 16, dtype=torch.bfloat16) - shift_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - scale_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - gate_mlp = torch.zeros(1, 16, dtype=torch.bfloat16) - context_payload = {"context": torch.zeros(1, 3, 16, dtype=torch.bfloat16), "mask": None} - - out = MoT._apply_expert_post_block( - block=block, - residual_x=residual, - mixed_attn_out=mixed_attn, - gate_msa=gate_msa, - shift_mlp=shift_mlp, - scale_mlp=scale_mlp, - gate_mlp=gate_mlp, - context_payload=context_payload, - ) - - assert out.dtype == torch.bfloat16 - assert out.shape == residual.shape - - -def test_video_dit_pre_dit_casts_double_latents_to_model_dtype(): - model = WanVideoDiT( - hidden_dim=4, - in_dim=48, - ffn_dim=8, - out_dim=48, - text_dim=6, - freq_dim=4, - eps=1e-6, - patch_size=(1, 2, 2), - num_heads=1, - attn_head_dim=4, - num_layers=0, - seperated_timestep=True, - fuse_vae_embedding_in_latents=True, - video_attention_mask_mode="first_frame_causal", - ).to(dtype=torch.bfloat16) - - state = model.pre_dit( - x=torch.zeros(1, 48, 1, 2, 2, dtype=torch.float64), - timestep=torch.zeros(1, dtype=torch.float64), - context=torch.zeros(1, 2, 6, dtype=torch.float64), - fuse_vae_embedding_in_latents=True, - ) - - assert state["tokens"].dtype == torch.bfloat16 - assert state["context"].dtype == torch.bfloat16 - assert state["t_mod"].dtype == torch.bfloat16 - - -def test_action_dit_pre_dit_casts_double_inputs_to_model_dtype(): - model = ActionDiT( - hidden_dim=16, +def test_fastwam_is_registered_and_publicly_exported(): + cfg = make_policy_config( + "fastwam", action_dim=3, - ffn_dim=32, - text_dim=6, - freq_dim=4, - eps=1e-6, - num_heads=4, - attn_head_dim=8, - num_layers=0, - ).to(dtype=torch.bfloat16) - - state = model.pre_dit( - action_tokens=torch.zeros(1, 2, 3, dtype=torch.float64), - timestep=torch.zeros(1, dtype=torch.float64), - context=torch.zeros(1, 2, 6, dtype=torch.float64), + proprio_dim=2, + action_horizon=4, + n_action_steps=2, + base_model_id=None, ) - assert state["tokens"].dtype == torch.bfloat16 - assert state["context"].dtype == torch.bfloat16 - assert state["t_mod"].dtype == torch.bfloat16 + assert isinstance(cfg, FastWAMConfig) + assert cfg.type == "fastwam" + assert get_policy_class("fastwam") is FastWAMPolicy -def test_forward_adapts_lerobot_batch_to_fastwam_sample(monkeypatch): - _patch_core_builder(monkeypatch) - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - policy = FastWAMPolicy(cfg) - batch = { - "observation.images.image": torch.zeros(1, 3, 16, 16), - "observation.state": torch.zeros(1, 2), - "action": torch.zeros(1, 4, 3), - "context": torch.zeros(1, 5, 4096), - "context_mask": torch.ones(1, 5, dtype=torch.bool), - } +def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path): + cfg = FastWAMConfig() + cfg.save_pretrained(tmp_path) + saved = json.loads((tmp_path / "config.json").read_text()) - output = policy.forward(batch) - - assert set(output) == {"loss", "loss_action"} - assert output["loss"].item() == 1.0 - assert output["loss_action"].item() == 1.0 + assert saved["pretrained_path"] is None + assert cfg.image_features["observation.images.image"].type == FeatureType.VISUAL + assert cfg.action_feature.shape == (7,) + assert cfg.robot_state_feature.shape == (8,) + with pytest.raises(ValueError, match="image feature"): + FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))}) + with pytest.raises(ValueError, match="tokenizer_model_id"): + FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer") -def test_get_optim_params_returns_lerobot_optimizer_dict(monkeypatch): - _patch_core_builder(monkeypatch) - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - policy = FastWAMPolicy(cfg) - - optim_params = policy.get_optim_params() - - assert isinstance(optim_params, dict) - assert set(optim_params) == {"params"} - assert list(optim_params["params"]) - - -def test_select_action_uses_action_queue(monkeypatch): - _patch_core_builder(monkeypatch) - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - policy = FastWAMPolicy(cfg) - batch = { - "input_image": torch.zeros(1, 3, 16, 16), - "observation.state": torch.zeros(1, 2), - "context": torch.zeros(1, 5, 4096), - "context_mask": torch.ones(1, 5, dtype=torch.bool), - } - - first = policy.select_action(batch) - second = policy.select_action(batch) - - assert first.shape == (1, 3) - assert second.shape == (1, 3) - - -def test_predict_action_prepares_lerobot_libero_observation(monkeypatch): - captured = {} - - class CapturingCore(FakeFastWAMCore): - def infer_action(self, **kwargs): - captured.update(kwargs) - return {"action": torch.ones(1, 4, 3)} - - monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore()) +def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_path): cfg = FastWAMConfig( action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, - image_size=(16, 32), + image_size=(2, 2), + device="cpu", + toggle_action_dimensions=[-1], input_features={ - "observation.images.image": {"type": "VISUAL", "shape": (3, 16, 32)}, - "observation.state": {"type": "STATE", "shape": (2,)}, + "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 2, 2)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)), }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + base_model_id=None, ) - policy = FastWAMPolicy(cfg) - batch = { - "observation.images.image": torch.ones(1, 3, 20, 20), - "observation.images.image2": torch.zeros(1, 3, 20, 20), - "observation.state": torch.zeros(1, 2), - "task": ["pick up the bowl"], + dataset_stats = { + "observation.images.image": { + "mean": torch.full((3, 1, 1), 0.2), + "std": torch.full((3, 1, 1), 0.1), + }, + OBS_STATE: { + "mean": torch.tensor([1.0, 3.0]), + "std": torch.tensor([2.0, 4.0]), + }, + ACTION: { + "mean": torch.zeros(3), + "std": torch.ones(3), + }, } - action = policy.predict_action_chunk(batch) + preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats) + processed = preprocessor( + { + "observation.images.image": torch.tensor( + [ + [[0.0, 0.5], [1.0, 0.5]], + [[0.0, 0.5], [1.0, 0.5]], + [[0.0, 0.5], [1.0, 0.5]], + ] + ), + OBS_STATE: torch.tensor([3.0, 7.0]), + } + ) + preprocessor.save_pretrained(tmp_path, config_filename="policy_preprocessor.json") + postprocessor.save_pretrained(tmp_path, config_filename="policy_postprocessor.json") + _, loaded_postprocessor = make_pre_post_processors(cfg, pretrained_path=str(tmp_path)) - assert action.shape == (1, 4, 3) - assert captured["prompt"] == [cfg.prompt_template.format(task="pick up the bowl")] - assert tuple(captured["input_image"].shape) == (1, 3, 16, 32) - assert captured["input_image"].amin().item() == -1.0 - assert captured["input_image"].amax().item() == 1.0 - assert "num_video_frames" not in captured + expected_image = torch.tensor( + [[[[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]]]] + ) + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + assert torch.allclose(processed["observation.images.image"], expected_image) + assert torch.allclose(processed[OBS_STATE], torch.tensor([[1.0, 1.0]])) + assert torch.equal(dataset_stats["observation.images.image"]["mean"], torch.full((3, 1, 1), 0.2)) + assert any(isinstance(step, FastWAMActionToggleProcessorStep) for step in loaded_postprocessor.steps) + assert torch.equal( + loaded_postprocessor(torch.tensor([[0.25, 0.5, 1.0]])), torch.tensor([[0.25, 0.5, -1.0]]) + ) -def test_predict_action_splits_parallel_eval_batch_into_single_infer_calls(monkeypatch): +def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): captured = [] class CapturingCore(FakeFastWAMCore): def infer_action(self, **kwargs): captured.append( { - "input_image_shape": tuple(kwargs["input_image"].shape), - "input_image_sum": float(kwargs["input_image"].sum()), + "image_shape": tuple(kwargs["input_image"].shape), "proprio_shape": tuple(kwargs["proprio"].shape), - "proprio_sum": float(kwargs["proprio"].sum()), "prompt": kwargs["prompt"], } ) - action = torch.full((1, kwargs["action_horizon"], 3), float(len(captured))) - return {"action": action} + return {"action": torch.full((1, kwargs["action_horizon"], 3), float(len(captured)))} monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore()) cfg = FastWAMConfig( @@ -335,42 +164,54 @@ def test_predict_action_splits_parallel_eval_batch_into_single_infer_calls(monke n_action_steps=2, image_size=(16, 16), input_features={ - "observation.images.image": {"type": "VISUAL", "shape": (3, 16, 16)}, - "observation.state": {"type": "STATE", "shape": (2,)}, + "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)), }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + base_model_id=None, ) - policy = FastWAMPolicy(cfg) - batch = { - "observation.images.image": torch.stack( - [ - torch.zeros(3, 16, 16), - torch.ones(3, 16, 16), - torch.full((3, 16, 16), 2.0), - ] - ), - "observation.state": torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]), - "task": ["task 0", "task 1", "task 2"], - } + with pytest.warns(RuntimeWarning, match="does not load pretrained FastWAM weights"): + policy = FastWAMPolicy(cfg) - action = policy.predict_action_chunk(batch) + output = policy.forward( + { + "observation.images.image": torch.zeros(1, 3, 16, 16), + OBS_STATE: torch.zeros(1, 2), + ACTION: torch.zeros(1, 4, 3), + "context": torch.zeros(1, 5, 4096), + "context_mask": torch.ones(1, 5, dtype=torch.bool), + } + ) + action = policy.predict_action_chunk( + { + "observation.images.image": torch.stack( + [ + torch.zeros(3, 16, 16), + torch.ones(3, 16, 16), + ] + ), + OBS_STATE: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + "task": ["task 0", "task 1"], + } + ) - assert action.shape == (3, 4, 3) - assert action[:, 0, 0].tolist() == [1.0, 2.0, 3.0] - assert len(captured) == 3 - assert [item["input_image_shape"] for item in captured] == [(1, 3, 16, 16)] * 3 - assert [item["proprio_shape"] for item in captured] == [(1, 2)] * 3 + assert output["loss"].item() == 1.0 + assert output["loss_action"].item() == 1.0 + assert action.shape == (2, 4, 3) + assert action[:, 0, 0].tolist() == [1.0, 2.0] + assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)] + assert [item["proprio_shape"] for item in captured] == [(1, 2), (1, 2)] assert [item["prompt"] for item in captured] == [ cfg.prompt_template.format(task="task 0"), cfg.prompt_template.format(task="task 1"), - cfg.prompt_template.format(task="task 2"), ] -def test_from_pretrained_does_not_initialize_wan_backbone(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) +def test_from_pretrained_loads_weights_without_initializing_wan_backbone(monkeypatch, tmp_path): + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) cfg.save_pretrained(tmp_path) - _patch_core_builder(monkeypatch) - reference_policy = FastWAMPolicy(cfg) + monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: FakeFastWAMCore()) + reference_policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True) save_model(reference_policy, str(tmp_path / "model.safetensors")) def fail_if_wan_pretrained_is_loaded(*args, **kwargs): @@ -399,52 +240,13 @@ def test_from_pretrained_does_not_initialize_wan_backbone(monkeypatch, tmp_path) assert loaded_components_from == [tmp_path] -def test_from_pretrained_resolves_hub_repo_to_snapshot_before_loading_sidecars(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) - cfg.save_pretrained(tmp_path) - snapshot_calls = [] - - def fake_snapshot_download(**kwargs): - snapshot_calls.append(kwargs) - return str(tmp_path) - - @classmethod - def fake_base_from_pretrained(cls, pretrained_name_or_path, *, config=None, **kwargs): - assert pretrained_name_or_path == tmp_path - assert kwargs.pop("_skip_wan_init") is True - assert kwargs["strict"] is False - return cls(config, _skip_wan_init=True) - - monkeypatch.setattr("huggingface_hub.snapshot_download", fake_snapshot_download) - monkeypatch.setattr(PreTrainedPolicy, "from_pretrained", fake_base_from_pretrained) - monkeypatch.setattr( - modeling_fastwam, - "_build_core_model_from_architecture", - lambda config: FakeFastWAMCore(), - raising=False, - ) - loaded_components_from = [] - monkeypatch.setattr( - FastWAMPolicy, - "load_wan_components_from_pretrained", - lambda self, path: loaded_components_from.append(path), - ) - - FastWAMPolicy.from_pretrained("org/fastwam", strict=False, local_files_only=True, revision="main") - - assert snapshot_calls[0]["repo_id"] == "org/fastwam" - assert snapshot_calls[0]["local_files_only"] is True - assert snapshot_calls[0]["revision"] == "main" - assert loaded_components_from == [tmp_path] - - -def test_save_pretrained_copies_wan_components(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2) +def test_save_pretrained_copies_required_wan_sidecars(monkeypatch, tmp_path): + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) source = tmp_path / "source" - tokenizer = source / "google" / "umt5-xxl" + tokenizer = source / WAN_T5_TOKENIZER tokenizer.mkdir(parents=True) - vae = source / "Wan2.2_VAE.pth" - text_encoder = source / "models_t5_umt5-xxl-enc-bf16.pth" + vae = source / WAN_VAE_CHECKPOINT + text_encoder = source / WAN_T5_CHECKPOINT tokenizer_file = tokenizer / "tokenizer.json" vae.write_bytes(b"vae") text_encoder.write_bytes(b"text") @@ -456,12 +258,47 @@ def test_save_pretrained_copies_wan_components(monkeypatch, tmp_path): "tokenizer": str(tokenizer), } monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: core) - policy = FastWAMPolicy(cfg) + policy = FastWAMPolicy(cfg, _suppress_base_init_warning=True) save_dir = tmp_path / "saved" policy.save_pretrained(save_dir) assert (save_dir / "model.safetensors").is_file() - assert (save_dir / "Wan2.2_VAE.pth").read_bytes() == b"vae" - assert (save_dir / "models_t5_umt5-xxl-enc-bf16.pth").read_bytes() == b"text" - assert (save_dir / "google" / "umt5-xxl" / "tokenizer.json").read_text() == "{}" + assert (save_dir / WAN_VAE_CHECKPOINT).read_bytes() == b"vae" + assert (save_dir / WAN_T5_CHECKPOINT).read_bytes() == b"text" + assert (save_dir / WAN_T5_TOKENIZER / "tokenizer.json").read_text() == "{}" + + +def test_wan_component_resolution_uses_fixed_safetensors_layout(tmp_path): + tokenizer = tmp_path / WAN_T5_TOKENIZER + tokenizer.mkdir(parents=True) + (tmp_path / WAN_VAE_CHECKPOINT).touch() + (tmp_path / WAN_T5_CHECKPOINT).touch() + (tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors").touch() + (tokenizer / "tokenizer.json").touch() + + paths = resolve_wan_checkpoint_paths(tmp_path) + sidecar_paths = resolve_wan_component_paths(tmp_path) + + assert paths.dit == [tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors"] + assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT + assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT + assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER + assert sidecar_paths.dit == [] + assert WAN_DIT_PATTERN == "diffusion_pytorch_model*.safetensors" + + (tmp_path / WAN_T5_CHECKPOINT).unlink() + with pytest.raises(FileNotFoundError, match="text encoder"): + resolve_wan_checkpoint_paths(tmp_path) + + +def test_pretrained_config_round_trips_fastwam_features(tmp_path): + cfg = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448), base_model_id=None) + cfg.save_pretrained(tmp_path) + + loaded = PreTrainedConfig.from_pretrained(tmp_path) + + assert loaded.type == "fastwam" + assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL + assert loaded.action_feature.shape == (7,) + assert loaded.robot_state_feature.shape == (8,) diff --git a/tests/policies/fastwam/test_fastwam_wan_components.py b/tests/policies/fastwam/test_fastwam_wan_components.py deleted file mode 100644 index c2c835dd2..000000000 --- a/tests/policies/fastwam/test_fastwam_wan_components.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path - -import pytest -from torch import nn - -from lerobot.policies.fastwam import modeling_fastwam -from lerobot.policies.fastwam.configuration_fastwam import FastWAMConfig -from lerobot.policies.fastwam.modeling_fastwam import ( - FastWAMPolicy, - resolve_wan_component_paths, -) -from lerobot.policies.fastwam.wan_components import ( - WAN_DIT_PATTERN, - WAN_T5_CHECKPOINT, - WAN_T5_TOKENIZER, - WAN_VAE_CHECKPOINT, - resolve_wan_checkpoint_paths, -) - - -def _make_wan_component_tree(root: Path) -> None: - tokenizer = root / WAN_T5_TOKENIZER - tokenizer.mkdir(parents=True) - (root / WAN_VAE_CHECKPOINT).touch() - (root / WAN_T5_CHECKPOINT).touch() - (root / "diffusion_pytorch_model-00001-of-00001.safetensors").touch() - (tokenizer / "tokenizer.json").touch() - - -def test_resolve_wan_component_paths_finds_complete_local_directory(tmp_path): - _make_wan_component_tree(tmp_path) - - paths = resolve_wan_component_paths(tmp_path) - - assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT - assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT - assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER - - -def test_resolve_wan_component_paths_does_not_require_original_dit_shards(tmp_path): - _make_wan_component_tree(tmp_path) - for shard in tmp_path.glob(WAN_DIT_PATTERN): - shard.unlink() - - paths = resolve_wan_component_paths(tmp_path) - - assert paths.dit == [] - assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT - assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT - assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER - - -def test_resolve_wan_checkpoint_paths_uses_official_wan_layout(tmp_path): - _make_wan_component_tree(tmp_path) - - paths = resolve_wan_checkpoint_paths(tmp_path) - - assert paths.root == tmp_path - assert paths.dit == [tmp_path / "diffusion_pytorch_model-00001-of-00001.safetensors"] - assert paths.vae == tmp_path / WAN_VAE_CHECKPOINT - assert paths.text_encoder == tmp_path / WAN_T5_CHECKPOINT - assert paths.tokenizer == tmp_path / WAN_T5_TOKENIZER - assert WAN_DIT_PATTERN == "diffusion_pytorch_model*.safetensors" - - -def test_resolve_wan_component_paths_rejects_partial_local_directory(tmp_path): - _make_wan_component_tree(tmp_path) - (tmp_path / WAN_T5_CHECKPOINT).unlink() - - with pytest.raises(FileNotFoundError, match="text encoder"): - resolve_wan_component_paths(tmp_path) - - -def test_policy_config_construction_loads_wan22_backbone_from_config(monkeypatch): - class TinyCore(nn.Module): - def __init__(self): - super().__init__() - self.text_encoder = None - - calls = [] - - def fake_from_wan22_pretrained(**kwargs): - calls.append(kwargs) - return TinyCore() - - monkeypatch.setattr( - "lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained", - fake_from_wan22_pretrained, - ) - - cfg = FastWAMConfig() - policy = FastWAMPolicy(cfg) - - assert policy.model.text_encoder is None - assert calls == [ - { - "device": cfg.device, - "torch_dtype": modeling_fastwam._dtype_from_name(cfg.torch_dtype), - "model_id": "Wan-AI/Wan2.2-TI2V-5B", - "tokenizer_model_id": "Wan-AI/Wan2.2-TI2V-5B", - "tokenizer_max_len": cfg.tokenizer_max_len, - "load_text_encoder": cfg.load_text_encoder, - "proprio_dim": cfg.proprio_dim, - "video_dit_config": cfg.video_dit_config, - "action_dit_config": cfg.action_dit_config, - "mot_checkpoint_mixed_attn": cfg.mot_checkpoint_mixed_attn, - "video_train_shift": float(cfg.video_scheduler["train_shift"]), - "video_infer_shift": float(cfg.video_scheduler["infer_shift"]), - "video_num_train_timesteps": int(cfg.video_scheduler["num_train_timesteps"]), - "action_train_shift": float(cfg.action_scheduler["train_shift"]), - "action_infer_shift": float(cfg.action_scheduler["infer_shift"]), - "action_num_train_timesteps": int(cfg.action_scheduler["num_train_timesteps"]), - "loss_lambda_video": float(cfg.loss["lambda_video"]), - "loss_lambda_action": float(cfg.loss["lambda_action"]), - } - ] - - -def test_explicit_local_wan_path_is_preserved(tmp_path): - cfg = FastWAMConfig(model_id=str(tmp_path), tokenizer_model_id=str(tmp_path)) - - assert cfg.model_id == str(tmp_path) - assert cfg.tokenizer_model_id == str(tmp_path) - - -def test_other_hub_model_ids_are_rejected(): - with pytest.raises(ValueError, match="model_id"): - FastWAMConfig(model_id="somebody/other-model") - - with pytest.raises(ValueError, match="tokenizer_model_id"): - FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer") - - -def test_resolve_wan_checkpoint_paths_can_skip_text_encoder(tmp_path): - _make_wan_component_tree(tmp_path) - (tmp_path / WAN_T5_CHECKPOINT).unlink() - shutil_tokenizer = tmp_path / WAN_T5_TOKENIZER - for child in shutil_tokenizer.iterdir(): - child.unlink() - shutil_tokenizer.rmdir() - shutil_tokenizer.parent.rmdir() - - paths = resolve_wan_checkpoint_paths(tmp_path, load_text_encoder=False) - - assert paths.text_encoder is None - assert paths.tokenizer is None