simplify accelerate main process detection

This commit is contained in:
Pepijn
2025-10-14 13:38:36 +02:00
parent 50ff388bf6
commit cabc47c5ad
+5 -14
View File
@@ -163,24 +163,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
"""
cfg.validate()
# Check if this is the main process
# Use LOCAL_RANK environment variable (set by accelerate) for reliable detection
local_rank = int(os.environ.get("LOCAL_RANK", -1))
if local_rank == -1:
# No LOCAL_RANK, check accelerator object or assume main process
is_main_process = not accelerator or (hasattr(accelerator, 'is_main_process') and accelerator.is_main_process)
else:
# LOCAL_RANK is set, main process is rank 0
is_main_process = local_rank == 0
if accelerator and not is_main_process:
# Disable WandB and logging on non-main processes.
cfg.wandb.enable = False
# Determine if this is the main process (for logging and checkpointing)
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process if accelerator else True
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
# Only create WandB logger on main process
# Initialize WandB only on the main process
if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg)
else: