diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 3f5d89ec5..a1499d077 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -206,6 +206,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def push_model_to_hub( self, cfg: TrainPipelineConfig, + peft_model=None, ): api = HfApi() repo_id = api.create_repo( @@ -216,7 +217,14 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: saved_path = Path(tmp) / repo_id - self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors + if peft_model is not None: + # Since PEFT just forwards calls to `push_model_to_hub`, `self` is not the PeftModel wrapper + # but the actual policy which is why we need the PEFT model passed to us to save the adapter. + # That also means that we need to store the policy config ourselves since PEFT can't. + peft_model.save_pretrained(saved_path) + self.config.save_pretrained(saved_path) + else: + self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors card = self.generate_model_card( cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9c85e9dc0..d8822a670 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -524,7 +524,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.policy.push_to_hub: unwrapped_policy = accelerator.unwrap_model(policy) - unwrapped_policy.push_model_to_hub(cfg) + if cfg.policy.use_peft: + unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy) + else: + unwrapped_policy.push_model_to_hub(cfg) preprocessor.push_to_hub(cfg.policy.repo_id) postprocessor.push_to_hub(cfg.policy.repo_id) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2f5ddebe8..3760e4b78 100644 --- 
a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,5 +1,6 @@
 import importlib
-from unittest.mock import patch
+import os
+from unittest.mock import MagicMock, patch
 
 import pytest
 from safetensors.torch import load_file
@@ -19,6 +20,44 @@ def lerobot_record(args):
     return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
 
 
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+def test_peft_training_push_to_hub_works(policy_type, tmp_path):
+    """Ensure that pushing a PEFT policy to the hub stores only the adapter, not the full model weights."""
+    output_dir = tmp_path / f"output_{policy_type}"
+    upload_folder_contents = set()
+
+    def mock_upload_folder(*args, **kwargs):
+        # Snapshot the folder now: the policy is saved to a temporary directory that is deleted afterwards.
+        folder_path = kwargs["folder_path"]
+        # This over-approximates the upload since {allow,ignore}_patterns of upload_folder() are ignored.
+        upload_folder_contents.update(os.listdir(folder_path))
+        return MagicMock()
+
+    with (
+        patch("huggingface_hub.HfApi.create_repo"),
+        patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder),
+    ):
+        lerobot_train(
+            [
+                f"--policy.type={policy_type}",
+                "--policy.push_to_hub=true",
+                "--policy.repo_id=foo/bar",
+                "--peft.method=LORA",
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0, 1]",
+                "--steps=1",
+                f"--output_dir={output_dir}",
+            ]
+        )
+
+    # The adapter weights, the adapter config and the policy config must all be part of the upload ...
+    assert "adapter_model.safetensors" in upload_folder_contents
+    assert "config.json" in upload_folder_contents
+    assert "adapter_config.json" in upload_folder_contents
+    # ... but the full policy checkpoint must not be, otherwise pushing the adapter alone is pointless.
+    assert "model.safetensors" not in upload_folder_contents
+
+
 @pytest.mark.parametrize("policy_type", ["smolvla"])
 def test_peft_training_works(policy_type, tmp_path):
     """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""