mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Make sure push_to_hub works
Since PEFT only wraps `push_to_hub` and not `push_model_to_hub`, the reference to `self` in `policy.push_model_to_hub` is the unwrapped policy which, of course, doesn't know anything about PEFT. To make the upload process aware of PEFT, we pass the unwrapped policy down to `push_model_to_hub` as a kwarg. This is not ideal but I think it is the best way for now.
This commit is contained in:
+36
-1
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from safetensors.torch import load_file
|
||||
@@ -19,6 +20,40 @@ def lerobot_record(args):
|
||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
||||
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
|
||||
output_dir = tmp_path / f"output_{policy_type}"
|
||||
upload_folder_contents = set()
|
||||
|
||||
def mock_upload_folder(*args, **kwargs):
|
||||
folder_path = kwargs['folder_path']
|
||||
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
|
||||
upload_folder_contents.update(os.listdir(folder_path))
|
||||
return MagicMock()
|
||||
|
||||
with (
|
||||
patch("huggingface_hub.HfApi.create_repo"),
|
||||
patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder),
|
||||
):
|
||||
lerobot_train(
|
||||
[
|
||||
f"--policy.type={policy_type}",
|
||||
"--policy.push_to_hub=true",
|
||||
"--policy.repo_id=foo/bar",
|
||||
"--peft.method=LORA",
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0, 1]",
|
||||
"--steps=1",
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
assert 'adapter_model.safetensors' in upload_folder_contents
|
||||
assert 'config.json' in upload_folder_contents
|
||||
assert 'adapter_config.json' in upload_folder_contents
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||
def test_peft_training_works(policy_type, tmp_path):
|
||||
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
||||
|
||||
Reference in New Issue
Block a user