mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
pre download dataset in tests
This commit is contained in:
@@ -35,6 +35,8 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
def get_num_available_gpus():
|
def get_num_available_gpus():
|
||||||
"""Returns the number of available GPUs."""
|
"""Returns the number of available GPUs."""
|
||||||
@@ -43,6 +45,19 @@ def get_num_available_gpus():
|
|||||||
return torch.cuda.device_count()
|
return torch.cuda.device_count()
|
||||||
|
|
||||||
|
|
||||||
|
def download_dataset(repo_id, episodes):
|
||||||
|
"""
|
||||||
|
Pre-download dataset to avoid race conditions in multi-GPU training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: HuggingFace dataset repository ID
|
||||||
|
episodes: List of episode indices to download
|
||||||
|
"""
|
||||||
|
# Simply instantiating the dataset will download it
|
||||||
|
_ = LeRobotDataset(repo_id, episodes=episodes)
|
||||||
|
print(f"Dataset {repo_id} downloaded successfully")
|
||||||
|
|
||||||
|
|
||||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||||
"""
|
"""
|
||||||
Helper function to run training with accelerate launch.
|
Helper function to run training with accelerate launch.
|
||||||
@@ -113,6 +128,9 @@ class TestMultiGPUTraining:
|
|||||||
Test that basic multi-GPU training runs successfully.
|
Test that basic multi-GPU training runs successfully.
|
||||||
Verifies that the training completes without errors.
|
Verifies that the training completes without errors.
|
||||||
"""
|
"""
|
||||||
|
# Pre-download dataset to avoid race conditions
|
||||||
|
download_dataset("lerobot/pusht", episodes=[0])
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
output_dir = Path(temp_dir) / "outputs"
|
output_dir = Path(temp_dir) / "outputs"
|
||||||
|
|
||||||
@@ -129,6 +147,7 @@ class TestMultiGPUTraining:
|
|||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
|
"--num_workers=0",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
||||||
@@ -152,6 +171,9 @@ class TestMultiGPUTraining:
|
|||||||
Test that checkpoints are correctly saved during multi-GPU training.
|
Test that checkpoints are correctly saved during multi-GPU training.
|
||||||
Only the main process (rank 0) should save checkpoints.
|
Only the main process (rank 0) should save checkpoints.
|
||||||
"""
|
"""
|
||||||
|
# Pre-download dataset to avoid race conditions
|
||||||
|
download_dataset("lerobot/pusht", episodes=[0])
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
output_dir = Path(temp_dir) / "outputs"
|
output_dir = Path(temp_dir) / "outputs"
|
||||||
|
|
||||||
@@ -168,6 +190,7 @@ class TestMultiGPUTraining:
|
|||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
|
"--num_workers=0",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user