From 3ce50c346855a0a52e4fa1cbe212aeb9317b1283 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Mon, 15 Jun 2026 14:36:22 +0000 Subject: [PATCH] adding a test for the fsdp checkpoint path --- tests/policies/test_policies.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index e9388b3ed..285b87d4c 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -23,6 +23,7 @@ import torch pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from packaging import version from safetensors.torch import load_file @@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) +def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path): + """Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict.""" + policy_cls = get_policy_class("act") + policy_cfg = make_policy_config("act") + features = dataset_to_policy_features(dummy_dataset_metadata.features) + policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + policy_cfg.input_features = { + key: ft for key, ft in features.items() if key not in policy_cfg.output_features + } + policy = policy_cls(policy_cfg) + policy.to(policy_cfg.device) + + save_dir = tmp_path / "fsdp_state_dict" + policy.save_pretrained(save_dir, state_dict=policy.state_dict()) + + # A single, unsharded safetensors file (no sharded set + index). + assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file() + assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists() + + loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg) + torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) + + @pytest.mark.parametrize("multikey", [True, False]) def test_multikey_construction(multikey: bool): """