Add test checking that PEFT actually reduces params

This commit is contained in:
nemo
2025-12-16 17:58:27 +01:00
parent 9649f140ca
commit 51876e7f55
+35
View File
@@ -99,6 +99,41 @@ def test_peft_training_works(policy_type, tmp_path):
assert set(found_keys) == set(fully_trained_keys)
@pytest.mark.parametrize("policy_type", ["smolvla"])
def test_peft_training_params_are_fewer(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}"
def dummy_update_policy(
train_metrics,
policy,
batch,
optimizer,
grad_clip_norm: float,
accelerator,
**kwargs
):
params_total = sum(p.numel() for p in policy.parameters())
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
assert params_total > params_trainable
return train_metrics, {}
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
lerobot_train(
[
f"--policy.type={policy_type}",
"--policy.push_to_hub=false",
"--peft.method=LORA",
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0, 1]",
"--steps=1",
f"--output_dir={output_dir}",
]
)
class DummyRobot:
name = "dummy"
cameras = []