mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
removing some preprocessors
This commit is contained in:
@@ -68,7 +68,7 @@ def default_video_dit_config(action_dim: int) -> dict[str, Any]:
|
||||
"attn_head_dim": 128,
|
||||
"num_layers": 30,
|
||||
"eps": 1.0e-6,
|
||||
"separated_timestep": True,
|
||||
"seperated_timestep": True,
|
||||
"use_gradient_checkpointing": False,
|
||||
"video_attention_mask_mode": "first_frame_causal",
|
||||
"action_conditioned": False,
|
||||
@@ -215,7 +215,7 @@ class FastWAMConfig(PreTrainedConfig):
|
||||
action_dit_config: dict[str, Any] | None = None
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
@@ -145,6 +145,26 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
    """Save the policy, casting floating-point tensors to the configured dtype first.

    When an explicit `state_dict` is supplied (the FSDP FULL_STATE_DICT gather, which
    yields fp32 master weights), every floating-point tensor in it is converted to
    `config.torch_dtype` before the parent class writes the checkpoint. Without this
    the saved `model.safetensors` would be fp32 (~24 GB) even though FastWAM runs in
    bf16, doubling disk/upload size and making an FSDP reload OOM because each rank
    materializes the full fp32 model on GPU before sharding.

    Non-floating tensors (e.g. integer buffers) are forwarded untouched. When
    `state_dict` is None (non-FSDP saves) the parameters are already held at
    `config.torch_dtype`, so the call is delegated as-is.

    Args:
        save_directory: Destination directory, forwarded to the parent implementation.
        state_dict: Optional gathered state dict to save instead of the module's own.
    """
    if state_dict is None:
        # Nothing to down-cast: let the parent save the module's own parameters.
        super()._save_pretrained(save_directory, state_dict)
        return

    target_dtype = _dtype_from_name(self.config.torch_dtype)
    downcast: dict[str, Tensor] = {}
    for name, tensor in state_dict.items():
        # Only float tensors are converted; everything else keeps its dtype.
        downcast[name] = tensor.to(target_dtype) if torch.is_floating_point(tensor) else tensor
    super()._save_pretrained(save_directory, downcast)
|
||||
|
||||
def get_optim_params(self) -> list[Tensor]:
|
||||
# Return the trainable tensors directly (a single param group). The optimizer
|
||||
# builder wraps these in a param group; returning a bare {"params": [...]} dict
|
||||
@@ -331,7 +351,9 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
def _debug_tensor_to_pil(image: Tensor):
    """Convert a single image tensor in [0, 1] into a PIL image for debug dumps.

    The tensor is detached, cast to float, clamped to [0, 1], scaled to [0, 255] and
    cast to uint8. Its first axis is then moved last (`permute(1, 2, 0)`) before it
    is handed to `PIL.Image.fromarray`, so a 3-D channel-first tensor is expected.

    Args:
        image: 3-D image tensor; presumably `[C, H, W]` with values in [0, 1].

    Returns:
        The `PIL.Image.Image` built from the converted array.
    """
    # Imported lazily so Pillow is only needed when debug images are produced.
    from PIL import Image

    # `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE
    # encode boundary), so map [0, 1] -> [0, 255] for display. The former [-1, 1] -> [0, 255]
    # conversion was a dead store overwritten by this one and has been dropped.
    pixels = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8)
    return Image.fromarray(pixels.cpu().permute(1, 2, 0).numpy())
