cleanup logging

This commit is contained in:
Pepijn
2025-10-14 17:05:47 +02:00
parent a66b50d372
commit f8a185f753
2 changed files with 29 additions and 28 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ First, ensure you have accelerate installed:
pip install accelerate pip install accelerate
``` ```
## Training with Multiple GPUss ## Training with Multiple GPUs
You can launch training in two ways: You can launch training in two ways:
+28 -27
View File
@@ -21,6 +21,7 @@ import subprocess
import sys import sys
import time import time
from collections.abc import Callable from collections.abc import Callable
from accelerate import Accelerator
from copy import copy, deepcopy from copy import copy, deepcopy
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -113,52 +114,52 @@ def init_logging(
display_pid: bool = False, display_pid: bool = False,
console_level: str = "INFO", console_level: str = "INFO",
file_level: str = "DEBUG", file_level: str = "DEBUG",
accelerator: Callable | None = None, accelerator: Accelerator | None = None,
): ):
"""Initialize logging configuration for LeRobot.
In multi-GPU training, only the main process logs to console to avoid duplicate output.
Non-main processes have console logging suppressed but can still log to file.
Args:
log_file: Optional file path to write logs to
display_pid: Include process ID in log messages (useful for debugging multi-process)
console_level: Logging level for console output
file_level: Logging level for file output
accelerator: Optional Accelerator instance (for multi-GPU detection)
"""
def custom_format(record: logging.LogRecord) -> str: def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}" fnameline = f"{record.pathname}:{record.lineno}"
pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
# NOTE: Display PID is useful for multi-process logging. return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
return message
formatter = logging.Formatter() formatter = logging.Formatter()
formatter.format = custom_format formatter.format = custom_format
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages logger.setLevel(logging.NOTSET)
# Remove unused default handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Check if this is a non-main process in multi-GPU training
# Check environment variables set by accelerate (more reliable than checking accelerator object)
local_rank = int(os.environ.get("LOCAL_RANK", -1))
is_non_main_process = local_rank > 0
# Fallback to accelerator object check if LOCAL_RANK not set # Clear any existing handlers
if local_rank == -1 and accelerator is not None: logger.handlers.clear()
is_non_main_process = hasattr(accelerator, 'process_index') and accelerator.process_index != 0
# Write logs to console (only for main process in multi-GPU training) # Determine if this is a non-main process in distributed training
if not is_non_main_process: # Check LOCAL_RANK env var (set by accelerate/torchrun)
local_rank = int(os.environ.get("LOCAL_RANK", -1))
is_main_process = local_rank <= 0 # -1 means not distributed, 0 means main process
# Console logging (main process only)
if is_main_process:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper()) console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler) logger.addHandler(console_handler)
else: else:
# For non-main processes, add a NullHandler to completely suppress output # Suppress console output for non-main processes
# and set level to ERROR to minimize any logging
logger.addHandler(logging.NullHandler()) logger.addHandler(logging.NullHandler())
logger.setLevel(logging.ERROR) logger.setLevel(logging.ERROR)
# Additionally write logs to file # File logging (optional, all processes)
if log_file is not None: if log_file is not None:
file_handler = logging.FileHandler(log_file) file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)