diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py index f0668a8ef..5aedde374 100644 --- a/tests/training/test_multi_gpu.py +++ b/tests/training/test_multi_gpu.py @@ -119,7 +119,7 @@ class TestMultiGPUTraining: config_args = [ "--dataset.repo_id=lerobot/pusht", "--dataset.episodes=[0]", - "--policy=act", + "--policy.type=act", "--policy.device=cuda", f"--output_dir={output_dir}", "--batch_size=4", @@ -157,7 +157,7 @@ class TestMultiGPUTraining: config_args = [ "--dataset.repo_id=lerobot/pusht", "--dataset.episodes=[0]", - "--policy=act", + "--policy.type=act", "--policy.device=cuda", f"--output_dir={output_dir}", "--batch_size=4",