|
||||
|
||||
@staticmethod
|
||||
@@ -491,13 +513,29 @@ def batch_device(batch: dict[str, Any]) -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor:
|
||||
"""Resize a frame tensor to `size` (H, W), tolerating a leading temporal/batch stack.
|
||||
|
||||
`interpolate` only accepts a single leading batch dim (`[N, C, H, W]`), but FastWAM camera
|
||||
tensors arrive as `[B, C, H, W]` (live eval) or `[B, T, C, H, W]` (temporal stack), so flatten
|
||||
any leading dims into the batch, resize, then restore. A no-op when already at `size`.
|
||||
"""
|
||||
if tuple(frames.shape[-2:]) == size:
|
||||
return frames
|
||||
lead = frames.shape[:-3]
|
||||
flat = frames.reshape(-1, *frames.shape[-3:])
|
||||
flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True)
|
||||
return flat.reshape(*lead, *flat.shape[-3:])
|
||||
|
||||
|
||||
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
|
||||
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
|
||||
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
|
||||
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
|
||||
if not image_keys:
|
||||
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
|
||||
images = [batch[key] for key in image_keys]
|
||||
per_cam = (int(config.image_size[0]), int(config.image_size[1]) // len(image_keys))
|
||||
images = [_resize_frames(batch[key], per_cam) for key in image_keys]
|
||||
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
|
||||
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
|
||||
if image.ndim == 4:
|
||||
@@ -530,11 +568,8 @@ def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
|
||||
|
||||
target_h, target_w = config.image_size
|
||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||
raise ValueError(
|
||||
"FastWAM policy expects preprocessed image tensors with shape "
|
||||
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
|
||||
"Run the FastWAM preprocessor before calling the policy."
|
||||
)
|
||||
return image
|
||||
# Resize to the full configured resolution (no-op when the video path already produced it, but
|
||||
# also covers a directly-supplied `input_image`). The model owns its input resolution — see
|
||||
# `_stack_video_from_images` — so we resize rather than assert on a mismatch.
|
||||
target_h, target_w = int(config.image_size[0]), int(config.image_size[1])
|
||||
return _resize_frames(image, (target_h, target_w))
|
||||
|
||||
@@ -1075,6 +1075,10 @@ class FastWAM(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_video_latents(self, video_tensor, tiled=False, tile_size=(30, 52), tile_stride=(15, 26)):
|
||||
# The Wan VAE expects pixels in [-1, 1]; model inputs arrive in [0, 1] (VISUAL is IDENTITY in
|
||||
# the preprocessor — see configuration_fastwam.normalization_mapping). Map here, at the single
|
||||
# video-encode boundary, so it is applied exactly once on every path.
|
||||
video_tensor = video_tensor * 2.0 - 1.0
|
||||
z = self.vae.encode(
|
||||
video_tensor,
|
||||
device=self.device,
|
||||
@@ -1094,6 +1098,8 @@ class FastWAM(torch.nn.Module):
|
||||
raise ValueError(
|
||||
f"`input_image` must have shape [1,3,H,W] or [3,H,W], got {tuple(input_image.shape)}"
|
||||
)
|
||||
# [0, 1] -> [-1, 1] for the Wan VAE (mirrors `_encode_video_latents`); single image-encode boundary.
|
||||
input_image = input_image * 2.0 - 1.0
|
||||
image = input_image.to(device=self.device)[0].unsqueeze(1)
|
||||
z = self.vae.encode(
|
||||
[image], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
|
||||
@@ -24,7 +24,6 @@ from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
@@ -42,39 +41,6 @@ from lerobot.utils.constants import (
|
||||
from .configuration_fastwam import FastWAMConfig
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
    """Crop/resize step that accepts image tensors with extra leading dims.

    FastWAM loads a per-camera video stack, so image observations arrive as
    ``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
    single leading batch dim (resize raises on 5-D input). This step therefore folds all
    leading dims into the batch, runs the inherited 4-D crop/resize, and unfolds the
    result back to the original leading shape. Crop/resize parameters and feature-shape
    bookkeeping come from the base class unchanged.
    """

    def observation(self, observation: dict) -> dict:
        """Resize image entries of `observation`, leaving `*_is_pad` masks untouched.

        Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) under
        the same `observation.images.` prefix. They are padding flags, not frames, and the
        base step matches on the `"image"` substring, so they are held back and re-attached
        after the base step runs.
        """
        masks: dict = {}
        frames: dict = {}
        stacked_shapes: dict[str, tuple] = {}
        for name, value in observation.items():
            if "_is_pad" in name:
                # Padding mask: bypass the base crop/resize entirely.
                masks[name] = value
                continue
            if "image" in name and torch.is_tensor(value) and value.ndim > 4:
                # Remember the leading dims, then collapse them into a single batch dim.
                stacked_shapes[name] = tuple(value.shape[:-3])
                value = value.reshape(-1, *value.shape[-3:])
            frames[name] = value

        resized = super().observation(frames)

        restored = dict(resized)
        for name, lead_shape in stacked_shapes.items():
            tensor = resized[name]
            restored[name] = tensor.reshape(*lead_shape, *tensor.shape[-3:])
        restored.update(masks)
        return restored
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
|
||||
class FastWAMActionToggleProcessorStep(ActionProcessorStep):
|
||||
@@ -124,32 +90,25 @@ def make_fastwam_pre_post_processors(
|
||||
output processor pipelines discoverable by LeRobot.
|
||||
"""
|
||||
|
||||
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
|
||||
# NOTE: no visual normalization here. VISUAL is IDENTITY (see configuration_fastwam.normalization_mapping)
|
||||
# — images pass through in [0, 1] and the model maps them to the Wan VAE's [-1, 1] at the encode
|
||||
# boundary. This is deliberate: `lerobot_train.py` overrides the normalizer stats with
|
||||
# `dataset.meta.stats` when fine-tuning, and a real dataset's per-channel image std is the tiny
|
||||
# frame-to-frame brightness variance, which would blow images far outside [-1,1] and saturate them.
|
||||
# STATE/ACTION still normalize with dataset stats below.
|
||||
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
|
||||
for key, feature in config.input_features.items():
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
channels = int(feature.shape[0])
|
||||
normalization_stats[key] = {
|
||||
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# resize visual inputs to match model expected input size, if necessary
|
||||
visual_shapes = [
|
||||
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
|
||||
]
|
||||
resize_steps = []
|
||||
if visual_shapes:
|
||||
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
|
||||
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
|
||||
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
|
||||
# NOTE: no resize step here. The model is the single authority on input resolution: it resizes
|
||||
# each camera to the per-camera target (image_size split across cameras) in
|
||||
# `_stack_video_from_images` / `_prepare_infer_image`, on every path (train forward, rollout and
|
||||
# eval select_action). A preprocessor resize step would be both redundant (the model re-resizes
|
||||
# anyway) and unsafe across fine-tuning: its `resize_size` would be inherited from the base
|
||||
# checkpoint's camera geometry, not this dataset's, making the concatenation N_cameras x too wide.
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
*resize_steps,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
|
||||
@@ -425,7 +425,7 @@ class WanVideoDiT(WanModel):
|
||||
has_ref_conv: bool = False,
|
||||
add_control_adapter: bool = False,
|
||||
in_dim_control_adapter: int = 24,
|
||||
separated_timestep: bool = False,
|
||||
seperated_timestep: bool = False,
|
||||
require_vae_embedding: bool = False,
|
||||
require_clip_embedding: bool = False,
|
||||
fuse_vae_embedding_in_latents: bool = True,
|
||||
@@ -489,7 +489,7 @@ class WanVideoDiT(WanModel):
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.attn_head_dim = attn_head_dim
|
||||
self.separated_timestep = separated_timestep
|
||||
self.seperated_timestep = seperated_timestep
|
||||
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
|
||||
self.video_attention_mask_mode = str(video_attention_mask_mode)
|
||||
self.action_conditioned = action_conditioned
|
||||
@@ -647,7 +647,7 @@ class WanVideoDiT(WanModel):
|
||||
)
|
||||
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
|
||||
|
||||
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
|
||||
if not (self.seperated_timestep and fuse_vae_embedding_in_latents):
|
||||
raise NotImplementedError(
|
||||
"FastWAM currently requires separated timesteps with fused VAE latents."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user