use accelerate to determine logging

This commit is contained in:
Pepijn
2025-10-15 13:04:34 +02:00
parent c775d8d07a
commit 8a32764c38
2 changed files with 4 additions and 4 deletions
+2 -1
View File
@@ -151,6 +151,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
init_logging(accelerator=accelerator)
# Determine if this is the main process (for logging and checkpointing)
# When using accelerate, only the main process should log to avoid duplicate outputs
@@ -435,7 +437,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
def main():
init_logging()
train()
+2 -3
View File
@@ -143,9 +143,8 @@ def init_logging(
logger.handlers.clear()
# Determine if this is a non-main process in distributed training
# Check LOCAL_RANK env var (set by accelerate/torchrun)
local_rank = int(os.environ.get("LOCAL_RANK", -1))
is_main_process = local_rank <= 0 # -1 means not distributed, 0 means main process
if accelerator is not None:
is_main_process = accelerator.is_main_process
# Console logging (main process only)
if is_main_process: