Disallow PEFT training on non-pretrained policies

At first I thought it would make sense to have this feature
in case you want to fine-tune a pre-trained section but in the
end it makes more trouble than it's worth.

It's still possible to allow this in the future when a concrete
need arises.
This commit is contained in:
nemo
2025-12-19 11:50:29 +01:00
parent e01927f641
commit 54c25a4400
2 changed files with 33 additions and 10 deletions
+4 -3
View File
@@ -153,9 +153,9 @@ def wrap_policy_in_peft_model(cfg, policy):
p.requires_grad_(False) p.requires_grad_(False)
if not cfg.policy.pretrained_path: if not cfg.policy.pretrained_path:
logging.warning( raise ValueError(
"Training from scratch using PEFT. This is unlikely to yield good results. " "Training from scratch using PEFT. This is unlikely to yield good results. "
"Consider supplying a `policy.path` to fine-tune an existing model." "Supply a `policy.path` to fine-tune an existing model."
) )
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
@@ -195,7 +195,8 @@ def wrap_policy_in_peft_model(cfg, policy):
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the # PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
# adapter, not the base model. # adapter, not the base model.
policy.name_or_path = str(policy.config.pretrained_path) if policy.config.pretrained_path:
policy.name_or_path = str(policy.config.pretrained_path)
# Finally wrap the policy in a PEFT model # Finally wrap the policy in a PEFT model
policy = get_peft_model( policy = get_peft_model(
+29 -7
View File
@@ -20,12 +20,22 @@ def lerobot_record(args):
return run_command(cmd="lerobot-record", module="lerobot_record", args=args) return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
def resolve_model_id_for_peft_training(policy_type):
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
if policy_type == "smolvla":
return "lerobot/smolvla_base"
raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")
@pytest.mark.parametrize("policy_type", ["smolvla"]) @pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_push_to_hub_works(policy_type, tmp_path): def test_peft_training_push_to_hub_works(policy_type, tmp_path):
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights.""" """Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
output_dir = tmp_path / f"output_{policy_type}" output_dir = tmp_path / f"output_{policy_type}"
upload_folder_contents = set() upload_folder_contents = set()
model_id = resolve_model_id_for_peft_training(policy_type)
def mock_upload_folder(*args, **kwargs): def mock_upload_folder(*args, **kwargs):
folder_path = kwargs["folder_path"] folder_path = kwargs["folder_path"]
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders() # we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
@@ -38,9 +48,11 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
): ):
lerobot_train( lerobot_train(
[ [
f"--policy.type={policy_type}", f"--policy.path={model_id}",
"--policy.push_to_hub=true", "--policy.push_to_hub=true",
"--policy.repo_id=foo/bar", "--policy.repo_id=foo/bar",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA", "--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht", "--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]", "--dataset.episodes=[0, 1]",
@@ -58,11 +70,14 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
def test_peft_training_works(policy_type, tmp_path): def test_peft_training_works(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}" output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train( lerobot_train(
[ [
f"--policy.type={policy_type}", f"--policy.path={model_id}",
"--policy.push_to_hub=false", "--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA", "--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht", "--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]", "--dataset.episodes=[0, 1]",
@@ -81,7 +96,7 @@ def test_peft_training_works(policy_type, tmp_path):
# so these must be part of the trained state dict. # so these must be part of the trained state dict.
state_dict = load_file(policy_dir / "adapter_model.safetensors") state_dict = load_file(policy_dir / "adapter_model.safetensors")
fully_trained_keys = [ adapted_keys = [
"state_proj", "state_proj",
"action_in_proj", "action_in_proj",
"action_out_proj", "action_out_proj",
@@ -91,18 +106,19 @@ def test_peft_training_works(policy_type, tmp_path):
found_keys = [ found_keys = [
module_key module_key
for module_key in fully_trained_keys for module_key in adapted_keys
for state_dict_key in state_dict for state_dict_key in state_dict
if f".{module_key}." in state_dict_key if f".{module_key}." in state_dict_key
] ]
assert set(found_keys) == set(fully_trained_keys) assert set(found_keys) == set(adapted_keys)
@pytest.mark.parametrize("policy_type", ["smolvla"]) @pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_params_are_fewer(policy_type, tmp_path): def test_peft_training_params_are_fewer(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}" output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
def dummy_update_policy( def dummy_update_policy(
train_metrics, train_metrics,
@@ -123,8 +139,10 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy): with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
lerobot_train( lerobot_train(
[ [
f"--policy.type={policy_type}", f"--policy.path={model_id}",
"--policy.push_to_hub=false", "--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA", "--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht", "--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]", "--dataset.episodes=[0, 1]",
@@ -139,6 +157,7 @@ class DummyRobot:
cameras = [] cameras = []
action_features = {"foo": 1.0, "bar": 2.0} action_features = {"foo": 1.0, "bar": 2.0}
observation_features = {"obs1": 1.0, "obs2": 2.0} observation_features = {"obs1": 1.0, "obs2": 2.0}
is_connected = True
def connect(self, *args): def connect(self, *args):
pass pass
@@ -157,11 +176,14 @@ def test_peft_record_loads_policy(policy_type, tmp_path):
from peft import PeftModel from peft import PeftModel
output_dir = tmp_path / f"output_{policy_type}" output_dir = tmp_path / f"output_{policy_type}"
model_id = resolve_model_id_for_peft_training(policy_type)
lerobot_train( lerobot_train(
[ [
f"--policy.type={policy_type}", f"--policy.path={model_id}",
"--policy.push_to_hub=false", "--policy.push_to_hub=false",
"--policy.input_features=null",
"--policy.output_features=null",
"--peft.method=LORA", "--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht", "--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]", "--dataset.episodes=[0, 1]",