Compare commits

...

1 Commit

Author SHA1 Message Date
danaaubakirova d148279921 Support accelerate training and add test configs for SmolVLA
- 2-GPU SLURM job (distributed training)
- 1-GPU local accelerate and direct training scripts
- Accelerate configs for 1-GPU and 2-GPU setups
2025-09-04 13:07:25 +00:00
8 changed files with 321 additions and 34 deletions
+11
@@ -0,0 +1,11 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: NO
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
use_cpu: false
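
Note (not part of the change): with distributed_type: NO and num_processes: 1 this config runs a single process, but the Accelerator still owns device placement, which is why the training loop changes below skip the manual .to(device) copy whenever an accelerator is present. A minimal Python sketch of that behaviour, purely illustrative and assuming accelerate and torch are installed:

    # Illustrative sketch: under the single-process config above, an Accelerator
    # still moves batches from a prepared dataloader onto its device.
    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator()              # picks up the launch-time config
    print(accelerator.num_processes)         # 1 under this config
    print(accelerator.distributed_type)      # DistributedType.NO

    loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2)
    loader = accelerator.prepare(loader)     # prepared loader handles device placement
    (batch,) = next(iter(loader))
    print(batch.device)                      # accelerator.device (e.g. cuda:0 if a GPU is visible)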
+18
@@ -0,0 +1,18 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
dynamo_backend: "no"
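
For orientation (not part of the diff): num_processes: 2 makes accelerate spawn two ranks on one machine, each consuming its own per-process batch, so the effective batch per optimizer step is the per-process --batch_size times the number of processes (times any gradient accumulation). A tiny worked example with the values from the 2-GPU SLURM test script below:

    # Illustrative arithmetic for the 2-GPU config above.
    num_processes = 2        # num_processes in this accelerate config
    per_process_batch = 2    # --batch_size=2 in the 2-GPU SLURM script below
    grad_accum_steps = 1     # default picked up via getattr(cfg, "gradient_accumulation_steps", 1)

    effective_batch = num_processes * per_process_batch * grad_accum_steps
    print(effective_batch)   # 4 samples contribute to each optimizer step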
+5 -1
@@ -243,7 +243,11 @@ def eval_policy(
     if max_episodes_rendered > 0 and not videos_dir:
         raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
-    if not isinstance(policy, PreTrainedPolicy):
+    # Handle accelerate-wrapped models by unwrapping them
+    if hasattr(policy, 'module') and isinstance(policy.module, PreTrainedPolicy):
+        # This is likely an accelerate-wrapped model (DistributedDataParallel)
+        policy = policy.module
+    elif not isinstance(policy, PreTrainedPolicy):
         raise ValueError(
             f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided."
         )
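
Background for the branch above (illustrative, not from the PR): under the 2-GPU config, accelerator.prepare(policy) returns a torch.nn.parallel.DistributedDataParallel wrapper, and DDP exposes the original model as .module; that is the attribute the new hasattr check unwraps before the isinstance(policy, PreTrainedPolicy) validation. A self-contained sketch with a stand-in module:

    # Illustrative: DDP-style wrappers keep the original model under `.module`.
    import torch.nn as nn

    class TinyPolicy(nn.Module):          # stand-in for a real PreTrainedPolicy
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)

    def unwrap(model: nn.Module) -> nn.Module:
        # Same idea as the eval_policy change above.
        return model.module if hasattr(model, "module") else model

    policy = TinyPolicy()
    assert unwrap(policy) is policy       # a plain module passes through unchanged
    # In a 2-GPU run, accelerator.prepare(policy) returns a DistributedDataParallel
    # wrapper w with w.module is policy, so unwrap(w) recovers the original policy.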
+99 -30
@@ -16,6 +16,7 @@
 import logging
 import time
 from contextlib import nullcontext
+from functools import partial
 from pprint import pformat
 from typing import Any
@@ -23,6 +24,8 @@ import torch
 from termcolor import colored
 from torch.amp import GradScaler
 from torch.optim import Optimizer
+import os
+from datetime import timedelta
 
 from lerobot.configs import parser
 from lerobot.configs.train import TrainPipelineConfig
@@ -52,6 +55,8 @@ from lerobot.utils.utils import (
 )
 from lerobot.utils.wandb_utils import WandBLogger
 
+def is_launched_with_accelerate() -> bool:
+    return "ACCELERATE_MIXED_PRECISION" in os.environ
 
 def update_policy(
     train_metrics: MetricsTracker,
@@ -59,36 +64,65 @@ def update_policy(
     batch: Any,
     optimizer: Optimizer,
     grad_clip_norm: float,
-    grad_scaler: GradScaler,
+    grad_scaler: GradScaler | None,
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator=None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
-        loss, output_dict = policy.forward(batch)
-        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()
-
-    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
-    grad_scaler.unscale_(optimizer)
-
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(),
-        grad_clip_norm,
-        error_if_nonfinite=False,
-    )
-
-    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
-    # Updates the scale for next iteration.
-    grad_scaler.update()
-
-    optimizer.zero_grad()
+    grad_norm = 0.0  # Initialize grad_norm to avoid undefined variable
+
+    if accelerator:
+        with accelerator.accumulate(policy):
+            with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+                loss, output_dict = policy.forward(batch)
+                # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+            accelerator.backward(loss)
+            if accelerator.sync_gradients:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    policy.parameters(),
+                    grad_clip_norm,
+                    error_if_nonfinite=False,
+                )
+            optimizer.step()
+            optimizer.zero_grad()
+    else:
+        # Standard training loop without accelerate
+        with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+            loss, output_dict = policy.forward(batch)
+            # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+
+        if grad_scaler is not None:
+            grad_scaler.scale(loss).backward()
+            # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+            grad_scaler.unscale_(optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                policy.parameters(),
+                grad_clip_norm,
+                error_if_nonfinite=False,
+            )
+            # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+            # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+            with lock if lock is not None else nullcontext():
+                grad_scaler.step(optimizer)
+            # Updates the scale for next iteration.
+            grad_scaler.update()
+        else:
+            # Without GradScaler (fallback)
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                policy.parameters(),
+                grad_clip_norm,
+                error_if_nonfinite=False,
+            )
+            with lock if lock is not None else nullcontext():
+                optimizer.step()
+
+        optimizer.zero_grad()
 
     # Step through pytorch scheduler at every batch instead of epoch
     if lr_scheduler is not None:
@@ -99,7 +133,7 @@ def update_policy(
         policy.update()
 
     train_metrics.loss = loss.item()
-    train_metrics.grad_norm = grad_norm.item()
+    train_metrics.grad_norm = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
     train_metrics.lr = optimizer.param_groups[0]["lr"]
     train_metrics.update_s = time.perf_counter() - start_time
     return train_metrics, output_dict
@@ -108,8 +142,33 @@ def update_policy(
 @parser.wrap()
 def train(cfg: TrainPipelineConfig):
     cfg.validate()
+
+    accelerator = None
+    if is_launched_with_accelerate():
+        import accelerate
+        # For example pi0 has unused params (last llm block)
+        from accelerate import DistributedDataParallelKwargs
+        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        # accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
+        from accelerate import InitProcessGroupKwargs
+        # Set NCCL timeout (default 30 minutes = 1800 seconds)
+        nccl_timeout = getattr(cfg, 'nccl_timeout', 1800)
+        ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout))  # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time
+        # Set gradient accumulation steps (default 1)
+        gradient_accumulation_steps = getattr(cfg, 'gradient_accumulation_steps', 1)
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_init_kwargs, ddp_kwargs])
+
+    if accelerator is not None and not accelerator.is_main_process:
+        # Disable duplicate logging on non-main processes
+        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+        logging.getLogger().setLevel(logging.WARNING)
 
     logging.info(pformat(cfg.to_dict()))
 
+    if accelerator and not accelerator.is_main_process:
+        # Disable wandb logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
@@ -143,7 +202,8 @@ def train(cfg: TrainPipelineConfig):
logging.info("Creating optimizer and scheduler") logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) # Only use GradScaler when not using accelerate (accelerate handles mixed precision internally)
grad_scaler = None if accelerator else GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
@@ -185,6 +245,11 @@ def train(cfg: TrainPipelineConfig):
     )
     dl_iter = cycle(dataloader)
 
+    # Prepare models for accelerate if using multi-GPU
+    if accelerator:
+        policy, optimizer, dataloader = accelerator.prepare(policy, optimizer, dataloader)
+        dl_iter = cycle(dataloader)
+
     policy.train()
 
     train_metrics = {
@@ -205,9 +270,10 @@ def train(cfg: TrainPipelineConfig):
         batch = next(dl_iter)
         train_tracker.dataloading_s = time.perf_counter() - start_time
 
-        for key in batch:
-            if isinstance(batch[key], torch.Tensor):
-                batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
+        if not accelerator:
+            for key in batch:
+                if isinstance(batch[key], torch.Tensor):
+                    batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
 
         train_tracker, output_dict = update_policy(
             train_tracker,
@@ -218,6 +284,7 @@ def train(cfg: TrainPipelineConfig):
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.policy.use_amp,
+            accelerator=accelerator,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -237,15 +304,17 @@ def train(cfg: TrainPipelineConfig):
                 wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
 
-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            # Unwrap model for accelerate
+            policy_to_save = accelerator.unwrap_model(policy) if accelerator else policy
+            save_checkpoint(checkpoint_dir, step, cfg, policy_to_save, optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
 
-        if cfg.env and is_eval_step:
+        if cfg.env and is_eval_step and (not accelerator or accelerator.is_main_process):
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
             with (
@@ -254,7 +323,7 @@ def train(cfg: TrainPipelineConfig):
             ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    accelerator.unwrap_model(policy) if accelerator else policy,
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
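
Two details of the train.py changes above are worth spelling out. First, is_launched_with_accelerate() keys off the ACCELERATE_MIXED_PRECISION environment variable, which accelerate launch exports for the processes it spawns (mirroring the mixed_precision field of the config); a direct python -m lerobot.scripts.train run leaves it unset and keeps the original single-process path. Second, when an Accelerator is created, grad_scaler is None and mixed precision, backward, and gradient sync are delegated to accelerate, hence the two branches in update_policy. A short sketch of the detection under those assumptions:

    # Illustrative sketch of the detection logic added above.
    import os

    def is_launched_with_accelerate() -> bool:
        # `accelerate launch` exports ACCELERATE_MIXED_PRECISION (e.g. "no", "fp16", "bf16");
        # a plain `python -m ...` invocation does not.
        return "ACCELERATE_MIXED_PRECISION" in os.environ

    os.environ.pop("ACCELERATE_MIXED_PRECISION", None)
    assert not is_launched_with_accelerate()         # direct run -> GradScaler path

    os.environ["ACCELERATE_MIXED_PRECISION"] = "no"  # what the configs above would set
    assert is_launched_with_accelerate()             # accelerate path -> grad_scaler is None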
+31 -3
@@ -60,11 +60,39 @@ def load_training_step(save_dir: Path) -> int:
 def update_last_checkpoint(checkpoint_dir: Path) -> Path:
+    import fcntl
+    import tempfile
+    import os
+
     last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
-    if last_checkpoint_dir.is_symlink():
-        last_checkpoint_dir.unlink()
     relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
-    last_checkpoint_dir.symlink_to(relative_target)
+
+    # Use file locking to prevent race conditions in multi-GPU training
+    lock_file = checkpoint_dir.parent / ".symlink_lock"
+    try:
+        with open(lock_file, 'w') as f:
+            # Get exclusive lock
+            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+            # Update symlink atomically
+            if last_checkpoint_dir.exists() or last_checkpoint_dir.is_symlink():
+                last_checkpoint_dir.unlink()
+            last_checkpoint_dir.symlink_to(relative_target)
+    except (OSError, FileExistsError) as e:
+        # Handle race conditions gracefully - another process may have already updated
+        if not last_checkpoint_dir.exists():
+            try:
+                last_checkpoint_dir.symlink_to(relative_target)
+            except FileExistsError:
+                pass  # Another process created it, that's fine
+    finally:
+        # Clean up lock file
+        try:
+            lock_file.unlink()
+        except FileNotFoundError:
+            pass
 
 
 def save_checkpoint(
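
The locking above serializes updates to the last-checkpoint symlink when several processes call update_last_checkpoint concurrently; note that the training loop already gates checkpointing on accelerator.is_main_process, so the lock acts primarily as a safety net. A stripped-down sketch of the same fcntl pattern (illustrative, reusing the .symlink_lock name from the diff):

    # Illustrative: serialize a symlink update across processes with an exclusive flock.
    import fcntl
    from pathlib import Path

    def point_symlink(link: Path, target: Path) -> None:
        lock_file = link.parent / ".symlink_lock"
        with open(lock_file, "w") as f:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)   # blocks until the lock is held
            if link.exists() or link.is_symlink():
                link.unlink()
            link.symlink_to(target)
            # the lock is released when `f` is closed at the end of the with-block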
+45
@@ -0,0 +1,45 @@
#!/bin/bash
echo "=== Local 1-GPU Accelerate Training Test with SmolVLA ==="
echo "Environment: multi"
echo "GPU: 1"
echo "Steps: 50 (quick local test)"
echo ""
# Activate conda environment
source /fsx/dana_aubakirova/miniconda3/etc/profile.d/conda.sh
conda activate multi
# Set CUDA environment for 1 GPU
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True
export TORCH_DISTRIBUTED_DEBUG=OFF
export CUDA_LAUNCH_BLOCKING=0
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# Change to working directory
cd /fsx/dana_aubakirova/vla/pr/lerobot
# Set output directory with timestamp
export OUTPUT_DIR="outputs/test_accelerate_1gpu_local_$(date +%Y%m%d_%H%M%S)"
echo "Output directory: $OUTPUT_DIR"
echo ""
# Test accelerate training with 1 GPU
accelerate launch --config_file accelerate_configs/1gpu_config.yaml -m lerobot.scripts.train \
--policy.path=lerobot/smolvla_base \
--policy.push_to_hub=false \
--dataset.repo_id=lerobot/svla_so100_sorting \
--dataset.video_backend=pyav \
--steps=50 \
--save_freq=25 \
--log_freq=5 \
--batch_size=1 \
--num_workers=0 \
--output_dir=$OUTPUT_DIR \
--wandb.enable=false
echo ""
echo "=== Training completed! ==="
echo "Check outputs in: $OUTPUT_DIR"
+67
@@ -0,0 +1,67 @@
#!/bin/bash
#SBATCH --job-name=test_accelerate
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --gres=gpu:2
#SBATCH --time=1:00:00
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/dana_aubakirova/vla/logs/test_accelerate_%j.out
#SBATCH --error=/fsx/dana_aubakirova/vla/logs/test_accelerate_%j.err
# Create logs directory if it doesn't exist
mkdir -p /fsx/dana_aubakirova/vla/pr/lerobot/logs
# Activate conda environment
source /fsx/dana_aubakirova/miniconda3/etc/profile.d/conda.sh
conda activate multi
# 2-GPU Test CUDA environment
export CUDA_VISIBLE_DEVICES=0,1
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True
export TORCH_DISTRIBUTED_DEBUG=OFF
export NCCL_DEBUG=INFO
export CUDA_LAUNCH_BLOCKING=0
export ACCELERATE_USE_FSDP=false
export ACCELERATE_USE_DEEPSPEED=false
export HF_ACCELERATE_DEVICE_MAP=false
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
export SAFETENSORS_FAST_GPU=1
export HF_HUB_ENABLE_HF_TRANSFER=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ACCELERATE_TORCH_DEVICE_MAP_AUTO=false
# Change to working directory
cd /fsx/dana_aubakirova/vla/pr/lerobot
echo "=== Testing Accelerate Multi-GPU Training with SmolVLA ==="
echo "Dataset: lerobot/svla_so100_sorting"
echo "GPUs: 2"
echo "Steps: 100 (for quick test)"
echo "Job ID: $SLURM_JOB_ID"
echo ""
# Set output directory with job ID
export OUTPUT_DIR="outputs/test_accelerate_2gpu_job_${SLURM_JOB_ID}"
echo "Output directory: $OUTPUT_DIR"
echo ""
# Test accelerate training
accelerate launch --config_file accelerate_configs/2gpu_config_safe.yaml -m lerobot.scripts.train \
--policy.type=smolvla \
--policy.push_to_hub=false \
--dataset.repo_id=lerobot/svla_so100_sorting \
--dataset.video_backend=pyav \
--steps=100 \
--save_freq=50 \
--log_freq=5 \
--batch_size=2 \
--num_workers=0 \
--output_dir=$OUTPUT_DIR \
--wandb.enable=false
echo ""
echo "=== Training completed! ==="
echo "Check logs and outputs in: $OUTPUT_DIR"
echo "Job ID: $SLURM_JOB_ID"
+45
@@ -0,0 +1,45 @@
#!/bin/bash
echo "=== Direct 1-GPU Training Test with SmolVLA (no accelerate) ==="
echo "Environment: multi"
echo "GPU: 1"
echo "Steps: 50 (quick local test)"
echo ""
# Activate conda environment
source /fsx/dana_aubakirova/miniconda3/etc/profile.d/conda.sh
conda activate multi
# Set CUDA environment for 1 GPU
export CUDA_VISIBLE_DEVICES=0
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,expandable_segments:True
export TORCH_DISTRIBUTED_DEBUG=OFF
export CUDA_LAUNCH_BLOCKING=0
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# Change to working directory
cd /fsx/dana_aubakirova/vla/pr/lerobot
# Set output directory with timestamp
export OUTPUT_DIR="outputs/test_direct_1gpu_local_$(date +%Y%m%d_%H%M%S)"
echo "Output directory: $OUTPUT_DIR"
echo ""
# Test direct training with 1 GPU (no accelerate)
python -m lerobot.scripts.train \
--policy.path=lerobot/smolvla_base \
--policy.push_to_hub=false \
--dataset.repo_id=lerobot/svla_so100_sorting \
--dataset.video_backend=pyav \
--steps=50 \
--save_freq=25 \
--log_freq=5 \
--batch_size=1 \
--num_workers=0 \
--output_dir=$OUTPUT_DIR \
--wandb.enable=false
echo ""
echo "=== Training completed! ==="
echo "Check outputs in: $OUTPUT_DIR"