mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
pre download dataset in tests
This commit is contained in:
@@ -35,6 +35,8 @@ from pathlib import Path
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def get_num_available_gpus():
|
||||
"""Returns the number of available GPUs."""
|
||||
@@ -43,6 +45,19 @@ def get_num_available_gpus():
|
||||
return torch.cuda.device_count()
|
||||
|
||||
|
||||
def download_dataset(repo_id, episodes):
|
||||
"""
|
||||
Pre-download dataset to avoid race conditions in multi-GPU training.
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID
|
||||
episodes: List of episode indices to download
|
||||
"""
|
||||
# Simply instantiating the dataset will download it
|
||||
_ = LeRobotDataset(repo_id, episodes=episodes)
|
||||
print(f"Dataset {repo_id} downloaded successfully")
|
||||
|
||||
|
||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
"""
|
||||
Helper function to run training with accelerate launch.
|
||||
@@ -113,6 +128,9 @@ class TestMultiGPUTraining:
|
||||
Test that basic multi-GPU training runs successfully.
|
||||
Verifies that the training completes without errors.
|
||||
"""
|
||||
# Pre-download dataset to avoid race conditions
|
||||
download_dataset("lerobot/pusht", episodes=[0])
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_dir = Path(temp_dir) / "outputs"
|
||||
|
||||
@@ -129,6 +147,7 @@ class TestMultiGPUTraining:
|
||||
"--log_freq=5",
|
||||
"--save_freq=10",
|
||||
"--seed=42",
|
||||
"--num_workers=0",
|
||||
]
|
||||
|
||||
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
||||
@@ -152,6 +171,9 @@ class TestMultiGPUTraining:
|
||||
Test that checkpoints are correctly saved during multi-GPU training.
|
||||
Only the main process (rank 0) should save checkpoints.
|
||||
"""
|
||||
# Pre-download dataset to avoid race conditions
|
||||
download_dataset("lerobot/pusht", episodes=[0])
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_dir = Path(temp_dir) / "outputs"
|
||||
|
||||
@@ -168,6 +190,7 @@ class TestMultiGPUTraining:
|
||||
"--log_freq=5",
|
||||
"--save_freq=10",
|
||||
"--seed=42",
|
||||
"--num_workers=0",
|
||||
]
|
||||
|
||||
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
||||
|
||||
Reference in New Issue
Block a user