preparing for training adding some temporary debug code aswell to visualize model output

This commit is contained in:
Maxime Ellerbach
2026-06-12 15:25:28 +00:00
parent 7c063c3fbc
commit a323ea67b6
6 changed files with 282 additions and 52 deletions
@@ -153,11 +153,25 @@ class FastWAMConfig(PreTrainedConfig):
proprio_dim (int | None): Number of proprioception channels used as an
extra text-context token. `None` disables proprio conditioning.
action_horizon (int): Number of actions predicted by one policy call.
num_video_frames (int): Number of video frames used by FastWAM rollout.
num_video_frames (int): Raw video sampling window (in dataset frames). The
model actually operates on `model_video_frames` frames after subsampling
by `action_video_freq_ratio`.
action_video_freq_ratio (int): Actions are sampled at this multiple of the
video frame rate. Video frames are taken every `action_video_freq_ratio`-th
raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames
spanning the same time window as `action_horizon` actions (ratio actions
per video frame).
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
context_len (int): Maximum text embedding token length.
video_dit_config (dict[str, Any] | None): Wan video expert config.
action_dit_config (dict[str, Any] | None): Action expert config.
use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT
experts (trades compute for memory; propagated into the DiT configs).
freeze_video_expert (bool): Freeze the ~5B Wan video expert
(`model.video_expert`) so only the action expert + proprio encoder train.
Cuts the AdamW optimizer footprint substantially; the video expert keeps its
pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the
now-gradient-free video loss compute.)
"""
n_obs_steps: int = 1
@@ -166,6 +180,7 @@ class FastWAMConfig(PreTrainedConfig):
action_horizon: int = 32
n_action_steps: int = 32
num_video_frames: int = 33
action_video_freq_ratio: int = 4
image_size: tuple[int, int] = (224, 448)
context_len: int = 128
model_id: str = WAN22_MODEL_ID
@@ -186,6 +201,8 @@ class FastWAMConfig(PreTrainedConfig):
sigma_shift: float | None = None
tiled: bool = False
fp32_attention: bool = True
use_gradient_checkpointing: bool = False
freeze_video_expert: bool = False
toggle_action_dimensions: list[int] = field(default_factory=list)
video_scheduler: dict[str, float | int] = field(
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
@@ -220,6 +237,8 @@ class FastWAMConfig(PreTrainedConfig):
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
if self.input_features is None:
height, width = self.image_size
self.input_features = {
@@ -300,8 +319,28 @@ class FastWAMConfig(PreTrainedConfig):
raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
if self.n_action_steps > self.action_horizon:
raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
if self.num_video_frames % 4 != 1:
raise ValueError(f"`num_video_frames` must satisfy T % 4 == 1, got {self.num_video_frames}.")
if self.action_video_freq_ratio <= 0:
raise ValueError(
f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}."
)
# Video frames are subsampled by action_video_freq_ratio; the resulting model frame
# count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the
# original FastWAM dataset asserts).
if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0:
raise ValueError(
f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by "
f"`action_video_freq_ratio` ({self.action_video_freq_ratio})."
)
if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0:
raise ValueError(
f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) "
"must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)."
)
if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0:
raise ValueError(
f"`action_horizon` ({self.action_horizon}) must be divisible by the number of "
f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})."
)
if not self.image_features:
raise ValueError("FastWAM requires at least one image feature.")
if self.action_feature is None:
@@ -333,8 +372,19 @@ class FastWAMConfig(PreTrainedConfig):
raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
@property
def observation_delta_indices(self) -> None:
return None
def model_video_frames(self) -> int:
"""Number of video frames the model actually operates on, after subsampling the
raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9)."""
return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1
@property
def observation_delta_indices(self) -> list[int]:
# Load the video frames the model is supervised on: the future window subsampled by
# action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is
# thus `action_video_freq_ratio` actions apart, while actions load at the full rate
# (`action_delta_indices` = range(action_horizon)). Returning None would load only the
# current frame, making the video target a static repeat (degenerate supervision).
return list(range(0, self.num_video_frames, self.action_video_freq_ratio))
@property
def action_delta_indices(self) -> list[int]:
+176 -26
View File
@@ -15,7 +15,9 @@
from __future__ import annotations
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any
import torch
@@ -25,6 +27,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import OBS_STATE
from .configuration_fastwam import FastWAMConfig
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
from .wan_video_dit import WanVideoDiT
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
# eval episode's action chunks through `infer_joint` so the predicted video latents are
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
_DEBUG_PRED_RATE_DIV = 1
class FastWAMPolicy(PreTrainedPolicy):
@@ -43,13 +62,32 @@ class FastWAMPolicy(PreTrainedPolicy):
self,
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
):
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
# dataset feature metadata is already applied to `config` by make_policy upstream,
# so we accept and ignore them, matching the other LeRobot policies.
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.model = self._build_core_model(config)
if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None:
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
# so its params drop out of the optimizer (and DDP skips them).
self.model.video_expert.requires_grad_(False)
self.reset()
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
# counts only eval-rollout resets (one per episode), not this __init__ one.
self._debug_constructed = True
self._debug_episode_index = -1
self._debug_seen_tasks: set[str] = set()
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video: list | None = None
self._debug_pairs: list = []
@classmethod
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
@@ -100,17 +138,33 @@ class FastWAMPolicy(PreTrainedPolicy):
model.to(map_location)
return model
def get_optim_params(self) -> dict[str, Any]:
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
# instead would make `list(...)` yield the key string "params".
params = (
list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters())
)
proprio_encoder = getattr(self.model, "proprio_encoder", None)
if proprio_encoder is not None:
params.extend(list(proprio_encoder.parameters()))
return {"params": [p for p in params if p.requires_grad]}
return [p for p in params if p.requires_grad]
def reset(self) -> None:
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
# true-vs-pred video if it was a captured one (pairs accumulate only while
# capturing), then reset per-episode capture state.
if getattr(self, "_debug_constructed", False):
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
self._save_debug_video()
self._debug_episode_index += 1
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video = None
self._debug_pairs = []
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
@@ -144,7 +198,7 @@ class FastWAMPolicy(PreTrainedPolicy):
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
return sample
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Compute FastWAM training loss for a LeRobot batch.
Args:
@@ -154,19 +208,14 @@ class FastWAMPolicy(PreTrainedPolicy):
`action`, `action_is_pad`).
Returns:
dict[str, Tensor]: Output dictionary containing the scalar `loss`
key required by LeRobot and optional tensor metrics.
tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of
logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)`
contract the LeRobot training loop expects.
"""
sample = self._batch_to_training_sample(batch)
loss, metrics = self.model.training_loss(sample)
output = {"loss": loss}
for key, value in (metrics or {}).items():
if isinstance(value, Tensor):
output[key] = value.to(device=loss.device)
else:
output[key] = torch.as_tensor(value, device=loss.device)
return output
return loss, dict(metrics or {})
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
@@ -183,7 +232,21 @@ class FastWAMPolicy(PreTrainedPolicy):
self.eval()
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
batch_size = _infer_kwargs_batch_size(infer_kwargs)
if batch_size == 1:
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
# run the joint video+action path so the predicted video is VAE-decoded; stash it
# so select_action can pair each predicted frame with the real obs that follows.
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
out = self.model.infer_joint(
**infer_kwargs,
num_video_frames=self.config.model_video_frames,
test_action_with_infer_action=False,
)
# The decoded rollout has model_video_frames frames spanning the full
# action_horizon (action_video_freq_ratio actions per frame); the per-step
# pairing indexes into it, so keep all frames.
self._debug_last_video = out["video"]
action = _action_from_model_output(out)
elif batch_size == 1:
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
else:
action = torch.cat(
@@ -202,12 +265,98 @@ class FastWAMPolicy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
self.eval()
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
# get the first episode of every task).
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
self._debug_episode_started = True
task = self._debug_task_name(batch)
if task not in self._debug_seen_tasks:
self._debug_seen_tasks.add(task)
self._debug_capturing = True
self._debug_episode_task = task
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
if capturing:
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
if capturing:
self._debug_capture_pair(batch)
self._debug_step_in_chunk += 1
return self._action_queue.popleft()
def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module:
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
@staticmethod
def _debug_task_name(batch: dict[str, Any]) -> str:
task = batch.get("task")
if isinstance(task, (list, tuple)):
task = task[0] if task else None
return str(task) if task else "no_task"
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
video = getattr(self, "_debug_last_video", None)
if not video:
return
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
idx = min(
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
len(video) - 1,
)
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
self._debug_pairs.append(pair)
@staticmethod
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
from PIL import ImageDraw
draw = ImageDraw.Draw(pair)
draw.text((3, 3), "true", fill=(255, 255, 0))
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
@staticmethod
def _debug_tensor_to_pil(image: Tensor):
from PIL import Image
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
def _debug_hstack(left, right):
from PIL import Image
if right.height != left.height:
right = right.resize((round(right.width * left.height / right.height), left.height))
canvas = Image.new("RGB", (left.width + right.width, left.height))
canvas.paste(left, (0, 0))
canvas.paste(right, (left.width, 0))
return canvas
def _save_debug_video(self) -> None:
import re
import numpy as np
from lerobot.utils.io_utils import write_video
pairs = getattr(self, "_debug_pairs", None)
if not pairs:
return
out_dir = Path("outputs/fastwam_debug")
out_dir.mkdir(parents=True, exist_ok=True)
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
write_video(path, frames, fps=30)
logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path)
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
"""Build the FastWAM core for training / inference.
Only the trainable parts (the MoT DiT and the proprio encoder) are
@@ -218,14 +367,6 @@ class FastWAMPolicy(PreTrainedPolicy):
across checkpoints) and are intentionally excluded from `model.safetensors`
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
"""
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
from .wan_video_dit import WanVideoDiT
dtype = _dtype_from_name(config.torch_dtype)
device = config.device
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
@@ -342,15 +483,24 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
image_keys = sorted(k for k in batch if k.startswith("observation.images."))
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(
k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")
)
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
if image.ndim == 4:
image = image.unsqueeze(2).repeat(1, 1, config.num_video_frames, 1, 1)
if image.ndim != 5:
raise ValueError(f"Expected image batch [B,C,H,W] or video [B,C,T,H,W], got {tuple(image.shape)}.")
# [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time.
image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1)
elif image.ndim == 5:
# [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W].
image = image.permute(0, 2, 1, 3, 4)
else:
raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.")
return image
@@ -42,6 +42,35 @@ from lerobot.utils.constants import (
from .configuration_fastwam import FastWAMConfig
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
"""`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack.
FastWAM loads a per-camera video stack, so image observations arrive as
``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
single leading batch dim (resize raises on 5-D input), so we flatten any leading
dims into the batch, apply the base 4-D crop/resize, then restore the leading shape.
Crop/resize params and feature-shape bookkeeping are inherited unchanged.
"""
def observation(self, observation: dict) -> dict:
leads: dict[str, tuple] = {}
flat_input = dict(observation)
for key, img in observation.items():
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
leads[key] = tuple(img.shape[:-3])
flat_input[key] = img.reshape(-1, *img.shape[-3:])
processed = super().observation(flat_input)
if not leads:
return processed
out = dict(processed)
for key, lead in leads.items():
im = processed[key]
out[key] = im.reshape(*lead, *im.shape[-3:])
return out
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
@@ -111,7 +140,8 @@ def make_fastwam_pre_post_processors(
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw))
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
+7 -12
View File
@@ -20,12 +20,19 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, UMT5EncoderModel
if TYPE_CHECKING:
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
logger = logging.getLogger(__name__)
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
@@ -65,8 +72,6 @@ class WanTokenizer:
FastWAM call site expects."""
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.seq_len = int(seq_len)
@@ -94,10 +99,6 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
vae = AutoencoderKLWan.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype
)
@@ -106,8 +107,6 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
from transformers import UMT5EncoderModel
encoder = UMT5EncoderModel.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
)
@@ -126,8 +125,6 @@ def resolve_wan_dit_paths(
if path.is_dir():
return sorted(path.glob(WAN_DIT_PATTERN))
from huggingface_hub import snapshot_download
snapshot_path = snapshot_download(
repo_id=str(model_id_or_path),
revision=revision,
@@ -145,8 +142,6 @@ def load_wan_video_dit(
torch_dtype: torch.dtype,
device: str,
) -> WanVideoDiT:
from .wan_video_dit import WanVideoDiT
model = WanVideoDiT(**dit_config)
state_dict = _read_wan_dit_safetensors(paths)
model.load_state_dict(state_dict, strict=False)
@@ -29,6 +29,7 @@ from .wan.modules.model import (
rope_params,
sinusoidal_embedding_1d,
)
from .wan.utils.fm_solvers import get_sampling_sigmas
logger = logging.getLogger(__name__)
@@ -94,8 +95,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
from .wan.utils.fm_solvers import get_sampling_sigmas
return get_sampling_sigmas(num_inference_steps, shift)
+12 -6
View File
@@ -49,6 +49,8 @@ def test_fastwam_is_registered_and_publicly_exported():
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
@@ -78,6 +80,8 @@ def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_pa
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(2, 2),
device="cpu",
toggle_action_dimensions=[-1],
@@ -154,6 +158,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(16, 16),
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
@@ -164,7 +170,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
)
policy = FastWAMPolicy(cfg)
output = policy.forward(
loss, metrics = policy.forward(
{
"observation.images.image": torch.zeros(1, 3, 16, 16),
OBS_STATE: torch.zeros(1, 2),
@@ -186,8 +192,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
}
)
assert output["loss"].item() == 1.0
assert output["loss_action"].item() == 1.0
assert loss.item() == 1.0
assert metrics["loss_action"] == 1.0
assert action.shape == (2, 4, 3)
assert action[:, 0, 0].tolist() == [1.0, 2.0]
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
@@ -218,7 +224,7 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
def build_core(self, config):
core = CoreWithFrozenComponents()
@@ -250,7 +256,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
@@ -272,7 +278,7 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)