From 0483afc7432de609b84faf255ffbdeef05c213e4 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Mon, 15 Jun 2026 14:03:17 +0000 Subject: [PATCH] feat(train): FSDP checkpoint saving --- src/lerobot/common/train_utils.py | 8 +++++- src/lerobot/policies/pretrained.py | 39 +++++++++++++++++++++++++--- src/lerobot/scripts/lerobot_train.py | 15 +++++++++-- 3 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 2d23b4003..cd8d43381 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -98,6 +98,7 @@ def save_checkpoint( postprocessor: PolicyProcessorPipeline | None = None, num_processes: int | None = None, batch_size: int | None = None, + model_state_dict: dict | None = None, ) -> None: """This function creates the following directory structure: @@ -127,9 +128,14 @@ def save_checkpoint( resume. Defaults to None (not recorded). batch_size (int | None, optional): Per-process batch size to record for sample-exact resume. Defaults to None (not recorded). + model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP, + where `policy.state_dict()` would return sharded tensors; the caller gathers it via a + cross-rank collective and passes it here so rank 0 can write it directly. It holds + FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on + read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR - policy.save_pretrained(pretrained_dir) + policy.save_pretrained(pretrained_dir, state_dict=model_state_dict) cfg.save_pretrained(pretrained_dir) if cfg.peft is not None: # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index a69487f3f..eedf9d99e 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack import packaging import safetensors -from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download +from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor @@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): if not getattr(cls, "name", None): raise TypeError(f"Class {cls.__name__} must define 'name'") - def _save_pretrained(self, save_directory: Path) -> None: + def save_pretrained( + self, + save_directory: str | Path, + *, + state_dict: dict[str, Tensor] | None = None, + repo_id: str | None = None, + push_to_hub: bool = False, + card_kwargs: dict | None = None, + **push_to_hub_kwargs, + ) -> str | None: + """Save the policy to a directory (and optionally push to the Hub). + + Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring + `transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would + return sharded tensors, so the caller gathers the full state dict via a cross-rank + collective and passes it here for `_save_pretrained` to write directly. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + self._save_pretrained(save_directory, state_dict=state_dict) + if push_to_hub: + if repo_id is None: + repo_id = save_directory.name + return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs) + return None + + def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None: self.config._save_pretrained(save_directory) model_to_save = self.module if hasattr(self, "module") else self - save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) + if state_dict is None: + save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE)) + return + # A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly. + # `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does; + # pin `max_shard_size` above the total size so the output stays a single `model.safetensors` + total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values()) + save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1)) @classmethod def from_pretrained( diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 70a5e9e9d..f60aae8b6 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -189,6 +189,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): require_package("accelerate", extra="training") from accelerate import Accelerator + from accelerate.utils import DistributedDataParallelKwargs, DistributedType cfg.validate() @@ -197,7 +198,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes # We set find_unused_parameters=True to handle models with conditional computation if accelerator is None: - from accelerate.utils import DistributedDataParallelKwargs ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. @@ -558,20 +558,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: + # All ranks must call get_state_dict; rank 0 gets the + # full state dict, others get an empty dict. + is_fsdp = accelerator.distributed_type == DistributedType.FSDP + model_state_dict = accelerator.get_state_dict(policy) if is_main_process: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + if is_fsdp: + # TODO(fsdp): sharded optimizer-state save/resume is not wired up yet. + logging.warning( + "FSDP checkpoint: saving model weights only (optimizer state skipped; " + "resume-from-checkpoint not supported under FSDP yet)." + ) save_checkpoint( checkpoint_dir=checkpoint_dir, step=step, cfg=cfg, policy=accelerator.unwrap_model(policy), - optimizer=optimizer, + optimizer=None if is_fsdp else optimizer, scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, num_processes=accelerator.num_processes, batch_size=cfg.batch_size, + model_state_dict=model_state_dict, ) update_last_checkpoint(checkpoint_dir) if wandb_logger: