This commit is contained in:
Pepijn
2025-09-24 12:29:43 +02:00
parent 40e8aaa05b
commit 5bc06126c0
+11 -2
View File
@@ -165,7 +165,15 @@ def train(cfg: TrainPipelineConfig):
# Initialize Accelerate if enabled
accelerator = None
if cfg.use_accelerate:
logging.info(f"DEBUG: cfg.use_accelerate = {cfg.use_accelerate}")
# Auto-detect if we're using accelerate launch (fallback)
import os
using_accelerate_launch = "ACCELERATE_LAUNCH" in os.environ or "WORLD_SIZE" in os.environ
logging.info(f"DEBUG: Auto-detected accelerate launch = {using_accelerate_launch}")
if cfg.use_accelerate or using_accelerate_launch:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
@@ -174,8 +182,9 @@ def train(cfg: TrainPipelineConfig):
)
device = accelerator.device
if accelerator.is_main_process:
accelerate_source = "config" if cfg.use_accelerate else "auto-detected from accelerate launch"
logging.info(
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision} (source: {accelerate_source})"
)
logging.info(f"Training on {accelerator.num_processes} processes")
else: