Restrict logging and WandB setup to the main process

This commit is contained in:
Pepijn
2025-10-10 15:01:27 +02:00
parent 8ebda30d1a
commit 63fcebd5a7
3 changed files with 26 additions and 29 deletions
+1 -1
View File
@@ -99,7 +99,7 @@ class WandBLogger:
cfg.wandb.run_id = run_id
# Handle custom step key for rl asynchronous training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
+19 -25
View File
@@ -164,18 +164,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
"""
cfg.validate()
if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes.
# Check if this is the main process (process_index == 0 or no accelerator)
is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
if accelerator and not is_main_process:
# Disable WandB and logging on non-main processes.
cfg.wandb.enable = False
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info(pformat(cfg.to_dict()))
if cfg.wandb.enable and cfg.wandb.project:
# Only create WandB logger on main process
if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
@@ -186,7 +190,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
@@ -195,11 +199,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
@@ -238,7 +242,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
**postprocessor_kwargs,
)
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -251,7 +255,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
@@ -309,7 +313,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
accelerator=accelerator,
)
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
@@ -334,19 +338,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
# increment `step` here.
step += 1
train_tracker.step()
is_log_step = (
cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
)
is_saving_step = (
step % cfg.save_freq == 0
or step == cfg.steps
and (not accelerator or accelerator.is_main_process)
)
is_eval_step = (
cfg.eval_freq > 0
and step % cfg.eval_freq == 0
and (not accelerator or accelerator.is_main_process)
)
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and is_main_process
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and is_main_process
if is_log_step:
logging.info(train_tracker)
@@ -426,7 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if eval_env:
close_envs(eval_env)
if not accelerator or accelerator.is_main_process:
if is_main_process:
logging.info("End of training")
if cfg.policy.push_to_hub:
+6 -3
View File
@@ -138,7 +138,8 @@ def init_logging(
logger.removeHandler(handler)
# Check if this is a non-main process in multi-GPU training
is_non_main_process = accelerator is not None and not accelerator.is_main_process
# Use process_index to be more explicit (main process is index 0)
is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0
# Write logs to console (only for main process in multi-GPU training)
if not is_non_main_process:
@@ -147,8 +148,10 @@ def init_logging(
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
else:
# For non-main processes, set logger level to WARNING to suppress most logs
logger.setLevel(logging.WARNING)
# For non-main processes, add a NullHandler to completely suppress output
# and set level to ERROR to minimize any logging
logger.addHandler(logging.NullHandler())
logger.setLevel(logging.ERROR)
# Additionally write logs to file
if log_file is not None: