diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index aed53a53c..f58d13f4b 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -167,6 +167,10 @@ def train(cfg: TrainPipelineConfig): cfg: A `TrainPipelineConfig` object containing all training configurations. """ cfg.validate() + + # Only log config on main process when using accelerate + # For now we don't know if we're using accelerate yet, so we'll log this always + # and fix the duplicate later if needed logging.info(pformat(cfg.to_dict())) # Initialize Accelerate if requested @@ -177,16 +181,25 @@ def train(cfg: TrainPipelineConfig): mixed_precision=cfg.mixed_precision, ) device = accelerator.device - logging.info(f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}") + if accelerator.is_main_process: + logging.info( + f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}" + ) + logging.info(f"Training on {accelerator.num_processes} processes") else: # Check device is available (original behavior) device = get_safe_torch_device(cfg.policy.device, log=True) + # Only create wandb logger on main process if cfg.wandb.enable and cfg.wandb.project: - wandb_logger = WandBLogger(cfg) + if accelerator is None or accelerator.is_main_process: + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None else: wandb_logger = None - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + if accelerator is None or accelerator.is_main_process: + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) if cfg.seed is not None: set_seed(cfg.seed) @@ -194,7 +207,8 @@ def train(cfg: TrainPipelineConfig): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - logging.info("Creating dataset") + if accelerator is None or accelerator.is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -202,10 +216,12 @@ def train(cfg: TrainPipelineConfig): # using the eval.py instead, with gym_dora environment and dora-rs. eval_env = None if cfg.eval_freq > 0 and cfg.env is not None: - logging.info("Creating env") + if accelerator is None or accelerator.is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - logging.info("Creating policy") + if accelerator is None or accelerator.is_main_process: + logging.info("Creating policy") policy = make_policy( cfg=cfg.policy, ds_meta=dataset.meta, @@ -224,7 +240,8 @@ def train(cfg: TrainPipelineConfig): policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs ) - logging.info("Creating optimizer and scheduler") + if accelerator is None or accelerator.is_main_process: + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) @@ -236,21 +253,24 @@ def train(cfg: TrainPipelineConfig): accelerate_state_path = cfg.checkpoint_path / "accelerate_state" if accelerate_state_path.exists(): accelerator.load_state(str(accelerate_state_path)) - logging.info("Loaded Accelerate state from checkpoint") + if accelerator.is_main_process: + logging.info("Loaded Accelerate state from checkpoint") step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + # Only log setup info on main process + if accelerator is None or accelerator.is_main_process: + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + if cfg.env is not None: + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training if hasattr(cfg.policy, "drop_n_last_frames"): @@ -281,7 +301,8 @@ def train(cfg: TrainPipelineConfig): policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( policy, optimizer, dataloader, lr_scheduler ) - logging.info("Policy, optimizer, dataloader, and scheduler prepared with Accelerate") + if accelerator.is_main_process: + logging.info("Policy, optimizer, dataloader, and scheduler prepared with Accelerate") dl_iter = cycle(dataloader) @@ -299,7 +320,8 @@ def train(cfg: TrainPipelineConfig): cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step ) - logging.info("Start offline training on a fixed dataset") + if accelerator is None or accelerator.is_main_process: + logging.info("Start offline training on a fixed dataset") for _ in range(step, cfg.steps): # Handle gradient accumulation if accelerator is not None: @@ -347,16 +369,19 @@ def train(cfg: TrainPipelineConfig): is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 if is_log_step: - logging.info(train_tracker) - if wandb_logger: - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - wandb_logger.log_dict(wandb_log_dict, step) + # Only log training metrics on main process + if accelerator is None or accelerator.is_main_process: + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: - logging.info(f"Checkpoint policy after step {step}") + if accelerator is None or accelerator.is_main_process: + logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) if accelerator is not None: @@ -443,7 +468,13 @@ def train(cfg: TrainPipelineConfig): if eval_env: close_envs(eval_env) - logging.info("End of training") + + if accelerator is None or accelerator.is_main_process: + logging.info("End of training") + + # Synchronize all processes before finishing + if accelerator is not None: + accelerator.wait_for_everyone() if cfg.policy.push_to_hub: # Only push to hub from main process when using accelerate