diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 1eed16963..77da7899f 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
 import time
 from collections.abc import Callable
 from contextlib import nullcontext
@@ -164,8 +165,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     """
     cfg.validate()
 
-    # Check if this is the main process (process_index == 0 or no accelerator)
-    is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
+    # Determine whether this is the global main process.
+    # Prefer the launcher-provided RANK (global rank) over LOCAL_RANK: LOCAL_RANK
+    # restarts at 0 on every node, so using it alone would elect one "main"
+    # process per node in multi-node runs. LOCAL_RANK is kept as a fallback.
+    rank_str = os.environ.get("RANK", os.environ.get("LOCAL_RANK", ""))
+    try:
+        env_rank = int(rank_str)
+    except ValueError:
+        # Unset or malformed value: defer to the accelerator object below.
+        env_rank = -1
+    if env_rank >= 0:
+        # A rank was provided by the launcher; the main process is rank 0.
+        is_main_process = env_rank == 0
+    else:
+        # No usable rank: check the accelerator object or assume main process.
+        is_main_process = not accelerator or getattr(accelerator, "is_main_process", False)
 
     if accelerator and not is_main_process:
         # Disable WandB and logging on non-main processes.
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 77d41a499..b32decbb9 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -138,8 +138,19 @@ def init_logging(
         logger.removeHandler(handler)
 
     # Check if this is a non-main process in multi-GPU training
-    # Use process_index to be more explicit (main process is index 0)
-    is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0
+    # Prefer the launcher-provided global RANK (LOCAL_RANK restarts at 0 on every
+    # node, so on its own it cannot identify the single global main process).
+    rank_str = os.environ.get("RANK", os.environ.get("LOCAL_RANK", ""))
+    try:
+        env_rank = int(rank_str)
+    except ValueError:
+        # Unset or malformed value: fall back to the accelerator object below.
+        env_rank = -1
+    if env_rank >= 0:
+        is_non_main_process = env_rank != 0
+    else:
+        # Fallback to accelerator object check if no rank is set in the environment
+        is_non_main_process = accelerator is not None and getattr(accelerator, "process_index", 0) != 0
 
     # Write logs to console (only for main process in multi-GPU training)
     if not is_non_main_process: