scale all params better

This commit is contained in:
Pepijn
2025-09-24 09:47:05 +02:00
parent fc7998a3d5
commit 10acbe1069
2 changed files with 26 additions and 1 deletions
+1
View File
@@ -67,6 +67,7 @@ class TrainPipelineConfig(HubMixin):
use_accelerate: bool = False
gradient_accumulation_steps: int = 1
mixed_precision: str = "no" # Options: "no", "fp16", "bf16"
scale_lr_with_num_gpus: bool = True # Automatically scale learning rate with number of GPUs
def __post_init__(self):
self.checkpoint_path = None
+25 -1
View File
@@ -247,6 +247,18 @@ def train(cfg: TrainPipelineConfig):
if accelerator is None or accelerator.is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
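# Scale learning rate for multi-GPU training.
# NOTE(review): lr_scheduler was already created from this optimizer on the line above. PyTorch
# schedulers record each param group's base LR at construction, so this in-place scaling may be
# overwritten on the first scheduler step — confirm, or scale the LR before building the scheduler.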
if accelerator is not None and accelerator.num_processes > 1 and cfg.scale_lr_with_num_gpus:
# Scale learning rate linearly with number of GPUs
original_lr = optimizer.param_groups[0]["lr"]
for param_group in optimizer.param_groups:
param_group["lr"] *= accelerator.num_processes
if accelerator.is_main_process:
logging.info(
f"Scaled learning rate by {accelerator.num_processes}x for multi-GPU training: {original_lr:.2e} -> {optimizer.param_groups[0]['lr']:.2e}"
)
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
@@ -276,6 +288,15 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Log batch size and learning rate info
if accelerator is not None:
logging.info(f"Per-GPU batch size: {cfg.batch_size}")
logging.info(f"Effective batch size (total): {cfg.batch_size * accelerator.num_processes}")
logging.info(f"Learning rate (scaled): {optimizer.param_groups[0]['lr']:.2e}")
else:
logging.info(f"Batch size: {cfg.batch_size}")
logging.info(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
@@ -320,8 +341,11 @@ def train(cfg: TrainPipelineConfig):
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
# Calculate effective batch size for metrics (total across all GPUs)
effective_batch_size = cfg.batch_size * (accelerator.num_processes if accelerator is not None else 1)
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
effective_batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
)
if accelerator is None or accelerator.is_main_process: