diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index edb88a2fb..b89ca9148 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -463,7 +463,17 @@ def main(): if __name__ == "__main__": + import os + distributed_env_vars = { + "LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"), + "WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"), + "RANK": os.environ.get("RANK", "NOT SET"), + "ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"), + } + print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}") + if is_launched_with_accelerate(): + print(f"[PID {os.getpid()}] Detected distributed training mode") import accelerate # We set step_scheduler_with_optimizer False to prevent accelerate from diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 1cdfed67c..3c9cf5145 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -179,19 +179,17 @@ def format_big_number(num, precision=0): def is_launched_with_accelerate() -> bool: - """Check if the script was launched using accelerate. + """Check if the script was launched in a distributed training context. - Accelerate sets several environment variables when launching distributed training. - We check for variables that are always set regardless of the configuration. + This checks for standard distributed training environment variables that are set + by accelerate launch, torchrun, or torch.distributed.launch. + + Returns: + True if running in a distributed context, False otherwise. """ - # Check for environment variables that accelerate always sets - accelerate_env_vars = [ - "ACCELERATE_MIXED_PRECISION", # Set when mixed precision is configured - "LOCAL_RANK", # Always set in distributed training - "WORLD_SIZE", # Always set in distributed training - "ACCELERATE_USE_CPU", # Set by accelerate - ] - return any(var in os.environ for var in accelerate_env_vars) + # Check for LOCAL_RANK which is the standard way to detect distributed training + # This is set by accelerate, torchrun, and torch.distributed.launch + return "LOCAL_RANK" in os.environ def say(text: str, blocking: bool = False):