mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
cleanup logging
This commit is contained in:
@@ -10,7 +10,7 @@ First, ensure you have accelerate installed:
|
|||||||
pip install accelerate
|
pip install accelerate
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training with Multiple GPUss
|
## Training with Multiple GPUs
|
||||||
|
|
||||||
You can launch training in two ways:
|
You can launch training in two ways:
|
||||||
|
|
||||||
|
|||||||
+28
-27
@@ -21,6 +21,7 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from accelerate import Accelerator
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -113,52 +114,52 @@ def init_logging(
|
|||||||
display_pid: bool = False,
|
display_pid: bool = False,
|
||||||
console_level: str = "INFO",
|
console_level: str = "INFO",
|
||||||
file_level: str = "DEBUG",
|
file_level: str = "DEBUG",
|
||||||
accelerator: Callable | None = None,
|
accelerator: Accelerator | None = None,
|
||||||
):
|
):
|
||||||
|
"""Initialize logging configuration for LeRobot.
|
||||||
|
|
||||||
|
In multi-GPU training, only the main process logs to console to avoid duplicate output.
|
||||||
|
Non-main processes have console logging suppressed but can still log to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_file: Optional file path to write logs to
|
||||||
|
display_pid: Include process ID in log messages (useful for debugging multi-process)
|
||||||
|
console_level: Logging level for console output
|
||||||
|
file_level: Logging level for file output
|
||||||
|
accelerator: Optional Accelerator instance (for multi-GPU detection)
|
||||||
|
"""
|
||||||
def custom_format(record: logging.LogRecord) -> str:
|
def custom_format(record: logging.LogRecord) -> str:
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
|
pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
|
||||||
# NOTE: Display PID is useful for multi-process logging.
|
return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
||||||
if display_pid:
|
|
||||||
pid_str = f"[PID: {os.getpid()}]"
|
|
||||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
|
||||||
else:
|
|
||||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
|
|
||||||
return message
|
|
||||||
|
|
||||||
formatter = logging.Formatter()
|
formatter = logging.Formatter()
|
||||||
formatter.format = custom_format
|
formatter.format = custom_format
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
# Remove unused default handlers
|
|
||||||
for handler in logger.handlers[:]:
|
|
||||||
logger.removeHandler(handler)
|
|
||||||
|
|
||||||
# Check if this is a non-main process in multi-GPU training
|
|
||||||
# Check environment variables set by accelerate (more reliable than checking accelerator object)
|
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
|
||||||
is_non_main_process = local_rank > 0
|
|
||||||
|
|
||||||
# Fallback to accelerator object check if LOCAL_RANK not set
|
# Clear any existing handlers
|
||||||
if local_rank == -1 and accelerator is not None:
|
logger.handlers.clear()
|
||||||
is_non_main_process = hasattr(accelerator, 'process_index') and accelerator.process_index != 0
|
|
||||||
|
|
||||||
# Write logs to console (only for main process in multi-GPU training)
|
# Determine if this is a non-main process in distributed training
|
||||||
if not is_non_main_process:
|
# Check LOCAL_RANK env var (set by accelerate/torchrun)
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||||
|
is_main_process = local_rank <= 0 # -1 means not distributed, 0 means main process
|
||||||
|
|
||||||
|
# Console logging (main process only)
|
||||||
|
if is_main_process:
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
console_handler.setFormatter(formatter)
|
console_handler.setFormatter(formatter)
|
||||||
console_handler.setLevel(console_level.upper())
|
console_handler.setLevel(console_level.upper())
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
else:
|
else:
|
||||||
# For non-main processes, add a NullHandler to completely suppress output
|
# Suppress console output for non-main processes
|
||||||
# and set level to ERROR to minimize any logging
|
|
||||||
logger.addHandler(logging.NullHandler())
|
logger.addHandler(logging.NullHandler())
|
||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.ERROR)
|
||||||
|
|
||||||
# Additionally write logs to file
|
# File logging (optional, all processes)
|
||||||
if log_file is not None:
|
if log_file is not None:
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
|
|||||||
Reference in New Issue
Block a user