mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 00:07:03 +00:00
preparing for training adding some temporary debug code aswell to visualize model output
This commit is contained in:
@@ -153,11 +153,25 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
proprio_dim (int | None): Number of proprioception channels used as an
|
||||
extra text-context token. `None` disables proprio conditioning.
|
||||
action_horizon (int): Number of actions predicted by one policy call.
|
||||
num_video_frames (int): Number of video frames used by FastWAM rollout.
|
||||
num_video_frames (int): Raw video sampling window (in dataset frames). The
|
||||
model actually operates on `model_video_frames` frames after subsampling
|
||||
by `action_video_freq_ratio`.
|
||||
action_video_freq_ratio (int): Actions are sampled at this multiple of the
|
||||
video frame rate. Video frames are taken every `action_video_freq_ratio`-th
|
||||
raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames
|
||||
spanning the same time window as `action_horizon` actions (ratio actions
|
||||
per video frame).
|
||||
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
|
||||
context_len (int): Maximum text embedding token length.
|
||||
video_dit_config (dict[str, Any] | None): Wan video expert config.
|
||||
action_dit_config (dict[str, Any] | None): Action expert config.
|
||||
use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT
|
||||
experts (trades compute for memory; propagated into the DiT configs).
|
||||
freeze_video_expert (bool): Freeze the ~5B Wan video expert
|
||||
(`model.video_expert`) so only the action expert + proprio encoder train.
|
||||
Cuts the AdamW optimizer footprint substantially; the video expert keeps its
|
||||
pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the
|
||||
now-gradient-free video loss compute.)
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
@@ -166,6 +180,7 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
action_horizon: int = 32
|
||||
n_action_steps: int = 32
|
||||
num_video_frames: int = 33
|
||||
action_video_freq_ratio: int = 4
|
||||
image_size: tuple[int, int] = (224, 448)
|
||||
context_len: int = 128
|
||||
model_id: str = WAN22_MODEL_ID
|
||||
@@ -186,6 +201,8 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
sigma_shift: float | None = None
|
||||
tiled: bool = False
|
||||
fp32_attention: bool = True
|
||||
use_gradient_checkpointing: bool = False
|
||||
freeze_video_expert: bool = False
|
||||
toggle_action_dimensions: list[int] = field(default_factory=list)
|
||||
video_scheduler: dict[str, float | int] = field(
|
||||
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
|
||||
@@ -220,6 +237,8 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
|
||||
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
|
||||
self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)
|
||||
self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
|
||||
self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
|
||||
if self.input_features is None:
|
||||
height, width = self.image_size
|
||||
self.input_features = {
|
||||
@@ -300,8 +319,28 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
|
||||
if self.num_video_frames % 4 != 1:
|
||||
raise ValueError(f"`num_video_frames` must satisfy T % 4 == 1, got {self.num_video_frames}.")
|
||||
if self.action_video_freq_ratio <= 0:
|
||||
raise ValueError(
|
||||
f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}."
|
||||
)
|
||||
# Video frames are subsampled by action_video_freq_ratio; the resulting model frame
|
||||
# count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the
|
||||
# original FastWAM dataset asserts).
|
||||
if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0:
|
||||
raise ValueError(
|
||||
f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by "
|
||||
f"`action_video_freq_ratio` ({self.action_video_freq_ratio})."
|
||||
)
|
||||
if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0:
|
||||
raise ValueError(
|
||||
f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) "
|
||||
"must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)."
|
||||
)
|
||||
if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0:
|
||||
raise ValueError(
|
||||
f"`action_horizon` ({self.action_horizon}) must be divisible by the number of "
|
||||
f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})."
|
||||
)
|
||||
if not self.image_features:
|
||||
raise ValueError("FastWAM requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
@@ -333,8 +372,19 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
def model_video_frames(self) -> int:
|
||||
"""Number of video frames the model actually operates on, after subsampling the
|
||||
raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9)."""
|
||||
return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# Load the video frames the model is supervised on: the future window subsampled by
|
||||
# action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is
|
||||
# thus `action_video_freq_ratio` actions apart, while actions load at the full rate
|
||||
# (`action_delta_indices` = range(action_horizon)). Returning None would load only the
|
||||
# current frame, making the video target a static repeat (degenerate supervision).
|
||||
return list(range(0, self.num_video_frames, self.action_video_freq_ratio))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -25,6 +27,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
from .wan_components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
)
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
|
||||
# eval episode's action chunks through `infer_joint` so the predicted video latents are
|
||||
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
|
||||
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
|
||||
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
|
||||
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
|
||||
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
|
||||
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
|
||||
_DEBUG_PRED_RATE_DIV = 1
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
@@ -43,13 +62,32 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
self,
|
||||
config: FastWAMConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
|
||||
# dataset feature metadata is already applied to `config` by make_policy upstream,
|
||||
# so we accept and ignore them, matching the other LeRobot policies.
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.model = self._build_core_model(config)
|
||||
if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None:
|
||||
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
|
||||
# so its params drop out of the optimizer (and DDP skips them).
|
||||
self.model.video_expert.requires_grad_(False)
|
||||
self.reset()
|
||||
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
|
||||
# counts only eval-rollout resets (one per episode), not this __init__ one.
|
||||
self._debug_constructed = True
|
||||
self._debug_episode_index = -1
|
||||
self._debug_seen_tasks: set[str] = set()
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video: list | None = None
|
||||
self._debug_pairs: list = []
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
|
||||
@@ -100,17 +138,33 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict[str, Any]:
|
||||
def get_optim_params(self) -> list[Tensor]:
|
||||
# Return the trainable tensors directly (a single param group). The optimizer
|
||||
# builder wraps these in a param group; returning a bare {"params": [...]} dict
|
||||
# instead would make `list(...)` yield the key string "params".
|
||||
params = (
|
||||
list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters())
|
||||
)
|
||||
proprio_encoder = getattr(self.model, "proprio_encoder", None)
|
||||
if proprio_encoder is not None:
|
||||
params.extend(list(proprio_encoder.parameters()))
|
||||
return {"params": [p for p in params if p.requires_grad]}
|
||||
return [p for p in params if p.requires_grad]
|
||||
|
||||
def reset(self) -> None:
|
||||
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
|
||||
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
|
||||
# true-vs-pred video if it was a captured one (pairs accumulate only while
|
||||
# capturing), then reset per-episode capture state.
|
||||
if getattr(self, "_debug_constructed", False):
|
||||
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
|
||||
self._save_debug_video()
|
||||
self._debug_episode_index += 1
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video = None
|
||||
self._debug_pairs = []
|
||||
|
||||
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
|
||||
@@ -144,7 +198,7 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
|
||||
return sample
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Compute FastWAM training loss for a LeRobot batch.
|
||||
|
||||
Args:
|
||||
@@ -154,19 +208,14 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
`action`, `action_is_pad`).
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: Output dictionary containing the scalar `loss`
|
||||
key required by LeRobot and optional tensor metrics.
|
||||
tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of
|
||||
logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)`
|
||||
contract the LeRobot training loop expects.
|
||||
"""
|
||||
|
||||
sample = self._batch_to_training_sample(batch)
|
||||
loss, metrics = self.model.training_loss(sample)
|
||||
output = {"loss": loss}
|
||||
for key, value in (metrics or {}).items():
|
||||
if isinstance(value, Tensor):
|
||||
output[key] = value.to(device=loss.device)
|
||||
else:
|
||||
output[key] = torch.as_tensor(value, device=loss.device)
|
||||
return output
|
||||
return loss, dict(metrics or {})
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
|
||||
@@ -183,7 +232,21 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
|
||||
batch_size = _infer_kwargs_batch_size(infer_kwargs)
|
||||
if batch_size == 1:
|
||||
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
|
||||
# run the joint video+action path so the predicted video is VAE-decoded; stash it
|
||||
# so select_action can pair each predicted frame with the real obs that follows.
|
||||
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
|
||||
out = self.model.infer_joint(
|
||||
**infer_kwargs,
|
||||
num_video_frames=self.config.model_video_frames,
|
||||
test_action_with_infer_action=False,
|
||||
)
|
||||
# The decoded rollout has model_video_frames frames spanning the full
|
||||
# action_horizon (action_video_freq_ratio actions per frame); the per-step
|
||||
# pairing indexes into it, so keep all frames.
|
||||
self._debug_last_video = out["video"]
|
||||
action = _action_from_model_output(out)
|
||||
elif batch_size == 1:
|
||||
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
|
||||
else:
|
||||
action = torch.cat(
|
||||
@@ -202,12 +265,98 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
|
||||
self.eval()
|
||||
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
|
||||
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
|
||||
# get the first episode of every task).
|
||||
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
|
||||
self._debug_episode_started = True
|
||||
task = self._debug_task_name(batch)
|
||||
if task not in self._debug_seen_tasks:
|
||||
self._debug_seen_tasks.add(task)
|
||||
self._debug_capturing = True
|
||||
self._debug_episode_task = task
|
||||
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
if capturing:
|
||||
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
|
||||
if capturing:
|
||||
self._debug_capture_pair(batch)
|
||||
self._debug_step_in_chunk += 1
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> torch.nn.Module:
|
||||
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
|
||||
@staticmethod
|
||||
def _debug_task_name(batch: dict[str, Any]) -> str:
|
||||
task = batch.get("task")
|
||||
if isinstance(task, (list, tuple)):
|
||||
task = task[0] if task else None
|
||||
return str(task) if task else "no_task"
|
||||
|
||||
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
|
||||
video = getattr(self, "_debug_last_video", None)
|
||||
if not video:
|
||||
return
|
||||
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
|
||||
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
|
||||
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
|
||||
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
|
||||
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
|
||||
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
|
||||
idx = min(
|
||||
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
|
||||
len(video) - 1,
|
||||
)
|
||||
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
|
||||
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
|
||||
self._debug_pairs.append(pair)
|
||||
|
||||
@staticmethod
|
||||
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
|
||||
from PIL import ImageDraw
|
||||
|
||||
draw = ImageDraw.Draw(pair)
|
||||
draw.text((3, 3), "true", fill=(255, 255, 0))
|
||||
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
|
||||
|
||||
@staticmethod
|
||||
def _debug_tensor_to_pil(image: Tensor):
|
||||
from PIL import Image
|
||||
|
||||
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
|
||||
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
|
||||
|
||||
@staticmethod
|
||||
def _debug_hstack(left, right):
|
||||
from PIL import Image
|
||||
|
||||
if right.height != left.height:
|
||||
right = right.resize((round(right.width * left.height / right.height), left.height))
|
||||
canvas = Image.new("RGB", (left.width + right.width, left.height))
|
||||
canvas.paste(left, (0, 0))
|
||||
canvas.paste(right, (left.width, 0))
|
||||
return canvas
|
||||
|
||||
def _save_debug_video(self) -> None:
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.io_utils import write_video
|
||||
|
||||
pairs = getattr(self, "_debug_pairs", None)
|
||||
if not pairs:
|
||||
return
|
||||
out_dir = Path("outputs/fastwam_debug")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
|
||||
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
|
||||
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
|
||||
write_video(path, frames, fps=30)
|
||||
logging.info("FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path)
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
|
||||
"""Build the FastWAM core for training / inference.
|
||||
|
||||
Only the trainable parts (the MoT DiT and the proprio encoder) are
|
||||
@@ -218,14 +367,6 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
across checkpoints) and are intentionally excluded from `model.safetensors`
|
||||
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
|
||||
"""
|
||||
from .modular_fastwam import ActionDiT, FastWAM, MoT
|
||||
from .wan_components import (
|
||||
build_wan_tokenizer,
|
||||
load_pretrained_wan_text_encoder,
|
||||
load_pretrained_wan_vae,
|
||||
)
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
dtype = _dtype_from_name(config.torch_dtype)
|
||||
device = config.device
|
||||
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
|
||||
@@ -342,15 +483,24 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
|
||||
|
||||
|
||||
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
image_keys = sorted(k for k in batch if k.startswith("observation.images."))
|
||||
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
|
||||
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
|
||||
image_keys = sorted(
|
||||
k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
|
||||
images = [batch[key] for key in image_keys]
|
||||
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
|
||||
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(2).repeat(1, 1, config.num_video_frames, 1, 1)
|
||||
if image.ndim != 5:
|
||||
raise ValueError(f"Expected image batch [B,C,H,W] or video [B,C,T,H,W], got {tuple(image.shape)}.")
|
||||
# [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time.
|
||||
image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1)
|
||||
elif image.ndim == 5:
|
||||
# [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W].
|
||||
image = image.permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.")
|
||||
return image
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,35 @@ from lerobot.utils.constants import (
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
|
||||
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
|
||||
"""`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack.
|
||||
|
||||
FastWAM loads a per-camera video stack, so image observations arrive as
|
||||
``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
|
||||
single leading batch dim (resize raises on 5-D input), so we flatten any leading
|
||||
dims into the batch, apply the base 4-D crop/resize, then restore the leading shape.
|
||||
Crop/resize params and feature-shape bookkeeping are inherited unchanged.
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
leads: dict[str, tuple] = {}
|
||||
flat_input = dict(observation)
|
||||
for key, img in observation.items():
|
||||
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
|
||||
leads[key] = tuple(img.shape[:-3])
|
||||
flat_input[key] = img.reshape(-1, *img.shape[-3:])
|
||||
processed = super().observation(flat_input)
|
||||
if not leads:
|
||||
return processed
|
||||
out = dict(processed)
|
||||
for key, lead in leads.items():
|
||||
im = processed[key]
|
||||
out[key] = im.reshape(*lead, *im.shape[-3:])
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
|
||||
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
|
||||
@@ -111,7 +140,8 @@ def make_fastwam_pre_post_processors(
|
||||
resize_steps = []
|
||||
if visual_shapes:
|
||||
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
|
||||
resize_steps.append(ImageCropResizeProcessorStep(resize_size=target_hw))
|
||||
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
|
||||
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
|
||||
@@ -20,12 +20,19 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
|
||||
@@ -65,8 +72,6 @@ class WanTokenizer:
|
||||
FastWAM call site expects."""
|
||||
|
||||
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name)
|
||||
self.seq_len = int(seq_len)
|
||||
|
||||
@@ -94,10 +99,6 @@ def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
|
||||
|
||||
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
|
||||
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from .wan_adapters import WanVideoVAE38
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype
|
||||
)
|
||||
@@ -106,8 +107,6 @@ def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVide
|
||||
|
||||
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
|
||||
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
|
||||
from transformers import UMT5EncoderModel
|
||||
|
||||
encoder = UMT5EncoderModel.from_pretrained(
|
||||
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
|
||||
)
|
||||
@@ -126,8 +125,6 @@ def resolve_wan_dit_paths(
|
||||
if path.is_dir():
|
||||
return sorted(path.glob(WAN_DIT_PATTERN))
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_path = snapshot_download(
|
||||
repo_id=str(model_id_or_path),
|
||||
revision=revision,
|
||||
@@ -145,8 +142,6 @@ def load_wan_video_dit(
|
||||
torch_dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> WanVideoDiT:
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
model = WanVideoDiT(**dit_config)
|
||||
state_dict = _read_wan_dit_safetensors(paths)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
@@ -29,6 +29,7 @@ from .wan.modules.model import (
|
||||
rope_params,
|
||||
sinusoidal_embedding_1d,
|
||||
)
|
||||
from .wan.utils.fm_solvers import get_sampling_sigmas
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,8 +95,6 @@ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
|
||||
|
||||
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
|
||||
from .wan.utils.fm_solvers import get_sampling_sigmas
|
||||
|
||||
return get_sampling_sigmas(num_inference_steps, shift)
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,8 @@ def test_fastwam_is_registered_and_publicly_exported():
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
|
||||
@@ -78,6 +80,8 @@ def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_pa
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(2, 2),
|
||||
device="cpu",
|
||||
toggle_action_dimensions=[-1],
|
||||
@@ -154,6 +158,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(16, 16),
|
||||
input_features={
|
||||
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
|
||||
@@ -164,7 +170,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
)
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
output = policy.forward(
|
||||
loss, metrics = policy.forward(
|
||||
{
|
||||
"observation.images.image": torch.zeros(1, 3, 16, 16),
|
||||
OBS_STATE: torch.zeros(1, 2),
|
||||
@@ -186,8 +192,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
}
|
||||
)
|
||||
|
||||
assert output["loss"].item() == 1.0
|
||||
assert output["loss_action"].item() == 1.0
|
||||
assert loss.item() == 1.0
|
||||
assert metrics["loss_action"] == 1.0
|
||||
assert action.shape == (2, 4, 3)
|
||||
assert action[:, 0, 0].tolist() == [1.0, 2.0]
|
||||
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
|
||||
@@ -218,7 +224,7 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
|
||||
def build_core(self, config):
|
||||
core = CoreWithFrozenComponents()
|
||||
@@ -250,7 +256,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
|
||||
|
||||
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
@@ -272,7 +278,7 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
|
||||
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user