From f8a185f7534f24e68c99901887c15b7015270efc Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 14 Oct 2025 17:05:47 +0200
Subject: [PATCH] cleanup logging

---
 docs/source/multi_gpu_training.mdx |  2 +-
 src/lerobot/utils/utils.py         | 58 +++++++++++++++++++++++-------------
 2 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/docs/source/multi_gpu_training.mdx b/docs/source/multi_gpu_training.mdx
index 89ab2a6dc..7c78d3c6f 100644
--- a/docs/source/multi_gpu_training.mdx
+++ b/docs/source/multi_gpu_training.mdx
@@ -10,7 +10,7 @@ First, ensure you have accelerate installed:
 pip install accelerate
 ```
 
-## Training with Multiple GPUss
+## Training with Multiple GPUs
 
 You can launch training in two ways:
 
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 79a116a45..e0ddc50d4 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -21,6 +21,7 @@ import subprocess
 import sys
 import time
 from collections.abc import Callable
+from accelerate import Accelerator
 from copy import copy, deepcopy
 from datetime import datetime
 from pathlib import Path
@@ -113,52 +114,55 @@ def init_logging(
     display_pid: bool = False,
     console_level: str = "INFO",
     file_level: str = "DEBUG",
-    accelerator: Callable | None = None,
+    accelerator: Accelerator | None = None,
 ):
+    """Initialize logging configuration for LeRobot.
+
+    In multi-GPU training, only the main process logs to the console to avoid duplicate output.
+    Non-main processes suppress console output and log only ERROR-level messages to file.
+
+    Args:
+        log_file: Optional file path to write logs to.
+        display_pid: Include the process ID in log messages (useful when debugging multi-process runs).
+        console_level: Logging level for console output.
+        file_level: Logging level for file output.
+        accelerator: Optional Accelerator instance, used as a fallback for multi-GPU detection.
+    """
     def custom_format(record: logging.LogRecord) -> str:
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
-
-        # NOTE: Display PID is useful for multi-process logging.
-        if display_pid:
-            pid_str = f"[PID: {os.getpid()}]"
-            message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
-        else:
-            message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
-        return message
+        pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
+        return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
 
     formatter = logging.Formatter()
     formatter.format = custom_format
 
     logger = logging.getLogger()
-    logger.setLevel(logging.NOTSET)  # Set the logger to the lowest level to capture all messages
-
-    # Remove unused default handlers
-    for handler in logger.handlers[:]:
-        logger.removeHandler(handler)
-
-    # Check if this is a non-main process in multi-GPU training
-    # Check environment variables set by accelerate (more reliable than checking accelerator object)
-    local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    is_non_main_process = local_rank > 0
+    logger.setLevel(logging.NOTSET)
 
-    # Fallback to accelerator object check if LOCAL_RANK not set
-    if local_rank == -1 and accelerator is not None:
-        is_non_main_process = hasattr(accelerator, 'process_index') and accelerator.process_index != 0
+    # Clear any existing handlers
+    logger.handlers.clear()
 
-    # Write logs to console (only for main process in multi-GPU training)
-    if not is_non_main_process:
+    # Determine whether this is the main process in distributed training.
+    # Check the LOCAL_RANK env var (set by accelerate/torchrun).
+    local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    is_main_process = local_rank <= 0  # -1 means not distributed, 0 means main process
+    if local_rank == -1 and accelerator is not None:
+        # Fall back to the Accelerator state when LOCAL_RANK is not exported
+        is_main_process = accelerator.is_local_main_process
+
+    # Console logging (main process only)
+    if is_main_process:
         console_handler = logging.StreamHandler()
         console_handler.setFormatter(formatter)
         console_handler.setLevel(console_level.upper())
         logger.addHandler(console_handler)
     else:
-        # For non-main processes, add a NullHandler to completely suppress output
-        # and set level to ERROR to minimize any logging
+        # Suppress console output for non-main processes
        logger.addHandler(logging.NullHandler())
         logger.setLevel(logging.ERROR)
 
-    # Additionally write logs to file
+    # File logging (optional; non-main processes only record ERROR and above)
     if log_file is not None:
         file_handler = logging.FileHandler(log_file)
         file_handler.setFormatter(formatter)
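
For reviewers, a minimal sketch of how the updated `init_logging` might be driven from an accelerate-launched training script. This is not part of the patch: the `main()` entry point and the `"train.log"` path are illustrative only; `Accelerator()` and `is_local_main_process` are standard accelerate API.

```python
# Usage sketch (assumption, not part of the patch): initialize logging in a
# script started with, e.g., `accelerate launch --num_processes=2 train.py`.
from accelerate import Accelerator

from lerobot.utils.utils import init_logging


def main():
    accelerator = Accelerator()
    # Only the local main process logs to the console; every rank writes to
    # the log file (non-main ranks at ERROR level and above).
    init_logging(
        log_file="train.log",    # illustrative path
        display_pid=True,        # helpful when debugging multi-process runs
        accelerator=accelerator, # fallback detection when LOCAL_RANK is unset
    )


if __name__ == "__main__":
    main()
```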