mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
cleaning up
This commit is contained in:
+2
-1
@@ -442,7 +442,8 @@ default.extend-ignore-identifiers-re = [
|
||||
"is_compileable",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE",
|
||||
"VanderBilt"
|
||||
"VanderBilt",
|
||||
"seperated_timestep",
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.optim import AdamWConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
|
||||
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
|
||||
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam_base"
|
||||
|
||||
|
||||
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
|
||||
@@ -130,7 +130,7 @@ def _validate_wan_model_id(value: str, field_name: str) -> str:
|
||||
|
||||
|
||||
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
|
||||
"""Return whether `fastwam-base` partial weights can initialize this config."""
|
||||
"""Return whether `fastwam_base` partial weights can initialize this config."""
|
||||
|
||||
default_video_config = default_video_dit_config(config.action_dim)
|
||||
default_action_config = default_action_dit_config(config.action_dim)
|
||||
|
||||
@@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -123,26 +122,6 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
||||
"""Down-cast float tensors to the policy dtype before saving.
|
||||
|
||||
FSDP's FULL_STATE_DICT gather returns fp32 master weights, so the default save would
|
||||
write a fp32 `model.safetensors` (~24 GB) even though FastWAM runs in
|
||||
`config.torch_dtype` (bf16). That doubles disk/upload and, worse, makes reloading OOM
|
||||
under FSDP — every rank materializes the full fp32 model on GPU before sharding.
|
||||
Casting float tensors to the configured dtype here halves the checkpoint and keeps
|
||||
loads within budget; non-float tensors (e.g. integer buffers) pass through unchanged.
|
||||
The `state_dict is None` path (non-FSDP saves) already holds params at
|
||||
`config.torch_dtype`, so it needs no cast.
|
||||
"""
|
||||
if state_dict is not None:
|
||||
dtype = _dtype_from_name(self.config.torch_dtype)
|
||||
state_dict = {
|
||||
key: (value.to(dtype) if torch.is_floating_point(value) else value)
|
||||
for key, value in state_dict.items()
|
||||
}
|
||||
super()._save_pretrained(save_directory, state_dict)
|
||||
|
||||
def get_optim_params(self) -> list[Tensor]:
|
||||
# Return the trainable tensors directly (a single param group). The optimizer
|
||||
# builder wraps these in a param group; returning a bare {"params": [...]} dict
|
||||
@@ -385,7 +364,9 @@ def _resize_frames(frames: Tensor, size: tuple[int, int]) -> Tensor:
|
||||
return frames
|
||||
lead = frames.shape[:-3]
|
||||
flat = frames.reshape(-1, *frames.shape[-3:])
|
||||
flat = torch.nn.functional.interpolate(flat, size=size, mode="bilinear", align_corners=False, antialias=True)
|
||||
flat = torch.nn.functional.interpolate(
|
||||
flat, size=size, mode="bilinear", align_corners=False, antialias=True
|
||||
)
|
||||
return flat.reshape(*lead, *flat.shape[-3:])
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
|
||||
Reference in New Issue
Block a user