Add debug output for distributed env vars and detect distributed launch via LOCAL_RANK only

This commit is contained in:
Pepijn
2025-10-14 14:13:50 +02:00
parent bb824f2275
commit d2687e9486
2 changed files with 19 additions and 11 deletions
+10
View File
@@ -463,7 +463,17 @@ def main():
if __name__ == "__main__":
import os
distributed_env_vars = {
"LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
"WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
"RANK": os.environ.get("RANK", "NOT SET"),
"ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
}
print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
if is_launched_with_accelerate():
print(f"[PID {os.getpid()}] Detected distributed training mode")
import accelerate
# We set step_scheduler_with_optimizer False to prevent accelerate from
+9 -11
View File
@@ -179,19 +179,17 @@ def format_big_number(num, precision=0):
def is_launched_with_accelerate() -> bool:
    """Check if the script was launched in a distributed training context.

    This checks for the standard distributed training environment variable
    that is set by accelerate launch, torchrun, or torch.distributed.launch.

    Returns:
        True if running in a distributed context, False otherwise.
    """
    # LOCAL_RANK is the standard way to detect distributed training; it is set
    # by accelerate, torchrun, and torch.distributed.launch. The broader list
    # checked previously (ACCELERATE_MIXED_PRECISION, WORLD_SIZE,
    # ACCELERATE_USE_CPU) is intentionally no longer consulted.
    return "LOCAL_RANK" in os.environ
def say(text: str, blocking: bool = False):