mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
fix
This commit is contained in:
@@ -165,7 +165,15 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
# Initialize Accelerate if enabled
|
# Initialize Accelerate if enabled
|
||||||
accelerator = None
|
accelerator = None
|
||||||
if cfg.use_accelerate:
|
logging.info(f"DEBUG: cfg.use_accelerate = {cfg.use_accelerate}")
|
||||||
|
|
||||||
|
# Auto-detect if we're using accelerate launch (fallback)
|
||||||
|
import os
|
||||||
|
|
||||||
|
using_accelerate_launch = "ACCELERATE_LAUNCH" in os.environ or "WORLD_SIZE" in os.environ
|
||||||
|
logging.info(f"DEBUG: Auto-detected accelerate launch = {using_accelerate_launch}")
|
||||||
|
|
||||||
|
if cfg.use_accelerate or using_accelerate_launch:
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
@@ -174,8 +182,9 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
)
|
)
|
||||||
device = accelerator.device
|
device = accelerator.device
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
|
accelerate_source = "config" if cfg.use_accelerate else "auto-detected from accelerate launch"
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
|
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision} (source: {accelerate_source})"
|
||||||
)
|
)
|
||||||
logging.info(f"Training on {accelerator.num_processes} processes")
|
logging.info(f"Training on {accelerator.num_processes} processes")
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user