diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 3ad4dd993..91c7be854 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -165,7 +165,15 @@ def train(cfg: TrainPipelineConfig): # Initialize Accelerate if enabled accelerator = None - if cfg.use_accelerate: + logging.info(f"DEBUG: cfg.use_accelerate = {cfg.use_accelerate}") + + # Auto-detect if we're using accelerate launch (fallback) + import os + + using_accelerate_launch = "ACCELERATE_LAUNCH" in os.environ or "WORLD_SIZE" in os.environ + logging.info(f"DEBUG: Auto-detected accelerate launch = {using_accelerate_launch}") + + if cfg.use_accelerate or using_accelerate_launch: ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=cfg.gradient_accumulation_steps, @@ -174,8 +182,9 @@ def train(cfg: TrainPipelineConfig): ) device = accelerator.device if accelerator.is_main_process: + accelerate_source = "config" if cfg.use_accelerate else "auto-detected from accelerate launch" logging.info( - f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}" + f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision} (source: {accelerate_source})" ) logging.info(f"Training on {accelerator.num_processes} processes") else: