This commit is contained in:
Pepijn
2025-09-23 22:17:37 +02:00
parent 199f3b927b
commit b794fc3c70
+57 -26
View File
@@ -167,6 +167,10 @@ def train(cfg: TrainPipelineConfig):
cfg: A `TrainPipelineConfig` object containing all training configurations.
"""
cfg.validate()
# NOTE: the config is logged unconditionally here, so every process emits it.
# Accelerate is only initialized below, so at this point we cannot yet tell
# whether this is the main process. TODO: deduplicate this log line if needed.
logging.info(pformat(cfg.to_dict()))
# Initialize Accelerate if requested
@@ -177,16 +181,25 @@ def train(cfg: TrainPipelineConfig):
mixed_precision=cfg.mixed_precision,
)
device = accelerator.device
logging.info(f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}")
if accelerator.is_main_process:
logging.info(
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
)
logging.info(f"Training on {accelerator.num_processes} processes")
else:
# Check device is available (original behavior)
device = get_safe_torch_device(cfg.policy.device, log=True)
# Only create wandb logger on main process
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
if accelerator is None or accelerator.is_main_process:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if accelerator is None or accelerator.is_main_process:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_seed(cfg.seed)
@@ -194,7 +207,8 @@ def train(cfg: TrainPipelineConfig):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
if accelerator is None or accelerator.is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -202,10 +216,12 @@ def train(cfg: TrainPipelineConfig):
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
if accelerator is None or accelerator.is_main_process:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
if accelerator is None or accelerator.is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
@@ -224,7 +240,8 @@ def train(cfg: TrainPipelineConfig):
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
)
logging.info("Creating optimizer and scheduler")
if accelerator is None or accelerator.is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
@@ -236,21 +253,24 @@ def train(cfg: TrainPipelineConfig):
accelerate_state_path = cfg.checkpoint_path / "accelerate_state"
if accelerate_state_path.exists():
accelerator.load_state(str(accelerate_state_path))
logging.info("Loaded Accelerate state from checkpoint")
if accelerator.is_main_process:
logging.info("Loaded Accelerate state from checkpoint")
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Only log setup info on main process
if accelerator is None or accelerator.is_main_process:
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -281,7 +301,8 @@ def train(cfg: TrainPipelineConfig):
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
logging.info("Policy, optimizer, dataloader, and scheduler prepared with Accelerate")
if accelerator.is_main_process:
logging.info("Policy, optimizer, dataloader, and scheduler prepared with Accelerate")
dl_iter = cycle(dataloader)
@@ -299,7 +320,8 @@ def train(cfg: TrainPipelineConfig):
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
)
logging.info("Start offline training on a fixed dataset")
if accelerator is None or accelerator.is_main_process:
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
# Handle gradient accumulation
if accelerator is not None:
@@ -347,16 +369,19 @@ def train(cfg: TrainPipelineConfig):
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
wandb_logger.log_dict(wandb_log_dict, step)
# Only log training metrics on main process
if accelerator is None or accelerator.is_main_process:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
if accelerator is None or accelerator.is_main_process:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
if accelerator is not None:
@@ -443,7 +468,13 @@ def train(cfg: TrainPipelineConfig):
if eval_env:
close_envs(eval_env)
logging.info("End of training")
if accelerator is None or accelerator.is_main_process:
logging.info("End of training")
# Synchronize all processes before finishing
if accelerator is not None:
accelerator.wait_for_everyone()
if cfg.policy.push_to_hub:
# Only push to hub from main process when using accelerate