diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index 57ccccb7c..c557b9d4f 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -153,11 +153,25 @@ class FastWAMConfig(PreTrainedConfig): proprio_dim (int | None): Number of proprioception channels used as an extra text-context token. `None` disables proprio conditioning. action_horizon (int): Number of actions predicted by one policy call. - num_video_frames (int): Number of video frames used by FastWAM rollout. + num_video_frames (int): Raw video sampling window (in dataset frames). The + model actually operates on `model_video_frames` frames after subsampling + by `action_video_freq_ratio`. + action_video_freq_ratio (int): Actions are sampled at this multiple of the + video frame rate. Video frames are taken every `action_video_freq_ratio`-th + raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames + spanning the same time window as `action_horizon` actions (ratio actions + per video frame). image_size (tuple[int, int]): Concatenated image size as `(height, width)`. context_len (int): Maximum text embedding token length. video_dit_config (dict[str, Any] | None): Wan video expert config. action_dit_config (dict[str, Any] | None): Action expert config. + use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT + experts (trades compute for memory; propagated into the DiT configs). + freeze_video_expert (bool): Freeze the ~5B Wan video expert + (`model.video_expert`) so only the action expert + proprio encoder train. + Cuts the AdamW optimizer footprint substantially; the video expert keeps its + pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the + now-gradient-free video loss compute.) """ n_obs_steps: int = 1 @@ -166,6 +180,7 @@ class FastWAMConfig(PreTrainedConfig): action_horizon: int = 32 n_action_steps: int = 32 num_video_frames: int = 33 + action_video_freq_ratio: int = 4 image_size: tuple[int, int] = (224, 448) context_len: int = 128 model_id: str = WAN22_MODEL_ID @@ -186,6 +201,8 @@ class FastWAMConfig(PreTrainedConfig): sigma_shift: float | None = None tiled: bool = False fp32_attention: bool = True + use_gradient_checkpointing: bool = False + freeze_video_expert: bool = False toggle_action_dimensions: list[int] = field(default_factory=list) video_scheduler: dict[str, float | int] = field( default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000} @@ -220,6 +237,8 @@ class FastWAMConfig(PreTrainedConfig): self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim) self.video_dit_config["fp32_attention"] = bool(self.fp32_attention) self.action_dit_config["fp32_attention"] = bool(self.fp32_attention) + self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing) + self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing) if self.input_features is None: height, width = self.image_size self.input_features = { @@ -300,8 +319,28 @@ class FastWAMConfig(PreTrainedConfig): raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.") if self.n_action_steps > self.action_horizon: raise ValueError("`n_action_steps` cannot exceed `action_horizon`.") - if self.num_video_frames % 4 != 1: - raise ValueError(f"`num_video_frames` must satisfy T % 4 == 1, got {self.num_video_frames}.") + if self.action_video_freq_ratio <= 0: + raise ValueError( + f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}." + ) + # Video frames are subsampled by action_video_freq_ratio; the resulting model frame + # count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the + # original FastWAM dataset asserts). + if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0: + raise ValueError( + f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by " + f"`action_video_freq_ratio` ({self.action_video_freq_ratio})." + ) + if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0: + raise ValueError( + f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) " + "must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)." + ) + if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0: + raise ValueError( + f"`action_horizon` ({self.action_horizon}) must be divisible by the number of " + f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})." + ) if not self.image_features: raise ValueError("FastWAM requires at least one image feature.") if self.action_feature is None: @@ -333,8 +372,19 @@ class FastWAMConfig(PreTrainedConfig): raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.") @property - def observation_delta_indices(self) -> None: - return None + def model_video_frames(self) -> int: + """Number of video frames the model actually operates on, after subsampling the + raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9).""" + return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1 + + @property + def observation_delta_indices(self) -> list[int]: + # Load the video frames the model is supervised on: the future window subsampled by + # action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is + # thus `action_video_freq_ratio` actions apart, while actions load at the full rate + # (`action_delta_indices` = range(action_horizon)). Returning None would load only the + # current frame, making the video target a static repeat (degenerate supervision). + return list(range(0, self.num_video_frames, self.action_video_freq_ratio)) @property def action_delta_indices(self) -> list[int]: diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index b64d44785..2dcee64d7 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -15,7 +15,9 @@ from __future__ import annotations import logging +import os from collections import deque +from pathlib import Path from typing import Any import torch @@ -25,6 +27,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.constants import OBS_STATE from .configuration_fastwam import FastWAMConfig +from .modular_fastwam import ActionDiT, FastWAM, MoT +from .wan_components import ( + build_wan_tokenizer, + load_pretrained_wan_text_encoder, + load_pretrained_wan_vae, +) +from .wan_video_dit import WanVideoDiT + +# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first +# eval episode's action chunks through `infer_joint` so the predicted video latents are +# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path). +_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1" +# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1 +# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/ +# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the +# (now-fixed) bug where the model ran on the un-subsampled num_video_frames. +_DEBUG_PRED_RATE_DIV = 1 class FastWAMPolicy(PreTrainedPolicy): @@ -43,13 +62,32 @@ class FastWAMPolicy(PreTrainedPolicy): self, config: FastWAMConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None, + **kwargs: Any, ): + # `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the + # dataset feature metadata is already applied to `config` by make_policy upstream, + # so we accept and ignore them, matching the other LeRobot policies. super().__init__(config, dataset_stats) config.validate_features() self.config = config self.dataset_stats = dataset_stats self.model = self._build_core_model(config) + if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None: + # Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad, + # so its params drop out of the optimizer (and DDP skips them). + self.model.video_expert.requires_grad_(False) self.reset() + # TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()` + # counts only eval-rollout resets (one per episode), not this __init__ one. + self._debug_constructed = True + self._debug_episode_index = -1 + self._debug_seen_tasks: set[str] = set() + self._debug_capturing = False + self._debug_episode_started = False + self._debug_episode_task = "" + self._debug_step_in_chunk = 0 + self._debug_last_video: list | None = None + self._debug_pairs: list = [] @classmethod def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool): @@ -100,17 +138,33 @@ class FastWAMPolicy(PreTrainedPolicy): model.to(map_location) return model - def get_optim_params(self) -> dict[str, Any]: + def get_optim_params(self) -> list[Tensor]: + # Return the trainable tensors directly (a single param group). The optimizer + # builder wraps these in a param group; returning a bare {"params": [...]} dict + # instead would make `list(...)` yield the key string "params". params = ( list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters()) ) proprio_encoder = getattr(self.model, "proprio_encoder", None) if proprio_encoder is not None: params.extend(list(proprio_encoder.parameters())) - return {"params": [p for p in params if p.requires_grad]} + return [p for p in params if p.requires_grad] def reset(self) -> None: self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps) + # TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's + # true-vs-pred video if it was a captured one (pairs accumulate only while + # capturing), then reset per-episode capture state. + if getattr(self, "_debug_constructed", False): + if _FASTWAM_DECODE_DEBUG and self._debug_pairs: + self._save_debug_video() + self._debug_episode_index += 1 + self._debug_capturing = False + self._debug_episode_started = False + self._debug_episode_task = "" + self._debug_step_in_chunk = 0 + self._debug_last_video = None + self._debug_pairs = [] def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Adapt a standard LeRobot batch to the FastWAM-native sample that @@ -144,7 +198,7 @@ class FastWAMPolicy(PreTrainedPolicy): sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state return sample - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]: """Compute FastWAM training loss for a LeRobot batch. Args: @@ -154,19 +208,14 @@ class FastWAMPolicy(PreTrainedPolicy): `action`, `action_is_pad`). Returns: - dict[str, Tensor]: Output dictionary containing the scalar `loss` - key required by LeRobot and optional tensor metrics. + tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of + logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)` + contract the LeRobot training loop expects. """ sample = self._batch_to_training_sample(batch) loss, metrics = self.model.training_loss(sample) - output = {"loss": loss} - for key, value in (metrics or {}).items(): - if isinstance(value, Tensor): - output[key] = value.to(device=loss.device) - else: - output[key] = torch.as_tensor(value, device=loss.device) - return output + return loss, dict(metrics or {}) @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor: @@ -183,7 +232,21 @@ class FastWAMPolicy(PreTrainedPolicy): self.eval() infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config) batch_size = _infer_kwargs_batch_size(infer_kwargs) - if batch_size == 1: + # TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task), + # run the joint video+action path so the predicted video is VAE-decoded; stash it + # so select_action can pair each predicted frame with the real obs that follows. + if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1: + out = self.model.infer_joint( + **infer_kwargs, + num_video_frames=self.config.model_video_frames, + test_action_with_infer_action=False, + ) + # The decoded rollout has model_video_frames frames spanning the full + # action_horizon (action_video_freq_ratio actions per frame); the per-step + # pairing indexes into it, so keep all frames. + self._debug_last_video = out["video"] + action = _action_from_model_output(out) + elif batch_size == 1: action = _action_from_model_output(self.model.infer_action(**infer_kwargs)) else: action = torch.cat( @@ -202,12 +265,98 @@ class FastWAMPolicy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor: self.eval() + # TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide + # whether to capture: yes iff this episode's task hasn't been captured yet (so we + # get the first episode of every task). + if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started: + self._debug_episode_started = True + task = self._debug_task_name(batch) + if task not in self._debug_seen_tasks: + self._debug_seen_tasks.add(task) + self._debug_capturing = True + self._debug_episode_task = task + capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing if len(self._action_queue) == 0: actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps] self._action_queue.extend(actions.transpose(0, 1)) + if capturing: + self._debug_step_in_chunk = 0 # a fresh chunk was just predicted + if capturing: + self._debug_capture_pair(batch) + self._debug_step_in_chunk += 1 return self._action_queue.popleft() - def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module: + # ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ---- + @staticmethod + def _debug_task_name(batch: dict[str, Any]) -> str: + task = batch.get("task") + if isinstance(task, (list, tuple)): + task = task[0] if task else None + return str(task) if task else "no_task" + + def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None: + video = getattr(self, "_debug_last_video", None) + if not video: + return + real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1] + # Map env-step offset within the chunk to a predicted-frame index. The rollout has + # (model_video_frames - 1) transitions over action_horizon actions, so each env step + # advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio, + # e.g. 8/32 = 0.25 — one predicted frame per ~4 actions). + frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon) + idx = min( + int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)), + len(video) - 1, + ) + pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx]) + self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx) + self._debug_pairs.append(pair) + + @staticmethod + def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None: + from PIL import ImageDraw + + draw = ImageDraw.Draw(pair) + draw.text((3, 3), "true", fill=(255, 255, 0)) + draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0)) + + @staticmethod + def _debug_tensor_to_pil(image: Tensor): + from PIL import Image + + arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) + return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy()) + + @staticmethod + def _debug_hstack(left, right): + from PIL import Image + + if right.height != left.height: + right = right.resize((round(right.width * left.height / right.height), left.height)) + canvas = Image.new("RGB", (left.width + right.width, left.height)) + canvas.paste(left, (0, 0)) + canvas.paste(right, (left.width, 0)) + return canvas + + def _save_debug_video(self) -> None: + import re + + import numpy as np + + from lerobot.utils.io_utils import write_video + + pairs = getattr(self, "_debug_pairs", None) + if not pairs: + return + out_dir = Path("outputs/fastwam_debug") + out_dir.mkdir(parents=True, exist_ok=True) + slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task" + path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4" + frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB + write_video(path, frames, fps=30) + logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path) + + def _build_core_model(self, config: FastWAMConfig) -> FastWAM: """Build the FastWAM core for training / inference. Only the trainable parts (the MoT DiT and the proprio encoder) are @@ -218,14 +367,6 @@ class FastWAMPolicy(PreTrainedPolicy): across checkpoints) and are intentionally excluded from `model.safetensors` — see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`. """ - from .modular_fastwam import ActionDiT, FastWAM, MoT - from .wan_components import ( - build_wan_tokenizer, - load_pretrained_wan_text_encoder, - load_pretrained_wan_vae, - ) - from .wan_video_dit import WanVideoDiT - dtype = _dtype_from_name(config.torch_dtype) device = config.device video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype) @@ -342,15 +483,24 @@ def batch_device(batch: dict[str, Any]) -> torch.device: def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor: - image_keys = sorted(k for k in batch if k.startswith("observation.images.")) + # Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside + # each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames. + image_keys = sorted( + k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad") + ) if not image_keys: raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.") images = [batch[key] for key in image_keys] + # Cameras concatenate along width (last dim) in both the single-frame and temporal case. image = torch.cat(images, dim=-1) if len(images) > 1 else images[0] if image.ndim == 4: - image = image.unsqueeze(2).repeat(1, 1, config.num_video_frames, 1, 1) - if image.ndim != 5: - raise ValueError(f"Expected image batch [B,C,H,W] or video [B,C,T,H,W], got {tuple(image.shape)}.") + # [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time. + image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1) + elif image.ndim == 5: + # [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W]. + image = image.permute(0, 2, 1, 3, 4) + else: + raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.") return image diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index 8fc61446b..fafc80c9f 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -42,6 +42,35 @@ from lerobot.utils.constants import ( from .configuration_fastwam import FastWAMConfig +@dataclass +@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor") +class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep): + """`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack. + + FastWAM loads a per-camera video stack, so image observations arrive as + ``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a + single leading batch dim (resize raises on 5-D input), so we flatten any leading + dims into the batch, apply the base 4-D crop/resize, then restore the leading shape. + Crop/resize params and feature-shape bookkeeping are inherited unchanged. + """ + + def observation(self, observation: dict) -> dict: + leads: dict[str, tuple] = {} + flat_input = dict(observation) + for key, img in observation.items(): + if "image" in key and torch.is_tensor(img) and img.ndim > 4: + leads[key] = tuple(img.shape[:-3]) + flat_input[key] = img.reshape(-1, *img.shape[-3:]) + processed = super().observation(flat_input) + if not leads: + return processed + out = dict(processed) + for key, lead in leads.items(): + im = processed[key] + out[key] = im.reshape(*lead, *im.shape[-3:]) + return out + + @dataclass @ProcessorStepRegistry.register(name="fastwam_action_toggle_processor") class FastWAMActionToggleProcessorStep(ActionProcessorStep): @@ -111,7 +140,8 @@ def make_fastwam_pre_post_processors( resize_steps = [] if visual_shapes: target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2])) - resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw)) + # FastWAM-aware resize: tolerates the leading temporal dim of the video stack. + resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw)) input_steps = [ RenameObservationsProcessorStep(rename_map={}), diff --git a/src/lerobot/policies/fastwam/wan_components.py b/src/lerobot/policies/fastwam/wan_components.py index 41c2fdafd..fd6e3dc52 100644 --- a/src/lerobot/policies/fastwam/wan_components.py +++ b/src/lerobot/policies/fastwam/wan_components.py @@ -20,12 +20,19 @@ from pathlib import Path from typing import TYPE_CHECKING, Any import torch +from huggingface_hub import snapshot_download from safetensors.torch import load_file +from transformers import AutoTokenizer, UMT5EncoderModel if TYPE_CHECKING: from .wan_adapters import WanVideoVAE38 from .wan_video_dit import WanVideoDiT +from diffusers import AutoencoderKLWan + +from .wan_adapters import WanVideoVAE38 +from .wan_video_dit import WanVideoDiT + logger = logging.getLogger(__name__) # The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2 @@ -65,8 +72,6 @@ class WanTokenizer: FastWAM call site expects.""" def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None: - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name) self.seq_len = int(seq_len) @@ -94,10 +99,6 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer: def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38: """Load real Wan2.2 VAE weights from the diffusers repo (offline base creation).""" - from diffusers import AutoencoderKLWan - - from .wan_adapters import WanVideoVAE38 - vae = AutoencoderKLWan.from_pretrained( WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype ) @@ -106,8 +107,6 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder: """Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation).""" - from transformers import UMT5EncoderModel - encoder = UMT5EncoderModel.from_pretrained( WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype ) @@ -126,8 +125,6 @@ def resolve_wan_dit_paths( if path.is_dir(): return sorted(path.glob(WAN_DIT_PATTERN)) - from huggingface_hub import snapshot_download - snapshot_path = snapshot_download( repo_id=str(model_id_or_path), revision=revision, @@ -145,8 +142,6 @@ def load_wan_video_dit( torch_dtype: torch.dtype, device: str, ) -> WanVideoDiT: - from .wan_video_dit import WanVideoDiT - model = WanVideoDiT(**dit_config) state_dict = _read_wan_dit_safetensors(paths) model.load_state_dict(state_dict, strict=False) diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan_video_dit.py index d5350ea90..0b38ad816 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan_video_dit.py @@ -29,6 +29,7 @@ from .wan.modules.model import ( rope_params, sinusoidal_embedding_1d, ) +from .wan.utils.fm_solvers import get_sampling_sigmas logger = logging.getLogger(__name__) @@ -94,8 +95,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]: - from .wan.utils.fm_solvers import get_sampling_sigmas - return get_sampling_sigmas(num_inference_steps, shift) diff --git a/tests/policies/fastwam/test_fastwam_policy.py b/tests/policies/fastwam/test_fastwam_policy.py index c3747c407..68ea6632b 100644 --- a/tests/policies/fastwam/test_fastwam_policy.py +++ b/tests/policies/fastwam/test_fastwam_policy.py @@ -49,6 +49,8 @@ def test_fastwam_is_registered_and_publicly_exported(): proprio_dim=2, action_horizon=4, n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, base_model_id=None, ) @@ -78,6 +80,8 @@ def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_pa proprio_dim=2, action_horizon=4, n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, image_size=(2, 2), device="cpu", toggle_action_dimensions=[-1], @@ -154,6 +158,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): proprio_dim=2, action_horizon=4, n_action_steps=2, + num_video_frames=5, + action_video_freq_ratio=1, image_size=(16, 16), input_features={ "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)), @@ -164,7 +170,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): ) policy = FastWAMPolicy(cfg) - output = policy.forward( + loss, metrics = policy.forward( { "observation.images.image": torch.zeros(1, 3, 16, 16), OBS_STATE: torch.zeros(1, 2), @@ -186,8 +192,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch): } ) - assert output["loss"].item() == 1.0 - assert output["loss_action"].item() == 1.0 + assert loss.item() == 1.0 + assert metrics["loss_action"] == 1.0 assert action.shape == (2, 4, 3) assert action[:, 0, 0].tolist() == [1.0, 2.0] assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)] @@ -218,7 +224,7 @@ class CoreWithFrozenComponents(FakeFastWAMCore): def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) def build_core(self, config): core = CoreWithFrozenComponents() @@ -250,7 +256,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) policy = FastWAMPolicy(cfg) @@ -272,7 +278,7 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path): def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch): - cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None) + cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None) monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents()) policy = FastWAMPolicy(cfg)