From 51876e7f55c5fc2ef0d933a16cee36535d0d7e32 Mon Sep 17 00:00:00 2001
From: nemo
Date: Tue, 16 Dec 2025 17:58:27 +0100
Subject: [PATCH] Add test checking that PEFT actually reduces params

---
 tests/test_cli.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/tests/test_cli.py b/tests/test_cli.py
index 1bd4e3ec0..e949e42d4 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -99,6 +99,51 @@ def test_peft_training_works(policy_type, tmp_path):
     assert set(found_keys) == set(fully_trained_keys)
 
 
+@pytest.mark.parametrize("policy_type", ["smolvla"])
+def test_peft_training_params_are_fewer(policy_type, tmp_path):
+    """Check that PEFT training leaves only a subset of the policy parameters trainable.
+
+    ``update_policy`` is replaced by a stub that inspects the policy passed to each
+    training step instead of optimizing it.
+    """
+    output_dir = tmp_path / f"output_{policy_type}"
+    # Parameter counts recorded by the stub; stays empty if the stub is never reached.
+    observed_counts = []
+
+    def dummy_update_policy(
+        train_metrics,
+        policy,
+        batch,
+        optimizer,
+        grad_clip_norm: float,
+        accelerator,
+        **kwargs,
+    ):
+        params_total = sum(p.numel() for p in policy.parameters())
+        params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+        observed_counts.append((params_total, params_trainable))
+
+        assert params_total > params_trainable
+
+        return train_metrics, {}
+
+    with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
+        lerobot_train(
+            [
+                f"--policy.type={policy_type}",
+                "--policy.push_to_hub=false",
+                "--peft.method=LORA",
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0, 1]",
+                "--steps=1",
+                f"--output_dir={output_dir}",
+            ]
+        )
+
+    # Guard against a vacuous pass: the assertion above only runs if the stub is called.
+    assert observed_counts, "patched update_policy was never called"
+
+
 class DummyRobot:
     name = "dummy"
     cameras = []