mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-27 13:17:10 +00:00
add some debugging
This commit is contained in:
@@ -463,7 +463,17 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
distributed_env_vars = {
|
||||
"LOCAL_RANK": os.environ.get("LOCAL_RANK", "NOT SET"),
|
||||
"WORLD_SIZE": os.environ.get("WORLD_SIZE", "NOT SET"),
|
||||
"RANK": os.environ.get("RANK", "NOT SET"),
|
||||
"ACCELERATE_MIXED_PRECISION": os.environ.get("ACCELERATE_MIXED_PRECISION", "NOT SET"),
|
||||
}
|
||||
print(f"[PID {os.getpid()}] Distributed env vars: {distributed_env_vars}")
|
||||
|
||||
if is_launched_with_accelerate():
|
||||
print(f"[PID {os.getpid()}] Detected distributed training mode")
|
||||
import accelerate
|
||||
|
||||
# We set step_scheduler_with_optimizer False to prevent accelerate from
|
||||
|
||||
@@ -179,19 +179,17 @@ def format_big_number(num, precision=0):
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
"""Check if the script was launched using accelerate.
|
||||
"""Check if the script was launched in a distributed training context.
|
||||
|
||||
Accelerate sets several environment variables when launching distributed training.
|
||||
We check for variables that are always set regardless of the configuration.
|
||||
This checks for standard distributed training environment variables that are set
|
||||
by accelerate launch, torchrun, or torch.distributed.launch.
|
||||
|
||||
Returns:
|
||||
True if running in a distributed context, False otherwise.
|
||||
"""
|
||||
# Check for environment variables that accelerate always sets
|
||||
accelerate_env_vars = [
|
||||
"ACCELERATE_MIXED_PRECISION", # Set when mixed precision is configured
|
||||
"LOCAL_RANK", # Always set in distributed training
|
||||
"WORLD_SIZE", # Always set in distributed training
|
||||
"ACCELERATE_USE_CPU", # Set by accelerate
|
||||
]
|
||||
return any(var in os.environ for var in accelerate_env_vars)
|
||||
# Check for LOCAL_RANK which is the standard way to detect distributed training
|
||||
# This is set by accelerate, torchrun, and torch.distributed.launch
|
||||
return "LOCAL_RANK" in os.environ
|
||||
|
||||
|
||||
def say(text: str, blocking: bool = False):
|
||||
|
||||
Reference in New Issue
Block a user