removing some preprocessors

This commit is contained in:
Maxime Ellerbach
2026-06-17 09:52:23 +00:00
parent fb1fadccd7
commit 5752558467
5 changed files with 68 additions and 68 deletions
@@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]:
"attn_head_dim": 128,
"num_layers": 30,
"eps": 1.0e-6,
"separated_timestep": True,
"seperated_timestep": True,
"use_gradient_checkpointing": False,
"video_attention_mask_mode": "first_frame_causal",
"action_conditioned": False,
@@ -215,7 +215,7 @@ class FastWAMConfig(PreTrainedConfig):
action_dit_config: dict[str, Any] | None = None
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
@@ -145,6 +145,26 @@ class FastWAMPolicy(PreTrainedPolicy):
model.to(map_location)
return model
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
    """Save the policy, down-casting floating-point tensors to the policy dtype first.

    When an explicit `state_dict` is supplied (the FSDP FULL_STATE_DICT gather, which hands back
    fp32 master weights), every floating-point tensor is cast to `config.torch_dtype` before the
    parent class writes it. Without the cast the checkpoint would be written in fp32 even though
    FastWAM runs in `config.torch_dtype` (bf16), doubling disk/upload size and risking OOM on
    reload under FSDP, where every rank materializes the full model before sharding.

    Non-floating tensors (e.g. integer buffers) are forwarded untouched. When `state_dict` is
    `None` (non-FSDP saves) the parameters already live at `config.torch_dtype`, so nothing is
    cast and the call is delegated as-is.

    Args:
        save_directory: Destination directory, forwarded to the parent implementation.
        state_dict: Optional gathered state dict to save instead of the module's own.
    """
    if state_dict is None:
        # Non-FSDP path: parameters are already stored at the configured dtype.
        super()._save_pretrained(save_directory, state_dict)
        return

    target_dtype = _dtype_from_name(self.config.torch_dtype)
    downcast: dict[str, Tensor] = {}
    for name, tensor in state_dict.items():
        # Only floating-point tensors are cast; everything else passes through unchanged.
        downcast[name] = tensor.to(target_dtype) if torch.is_floating_point(tensor) else tensor
    super()._save_pretrained(save_directory, downcast)
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
@@ -331,7 +351,9 @@ class FastWAMPolicy(PreTrainedPolicy):
def _debug_tensor_to_pil(image: Tensor):
from PIL import Image
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
# `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE
# encode boundary), so map [0, 1] -> [0, 255] for display.
arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8)
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
@@ -491,13 +513,29 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
return torch.device("cpu")
def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor:
"""Resize a frame tensor to `size` (H, W), tolerating a leading temporal/batch stack.
`interpolate` only accepts a single leading batch dim (`[N, C, H, W]`), but FastWAM camera
tensors arrive as `[B, C, H, W]` (live eval) or `[B, T, C, H, W]` (temporal stack), so flatten
any leading dims into the batch, resize, then restore. A no-op when already at `size`.
"""
if tuple(frames.shape[-2:]) == size:
return frames
lead = frames.shape[:-3]
flat = frames.reshape(-1, *frames.shape[-3:])
flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True)
return flat.reshape(*lead, *flat.shape[-3:])
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
per_cam = (int(config.image_size[0]), int(config.image_size[1]) // len(image_keys))
images = [_resize_frames(batch[key], per_cam) for key in image_keys]
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
if image.ndim == 4:
@@ -530,11 +568,8 @@ def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
if image.ndim != 4:
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
target_h, target_w = config.image_size
if tuple(image.shape[-2:]) != (target_h, target_w):
raise ValueError(
"FastWAM policy expects preprocessed image tensors with shape "
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
"Run the FastWAM preprocessor before calling the policy."
)
return image
# Resize to the full configured resolution (no-op when the video path already produced it, but
# also covers a directly-supplied `input_image`). The model owns its input resolution — see
# `_stack_video_from_images` — so we resize rather than assert on a mismatch.
target_h, target_w = int(config.image_size[0]), int(config.image_size[1])
return _resize_frames(image, (target_h, target_w))
@@ -1075,6 +1075,10 @@ class FastWAM(torch.nn.Module):
@torch.no_grad()
def _encode_video_latents(self, video_tensor, tiled=False, tile_size=(30, 52), tile_stride=(15, 26)):
# The Wan VAE expects pixels in [-1, 1]; model inputs arrive in [0, 1] (VISUAL is IDENTITY in
# the preprocessor — see configuration_fastwam.normalization_mapping). Map here, at the single
# video-encode boundary, so it is applied exactly once on every path.
video_tensor = video_tensor * 2.0 - 1.0
z = self.vae.encode(
video_tensor,
device=self.device,
@@ -1094,6 +1098,8 @@ class FastWAM(torch.nn.Module):
raise ValueError(
f"`input_image` must have shape [1,3,H,W] or [3,H,W], got {tuple(input_image.shape)}"
)
# [0, 1] -> [-1, 1] for the Wan VAE (mirrors `_encode_video_latents`); single image-encode boundary.
input_image = input_image * 2.0 - 1.0
image = input_image.to(device=self.device)[0].unsqueeze(1)
z = self.vae.encode(
[image], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
@@ -24,7 +24,6 @@ from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
ImageCropResizeProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
@@ -42,39 +41,6 @@ from lerobot.utils.constants import (
from .configuration_fastwam import FastWAMConfig
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
    """Crop/resize step that accepts image observations with extra leading dims.

    FastWAM loads a per-camera video stack, so image observations arrive as
    ``[B, T, C, H, W]``. The inherited crop/resize is only applied to 4-D input here: any
    image tensor with more than four dims has its leading dims folded into the batch axis,
    goes through the base ``observation`` call, and is then reshaped back to its original
    leading shape. Crop/resize parameters and feature-shape bookkeeping come from the base
    class unchanged.
    """

    def observation(self, observation: dict) -> dict:
        """Run the base crop/resize on flattened image tensors and restore their leading dims.

        Entries whose key contains ``"_is_pad"`` are padding masks added by delta-timestamp
        video loading (``[B, T]`` booleans sharing the ``observation.images.`` prefix). They
        are held back from the base step, which matches on the ``"image"`` substring, and
        written back untouched at the end.
        """
        padding_masks: dict = {}
        frames_in: dict = {}
        lead_shapes: dict[str, tuple] = {}

        for key, value in observation.items():
            if "_is_pad" in key:
                # Padding flag, not a frame: keep it away from the resize.
                padding_masks[key] = value
                continue
            if "image" in key and torch.is_tensor(value) and value.ndim > 4:
                # Remember the leading dims, then fold them into the batch axis.
                lead_shapes[key] = tuple(value.shape[:-3])
                value = value.reshape(-1, *value.shape[-3:])
            frames_in[key] = value

        resized = super().observation(frames_in)

        restored = dict(resized)
        for key, lead in lead_shapes.items():
            frame = resized[key]
            restored[key] = frame.reshape(*lead, *frame.shape[-3:])
        restored.update(padding_masks)
        return restored
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
@@ -124,32 +90,25 @@ def make_fastwam_pre_post_processors(
output processor pipelines discoverable by LeRobot.
"""
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
# NOTE: no visual normalization here. VISUAL is IDENTITY (see configuration_fastwam.normalization_mapping)
# — images pass through in [0, 1] and the model maps them to the Wan VAE's [-1, 1] at the encode
# boundary. This is deliberate: `lerobot_train.py` overrides the normalizer stats with
# `dataset.meta.stats` when fine-tuning, and a real dataset's per-channel image std is the tiny
# frame-to-frame brightness variance, which would blow images far outside [-1,1] and saturate them.
# STATE/ACTION still normalize with dataset stats below.
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
for key, feature in config.input_features.items():
if feature.type != FeatureType.VISUAL:
continue
channels = int(feature.shape[0])
normalization_stats[key] = {
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
}
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
# NOTE: no resize step here. The model is the single authority on input resolution: it resizes
# each camera to the per-camera target (image_size split across cameras) in
# `_stack_video_from_images` / `_prepare_infer_image`, on every path (train forward, rollout and
# eval select_action). A preprocessor resize step would be both redundant (the model re-resizes
# anyway) and unsafe across fine-tuning: its `resize_size` would be inherited from the base
# checkpoint's camera geometry, not this dataset's, making the concatenation N_cameras x too wide.
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
*resize_steps,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -425,7 +425,7 @@ class WanVideoDiT(WanModel):
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
separated_timestep: bool = False,
seperated_timestep: bool = False,
require_vae_embedding: bool = False,
require_clip_embedding: bool = False,
fuse_vae_embedding_in_latents: bool = True,
@@ -489,7 +489,7 @@ class WanVideoDiT(WanModel):
self.hidden_dim = hidden_dim
self.attn_head_dim = attn_head_dim
self.separated_timestep = separated_timestep
self.seperated_timestep = seperated_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.video_attention_mask_mode = str(video_attention_mask_mode)
self.action_conditioned = action_conditioned
@@ -647,7 +647,7 @@ class WanVideoDiT(WanModel):
)
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
raise NotImplementedError(
"FastWAM currently requires separated timesteps with fused VAE latents."
)