From 8a32764c3834c53b5ab2969fdd9b48be74c20dec Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 15 Oct 2025 13:04:34 +0200 Subject: [PATCH] use accelerate to determin logging --- src/lerobot/scripts/lerobot_train.py | 3 ++- src/lerobot/utils/utils.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index c738b89f8..524e68bfd 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -151,6 +151,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]) + + init_logging(accelerator=accelerator) # Determine if this is the main process (for logging and checkpointing) # When using accelerate, only the main process should log to avoid duplicate outputs @@ -435,7 +437,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): def main(): - init_logging() train() diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 1614a40cb..dd685654f 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -143,9 +143,8 @@ def init_logging( logger.handlers.clear() # Determine if this is a non-main process in distributed training - # Check LOCAL_RANK env var (set by accelerate/torchrun) - local_rank = int(os.environ.get("LOCAL_RANK", -1)) - is_main_process = local_rank <= 0 # -1 means not distributed, 0 means main process + if accelerator is not None: + is_main_process = accelerator.is_main_process # Console logging (main process only) if is_main_process: