diff --git a/docs/source/multi_gpu_training.mdx b/docs/source/multi_gpu_training.mdx
index 1e26e8806..03afc6a3a 100644
--- a/docs/source/multi_gpu_training.mdx
+++ b/docs/source/multi_gpu_training.mdx
@@ -92,9 +92,10 @@ For faster training, you can enable mixed precision (fp16 or bf16). This is conf
 ## Notes
 
 - The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
-- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts.
+- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
 - The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
 - Learning rate scheduling is handled correctly across multiple processes—LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
 - When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
+- WandB integration automatically initializes only on the main process, preventing multiple runs from being created.
 
 For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate).
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index c01906bc9..18bb395b1 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -136,15 +136,18 @@ def init_logging(
     for handler in logger.handlers[:]:
         logger.removeHandler(handler)
 
-    # Write logs to console
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(formatter)
-    console_handler.setLevel(console_level.upper())
-    logger.addHandler(console_handler)
-    if accelerator is not None and not accelerator.is_main_process:
-        # Disable duplicate logging on non-main processes
-        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
-        logging.getLogger().setLevel(logging.WARNING)
+    # Check if this is a non-main process in multi-GPU training
+    is_non_main_process = accelerator is not None and not accelerator.is_main_process
+
+    # Write logs to console (only for main process in multi-GPU training)
+    if not is_non_main_process:
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(formatter)
+        console_handler.setLevel(console_level.upper())
+        logger.addHandler(console_handler)
+    else:
+        # For non-main processes, set logger level to WARNING to suppress most logs
+        logger.setLevel(logging.WARNING)
 
     # Additionally write logs to file
     if log_file is not None: