mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
adding a test for the fsdp checkpoint path
This commit is contained in:
@@ -23,6 +23,7 @@ import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
|
||||
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
|
||||
policy_cls = get_policy_class("act")
|
||||
policy_cfg = make_policy_config("act")
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.to(policy_cfg.device)
|
||||
|
||||
save_dir = tmp_path / "fsdp_state_dict"
|
||||
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
|
||||
|
||||
# A single, unsharded safetensors file (no sharded set + index).
|
||||
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
|
||||
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
|
||||
|
||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user