mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
fix
This commit is contained in:
@@ -165,7 +165,15 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
# Initialize Accelerate if enabled
|
||||
accelerator = None
|
||||
if cfg.use_accelerate:
|
||||
logging.info(f"DEBUG: cfg.use_accelerate = {cfg.use_accelerate}")
|
||||
|
||||
# Auto-detect if we're using accelerate launch (fallback)
|
||||
import os
|
||||
|
||||
using_accelerate_launch = "ACCELERATE_LAUNCH" in os.environ or "WORLD_SIZE" in os.environ
|
||||
logging.info(f"DEBUG: Auto-detected accelerate launch = {using_accelerate_launch}")
|
||||
|
||||
if cfg.use_accelerate or using_accelerate_launch:
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
@@ -174,8 +182,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
)
|
||||
device = accelerator.device
|
||||
if accelerator.is_main_process:
|
||||
accelerate_source = "config" if cfg.use_accelerate else "auto-detected from accelerate launch"
|
||||
logging.info(
|
||||
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
|
||||
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision} (source: {accelerate_source})"
|
||||
)
|
||||
logging.info(f"Training on {accelerator.num_processes} processes")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user