From 9dca80a3a926ed10be5a2c6f96dd471138f84cb1 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Wed, 17 Jun 2026 09:52:23 +0000 Subject: [PATCH] removing some preprocessors --- .../policies/fastwam/configuration_fastwam.py | 4 +- .../policies/fastwam/modeling_fastwam.py | 55 +++++++++++++--- .../policies/fastwam/modular_fastwam.py | 6 ++ .../policies/fastwam/processor_fastwam.py | 65 ++++--------------- src/lerobot/policies/fastwam/wan_video_dit.py | 6 +- 5 files changed, 68 insertions(+), 68 deletions(-) diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index e6527e20b..c78482ace 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]: "attn_head_dim": 128, "num_layers": 30, "eps": 1.0e-6, - "separated_timestep": True, + "seperated_timestep": True, "use_gradient_checkpointing": False, "video_attention_mask_mode": "first_frame_causal", "action_conditioned": False, @@ -215,7 +215,7 @@ class FastWAMConfig(PreTrainedConfig): action_dit_config: dict[str, Any] | None = None normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { - "VISUAL": NormalizationMode.MEAN_STD, + "VISUAL": NormalizationMode.IDENTITY, "STATE": NormalizationMode.MEAN_STD, "ACTION": NormalizationMode.MEAN_STD, } diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 9e7124e2e..2f4e86229 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -145,6 +145,26 @@ class FastWAMPolicy(PreTrainedPolicy): model.to(map_location) return model + def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None: + """Down-cast float tensors to the policy dtype before saving. 
+ + FSDP's FULL_STATE_DICT gather returns fp32 master weights, so the default save would + write an fp32 `model.safetensors` (~24 GB) even though FastWAM runs in + `config.torch_dtype` (bf16). That doubles disk/upload and, worse, makes reloading OOM + under FSDP — every rank materializes the full fp32 model on GPU before sharding. + Casting float tensors to the configured dtype here halves the checkpoint and keeps + loads within budget; non-float tensors (e.g. integer buffers) pass through unchanged. + The `state_dict is None` path (non-FSDP saves) already holds params at + `config.torch_dtype`, so it needs no cast. + """ + if state_dict is not None: + dtype = _dtype_from_name(self.config.torch_dtype) + state_dict = { + key: (value.to(dtype) if torch.is_floating_point(value) else value) + for key, value in state_dict.items() + } + super()._save_pretrained(save_directory, state_dict) + def get_optim_params(self) -> list[Tensor]: # Return the trainable tensors directly (a single param group). The optimizer # builder wraps these in a param group; returning a bare {"params": [...]} dict @@ -331,7 +351,9 @@ class FastWAMPolicy(PreTrainedPolicy): def _debug_tensor_to_pil(image: Tensor): from PIL import Image - arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) + # `image` (the `real` model input) is in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE + # encode boundary), so map [0, 1] -> [0, 255] for display. + arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8) return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy()) @staticmethod @@ -491,13 +513,29 @@ def batch_device(batch: dict[str, Any]) -> torch.device: return torch.device("cpu") +def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor: + """Resize a frame tensor to `size` (H, W), tolerating a leading temporal/batch stack. 
+ + `interpolate` only accepts a single leading batch dim (`[N, C, H, W]`), but FastWAM camera + tensors arrive as `[B, C, H, W]` (live eval) or `[B, T, C, H, W]` (temporal stack), so flatten + any leading dims into the batch, resize, then restore. A no-op when already at `size`. + """ + if tuple(frames.shape[-2:]) == size: + return frames + lead = frames.shape[:-3] + flat = frames.reshape(-1, *frames.shape[-3:]) + flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True) + return flat.reshape(*lead, *flat.shape[-3:]) + + def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor: # Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside # each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames. image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad")) if not image_keys: raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.") - images = [batch[key] for key in image_keys] + per_cam = (int(config.image_size[0]), int(config.image_size[1]) // len(image_keys)) + images = [_resize_frames(batch[key], per_cam) for key in image_keys] # Cameras concatenate along width (last dim) in both the single-frame and temporal case. image = torch.cat(images, dim=-1) if len(images) > 1 else images[0] if image.ndim == 4: @@ -530,11 +568,8 @@ def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor: if image.ndim != 4: raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.") - target_h, target_w = config.image_size - if tuple(image.shape[-2:]) != (target_h, target_w): - raise ValueError( - "FastWAM policy expects preprocessed image tensors with shape " - f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. " - "Run the FastWAM preprocessor before calling the policy." 
- ) - return image + # Resize to the full configured resolution (no-op when the video path already produced it, but + # also covers a directly-supplied `input_image`). The model owns its input resolution — see + # `_stack_video_from_images` — so we resize rather than assert on a mismatch. + target_h, target_w = int(config.image_size[0]), int(config.image_size[1]) + return _resize_frames(image, (target_h, target_w)) diff --git a/src/lerobot/policies/fastwam/modular_fastwam.py b/src/lerobot/policies/fastwam/modular_fastwam.py index c220a1a73..8d3df9c91 100644 --- a/src/lerobot/policies/fastwam/modular_fastwam.py +++ b/src/lerobot/policies/fastwam/modular_fastwam.py @@ -1075,6 +1075,10 @@ class FastWAM(torch.nn.Module): @torch.no_grad() def _encode_video_latents(self, video_tensor, tiled=False, tile_size=(30, 52), tile_stride=(15, 26)): + # The Wan VAE expects pixels in [-1, 1]; model inputs arrive in [0, 1] (VISUAL is IDENTITY in + # the preprocessor — see configuration_fastwam.normalization_mapping). Map here, at the single + # video-encode boundary, so it is applied exactly once on every path. + video_tensor = video_tensor * 2.0 - 1.0 z = self.vae.encode( video_tensor, device=self.device, @@ -1094,6 +1098,8 @@ class FastWAM(torch.nn.Module): raise ValueError( f"`input_image` must have shape [1,3,H,W] or [3,H,W], got {tuple(input_image.shape)}" ) + # [0, 1] -> [-1, 1] for the Wan VAE (mirrors `_encode_video_latents`); single image-encode boundary. 
+ input_image = input_image * 2.0 - 1.0 image = input_image.to(device=self.device)[0].unsqueeze(1) z = self.vae.encode( [image], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index 9c31543f9..f135f52e9 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -24,7 +24,6 @@ from lerobot.processor import ( ActionProcessorStep, AddBatchDimensionProcessorStep, DeviceProcessorStep, - ImageCropResizeProcessorStep, NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, @@ -42,39 +41,6 @@ from lerobot.utils.constants import ( from .configuration_fastwam import FastWAMConfig -@dataclass -@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor") -class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep): - """`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack. - - FastWAM loads a per-camera video stack, so image observations arrive as - ``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a - single leading batch dim (resize raises on 5-D input), so we flatten any leading - dims into the batch, apply the base 4-D crop/resize, then restore the leading shape. - Crop/resize params and feature-shape bookkeeping are inherited unchanged. - """ - - def observation(self, observation: dict) -> dict: - # Delta-timestamp video loading adds `_is_pad` boolean masks ([B, T]) that share - # the `observation.images.` prefix but are padding flags, not frames. The base crop/resize - # matches on the `"image"` substring, so set these aside and restore them untouched rather - # than letting it try to resize a mask. 
- pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key} - leads: dict[str, tuple] = {} - flat_input = {key: value for key, value in observation.items() if key not in pad_keys} - for key, img in list(flat_input.items()): - if "image" in key and torch.is_tensor(img) and img.ndim > 4: - leads[key] = tuple(img.shape[:-3]) - flat_input[key] = img.reshape(-1, *img.shape[-3:]) - processed = super().observation(flat_input) - out = dict(processed) - for key, lead in leads.items(): - im = processed[key] - out[key] = im.reshape(*lead, *im.shape[-3:]) - out.update(pad_keys) - return out - - @dataclass @ProcessorStepRegistry.register(name="fastwam_action_toggle_processor") class FastWAMActionToggleProcessorStep(ActionProcessorStep): @@ -124,32 +90,25 @@ def make_fastwam_pre_post_processors( output processor pipelines discoverable by LeRobot. """ - # force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1] + # NOTE: no visual normalization here. VISUAL is IDENTITY (see configuration_fastwam.normalization_mapping) + # — images pass through in [0, 1] and the model maps them to the Wan VAE's [-1, 1] at the encode + # boundary. This is deliberate: `lerobot_train.py` overrides the normalizer stats with + # `dataset.meta.stats` when fine-tuning, and a real dataset's per-channel image std is the tiny + # frame-to-frame brightness variance, which would blow images far outside [-1,1] and saturate them. + # STATE/ACTION still normalize with dataset stats below. 
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {}) - for key, feature in config.input_features.items(): - if feature.type != FeatureType.VISUAL: - continue - channels = int(feature.shape[0]) - normalization_stats[key] = { - "mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), - "std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32), - } - # resize visual inputs to match model expected input size, if necessary - visual_shapes = [ - feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL - ] - resize_steps = [] - if visual_shapes: - target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2])) - # FastWAM-aware resize: tolerates the leading temporal dim of the video stack. - resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw)) + # NOTE: no resize step here. The model is the single authority on input resolution: it resizes + # each camera to the per-camera target (image_size split across cameras) in + # `_stack_video_from_images` / `_prepare_infer_image`, on every path (train forward, rollout and + # eval select_action). A preprocessor resize step would be both redundant (the model re-resizes + # anyway) and unsafe across fine-tuning: its `resize_size` would be inherited from the base + # checkpoint's camera geometry, not this dataset's, making the concatenation N_cameras x too wide. 
input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), DeviceProcessorStep(device=config.device), - *resize_steps, NormalizerProcessorStep( features={**config.input_features, **config.output_features}, norm_map=config.normalization_mapping, diff --git a/src/lerobot/policies/fastwam/wan_video_dit.py b/src/lerobot/policies/fastwam/wan_video_dit.py index 7a777e9df..0b38ad816 100644 --- a/src/lerobot/policies/fastwam/wan_video_dit.py +++ b/src/lerobot/policies/fastwam/wan_video_dit.py @@ -425,7 +425,7 @@ class WanVideoDiT(WanModel): has_ref_conv: bool = False, add_control_adapter: bool = False, in_dim_control_adapter: int = 24, - separated_timestep: bool = False, + seperated_timestep: bool = False, require_vae_embedding: bool = False, require_clip_embedding: bool = False, fuse_vae_embedding_in_latents: bool = True, @@ -489,7 +489,7 @@ class WanVideoDiT(WanModel): self.hidden_dim = hidden_dim self.attn_head_dim = attn_head_dim - self.separated_timestep = separated_timestep + self.seperated_timestep = seperated_timestep self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.video_attention_mask_mode = str(video_attention_mask_mode) self.action_conditioned = action_conditioned @@ -647,7 +647,7 @@ class WanVideoDiT(WanModel): ) tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w) - if not (self.separated_timestep and fuse_vae_embedding_in_latents): + if not (self.seperated_timestep and fuse_vae_embedding_in_latents): raise NotImplementedError( "FastWAM currently requires separated timesteps with fused VAE latents." )