mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
fix formatting
This commit is contained in:
@@ -25,9 +25,7 @@ The tests automatically generate accelerate configs and launch training
|
||||
with subprocess to properly test the distributed training environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -48,7 +46,7 @@ def get_num_available_gpus():
|
||||
def download_dataset(repo_id, episodes):
|
||||
"""
|
||||
Pre-download dataset to avoid race conditions in multi-GPU training.
|
||||
|
||||
|
||||
Args:
|
||||
repo_id: HuggingFace dataset repository ID
|
||||
episodes: List of episode indices to download
|
||||
@@ -61,27 +59,18 @@ def download_dataset(repo_id, episodes):
|
||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
"""
|
||||
Helper function to run training with accelerate launch.
|
||||
|
||||
|
||||
Args:
|
||||
config_args: List of config arguments to pass to lerobot_train.py
|
||||
num_processes: Number of processes (GPUs) to use
|
||||
temp_dir: Temporary directory for outputs
|
||||
|
||||
|
||||
Returns:
|
||||
subprocess.CompletedProcess result
|
||||
"""
|
||||
# Create accelerate config
|
||||
accelerate_config = {
|
||||
"compute_environment": "LOCAL_MACHINE",
|
||||
"distributed_type": "MULTI_GPU",
|
||||
"mixed_precision": "no",
|
||||
"num_processes": num_processes,
|
||||
"use_cpu": False,
|
||||
"gpu_ids": "all",
|
||||
}
|
||||
|
||||
|
||||
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
||||
|
||||
|
||||
# Write YAML config
|
||||
with open(config_path, "w") as f:
|
||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||
@@ -96,7 +85,7 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
f.write("num_machines: 1\n")
|
||||
f.write("rdzv_backend: static\n")
|
||||
f.write("same_network: true\n")
|
||||
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
@@ -105,14 +94,14 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
"-m",
|
||||
"lerobot.scripts.lerobot_train",
|
||||
] + config_args
|
||||
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -130,10 +119,10 @@ class TestMultiGPUTraining:
|
||||
"""
|
||||
# Pre-download dataset to avoid race conditions
|
||||
download_dataset("lerobot/pusht", episodes=[0])
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_dir = Path(temp_dir) / "outputs"
|
||||
|
||||
|
||||
config_args = [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
@@ -149,20 +138,20 @@ class TestMultiGPUTraining:
|
||||
"--seed=42",
|
||||
"--num_workers=0",
|
||||
]
|
||||
|
||||
|
||||
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
||||
|
||||
|
||||
# Check that training completed successfully
|
||||
assert result.returncode == 0, (
|
||||
f"Multi-GPU training failed with return code {result.returncode}\n"
|
||||
f"STDOUT:\n{result.stdout}\n"
|
||||
f"STDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
|
||||
# Verify checkpoint was saved
|
||||
checkpoints_dir = output_dir / "checkpoints"
|
||||
assert checkpoints_dir.exists(), "Checkpoints directory was not created"
|
||||
|
||||
|
||||
# Verify that the "End of training" message was logged (stdout or stderr)
|
||||
assert "End of training" in result.stdout or "End of training" in result.stderr
|
||||
|
||||
@@ -173,10 +162,10 @@ class TestMultiGPUTraining:
|
||||
"""
|
||||
# Pre-download dataset to avoid race conditions
|
||||
download_dataset("lerobot/pusht", episodes=[0])
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_dir = Path(temp_dir) / "outputs"
|
||||
|
||||
|
||||
config_args = [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
@@ -192,31 +181,31 @@ class TestMultiGPUTraining:
|
||||
"--seed=42",
|
||||
"--num_workers=0",
|
||||
]
|
||||
|
||||
|
||||
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
||||
|
||||
|
||||
assert result.returncode == 0, (
|
||||
f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
|
||||
# Verify checkpoint directory exists
|
||||
checkpoints_dir = output_dir / "checkpoints"
|
||||
assert checkpoints_dir.exists(), "Checkpoints directory not created"
|
||||
|
||||
|
||||
# Count checkpoint directories (at least one is required; checkpoints at step 10 and 20 are expected)
|
||||
checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
|
||||
assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
|
||||
|
||||
|
||||
# Verify checkpoint contents
|
||||
for checkpoint_dir in checkpoint_dirs:
|
||||
# Check for model files
|
||||
model_files = list(checkpoint_dir.rglob("*.safetensors"))
|
||||
assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
|
||||
|
||||
|
||||
# Check for training state
|
||||
training_state_dir = checkpoint_dir / "training_state"
|
||||
assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
|
||||
|
||||
|
||||
# Verify optimizer state exists
|
||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
||||
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
||||
|
||||
Reference in New Issue
Block a user