removing some preprocessors

This commit is contained in:
Maxime Ellerbach
2026-06-17 09:52:23 +00:00
parent abd36f338c
commit d73772b9d0
5 changed files with 68 additions and 68 deletions
@@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]:
"attn_head_dim": 128,
"num_layers": 30,
"eps": 1.0e-6,
"separated_timestep": True,
"seperated_timestep": True,
"use_gradient_checkpointing": False,
"video_attention_mask_mode": "first_frame_causal",
"action_conditioned": False,
@@ -215,7 +215,7 @@ class FastWAMConfig(PreTrainedConfig):
action_dit_config: dict[str, Any] | None = None
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
@@ -145,6 +145,26 @@ class FastWAMPolicy(PreTrainedPolicy):
model.to(map_location)
return model
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
    """Save the policy, down-casting floating-point tensors to the policy dtype first.

    When an explicit `state_dict` is supplied (the FSDP FULL_STATE_DICT gather path, which
    returns fp32 master weights), every floating-point tensor is cast to
    `config.torch_dtype` before being handed to the parent implementation. Without the
    cast, the default save would write a fp32 `model.safetensors` (~24 GB) even though
    FastWAM runs in `config.torch_dtype` (bf16). That doubles disk/upload size and, worse,
    makes reloading OOM under FSDP, because every rank materializes the full fp32 model on
    GPU before sharding.

    Non-float tensors (e.g. integer buffers) pass through unchanged. When `state_dict` is
    None (non-FSDP saves) the params are already held at `config.torch_dtype`, so the call
    is forwarded untouched.

    Args:
        save_directory: Destination directory, forwarded to the parent implementation.
        state_dict: Optional gathered state dict to save instead of the module's own.
    """
    if state_dict is None:
        # Nothing to cast: defer entirely to the parent implementation.
        super()._save_pretrained(save_directory, state_dict)
        return

    target_dtype = _dtype_from_name(self.config.torch_dtype)
    cast_state: dict[str, Tensor] = {}
    for name, tensor in state_dict.items():
        # Only floating-point tensors are down-cast; everything else is kept as-is.
        cast_state[name] = tensor.to(target_dtype) if torch.is_floating_point(tensor) else tensor
    super()._save_pretrained(save_directory, cast_state)
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
@@ -331,7 +351,9 @@ class FastWAMPolicy(PreTrainedPolicy):
def _debug_tensor_to_pil(image: Tensor):
    """Convert a channels-first image tensor into a PIL image for debug output.

    The tensor is detached, promoted to float, clamped to [0, 1], scaled to [0, 255] and
    converted to uint8, then permuted from `(C, H, W)` to `(H, W, C)` before being handed
    to `Image.fromarray`.

    Args:
        image: 3-D image tensor laid out channels-first, with values expected in [0, 1].

    Returns:
        A `PIL.Image.Image` built from the uint8 array.
    """
    # Imported lazily so PIL is only required when debug images are actually produced.
    from PIL import Image

    # `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE
    # encode boundary), so map [0, 1] -> [0, 255] for display. The stale [-1, 1] -> [0, 255]
    # conversion that used to precede this line was a dead store (immediately overwritten)
    # and has been dropped.
    arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8)
    return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
