diff --git a/pyproject.toml b/pyproject.toml index 2a1b86fe3..f11c6d3fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -442,7 +442,8 @@ default.extend-ignore-identifiers-re = [ "is_compileable", "ROBOTIS", "OT_VALUE", - "VanderBilt" + "VanderBilt", + "seperated_timestep", ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/policies/fastwam/configuration_fastwam.py b/src/lerobot/policies/fastwam/configuration_fastwam.py index c78482ace..0e28efb37 100644 --- a/src/lerobot/policies/fastwam/configuration_fastwam.py +++ b/src/lerobot/policies/fastwam/configuration_fastwam.py @@ -28,7 +28,7 @@ from lerobot.optim import AdamWConfig from lerobot.utils.constants import ACTION, OBS_STATE WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B" -FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base" +FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base" _FASTWAM_VIDEO_BASE_COMPAT_KEYS = ( @@ -130,7 +130,7 @@ def _validate_wan_model_id(value: str, field_name: str) -> str: def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool: - """Return whether `fastwam-base` partial weights can initialize this config.""" + """Return whether `fastwam_base` partial weights can initialize this config.""" default_video_config = default_video_dit_config(config.action_dim) default_action_config = default_action_dit_config(config.action_dim) diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 0c99613f3..1bde87eea 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -16,7 +16,6 @@ from __future__ import annotations import logging from collections import deque -from pathlib import Path from typing import Any import torch @@ -123,26 +122,6 @@ class FastWAMPolicy(PreTrainedPolicy): model.to(map_location) return model - def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None: - """Down-cast float tensors to the policy dtype before saving. - - FSDP's FULL_STATE_DICT gather returns fp32 master weights, so the default save would - write a fp32 `model.safetensors` (~24 GB) even though FastWAM runs in - `config.torch_dtype` (bf16). That doubles disk/upload and, worse, makes reloading OOM - under FSDP — every rank materializes the full fp32 model on GPU before sharding. - Casting float tensors to the configured dtype here halves the checkpoint and keeps - loads within budget; non-float tensors (e.g. integer buffers) pass through unchanged. - The `state_dict is None` path (non-FSDP saves) already holds params at - `config.torch_dtype`, so it needs no cast. - """ - if state_dict is not None: - dtype = _dtype_from_name(self.config.torch_dtype) - state_dict = { - key: (value.to(dtype) if torch.is_floating_point(value) else value) - for key, value in state_dict.items() - } - super()._save_pretrained(save_directory, state_dict) - def get_optim_params(self) -> list[Tensor]: # Return the trainable tensors directly (a single param group). The optimizer # builder wraps these in a param group; returning a bare {"params": [...]} dict @@ -385,7 +364,9 @@ def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor: return frames lead = frames.shape[:-3] flat = frames.reshape(-1, *frames.shape[-3:]) - flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True) + flat = torch.nn.functional.interpolate( + flat, size=size, mode="bilinear", align_corners=False, antialias=True + ) return flat.reshape(*lead, *flat.shape[-3:]) diff --git a/src/lerobot/policies/fastwam/processor_fastwam.py b/src/lerobot/policies/fastwam/processor_fastwam.py index f135f52e9..31f3b9277 100644 --- a/src/lerobot/policies/fastwam/processor_fastwam.py +++ b/src/lerobot/policies/fastwam/processor_fastwam.py @@ -19,7 +19,7 @@ from typing import Any import torch -from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( ActionProcessorStep, AddBatchDimensionProcessorStep,