diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py
index 01cef9487..1537b3783 100644
--- a/src/lerobot/rl/wandb_utils.py
+++ b/src/lerobot/rl/wandb_utils.py
@@ -99,7 +99,7 @@ class WandBLogger:
             cfg.wandb.run_id = run_id
 
         # Handle custom step key for rl asynchronous training.
         self._wandb_custom_step_key: set[str] | None = None
-        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+        logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
         logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
         self._wandb = wandb
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 1d1f0adc8..1eed16963 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -164,18 +164,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     """
     cfg.validate()
 
-    if accelerator and not accelerator.is_main_process:
-        # Disable logging on non-main processes.
+    # Check if this is the main process (process_index == 0 or no accelerator)
+    is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
+
+    if accelerator and not is_main_process:
+        # Disable WandB and logging on non-main processes.
         cfg.wandb.enable = False
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info(pformat(cfg.to_dict()))
 
-    if cfg.wandb.enable and cfg.wandb.project:
+    # Only create WandB logger on main process
+    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
-        if not accelerator or accelerator.is_main_process:
+        if is_main_process:
             logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
 
     if cfg.seed is not None:
@@ -186,7 +190,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating dataset")
     dataset = make_dataset(cfg)
 
@@ -195,11 +199,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     # using the eval.py instead, with gym_dora environment and dora-rs.
     eval_env = None
     if cfg.eval_freq > 0 and cfg.env is not None:
-        if not accelerator or accelerator.is_main_process:
+        if is_main_process:
             logging.info("Creating env")
         eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
@@ -238,7 +242,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         **postprocessor_kwargs,
     )
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -251,7 +255,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
         if cfg.env is not None:
             logging.info(f"{cfg.env.task=}")
@@ -309,7 +313,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
             accelerator=accelerator,
         )
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Start offline training on a fixed dataset")
 
     for _ in range(step, cfg.steps):
@@ -334,19 +338,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         # increment `step` here.
         step += 1
         train_tracker.step()
-        is_log_step = (
-            cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
-        )
-        is_saving_step = (
-            step % cfg.save_freq == 0
-            or step == cfg.steps
-            and (not accelerator or accelerator.is_main_process)
-        )
-        is_eval_step = (
-            cfg.eval_freq > 0
-            and step % cfg.eval_freq == 0
-            and (not accelerator or accelerator.is_main_process)
-        )
+        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
+        is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and is_main_process
+        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and is_main_process
 
         if is_log_step:
             logging.info(train_tracker)
@@ -426,7 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     if eval_env:
         close_envs(eval_env)
 
-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("End of training")
 
     if cfg.policy.push_to_hub:
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index afc16f4ea..77d41a499 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -138,7 +138,8 @@ def init_logging(
         logger.removeHandler(handler)
 
     # Check if this is a non-main process in multi-GPU training
-    is_non_main_process = accelerator is not None and not accelerator.is_main_process
+    # Use process_index to be more explicit (main process is index 0)
+    is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0
 
     # Write logs to console (only for main process in multi-GPU training)
     if not is_non_main_process:
@@ -147,8 +148,10 @@ def init_logging(
         console_handler.setLevel(console_level.upper())
         logger.addHandler(console_handler)
     else:
-        # For non-main processes, set logger level to WARNING to suppress most logs
-        logger.setLevel(logging.WARNING)
+        # For non-main processes, add a NullHandler to completely suppress output
+        # and set level to ERROR to minimize any logging
+        logger.addHandler(logging.NullHandler())
+        logger.setLevel(logging.ERROR)
 
     # Additionally write logs to file
     if log_file is not None:
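For reference, the gating pattern this patch applies (compute `is_main_process` once, then silence logging and WandB on every other rank) can be exercised outside LeRobot with a small standalone script. The sketch below is illustrative only and not part of the patch; `demo_main_process_logging` and the logger name are made up for the example, and it assumes Hugging Face Accelerate is installed and the script is started with `accelerate launch`.

```python
import logging

from accelerate import Accelerator


def demo_main_process_logging() -> None:
    """Hypothetical demo of per-rank log gating, mirroring the patch above."""
    accelerator = Accelerator()

    # Same check as in lerobot_train.py: rank 0 (or no accelerator) is "main".
    is_main_process = (
        not accelerator
        or (hasattr(accelerator, "process_index") and accelerator.process_index == 0)
        or accelerator.is_main_process
    )

    logger = logging.getLogger("demo")
    if is_main_process:
        # Main rank: log to the console as usual.
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)
        logger.info(
            "Main process (rank %d of %d): logging enabled",
            accelerator.process_index,
            accelerator.num_processes,
        )
    else:
        # Other ranks: NullHandler + ERROR level, as in the init_logging change.
        logger.addHandler(logging.NullHandler())
        logger.setLevel(logging.ERROR)
        logger.info("This message is dropped on non-main ranks")


if __name__ == "__main__":
    demo_main_process_logging()
```

Launched with e.g. `accelerate launch --num_processes 2 demo.py`, only rank 0 should print; the other rank ends up with the same NullHandler/ERROR configuration that `init_logging` now applies to non-main processes.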