mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
Add test checking that PEFT actually reduces params
This commit is contained in:
@@ -99,6 +99,41 @@ def test_peft_training_works(policy_type, tmp_path):
|
|||||||
assert set(found_keys) == set(fully_trained_keys)
|
assert set(found_keys) == set(fully_trained_keys)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||||
|
def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
||||||
|
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
|
||||||
|
output_dir = tmp_path / f"output_{policy_type}"
|
||||||
|
|
||||||
|
def dummy_update_policy(
|
||||||
|
train_metrics,
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
grad_clip_norm: float,
|
||||||
|
accelerator,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
params_total = sum(p.numel() for p in policy.parameters())
|
||||||
|
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
assert params_total > params_trainable
|
||||||
|
|
||||||
|
return train_metrics, {}
|
||||||
|
|
||||||
|
with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
|
||||||
|
lerobot_train(
|
||||||
|
[
|
||||||
|
f"--policy.type={policy_type}",
|
||||||
|
"--policy.push_to_hub=false",
|
||||||
|
"--peft.method=LORA",
|
||||||
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
|
"--dataset.episodes=[0, 1]",
|
||||||
|
"--steps=1",
|
||||||
|
f"--output_dir={output_dir}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DummyRobot:
|
class DummyRobot:
|
||||||
name = "dummy"
|
name = "dummy"
|
||||||
cameras = []
|
cameras = []
|
||||||
|
|||||||
Reference in New Issue
Block a user