mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
feat(train): FSDP checkpoint saving
This commit is contained in:
@@ -98,6 +98,7 @@ def save_checkpoint(
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
model_state_dict: dict | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -127,9 +128,14 @@ def save_checkpoint(
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
||||
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
||||
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
||||
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
||||
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if cfg.peft is not None:
|
||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | Path,
|
||||
*,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
card_kwargs: dict | None = None,
|
||||
**push_to_hub_kwargs,
|
||||
) -> str | None:
|
||||
"""Save the policy to a directory (and optionally push to the Hub).
|
||||
|
||||
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
|
||||
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
|
||||
return sharded tensors, so the caller gathers the full state dict via a cross-rank
|
||||
collective and passes it here for `_save_pretrained` to write directly.
|
||||
"""
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._save_pretrained(save_directory, state_dict=state_dict)
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
if state_dict is None:
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
return
|
||||
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
|
||||
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
|
||||
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
|
||||
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
|
||||
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
|
||||
@@ -189,6 +189,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
require_package("accelerate", extra="training")
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
|
||||
|
||||
cfg.validate()
|
||||
|
||||
@@ -197,7 +198,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
@@ -558,20 +558,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
# All ranks must call get_state_dict; rank 0 gets the
|
||||
# full state dict, others get an empty dict.
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
model_state_dict = accelerator.get_state_dict(policy)
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
if is_fsdp:
|
||||
# TODO(fsdp): sharded optimizer-state save/resume is not wired up yet.
|
||||
logging.warning(
|
||||
"FSDP checkpoint: saving model weights only (optimizer state skipped; "
|
||||
"resume-from-checkpoint not supported under FSDP yet)."
|
||||
)
|
||||
save_checkpoint(
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
step=step,
|
||||
cfg=cfg,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
optimizer=optimizer,
|
||||
optimizer=None if is_fsdp else optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
model_state_dict=model_state_dict,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
Reference in New Issue
Block a user