only log in main process

This commit is contained in:
Pepijn
2025-10-10 14:05:53 +02:00
parent 771b03c30d
commit b65172f819
2 changed files with 14 additions and 10 deletions
+2 -1
View File
@@ -92,9 +92,10 @@ For faster training, you can enable mixed precision (fp16 or bf16). This is conf
## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts.
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
- Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate).
+12 -9
View File
@@ -136,15 +136,18 @@ def init_logging(
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Write logs to console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
# Check if this is a non-main process in multi-GPU training
is_non_main_process = accelerator is not None and not accelerator.is_main_process
# Write logs to console (only for main process in multi-GPU training)
if not is_non_main_process:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
else:
# For non-main processes, set logger level to WARNING to suppress most logs
logger.setLevel(logging.WARNING)
# Additionally write logs to file
if log_file is not None: