mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Disallow PEFT training on non-pretrained policies
At first I thought it would make sense to have this feature in case you want to fine-tune a pre-trained section of a policy, but in the end it causes more trouble than it's worth. It is still possible to allow this in the future if a concrete need arises.
This commit is contained in:
@@ -153,9 +153,9 @@ def wrap_policy_in_peft_model(cfg, policy):
|
|||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
|
|
||||||
if not cfg.policy.pretrained_path:
|
if not cfg.policy.pretrained_path:
|
||||||
logging.warning(
|
raise ValueError(
|
||||||
"Training from scratch using PEFT. This is unlikely to yield good results. "
|
"Training from scratch using PEFT. This is unlikely to yield good results. "
|
||||||
"Consider supplying a `policy.path` to fine-tune an existing model."
|
"Supply a `policy.path` to fine-tune an existing model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
||||||
@@ -195,7 +195,8 @@ def wrap_policy_in_peft_model(cfg, policy):
|
|||||||
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
||||||
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
||||||
# adapter, not the base model.
|
# adapter, not the base model.
|
||||||
policy.name_or_path = str(policy.config.pretrained_path)
|
if policy.config.pretrained_path:
|
||||||
|
policy.name_or_path = str(policy.config.pretrained_path)
|
||||||
|
|
||||||
# Finally wrap the policy in a PEFT model
|
# Finally wrap the policy in a PEFT model
|
||||||
policy = get_peft_model(
|
policy = get_peft_model(
|
||||||
|
|||||||
+29
-7
@@ -20,12 +20,22 @@ def lerobot_record(args):
|
|||||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_id_for_peft_training(policy_type):
    """Return the pretrained model id to use for PEFT training of a policy type.

    PEFT training needs a pretrained model, so this looks up the pretrained
    model known for the given policy type.

    Args:
        policy_type: Name of the policy type, e.g. ``"smolvla"``.

    Returns:
        The model id of the pretrained model for ``policy_type``.

    Raises:
        ValueError: If no pretrained model is known for ``policy_type``.
    """
    if policy_type == "smolvla":
        return "lerobot/smolvla_base"

    # Fail loudly: without a pretrained model PEFT training cannot be set up.
    raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||||
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
||||||
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
|
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
|
||||||
output_dir = tmp_path / f"output_{policy_type}"
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
upload_folder_contents = set()
|
upload_folder_contents = set()
|
||||||
|
|
||||||
|
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||||
|
|
||||||
def mock_upload_folder(*args, **kwargs):
|
def mock_upload_folder(*args, **kwargs):
|
||||||
folder_path = kwargs["folder_path"]
|
folder_path = kwargs["folder_path"]
|
||||||
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
|
# we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
|
||||||
@@ -38,9 +48,11 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
|||||||
):
|
):
|
||||||
lerobot_train(
|
lerobot_train(
|
||||||
[
|
[
|
||||||
f"--policy.type={policy_type}",
|
f"--policy.path={model_id}",
|
||||||
"--policy.push_to_hub=true",
|
"--policy.push_to_hub=true",
|
||||||
"--policy.repo_id=foo/bar",
|
"--policy.repo_id=foo/bar",
|
||||||
|
"--policy.input_features=null",
|
||||||
|
"--policy.output_features=null",
|
||||||
"--peft.method=LORA",
|
"--peft.method=LORA",
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0, 1]",
|
"--dataset.episodes=[0, 1]",
|
||||||
@@ -58,11 +70,14 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
|
|||||||
def test_peft_training_works(policy_type, tmp_path):
|
def test_peft_training_works(policy_type, tmp_path):
|
||||||
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
||||||
output_dir = tmp_path / f"output_{policy_type}"
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
|
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||||
|
|
||||||
lerobot_train(
|
lerobot_train(
|
||||||
[
|
[
|
||||||
f"--policy.type={policy_type}",
|
f"--policy.path={model_id}",
|
||||||
"--policy.push_to_hub=false",
|
"--policy.push_to_hub=false",
|
||||||
|
"--policy.input_features=null",
|
||||||
|
"--policy.output_features=null",
|
||||||
"--peft.method=LORA",
|
"--peft.method=LORA",
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0, 1]",
|
"--dataset.episodes=[0, 1]",
|
||||||
@@ -81,7 +96,7 @@ def test_peft_training_works(policy_type, tmp_path):
|
|||||||
# so these must be part of the trained state dict.
|
# so these must be part of the trained state dict.
|
||||||
state_dict = load_file(policy_dir / "adapter_model.safetensors")
|
state_dict = load_file(policy_dir / "adapter_model.safetensors")
|
||||||
|
|
||||||
fully_trained_keys = [
|
adapted_keys = [
|
||||||
"state_proj",
|
"state_proj",
|
||||||
"action_in_proj",
|
"action_in_proj",
|
||||||
"action_out_proj",
|
"action_out_proj",
|
||||||
@@ -91,18 +106,19 @@ def test_peft_training_works(policy_type, tmp_path):
|
|||||||
|
|
||||||
found_keys = [
|
found_keys = [
|
||||||
module_key
|
module_key
|
||||||
for module_key in fully_trained_keys
|
for module_key in adapted_keys
|
||||||
for state_dict_key in state_dict
|
for state_dict_key in state_dict
|
||||||
if f".{module_key}." in state_dict_key
|
if f".{module_key}." in state_dict_key
|
||||||
]
|
]
|
||||||
|
|
||||||
assert set(found_keys) == set(fully_trained_keys)
|
assert set(found_keys) == set(adapted_keys)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||||
def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
||||||
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
||||||
output_dir = tmp_path / f"output_{policy_type}"
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
|
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||||
|
|
||||||
def dummy_update_policy(
|
def dummy_update_policy(
|
||||||
train_metrics,
|
train_metrics,
|
||||||
@@ -123,8 +139,10 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
|||||||
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
|
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
|
||||||
lerobot_train(
|
lerobot_train(
|
||||||
[
|
[
|
||||||
f"--policy.type={policy_type}",
|
f"--policy.path={model_id}",
|
||||||
"--policy.push_to_hub=false",
|
"--policy.push_to_hub=false",
|
||||||
|
"--policy.input_features=null",
|
||||||
|
"--policy.output_features=null",
|
||||||
"--peft.method=LORA",
|
"--peft.method=LORA",
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0, 1]",
|
"--dataset.episodes=[0, 1]",
|
||||||
@@ -139,6 +157,7 @@ class DummyRobot:
|
|||||||
cameras = []
|
cameras = []
|
||||||
action_features = {"foo": 1.0, "bar": 2.0}
|
action_features = {"foo": 1.0, "bar": 2.0}
|
||||||
observation_features = {"obs1": 1.0, "obs2": 2.0}
|
observation_features = {"obs1": 1.0, "obs2": 2.0}
|
||||||
|
is_connected = True
|
||||||
|
|
||||||
def connect(self, *args):
|
def connect(self, *args):
|
||||||
pass
|
pass
|
||||||
@@ -157,11 +176,14 @@ def test_peft_record_loads_policy(policy_type, tmp_path):
|
|||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
output_dir = tmp_path / f"output_{policy_type}"
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
|
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||||
|
|
||||||
lerobot_train(
|
lerobot_train(
|
||||||
[
|
[
|
||||||
f"--policy.type={policy_type}",
|
f"--policy.path={model_id}",
|
||||||
"--policy.push_to_hub=false",
|
"--policy.push_to_hub=false",
|
||||||
|
"--policy.input_features=null",
|
||||||
|
"--policy.output_features=null",
|
||||||
"--peft.method=LORA",
|
"--peft.method=LORA",
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0, 1]",
|
"--dataset.episodes=[0, 1]",
|
||||||
|
|||||||
Reference in New Issue
Block a user