mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
fix formatting
This commit is contained in:
@@ -186,7 +186,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
nvidia-smi
|
nvidia-smi
|
||||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||||
|
|
||||||
- name: Run multi-GPU training tests
|
- name: Run multi-GPU training tests
|
||||||
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
|
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ When you launch training with accelerate:
|
|||||||
|
|
||||||
### Why No Automatic Scaling?
|
### Why No Automatic Scaling?
|
||||||
|
|
||||||
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`).
|
||||||
However, LeRobot keeps the learning rate exactly as you specify it.
|
However, LeRobot keeps the learning rate exactly as you specify it.
|
||||||
|
|
||||||
### When and How to Scale
|
### When and How to Scale
|
||||||
@@ -104,8 +104,8 @@ Since the effective batch size `bs` increases with multiple GPUs (batch_size ×
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Example: 2 GPUs with effective batch size 2x larger
|
# Example: 2 GPUs with effective batch size 2x larger
|
||||||
# Original: batch_size=8, steps=100000
|
# Original: batch_size=8, steps=100000
|
||||||
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
|
||||||
accelerate launch --num_processes=2 $(which lerobot-train) \
|
accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||||
--batch_size=8 \
|
--batch_size=8 \
|
||||||
--steps=50000 \
|
--steps=50000 \
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||||
"""Used by Physical Intelligence to train Pi0.
|
"""Used by Physical Intelligence to train Pi0.
|
||||||
|
|
||||||
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
|
Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
|
||||||
This ensures the learning rate schedule completes properly even with shorter training runs.
|
This ensures the learning rate schedule completes properly even with shorter training runs.
|
||||||
"""
|
"""
|
||||||
@@ -95,13 +95,13 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
|||||||
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
|
# Auto-scale scheduler parameters if training steps are shorter than configured decay steps
|
||||||
actual_warmup_steps = self.num_warmup_steps
|
actual_warmup_steps = self.num_warmup_steps
|
||||||
actual_decay_steps = self.num_decay_steps
|
actual_decay_steps = self.num_decay_steps
|
||||||
|
|
||||||
if num_training_steps < self.num_decay_steps:
|
if num_training_steps < self.num_decay_steps:
|
||||||
# Calculate scaling factor to fit the schedule into the available training steps
|
# Calculate scaling factor to fit the schedule into the available training steps
|
||||||
scale_factor = num_training_steps / self.num_decay_steps
|
scale_factor = num_training_steps / self.num_decay_steps
|
||||||
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
|
actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
|
||||||
actual_decay_steps = num_training_steps
|
actual_decay_steps = num_training_steps
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Auto-scaling LR scheduler: "
|
f"Auto-scaling LR scheduler: "
|
||||||
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
|
f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
@@ -86,7 +85,7 @@ def update_policy(
|
|||||||
"""
|
"""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Let accelerator handle mixed precision
|
# Let accelerator handle mixed precision
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
@@ -94,17 +93,17 @@ def update_policy(
|
|||||||
|
|
||||||
# Use accelerator's backward method
|
# Use accelerator's backward method
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|
||||||
# Clip gradients if specified
|
# Clip gradients if specified
|
||||||
if grad_clip_norm > 0:
|
if grad_clip_norm > 0:
|
||||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
grad_norm = torch.tensor(0.0, device=accelerator.device)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
with lock if lock is not None else nullcontext():
|
with lock if lock is not None else nullcontext():
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Step through pytorch scheduler at every batch instead of epoch
|
# Step through pytorch scheduler at every batch instead of epoch
|
||||||
@@ -143,16 +142,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
# Create Accelerator if not provided
|
# Create Accelerator if not provided
|
||||||
# It will automatically detect if running in distributed mode or single-process mode
|
# It will automatically detect if running in distributed mode or single-process mode
|
||||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting
|
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||||
# the lr_scheduler steps based on the num_processes
|
# We set find_unused_parameters=True to handle models with conditional computation
|
||||||
# We set find_unused_parameters=True to handle models with conditional computation paths
|
|
||||||
if accelerator is None:
|
if accelerator is None:
|
||||||
from accelerate.utils import DistributedDataParallelKwargs
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||||
step_scheduler_with_optimizer=False,
|
|
||||||
kwargs_handlers=[ddp_kwargs]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine if this is the main process (for logging and checkpointing)
|
# Determine if this is the main process (for logging and checkpointing)
|
||||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||||
@@ -182,10 +178,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
|
|
||||||
# Wait for main process to finish downloading/caching dataset
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# Now all other processes can safely load the dataset
|
# Now all other processes can safely load the dataset
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
@@ -205,7 +200,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for all processes to finish policy creation before continuing
|
# Wait for all processes to finish policy creation before continuing
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -288,7 +283,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare everything with accelerator
|
# Prepare everything with accelerator
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -341,7 +336,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
step += 1
|
step += 1
|
||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||||
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps)
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||||
|
|
||||||
if is_log_step:
|
if is_log_step:
|
||||||
@@ -431,7 +426,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
unwrapped_policy.push_model_to_hub(cfg)
|
unwrapped_policy.push_model_to_hub(cfg)
|
||||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
|
|
||||||
# Properly clean up the distributed process group
|
# Properly clean up the distributed process group
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from accelerate import Accelerator
|
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -29,6 +28,7 @@ from statistics import mean
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
||||||
|
|
||||||
|
|
||||||
@@ -117,10 +117,10 @@ def init_logging(
|
|||||||
accelerator: Accelerator | None = None,
|
accelerator: Accelerator | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize logging configuration for LeRobot.
|
"""Initialize logging configuration for LeRobot.
|
||||||
|
|
||||||
In multi-GPU training, only the main process logs to console to avoid duplicate output.
|
In multi-GPU training, only the main process logs to console to avoid duplicate output.
|
||||||
Non-main processes have console logging suppressed but can still log to file.
|
Non-main processes have console logging suppressed but can still log to file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
log_file: Optional file path to write logs to
|
log_file: Optional file path to write logs to
|
||||||
display_pid: Include process ID in log messages (useful for debugging multi-process)
|
display_pid: Include process ID in log messages (useful for debugging multi-process)
|
||||||
@@ -128,6 +128,7 @@ def init_logging(
|
|||||||
file_level: Logging level for file output
|
file_level: Logging level for file output
|
||||||
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def custom_format(record: logging.LogRecord) -> str:
|
def custom_format(record: logging.LogRecord) -> str:
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
@@ -139,7 +140,7 @@ def init_logging(
|
|||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.NOTSET)
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
# Clear any existing handlers
|
# Clear any existing handlers
|
||||||
logger.handlers.clear()
|
logger.handlers.clear()
|
||||||
|
|
||||||
@@ -159,7 +160,6 @@ def init_logging(
|
|||||||
logger.addHandler(logging.NullHandler())
|
logger.addHandler(logging.NullHandler())
|
||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.ERROR)
|
||||||
|
|
||||||
# File logging (optional, all processes)
|
|
||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
@@ -178,6 +178,7 @@ def format_big_number(num, precision=0):
|
|||||||
|
|
||||||
return num
|
return num
|
||||||
|
|
||||||
|
|
||||||
def say(text: str, blocking: bool = False):
|
def say(text: str, blocking: bool = False):
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,7 @@ The tests automatically generate accelerate configs and launch training
|
|||||||
with subprocess to properly test the distributed training environment.
|
with subprocess to properly test the distributed training environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -48,7 +46,7 @@ def get_num_available_gpus():
|
|||||||
def download_dataset(repo_id, episodes):
|
def download_dataset(repo_id, episodes):
|
||||||
"""
|
"""
|
||||||
Pre-download dataset to avoid race conditions in multi-GPU training.
|
Pre-download dataset to avoid race conditions in multi-GPU training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: HuggingFace dataset repository ID
|
repo_id: HuggingFace dataset repository ID
|
||||||
episodes: List of episode indices to download
|
episodes: List of episode indices to download
|
||||||
@@ -61,27 +59,18 @@ def download_dataset(repo_id, episodes):
|
|||||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||||
"""
|
"""
|
||||||
Helper function to run training with accelerate launch.
|
Helper function to run training with accelerate launch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_args: List of config arguments to pass to lerobot_train.py
|
config_args: List of config arguments to pass to lerobot_train.py
|
||||||
num_processes: Number of processes (GPUs) to use
|
num_processes: Number of processes (GPUs) to use
|
||||||
temp_dir: Temporary directory for outputs
|
temp_dir: Temporary directory for outputs
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
subprocess.CompletedProcess result
|
subprocess.CompletedProcess result
|
||||||
"""
|
"""
|
||||||
# Create accelerate config
|
|
||||||
accelerate_config = {
|
|
||||||
"compute_environment": "LOCAL_MACHINE",
|
|
||||||
"distributed_type": "MULTI_GPU",
|
|
||||||
"mixed_precision": "no",
|
|
||||||
"num_processes": num_processes,
|
|
||||||
"use_cpu": False,
|
|
||||||
"gpu_ids": "all",
|
|
||||||
}
|
|
||||||
|
|
||||||
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
config_path = Path(temp_dir) / "accelerate_config.yaml"
|
||||||
|
|
||||||
# Write YAML config
|
# Write YAML config
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||||
@@ -96,7 +85,7 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
|||||||
f.write("num_machines: 1\n")
|
f.write("num_machines: 1\n")
|
||||||
f.write("rdzv_backend: static\n")
|
f.write("rdzv_backend: static\n")
|
||||||
f.write("same_network: true\n")
|
f.write("same_network: true\n")
|
||||||
|
|
||||||
cmd = [
|
cmd = [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
"launch",
|
"launch",
|
||||||
@@ -105,14 +94,14 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
|||||||
"-m",
|
"-m",
|
||||||
"lerobot.scripts.lerobot_train",
|
"lerobot.scripts.lerobot_train",
|
||||||
] + config_args
|
] + config_args
|
||||||
|
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
cmd,
|
cmd,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
|
env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -130,10 +119,10 @@ class TestMultiGPUTraining:
|
|||||||
"""
|
"""
|
||||||
# Pre-download dataset to avoid race conditions
|
# Pre-download dataset to avoid race conditions
|
||||||
download_dataset("lerobot/pusht", episodes=[0])
|
download_dataset("lerobot/pusht", episodes=[0])
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
output_dir = Path(temp_dir) / "outputs"
|
output_dir = Path(temp_dir) / "outputs"
|
||||||
|
|
||||||
config_args = [
|
config_args = [
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0]",
|
"--dataset.episodes=[0]",
|
||||||
@@ -149,20 +138,20 @@ class TestMultiGPUTraining:
|
|||||||
"--seed=42",
|
"--seed=42",
|
||||||
"--num_workers=0",
|
"--num_workers=0",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
|
||||||
|
|
||||||
# Check that training completed successfully
|
# Check that training completed successfully
|
||||||
assert result.returncode == 0, (
|
assert result.returncode == 0, (
|
||||||
f"Multi-GPU training failed with return code {result.returncode}\n"
|
f"Multi-GPU training failed with return code {result.returncode}\n"
|
||||||
f"STDOUT:\n{result.stdout}\n"
|
f"STDOUT:\n{result.stdout}\n"
|
||||||
f"STDERR:\n{result.stderr}"
|
f"STDERR:\n{result.stderr}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify checkpoint was saved
|
# Verify checkpoint was saved
|
||||||
checkpoints_dir = output_dir / "checkpoints"
|
checkpoints_dir = output_dir / "checkpoints"
|
||||||
assert checkpoints_dir.exists(), "Checkpoints directory was not created"
|
assert checkpoints_dir.exists(), "Checkpoints directory was not created"
|
||||||
|
|
||||||
# Verify that training completed
|
# Verify that training completed
|
||||||
assert "End of training" in result.stdout or "End of training" in result.stderr
|
assert "End of training" in result.stdout or "End of training" in result.stderr
|
||||||
|
|
||||||
@@ -173,10 +162,10 @@ class TestMultiGPUTraining:
|
|||||||
"""
|
"""
|
||||||
# Pre-download dataset to avoid race conditions
|
# Pre-download dataset to avoid race conditions
|
||||||
download_dataset("lerobot/pusht", episodes=[0])
|
download_dataset("lerobot/pusht", episodes=[0])
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
output_dir = Path(temp_dir) / "outputs"
|
output_dir = Path(temp_dir) / "outputs"
|
||||||
|
|
||||||
config_args = [
|
config_args = [
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
"--dataset.repo_id=lerobot/pusht",
|
||||||
"--dataset.episodes=[0]",
|
"--dataset.episodes=[0]",
|
||||||
@@ -192,31 +181,31 @@ class TestMultiGPUTraining:
|
|||||||
"--seed=42",
|
"--seed=42",
|
||||||
"--num_workers=0",
|
"--num_workers=0",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
|
||||||
|
|
||||||
assert result.returncode == 0, (
|
assert result.returncode == 0, (
|
||||||
f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
|
f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify checkpoint directory exists
|
# Verify checkpoint directory exists
|
||||||
checkpoints_dir = output_dir / "checkpoints"
|
checkpoints_dir = output_dir / "checkpoints"
|
||||||
assert checkpoints_dir.exists(), "Checkpoints directory not created"
|
assert checkpoints_dir.exists(), "Checkpoints directory not created"
|
||||||
|
|
||||||
# Count checkpoint directories (should have checkpoint at step 10 and 20)
|
# Count checkpoint directories (should have checkpoint at step 10 and 20)
|
||||||
checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
|
checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
|
||||||
assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
|
assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
|
||||||
|
|
||||||
# Verify checkpoint contents
|
# Verify checkpoint contents
|
||||||
for checkpoint_dir in checkpoint_dirs:
|
for checkpoint_dir in checkpoint_dirs:
|
||||||
# Check for model files
|
# Check for model files
|
||||||
model_files = list(checkpoint_dir.rglob("*.safetensors"))
|
model_files = list(checkpoint_dir.rglob("*.safetensors"))
|
||||||
assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
|
assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
|
||||||
|
|
||||||
# Check for training state
|
# Check for training state
|
||||||
training_state_dir = checkpoint_dir / "training_state"
|
training_state_dir = checkpoint_dir / "training_state"
|
||||||
assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
|
assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
|
||||||
|
|
||||||
# Verify optimizer state exists
|
# Verify optimizer state exists
|
||||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
||||||
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
||||||
|
|||||||
Reference in New Issue
Block a user