mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
Make sure push_to_hub works
Since PEFT only wraps `push_to_hub` and not `push_model_to_hub`, the reference to `self` in `policy.push_model_to_hub` is the unwrapped policy which, of course, doesn't know anything about PEFT. To make the upload process aware of PEFT, we pass the unwrapped policy down to `push_model_to_hub` as a kwarg. This is not ideal but I think it is the best way for now.
This commit is contained in:
@@ -206,6 +206,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
def push_model_to_hub(
|
def push_model_to_hub(
|
||||||
self,
|
self,
|
||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
|
peft_model=None,
|
||||||
):
|
):
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
repo_id = api.create_repo(
|
repo_id = api.create_repo(
|
||||||
@@ -216,7 +217,14 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||||
saved_path = Path(tmp) / repo_id
|
saved_path = Path(tmp) / repo_id
|
||||||
|
|
||||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
if peft_model is not None:
|
||||||
|
# Since PEFT just forwards calls to `push_model_to_hub`, `self` is not the PeftModel wrapper
|
||||||
|
# but the actual policy which is why we need the PEFT model passed to us to save the adapter.
|
||||||
|
# That also means that we need to store the policy config ourselves since PEFT can't.
|
||||||
|
peft_model.save_pretrained(saved_path)
|
||||||
|
self.config.save_pretrained(saved_path)
|
||||||
|
else:
|
||||||
|
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||||
|
|
||||||
card = self.generate_model_card(
|
card = self.generate_model_card(
|
||||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||||
|
|||||||
@@ -524,7 +524,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||||
unwrapped_policy.push_model_to_hub(cfg)
|
if cfg.policy.use_peft:
|
||||||
|
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
|
||||||
|
else:
|
||||||
|
unwrapped_policy.push_model_to_hub(cfg)
|
||||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
|
|
||||||
|
|||||||
+36
-1
@@ -1,5 +1,6 @@
|
|||||||
import importlib
|
import importlib
|
||||||
from unittest.mock import patch
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@@ -19,6 +20,40 @@ def lerobot_record(args):
|
|||||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||||
|
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
||||||
|
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
|
||||||
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
|
upload_folder_contents = set()
|
||||||
|
|
||||||
|
def mock_upload_folder(*args, **kwargs):
|
||||||
|
folder_path = kwargs['folder_path']
|
||||||
|
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
|
||||||
|
upload_folder_contents.update(os.listdir(folder_path))
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("huggingface_hub.HfApi.create_repo"),
|
||||||
|
patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder),
|
||||||
|
):
|
||||||
|
lerobot_train(
|
||||||
|
[
|
||||||
|
f"--policy.type={policy_type}",
|
||||||
|
"--policy.push_to_hub=true",
|
||||||
|
"--policy.repo_id=foo/bar",
|
||||||
|
"--peft.method=LORA",
|
||||||
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
|
"--dataset.episodes=[0, 1]",
|
||||||
|
"--steps=1",
|
||||||
|
f"--output_dir={output_dir}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert 'adapter_model.safetensors' in upload_folder_contents
|
||||||
|
assert 'config.json' in upload_folder_contents
|
||||||
|
assert 'adapter_config.json' in upload_folder_contents
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||||
def test_peft_training_works(policy_type, tmp_path):
|
def test_peft_training_works(policy_type, tmp_path):
|
||||||
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
||||||
|
|||||||
Reference in New Issue
Block a user