mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
use accelerate to determine logging
This commit is contained in:
@@ -151,6 +151,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
init_logging(accelerator=accelerator)
|
||||
|
||||
# Determine if this is the main process (for logging and checkpointing)
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
@@ -435,7 +437,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
train()
|
||||
|
||||
|
||||
|
||||
@@ -143,9 +143,8 @@ def init_logging(
|
||||
logger.handlers.clear()
|
||||
|
||||
# Determine if this is a non-main process in distributed training
|
||||
# Check LOCAL_RANK env var (set by accelerate/torchrun)
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
is_main_process = local_rank <= 0 # -1 means not distributed, 0 means main process
|
||||
if accelerator is not None:
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# Console logging (main process only)
|
||||
if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user