cleaning up

This commit is contained in:
Maxime Ellerbach
2026-06-18 11:33:25 +00:00
parent 2ec82c68b4
commit 8ee60dbd95
4 changed files with 8 additions and 26 deletions
+2 -1
View File
@@ -442,7 +442,8 @@ default.extend-ignore-identifiers-re = [
"is_compileable",
"ROBOTIS",
"OT_VALUE",
"VanderBilt"
"VanderBilt",
"seperated_timestep",
]
# TODO: Uncomment when ready to use
@@ -28,7 +28,7 @@ from lerobot.optim import AdamWConfig
from lerobot.utils.constants import ACTION, OBS_STATE
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base"
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
@@ -130,7 +130,7 @@ def _validate_wan_model_id(value: str, field_name: str) -> str:
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
"""Return whether `fastwam-base` partial weights can initialize this config."""
"""Return whether `fastwam_base` partial weights can initialize this config."""
default_video_config = default_video_dit_config(config.action_dim)
default_action_config = default_action_dit_config(config.action_dim)
@@ -16,7 +16,6 @@ from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import Any
import torch
@@ -123,26 +122,6 @@ class FastWAMPolicy(PreTrainedPolicy):
model.to(map_location)
return model
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
"""Down-cast float tensors to the policy dtype before saving.
FSDP's FULL_STATE_DICT gather returns fp32 master weights, so the default save would
write a fp32 `model.safetensors` (~24 GB) even though FastWAM runs in
`config.torch_dtype` (bf16). That doubles disk/upload and, worse, makes reloading OOM
under FSDP — every rank materializes the full fp32 model on GPU before sharding.
Casting float tensors to the configured dtype here halves the checkpoint and keeps
loads within budget; non-float tensors (e.g. integer buffers) pass through unchanged.
The `state_dict is None` path (non-FSDP saves) already holds params at
`config.torch_dtype`, so it needs no cast.
"""
if state_dict is not None:
dtype = _dtype_from_name(self.config.torch_dtype)
state_dict = {
key: (value.to(dtype) if torch.is_floating_point(value) else value)
for key, value in state_dict.items()
}
super()._save_pretrained(save_directory, state_dict)
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
@@ -385,7 +364,9 @@ def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor:
return frames
lead = frames.shape[:-3]
flat = frames.reshape(-1, *frames.shape[-3:])
flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True)
flat = torch.nn.functional.interpolate(
flat, size=size, mode="bilinear", align_corners=False, antialias=True
)
return flat.reshape(*lead, *flat.shape[-3:])
@@ -19,7 +19,7 @@ from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,