main logging
@@ -99,7 +99,7 @@ class WandBLogger:
         cfg.wandb.run_id = run_id

         # Handle custom step key for rl asynchronous training.
         self._wandb_custom_step_key: set[str] | None = None
-        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+        logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
         logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
         self._wandb = wandb
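Note on this first hunk: print() writes straight to stdout on every process and cannot be silenced per rank, while logging.info() goes through the root logger, which the init_logging changes at the bottom of this diff configure differently for main and non-main processes. A minimal sketch of the call, assuming the termcolor package lerobot already imports:

import logging

from termcolor import colored

logging.basicConfig(level=logging.INFO)
# Unlike print(), this record can be filtered out by a rank-aware logger setup.
logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))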
@@ -164,18 +164,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     """
     cfg.validate()

-    if accelerator and not accelerator.is_main_process:
-        # Disable logging on non-main processes.
+    # Check if this is the main process (process_index == 0 or no accelerator)
+    is_main_process = not accelerator or (hasattr(accelerator, 'process_index') and accelerator.process_index == 0) or accelerator.is_main_process
+
+    if accelerator and not is_main_process:
+        # Disable WandB and logging on non-main processes.
         cfg.wandb.enable = False

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info(pformat(cfg.to_dict()))

-    if cfg.wandb.enable and cfg.wandb.project:
+    # Only create WandB logger on main process
+    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
-        if not accelerator or accelerator.is_main_process:
+        if is_main_process:
             logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

     if cfg.seed is not None:
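The train() hunks all reduce to one idea: compute is_main_process once, up front, and reuse it instead of re-deriving `not accelerator or accelerator.is_main_process` at every call site. A standalone sketch of the same guard against Hugging Face Accelerate (an assumption here; the diff only types the parameter as Callable | None):

from accelerate import Accelerator

accelerator = Accelerator()

# Same guard as the diff: no accelerator means single-process training,
# otherwise rank 0 (process_index == 0) is the main process.
is_main_process = (
    not accelerator
    or (hasattr(accelerator, "process_index") and accelerator.process_index == 0)
    or accelerator.is_main_process
)

if is_main_process:
    print(f"main process, world size = {accelerator.num_processes}")

With a real Accelerator the hasattr branch is redundant, since is_main_process is already defined as process_index == 0; it only matters for duck-typed stand-ins that lack process_index.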
@@ -186,7 +190,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating dataset")
     dataset = make_dataset(cfg)
@@ -195,11 +199,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     # using the eval.py instead, with gym_dora environment and dora-rs.
     eval_env = None
     if cfg.eval_freq > 0 and cfg.env is not None:
-        if not accelerator or accelerator.is_main_process:
+        if is_main_process:
             logging.info("Creating env")
         eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
@@ -238,7 +242,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         **postprocessor_kwargs,
     )

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -251,7 +255,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
         if cfg.env is not None:
             logging.info(f"{cfg.env.task=}")
@@ -309,7 +313,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         accelerator=accelerator,
     )

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("Start offline training on a fixed dataset")

     for _ in range(step, cfg.steps):
@@ -334,19 +338,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
         # increment `step` here.
         step += 1
         train_tracker.step()
-        is_log_step = (
-            cfg.log_freq > 0 and step % cfg.log_freq == 0 and (not accelerator or accelerator.is_main_process)
-        )
-        is_saving_step = (
-            step % cfg.save_freq == 0
-            or step == cfg.steps
-            and (not accelerator or accelerator.is_main_process)
-        )
-        is_eval_step = (
-            cfg.eval_freq > 0
-            and step % cfg.eval_freq == 0
-            and (not accelerator or accelerator.is_main_process)
-        )
+        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
+        is_saving_step = (step % cfg.save_freq == 0 or step == cfg.steps) and is_main_process
+        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 and is_main_process

         if is_log_step:
             logging.info(train_tracker)
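Collapsing the three flags also fixes an operator-precedence bug: in the removed is_saving_step, `and` binds tighter than `or`, so the expression parsed as `step % cfg.save_freq == 0 or (step == cfg.steps and <main-process check>)`, meaning every process hit the save branch on regular save steps and the main-process check only guarded the final step. The one-liner parenthesizes the `or` before applying `and is_main_process`. A quick demonstration:

# `a or b and c` parses as `a or (b and c)`, not `(a or b) and c`.
a, b, c = True, False, False   # regular save step, not final step, non-main process
old_behavior = a or b and c    # True: a non-main process would also save
new_behavior = (a or b) and c  # False: only the main process saves
assert old_behavior and not new_behavior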
@@ -426,7 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
     if eval_env:
         close_envs(eval_env)

-    if not accelerator or accelerator.is_main_process:
+    if is_main_process:
         logging.info("End of training")

         if cfg.policy.push_to_hub:
@@ -138,7 +138,8 @@ def init_logging(
             logger.removeHandler(handler)

     # Check if this is a non-main process in multi-GPU training
-    is_non_main_process = accelerator is not None and not accelerator.is_main_process
+    # Use process_index to be more explicit (main process is index 0)
+    is_non_main_process = accelerator is not None and hasattr(accelerator, 'process_index') and accelerator.process_index != 0

     # Write logs to console (only for main process in multi-GPU training)
     if not is_non_main_process:
@@ -147,8 +148,10 @@ def init_logging(
         console_handler.setLevel(console_level.upper())
         logger.addHandler(console_handler)
     else:
-        # For non-main processes, set logger level to WARNING to suppress most logs
-        logger.setLevel(logging.WARNING)
+        # For non-main processes, add a NullHandler to completely suppress output
+        # and set level to ERROR to minimize any logging
+        logger.addHandler(logging.NullHandler())
+        logger.setLevel(logging.ERROR)

     # Additionally write logs to file
     if log_file is not None:
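The new else branch attaches a NullHandler and raises the level to ERROR instead of WARNING. A self-contained sketch of the pattern (illustrative names, not lerobot's exact init_logging, whose full signature is not shown in this diff):

import logging

def setup_rank_aware_logging(is_main: bool) -> None:
    logger = logging.getLogger()
    if is_main:
        handler = logging.StreamHandler()
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    else:
        # The NullHandler stops Python's last-resort handler from writing
        # WARNING+ records to stderr when no other handler is attached;
        # the ERROR level drops INFO/WARNING records before any handler sees them.
        logger.addHandler(logging.NullHandler())
        logger.setLevel(logging.ERROR)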