Compare commits

...

4 Commits

Author SHA1 Message Date
Maxime Ellerbach b42d124007 cleanup 2026-06-15 14:50:23 +00:00
Maxime Ellerbach 3ce50c3468 adding a test for the fsdp checkpoint path 2026-06-15 14:36:22 +00:00
Maxime Ellerbach 44fd3c0a0e adding docs for FSDP 2026-06-15 14:15:09 +00:00
Maxime Ellerbach 0483afc743 feat(train): FSDP checkpoint saving 2026-06-15 14:03:17 +00:00
5 changed files with 126 additions and 7 deletions
+46
View File
@@ -113,6 +113,52 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
--policy=act
```
## Training Large Models with FSDP
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
```bash
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=<your_policy> \
--output_dir=outputs/train/my_policy_fsdp
```
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 4
fsdp_config:
fsdp_version: 1
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
fsdp_state_dict_type: FULL_STATE_DICT
```
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
optimizer before `accelerator.prepare()`.
### FSDP checkpoints
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two thigs to look out for:
- With mixed precision, (`bf16`/`fp16`) FSDP keeps an fp32 master copy, so the checkpoint is fp32
(~2× the bf16 size on disk) and is cast back to the policy dtype on load.
- **Optimizer state is not saved under FSDP**, so **resume-from-checkpoint is not supported**.
Saved weights are fully usable for evaluation and fine-tuning.
## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
+7 -1
View File
@@ -98,6 +98,7 @@ def save_checkpoint(
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
model_state_dict: dict | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -127,9 +128,14 @@ def save_checkpoint(
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
cross-rank collective and passes it here so rank 0 can write it directly. It holds
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
cfg.save_pretrained(pretrained_dir)
if cfg.peft is not None:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
+36 -3
View File
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
import packaging
import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'")
def _save_pretrained(self, save_directory: Path) -> None:
def save_pretrained(
self,
save_directory: str | Path,
*,
state_dict: dict[str, Tensor] | None = None,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""Save the policy to a directory (and optionally push to the Hub).
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
return sharded tensors, so the caller gathers the full state dict via a cross-rank
collective and passes it here for `_save_pretrained` to write directly.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
self._save_pretrained(save_directory, state_dict=state_dict)
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
if state_dict is None:
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
return
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
@classmethod
def from_pretrained(
+13 -3
View File
@@ -189,6 +189,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
require_package("accelerate", extra="training")
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
cfg.validate()
@@ -197,8 +198,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
@@ -558,20 +557,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
# All ranks must call get_state_dict; rank 0 gets the
# full state dict, others get an empty dict.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy)
if is_main_process:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
if is_fsdp:
# TODO(fsdp): sharded optimizer-state save/resume is not wired up yet.
logging.warning(
"FSDP checkpoint: saving model weights only (optimizer state skipped; "
"resume-from-checkpoint not supported under FSDP yet)."
)
save_checkpoint(
checkpoint_dir=checkpoint_dir,
step=step,
cfg=cfg,
policy=accelerator.unwrap_model(policy),
optimizer=optimizer,
optimizer=None if is_fsdp else optimizer,
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
model_state_dict=model_state_dict,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
+24
View File
@@ -23,6 +23,7 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from packaging import version
from safetensors.torch import load_file
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
policy_cls = get_policy_class("act")
policy_cfg = make_policy_config("act")
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / "fsdp_state_dict"
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
# A single, unsharded safetensors file (no sharded set + index).
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool):
"""