@@ -491,13 +513,29 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
return torch.device("cpu")
def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor:
"""Resize a frame tensor to `size` (H, W), tolerating a leading temporal/batch stack.
`interpolate` only accepts a single leading batch dim (`[N, C, H, W]`), but FastWAM camera
tensors arrive as `[B, C, H, W]` (live eval) or `[B, T, C, H, W]` (temporal stack), so flatten
any leading dims into the batch, resize, then restore. A no-op when already at `size`.
"""
if tuple(frames.shape[-2:]) == size:
return frames
lead = frames.shape[:-3]
flat = frames.reshape(-1, *frames.shape[-3:])
flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True)
return flat.reshape(*lead, *flat.shape[-3:])
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
per_cam = (int(config.image_size[0]), int(config.image_size[1]) // len(image_keys))
images = [_resize_frames(batch[key], per_cam) for key in image_keys]
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
if image.ndim == 4:
@@ -530,11 +568,8 @@ def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
if image.ndim != 4:
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
target_h, target_w = config.image_size
if tuple(image.shape[-2:]) != (target_h, target_w):
raise ValueError(
"FastWAM policy expects preprocessed image tensors with shape "
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
"Run the FastWAM preprocessor before calling the policy."
)
return image
# Resize to the full configured resolution (no-op when the video path already produced it, but
# also covers a directly-supplied `input_image`). The model owns its input resolution — see
# `_stack_video_from_images` — so we resize rather than assert on a mismatch.
target_h, target_w = int(config.image_size[0]), int(config.image_size[1])
return _resize_frames(image, (target_h, target_w))
@@ -1075,6 +1075,10 @@ class FastWAM(torch.nn.Module):
@torch.no_grad()
def _encode_video_latents(self, video_tensor, tiled=False, tile_size=(30, 52), tile_stride=(15, 26)):
# The Wan VAE expects pixels in [-1, 1]; model inputs arrive in [0, 1] (VISUAL is IDENTITY in
# the preprocessor — see configuration_fastwam.normalization_mapping). Map here, at the single
# video-encode boundary, so it is applied exactly once on every path.
video_tensor = video_tensor * 2.0 - 1.0
z = self.vae.encode(
video_tensor,
device=self.device,
@@ -1094,6 +1098,8 @@ class FastWAM(torch.nn.Module):
raise ValueError(
f"`input_image` must have shape [1,3,H,W] or [3,H,W], got {tuple(input_image.shape)}"
)
# [0, 1] -> [-1, 1] for the Wan VAE (mirrors `_encode_video_latents`); single image-encode boundary.
input_image = input_image * 2.0 - 1.0
image = input_image.to(device=self.device)[0].unsqueeze(1)
z = self.vae.encode(
[image], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
@@ -24,7 +24,6 @@ from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
ImageCropResizeProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
@@ -42,39 +41,6 @@ from lerobot.utils.constants import (
from .configuration_fastwam import FastWAMConfig
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
    """Crop/resize step that accepts image tensors with extra leading dims.

    FastWAM loads a per-camera video stack, so image observations arrive as
    ``[B, T, C, H, W]``. The base step's torchvision crop/resize only handle
    ``[..., C, H, W]`` with a single leading batch dim (resize raises on 5-D input), so
    tensors with more than four dims are flattened to 4-D, passed through the base
    implementation, and reshaped back to their original leading dims afterwards.
    Crop/resize parameters and feature-shape bookkeeping come from the parent unchanged.
    """

    def observation(self, observation: dict) -> dict:
        """Run the base crop/resize on `observation`, restoring leading dims and pad masks."""
        pad_masks: dict = {}
        frames_in: dict = {}
        original_leads: dict[str, tuple] = {}

        for name, value in observation.items():
            # Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T])
            # under the same `observation.images.` prefix. The base step matches on the
            # `"image"` substring, so keep the masks away from it and re-attach them later.
            if "_is_pad" in name:
                pad_masks[name] = value
                continue
            if "image" in name and torch.is_tensor(value) and value.ndim > 4:
                # Remember the leading dims, then fold them into a single batch dim.
                original_leads[name] = tuple(value.shape[:-3])
                value = value.reshape(-1, *value.shape[-3:])
            frames_in[name] = value

        resized = super().observation(frames_in)

        result = dict(resized)
        for name, lead in original_leads.items():
            frame = resized[name]
            result[name] = frame.reshape(*lead, *frame.shape[-3:])
        # Pad masks go back in untouched.
        result.update(pad_masks)
        return result
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
@@ -124,32 +90,25 @@ def make_fastwam_pre_post_processors(
output processor pipelines discoverable by LeRobot.
"""
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
# NOTE: no visual normalization here. VISUAL is IDENTITY (see configuration_fastwam.normalization_mapping)
# — images pass through in [0, 1] and the model maps them to the Wan VAE's [-1, 1] at the encode
# boundary. This is deliberate: `lerobot_train.py` overrides the normalizer stats with
# `dataset.meta.stats` when fine-tuning, and a real dataset's per-channel image std is the tiny
# frame-to-frame brightness variance, which would blow images far outside [-1,1] and saturate them.
# STATE/ACTION still normalize with dataset stats below.
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
for key, feature in config.input_features.items():
if feature.type != FeatureType.VISUAL:
continue
channels = int(feature.shape[0])
normalization_stats[key] = {
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
}
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
# NOTE: no resize step here. The model is the single authority on input resolution: it resizes
# each camera to the per-camera target (image_size split across cameras) in
# `_stack_video_from_images` / `_prepare_infer_image`, on every path (train forward, rollout and
# eval select_action). A preprocessor resize step would be both redundant (the model re-resizes
# anyway) and unsafe across fine-tuning: its `resize_size` would be inherited from the base
# checkpoint's camera geometry, not this dataset's, making the concatenation N_cameras x too wide.
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
*resize_steps,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -425,7 +425,7 @@ class WanVideoDiT(WanModel):
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
separated_timestep: bool = False,
seperated_timestep: bool = False,
require_vae_embedding: bool = False,
require_clip_embedding: bool = False,
fuse_vae_embedding_in_latents: bool = True,
@@ -489,7 +489,7 @@ class WanVideoDiT(WanModel):
self.hidden_dim = hidden_dim
self.attn_head_dim = attn_head_dim
self.separated_timestep = separated_timestep
self.seperated_timestep = seperated_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.video_attention_mask_mode = str(video_attention_mask_mode)
self.action_conditioned = action_conditioned
@@ -647,7 +647,7 @@ class WanVideoDiT(WanModel):
)
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
raise NotImplementedError(
"FastWAM currently requires separated timesteps with fused VAE latents."
)