mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| acd31c7de2 | |||
| 240393d238 |
@@ -113,52 +113,6 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Training Large Models with FSDP
|
||||
|
||||
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
|
||||
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
|
||||
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
|
||||
|
||||
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=<your_policy> \
|
||||
--output_dir=outputs/train/my_policy_fsdp
|
||||
```
|
||||
|
||||
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
fsdp_config:
|
||||
fsdp_version: 1
|
||||
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
|
||||
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
```
|
||||
|
||||
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
|
||||
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
|
||||
optimizer before `accelerator.prepare()`.
|
||||
|
||||
### FSDP checkpoints
|
||||
|
||||
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
|
||||
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two thigs to look out for:
|
||||
|
||||
- With mixed precision, (`bf16`/`fp16`) FSDP keeps an fp32 master copy, so the checkpoint is fp32
|
||||
(~2× the bf16 size on disk) and is cast back to the policy dtype on load.
|
||||
- **Optimizer state is not saved under FSDP**, so **resume-from-checkpoint is not supported**.
|
||||
Saved weights are fully usable for evaluation and fine-tuning.
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
|
||||
@@ -98,7 +98,6 @@ def save_checkpoint(
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
model_state_dict: dict | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -128,14 +127,9 @@ def save_checkpoint(
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
||||
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
||||
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
||||
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
||||
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if cfg.peft is not None:
|
||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||
|
||||
@@ -73,6 +73,8 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot v3.0 dataset on disk.
|
||||
recording: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size == 0:
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
@@ -129,43 +129,10 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | Path,
|
||||
*,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
card_kwargs: dict | None = None,
|
||||
**push_to_hub_kwargs,
|
||||
) -> str | None:
|
||||
"""Save the policy to a directory (and optionally push to the Hub).
|
||||
|
||||
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
|
||||
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
|
||||
return sharded tensors, so the caller gathers the full state dict via a cross-rank
|
||||
collective and passes it here for `_save_pretrained` to write directly.
|
||||
"""
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._save_pretrained(save_directory, state_dict=state_dict)
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
if state_dict is None:
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
return
|
||||
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
|
||||
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
|
||||
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
|
||||
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
|
||||
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -95,6 +96,81 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
|
||||
|
||||
If raw_obs is provided, visual feature shapes are inferred from the actual observation
|
||||
to avoid mismatches between the env config and the real observation resolution.
|
||||
"""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
elif raw_obs is not None and "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
|
||||
shape = img.shape[1:] # strip batch dim
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
shape = pixels.shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
return features
|
||||
|
||||
|
||||
def _build_raw_frame(
|
||||
raw_obs: dict,
|
||||
env_idx: int,
|
||||
action: np.ndarray,
|
||||
reward: float,
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
val = raw_obs[raw_key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -105,6 +181,7 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -145,6 +222,14 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
raw_observation = deepcopy(observation) if recording_dataset is not None else None
|
||||
task_desc = ""
|
||||
if recording_dataset is not None:
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -217,6 +302,26 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_dataset is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_dataset.save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
@@ -273,6 +378,7 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -361,6 +467,7 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -563,6 +670,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -572,10 +683,13 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -618,6 +732,7 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -635,6 +750,7 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -661,6 +777,8 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -672,21 +790,39 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
recording_dataset = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=str(task_recording_dir),
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
finally:
|
||||
if recording_dataset is not None:
|
||||
recording_dataset.finalize()
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -702,6 +838,8 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -761,6 +899,8 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -189,7 +189,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
require_package("accelerate", extra="training")
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
|
||||
|
||||
cfg.validate()
|
||||
|
||||
@@ -198,6 +197,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
@@ -557,31 +558,20 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
# All ranks must call get_state_dict; rank 0 gets the
|
||||
# full state dict, others get an empty dict.
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
model_state_dict = accelerator.get_state_dict(policy)
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
if is_fsdp:
|
||||
# TODO(fsdp): sharded optimizer-state save/resume is not wired up yet.
|
||||
logging.warning(
|
||||
"FSDP checkpoint: saving model weights only (optimizer state skipped; "
|
||||
"resume-from-checkpoint not supported under FSDP yet)."
|
||||
)
|
||||
save_checkpoint(
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
step=step,
|
||||
cfg=cfg,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
optimizer=None if is_fsdp else optimizer,
|
||||
optimizer=optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
model_state_dict=model_state_dict,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -23,7 +23,6 @@ import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -301,29 +300,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
|
||||
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
|
||||
policy_cls = get_policy_class("act")
|
||||
policy_cfg = make_policy_config("act")
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.to(policy_cfg.device)
|
||||
|
||||
save_dir = tmp_path / "fsdp_state_dict"
|
||||
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
|
||||
|
||||
# A single, unsharded safetensors file (no sharded set + index).
|
||||
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
|
||||
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
|
||||
|
||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user