diff --git a/tests/test_cli.py b/tests/test_cli.py index 1bd4e3ec0..e949e42d4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -99,6 +99,41 @@ def test_peft_training_works(policy_type, tmp_path): assert set(found_keys) == set(fully_trained_keys) +@pytest.mark.parametrize("policy_type", ["smolvla"]) +def test_peft_training_params_are_fewer(policy_type, tmp_path): + """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" + output_dir = tmp_path / f"output_{policy_type}" + + def dummy_update_policy( + train_metrics, + policy, + batch, + optimizer, + grad_clip_norm: float, + accelerator, + **kwargs + ): + params_total = sum(p.numel() for p in policy.parameters()) + params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad) + + assert params_total > params_trainable + + return train_metrics, {} + + with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy): + lerobot_train( + [ + f"--policy.type={policy_type}", + "--policy.push_to_hub=false", + "--peft.method=LORA", + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0, 1]", + "--steps=1", + f"--output_dir={output_dir}", + ] + ) + + class DummyRobot: name = "dummy" cameras = []