diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index cb915b00a..9a33bccab 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -153,9 +153,9 @@ def wrap_policy_in_peft_model(cfg, policy):
         p.requires_grad_(False)
 
     if not cfg.policy.pretrained_path:
-        logging.warning(
+        raise ValueError(
             "Training from scratch using PEFT. This is unlikely to yield good results. "
-            "Consider supplying a `policy.path` to fine-tune an existing model."
+            "Supply a `policy.path` to fine-tune an existing model."
         )
 
     if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
@@ -195,7 +195,8 @@ def wrap_policy_in_peft_model(cfg, policy):
     # PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
     # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
     # adapter, not the base model.
-    policy.name_or_path = str(policy.config.pretrained_path)
+    if policy.config.pretrained_path:
+        policy.name_or_path = str(policy.config.pretrained_path)
 
     # Finally wrap the policy in a PEFT model
     policy = get_peft_model(
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e949e42d4..8e49bbfce 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -20,12 +20,22 @@ def lerobot_record(args):
     return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
 
 
+def resolve_model_id_for_peft_training(policy_type):
+    """PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
+    if policy_type == "smolvla":
+        return "lerobot/smolvla_base"
+
+    raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")
+
+
 @pytest.mark.parametrize("policy_type", ["smolvla"])
 def test_peft_training_push_to_hub_works(policy_type, tmp_path):
     """Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
     output_dir = tmp_path / f"output_{policy_type}"
     upload_folder_contents = set()
 
+    model_id = resolve_model_id_for_peft_training(policy_type)
+
     def mock_upload_folder(*args, **kwargs):
         folder_path = kwargs["folder_path"]
         # we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders()
@@ -38,9 +48,11 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
     ):
         lerobot_train(
             [
-                f"--policy.type={policy_type}",
+                f"--policy.path={model_id}",
                 "--policy.push_to_hub=true",
                 "--policy.repo_id=foo/bar",
+                "--policy.input_features=null",
+                "--policy.output_features=null",
                 "--peft.method=LORA",
                 "--dataset.repo_id=lerobot/pusht",
                 "--dataset.episodes=[0, 1]",
@@ -58,11 +70,14 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
 def test_peft_training_works(policy_type, tmp_path):
     """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
     output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
 
     lerobot_train(
         [
-            f"--policy.type={policy_type}",
+            f"--policy.path={model_id}",
             "--policy.push_to_hub=false",
+            "--policy.input_features=null",
+            "--policy.output_features=null",
             "--peft.method=LORA",
             "--dataset.repo_id=lerobot/pusht",
             "--dataset.episodes=[0, 1]",
@@ -81,7 +96,7 @@ def test_peft_training_works(policy_type, tmp_path):
     # so these must be part of the trained state dict.
     state_dict = load_file(policy_dir / "adapter_model.safetensors")
 
-    fully_trained_keys = [
+    adapted_keys = [
         "state_proj",
         "action_in_proj",
         "action_out_proj",
@@ -91,18 +106,19 @@ def test_peft_training_works(policy_type, tmp_path):
 
     found_keys = [
         module_key
-        for module_key in fully_trained_keys
+        for module_key in adapted_keys
         for state_dict_key in state_dict
         if f".{module_key}." in state_dict_key
     ]
 
-    assert set(found_keys) == set(fully_trained_keys)
+    assert set(found_keys) == set(adapted_keys)
 
 
 @pytest.mark.parametrize("policy_type", ["smolvla"])
 def test_peft_training_params_are_fewer(policy_type, tmp_path):
     """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
     output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
 
     def dummy_update_policy(
         train_metrics,
@@ -123,8 +139,10 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
     with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
         lerobot_train(
             [
-                f"--policy.type={policy_type}",
+                f"--policy.path={model_id}",
                 "--policy.push_to_hub=false",
+                "--policy.input_features=null",
+                "--policy.output_features=null",
                 "--peft.method=LORA",
                 "--dataset.repo_id=lerobot/pusht",
                 "--dataset.episodes=[0, 1]",
@@ -139,6 +157,7 @@ class DummyRobot:
     cameras = []
     action_features = {"foo": 1.0, "bar": 2.0}
     observation_features = {"obs1": 1.0, "obs2": 2.0}
+    is_connected = True
 
     def connect(self, *args):
         pass
@@ -157,11 +176,14 @@ def test_peft_record_loads_policy(policy_type, tmp_path):
     from peft import PeftModel
 
     output_dir = tmp_path / f"output_{policy_type}"
+    model_id = resolve_model_id_for_peft_training(policy_type)
 
     lerobot_train(
         [
-            f"--policy.type={policy_type}",
+            f"--policy.path={model_id}",
             "--policy.push_to_hub=false",
+            "--policy.input_features=null",
+            "--policy.output_features=null",
             "--peft.method=LORA",
             "--dataset.repo_id=lerobot/pusht",
             "--dataset.episodes=[0, 1]",