From 0d791307297f20d115e5ba3e62c3a3f5d955d131 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 14 Oct 2025 10:24:46 +0200 Subject: [PATCH] pre download dataset in tests --- tests/training/test_multi_gpu.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py index 73c0d39dd..c24af6138 100644 --- a/tests/training/test_multi_gpu.py +++ b/tests/training/test_multi_gpu.py @@ -35,6 +35,8 @@ from pathlib import Path import pytest import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + def get_num_available_gpus(): """Returns the number of available GPUs.""" @@ -43,6 +45,19 @@ def get_num_available_gpus(): return torch.cuda.device_count() +def download_dataset(repo_id, episodes): + """ + Pre-download dataset to avoid race conditions in multi-GPU training. + + Args: + repo_id: HuggingFace dataset repository ID + episodes: List of episode indices to download + """ + # Simply instantiating the dataset will download it + _ = LeRobotDataset(repo_id, episodes=episodes) + print(f"Dataset {repo_id} downloaded successfully") + + def run_accelerate_training(config_args, num_processes=4, temp_dir=None): """ Helper function to run training with accelerate launch. @@ -113,6 +128,9 @@ class TestMultiGPUTraining: Test that basic multi-GPU training runs successfully. Verifies that the training completes without errors. """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + with tempfile.TemporaryDirectory() as temp_dir: output_dir = Path(temp_dir) / "outputs" @@ -129,6 +147,7 @@ class TestMultiGPUTraining: "--log_freq=5", "--save_freq=10", "--seed=42", + "--num_workers=0", ] result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir) @@ -152,6 +171,9 @@ class TestMultiGPUTraining: Test that checkpoints are correctly saved during multi-GPU training. Only the main process (rank 0) should save checkpoints. """ + # Pre-download dataset to avoid race conditions + download_dataset("lerobot/pusht", episodes=[0]) + with tempfile.TemporaryDirectory() as temp_dir: output_dir = Path(temp_dir) / "outputs" @@ -168,6 +190,7 @@ class TestMultiGPUTraining: "--log_freq=5", "--save_freq=10", "--seed=42", + "--num_workers=0", ] result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)