Disallow PEFT training on non-pretrained policies

At first I thought it would make sense to have this feature
in case you want to fine-tune a pre-trained section but in the
end it makes more trouble than it's worth.

It's still possible to allow this in the future when a concrete
need arises.
This commit is contained in:
nemo
2025-12-19 11:50:29 +01:00
parent e01927f641
commit 54c25a4400
2 changed files with 33 additions and 10 deletions
+29 -7
View File
@@ -20,12 +20,22 @@ def lerobot_record(args):
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
def resolve_model_id_for_peft_training(policy_type):
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
if policy_type == "smolvla":
return "lerobot/smolvla_base"
raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")
@pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
output_dir = tmp_path / f"output_{policy_type}"
upload_folder_contents = set()
model_id = resolve_model_id_for_peft_training(policy_type)
def mock_upload_folder(*args, **kwargs):
folder_path = kwargs["folder_path"]
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
@@ -38,9 +48,11 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
):
lerobot_train(
[
f"--policy.type={policy_type}",
f"--policy.path={model_id}",
"--policy.push_to_hub=true",
"--policy.repo_id=foo/bar",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
@@ -58,11 +70,14 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
def test_peft_training_works(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train(
[
f"--policy.type={policy_type}",
f"--policy.path={model_id}",
"--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
@@ -81,7 +96,7 @@ def test_peft_training_works(policy_type, tmp_path):
# so these must be part of the trained state dict.
state_dict = load_file(policy_dir / "adapter_model.safetensors")
fully_trained_keys = [
adapted_keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
@@ -91,18 +106,19 @@ def test_peft_training_works(policy_type, tmp_path):
found_keys = [
module_key
for module_key in fully_trained_keys
for module_key in adapted_keys
for state_dict_key in state_dict
if f".{module_key}." in state_dict_key
]
assert set(found_keys) == set(fully_trained_keys)
assert set(found_keys) == set(adapted_keys)
@pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_params_are_fewer(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
def dummy_update_policy(
train_metrics,
@@ -123,8 +139,10 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
lerobot_train(
[
f"--policy.type={policy_type}",
f"--policy.path={model_id}",
"--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
@@ -139,6 +157,7 @@ class DummyRobot:
cameras = []
action_features = {"foo": 1.0, "bar": 2.0}
observation_features = {"obs1": 1.0, "obs2": 2.0}
is_connected = True
def connect(self, *args):
pass
@@ -157,11 +176,14 @@ def test_peft_record_loads_policy(policy_type, tmp_path):
from peft import PeftModel
output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train(
[
f"--policy.type={policy_type}",
f"--policy.path={model_id}",
"--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",