Make sure push_to_hub works

Since PEFT only wraps `push_to_hub` and not `push_model_to_hub`, the reference
to `self` in `policy.push_model_to_hub` is the unwrapped policy which, of course,
doesn't know anything about PEFT.

To make the upload process aware of PEFT, we pass the unwrapped policy down to
`push_model_to_hub` as a kwarg. This is not ideal but I think it is the best way
for now.
This commit is contained in:
nemo
2025-11-24 18:50:02 +01:00
parent e9b3889bd2
commit 83aac1b42e
3 changed files with 49 additions and 3 deletions
+9 -1
View File
@@ -206,6 +206,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
def push_model_to_hub( def push_model_to_hub(
self, self,
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
peft_model=None,
): ):
api = HfApi() api = HfApi()
repo_id = api.create_repo( repo_id = api.create_repo(
@@ -216,7 +217,14 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
saved_path = Path(tmp) / repo_id saved_path = Path(tmp) / repo_id
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors if peft_model is not None:
# Since PEFT just forwards calls to `push_model_to_hub`, `self` is not the PeftModel wrapper
# but the actual policy which is why we need the PEFT model passed to us to save the adapter.
# That also means that we need to store the policy config ourselves since PEFT can't.
peft_model.save_pretrained(saved_path)
self.config.save_pretrained(saved_path)
else:
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
card = self.generate_model_card( card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
+4 -1
View File
@@ -524,7 +524,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.policy.push_to_hub: if cfg.policy.push_to_hub:
unwrapped_policy = accelerator.unwrap_model(policy) unwrapped_policy = accelerator.unwrap_model(policy)
unwrapped_policy.push_model_to_hub(cfg) if cfg.policy.use_peft:
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
else:
unwrapped_policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id) preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id) postprocessor.push_to_hub(cfg.policy.repo_id)
+36 -1
View File
@@ -1,5 +1,6 @@
import importlib import importlib
from unittest.mock import patch import os
from unittest.mock import MagicMock, patch
import pytest import pytest
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -19,6 +20,40 @@ def lerobot_record(args):
return run_command(cmd="lerobot-record", module="lerobot_record", args=args) return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
@pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
output_dir = tmp_path / f"output_{policy_type}"
upload_folder_contents = set()
def mock_upload_folder(*args, **kwargs):
folder_path = kwargs['folder_path']
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
upload_folder_contents.update(os.listdir(folder_path))
return MagicMock()
with (
patch("huggingface_hub.HfApi.create_repo"),
patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder),
):
lerobot_train(
[
f"--policy.type={policy_type}",
"--policy.push_to_hub=true",
"--policy.repo_id=foo/bar",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
"--steps=1",
f"--output_dir={output_dir}",
]
)
assert 'adapter_model.safetensors' in upload_folder_contents
assert 'config.json' in upload_folder_contents
assert 'adapter_config.json' in upload_folder_contents
@pytest.mark.parametrize("policy_type", ["smolvla"]) @pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_works(policy_type, tmp_path): def test_peft_training_works(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""