Revert "fix"

This reverts commit 1ea65730ac.
Author: Pepijn
Date:   2025-09-24 11:55:44 +02:00
Parent: 1ea65730ac
Commit: 439d79fa11

2 changed files with 23 additions and 110 deletions
+5 -10
@@ -92,18 +92,13 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
         def lr_lambda(current_step):
             def linear_warmup_schedule(current_step):
                 if current_step <= 0:
-                    return 0.1  # Start at 10% instead of 0.1% of peak LR
-                if current_step >= self.num_warmup_steps:
-                    return 1.0  # Reach peak at end of warmup
-                # Linear interpolation from 10% to 100% of peak LR
-                return 0.1 + 0.9 * (current_step / self.num_warmup_steps)
+                    return 1 / (self.num_warmup_steps + 1)
+                frac = 1 - current_step / self.num_warmup_steps
+                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

             def cosine_decay_schedule(current_step):
-                # CRITICAL FIX: Decay should count from END of warmup, not from step 0!
-                decay_step = current_step - self.num_warmup_steps
-                decay_step = max(0, min(decay_step, self.num_decay_steps))
-                cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_step / self.num_decay_steps))
+                step = min(current_step, self.num_decay_steps)
+                cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
                 alpha = self.decay_lr / self.peak_lr
                 decayed = (1 - alpha) * cosine_decay + alpha
                 return decayed
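Read back from the restored `+` lines: the warmup multiplier ramps linearly from 1 / (num_warmup_steps + 1) at step 0 (about 1% of peak for 100 warmup steps, rather than the 10% floor the reverted change used) to exactly 1.0 when current_step reaches num_warmup_steps, and the cosine decay clock runs from step 0 rather than from the end of warmup. Below is a minimal standalone sketch wiring these two functions into torch.optim.lr_scheduler.LambdaLR; the phase dispatch at the bottom and the concrete hyperparameter values are outside this hunk and are assumptions for illustration only:

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

# Illustrative hyperparameters, not taken from the diff.
num_warmup_steps = 100
num_decay_steps = 1000
peak_lr = 1e-4
decay_lr = 1e-6


def lr_lambda(current_step):
    def linear_warmup_schedule(step):
        # Restored behavior: ramp from 1 / (num_warmup_steps + 1) at step 0
        # up to exactly 1.0 at step == num_warmup_steps (frac hits 0 there).
        if step <= 0:
            return 1 / (num_warmup_steps + 1)
        frac = 1 - step / num_warmup_steps
        return (1 / (num_warmup_steps + 1) - 1) * frac + 1

    def cosine_decay_schedule(step):
        # Restored behavior: cosine from 1.0 down to alpha = decay_lr / peak_lr,
        # with the decay clock starting at step 0, not at the end of warmup.
        step = min(step, num_decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * step / num_decay_steps))
        alpha = decay_lr / peak_lr
        return (1 - alpha) * cosine_decay + alpha

    # Phase dispatch: assumed, since it sits outside the hunk shown above.
    if current_step < num_warmup_steps:
        return linear_warmup_schedule(current_step)
    return cosine_decay_schedule(current_step)


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=peak_lr)
scheduler = LambdaLR(optimizer, lr_lambda)  # lr at step s = peak_lr * lr_lambda(s)

Note that the two phases do not meet exactly at the warmup boundary: the decay branch has already advanced num_warmup_steps into its cosine by then, which is precisely the behavior the reverted commit had tried to change.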
+18 -100
@@ -20,8 +20,6 @@ from pprint import pformat
 from typing import Any

 import torch
-from accelerate import Accelerator
-from accelerate.utils import DistributedDataParallelKwargs
 from termcolor import colored
 from torch.amp import GradScaler
 from torch.optim import Optimizer
@@ -149,40 +147,17 @@ def train(cfg: TrainPipelineConfig):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))

-    # Initialize Accelerate if requested
-    accelerator = None
-    if cfg.use_accelerate:
-        # Configure DDP to handle unused parameters
-        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-        accelerator = Accelerator(
-            gradient_accumulation_steps=cfg.gradient_accumulation_steps,
-            mixed_precision=cfg.mixed_precision,
-            kwargs_handlers=[ddp_kwargs],
-        )
-        device = accelerator.device
-        if accelerator.is_main_process:
-            logging.info(
-                f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
-            )
-            logging.info(f"Training on {accelerator.num_processes} processes")
-    else:
-        # Check device is available (original behavior)
-        device = get_safe_torch_device(cfg.policy.device, log=True)
-
-    # Only create wandb logger on main process when using accelerate
     if cfg.wandb.enable and cfg.wandb.project:
-        if accelerator is None or accelerator.is_main_process:
-            wandb_logger = WandBLogger(cfg)
-        else:
-            wandb_logger = None
+        wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
-        if accelerator is None or accelerator.is_main_process:
-            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

     if cfg.seed is not None:
         set_seed(cfg.seed)

+    # Check device is available
+    device = get_safe_torch_device(cfg.policy.device, log=True)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
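For reviewer context, the deleted setup followed the standard Hugging Face Accelerate pattern: build an Accelerator with DDP kwargs, take the device from it, and (in the later hunks) prepare the objects and route the backward pass through it. A condensed, runnable sketch of that pattern is below; cfg.use_accelerate, cfg.gradient_accumulation_steps, and cfg.mixed_precision existed only on the reverted branch, and the tiny model, batch, and values here are placeholders:

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# find_unused_parameters=True lets DDP tolerate parameters that receive no
# gradient in a given forward pass, at some synchronization cost.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    gradient_accumulation_steps=1,  # the reverted branch read this from cfg
    mixed_precision="no",           # the branch read this from cfg ("fp16"/"bf16")
    kwargs_handlers=[ddp_kwargs],
)

policy = torch.nn.Linear(8, 1)  # stand-in for the real policy
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)
# prepare() wraps the model for DDP and moves it to accelerator.device.
policy, optimizer = accelerator.prepare(policy, optimizer)

batch = torch.randn(4, 8, device=accelerator.device)
with accelerator.autocast():
    loss = policy(batch).pow(2).mean()  # stand-in for policy.forward(batch)
accelerator.backward(loss)  # replaces grad_scaler.scale(loss).backward()
accelerator.clip_grad_norm_(policy.parameters(), max_norm=10.0)
optimizer.step()
optimizer.zero_grad()

After the revert, device selection goes back through get_safe_torch_device and mixed precision back through torch.amp.GradScaler, as the restored `+` lines above and the hunks below show.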
@@ -225,12 +200,6 @@ def train(cfg: TrainPipelineConfig):
     if cfg.resume:
         step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

-    # Prepare objects with Accelerate if enabled
-    if accelerator is not None:
-        policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
-        if accelerator.is_main_process:
-            logging.info("Policy, optimizer, and scheduler prepared with Accelerate")
-
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -266,11 +235,6 @@ def train(cfg: TrainPipelineConfig):
         drop_last=False,
         prefetch_factor=2,
     )
-
-    # Prepare dataloader with Accelerate if enabled
-    if accelerator is not None:
-        dataloader = accelerator.prepare(dataloader)
-
     dl_iter = cycle(dataloader)

     policy.train()
@@ -288,68 +252,22 @@ def train(cfg: TrainPipelineConfig):
     )

     logging.info("Start offline training on a fixed dataset")
-
-    # Calculate gradient accumulation steps for multi-GPU training
-    # This ensures effective batch size matches single-GPU training
-    gradient_accumulation_steps = accelerator.num_processes if accelerator is not None else 1
-    if accelerator and accelerator.is_main_process:
-        logging.info(f"Using gradient accumulation: {gradient_accumulation_steps} steps")
-        logging.info(f"Effective batch size: {cfg.batch_size} (same as single-GPU)")
-
     for _ in range(step, cfg.steps):
-        policy.train()
-        optimizer.zero_grad()
-
+        start_time = time.perf_counter()
+        batch = next(dl_iter)
+        batch = preprocessor(batch)
+        train_tracker.dataloading_s = time.perf_counter() - start_time
-        # Accumulate gradients over multiple mini-batches to match single-GPU effective batch size
-        accumulated_loss = 0
-        accumulated_output_dict = {}
-        for accum_step in range(gradient_accumulation_steps):
-            start_time = time.perf_counter()
-            batch = next(dl_iter)
-            batch = preprocessor(batch)
-            if accum_step == 0:  # Only track data loading time once per step
-                train_tracker.dataloading_s = time.perf_counter() - start_time
-
-            # Forward pass
-            start_time = time.perf_counter()
-            with accelerator.autocast() if accelerator else nullcontext():
-                loss, output_dict = policy.forward(batch)
-                # Scale loss by accumulation steps to get proper average
-                loss = loss / gradient_accumulation_steps
-
-            # Backward pass
-            if accelerator:
-                accelerator.backward(loss)
-            else:
-                grad_scaler.scale(loss).backward()
-
-            # Accumulate metrics
-            accumulated_loss += loss.item()
-            if accum_step == 0:
-                accumulated_output_dict = output_dict
-
-        # Gradient clipping and optimizer step
-        if accelerator:
-            accelerator.clip_grad_norm_(policy.parameters(), cfg.optimizer.grad_clip_norm)
-            optimizer.step()
-        else:
-            grad_scaler.unscale_(optimizer)
-            _ = torch.nn.utils.clip_grad_norm_(
-                policy.parameters(), cfg.optimizer.grad_clip_norm, error_if_nonfinite=False
-            )
-            grad_scaler.step(optimizer)
-            grad_scaler.update()
-
-        # Update learning rate scheduler
-        if lr_scheduler is not None:
-            lr_scheduler.step()
-
-        # Update metrics with accumulated values
-        train_tracker.loss = accumulated_loss
-        train_tracker.lr = optimizer.param_groups[0]["lr"]
-        train_tracker.update_s = time.perf_counter() - start_time
-        output_dict = accumulated_output_dict
+
+        train_tracker, output_dict = update_policy(
+            train_tracker,
+            policy,
+            batch,
+            optimizer,
+            cfg.optimizer.grad_clip_norm,
+            grad_scaler=grad_scaler,
+            lr_scheduler=lr_scheduler,
+            use_amp=cfg.policy.use_amp,
+        )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
         # increment `step` here.
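update_policy itself is not part of this diff, but the deleted else: branch above mirrors its GradScaler sequence, so the restored per-step update is essentially the following. This is a self-contained sketch with illustrative stand-ins for the policy, batch, and clip norm, not the actual update_policy source:

import contextlib

import torch
from torch.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

policy = torch.nn.Linear(8, 1).to(device)  # stand-in for the real policy
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)
grad_scaler = GradScaler(device, enabled=use_amp)
grad_clip_norm = 10.0  # stand-in for cfg.optimizer.grad_clip_norm

batch = torch.randn(4, 8, device=device)
amp_ctx = torch.autocast(device_type=device) if use_amp else contextlib.nullcontext()
with amp_ctx:
    loss = policy(batch).pow(2).mean()  # stand-in for policy.forward(batch)

grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)  # clip against true (unscaled) gradients
grad_norm = torch.nn.utils.clip_grad_norm_(
    policy.parameters(), grad_clip_norm, error_if_nonfinite=False
)
grad_scaler.step(optimizer)  # skips the step if grads are inf/nan
grad_scaler.update()
optimizer.zero_grad()

Calling unscale_ before clipping makes the norm threshold apply to true gradient values rather than loss-scaled ones, and grad_scaler.step skips the optimizer step whenever it detects non-finite gradients.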