Revert "fix"

This reverts commit 1ea65730ac.
This commit is contained in:
Pepijn
2025-09-24 11:55:44 +02:00
parent 1ea65730ac
commit 439d79fa11
2 changed files with 23 additions and 110 deletions
+5 -10
View File
@@ -92,18 +92,13 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
def lr_lambda(current_step): def lr_lambda(current_step):
def linear_warmup_schedule(current_step): def linear_warmup_schedule(current_step):
if current_step <= 0: if current_step <= 0:
return 0.1 # Start at 10% instead of 0.1% of peak LR return 1 / (self.num_warmup_steps + 1)
if current_step >= self.num_warmup_steps: frac = 1 - current_step / self.num_warmup_steps
return 1.0 # Reach peak at end of warmup return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
# Linear interpolation from 10% to 100% of peak LR
return 0.1 + 0.9 * (current_step / self.num_warmup_steps)
def cosine_decay_schedule(current_step): def cosine_decay_schedule(current_step):
# CRITICAL FIX: Decay should count from END of warmup, not from step 0! step = min(current_step, self.num_decay_steps)
decay_step = current_step - self.num_warmup_steps cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
decay_step = max(0, min(decay_step, self.num_decay_steps))
cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha decayed = (1 - alpha) * cosine_decay + alpha
return decayed return decayed
+18 -100
View File
@@ -20,8 +20,6 @@ from pprint import pformat
from typing import Any from typing import Any
import torch import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from termcolor import colored from termcolor import colored
from torch.amp import GradScaler from torch.amp import GradScaler
from torch.optim import Optimizer from torch.optim import Optimizer
@@ -149,40 +147,17 @@ def train(cfg: TrainPipelineConfig):
cfg.validate() cfg.validate()
logging.info(pformat(cfg.to_dict())) logging.info(pformat(cfg.to_dict()))
# Initialize Accelerate if requested
accelerator = None
if cfg.use_accelerate:
# Configure DDP to handle unused parameters
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
mixed_precision=cfg.mixed_precision,
kwargs_handlers=[ddp_kwargs],
)
device = accelerator.device
if accelerator.is_main_process:
logging.info(
f"Accelerate initialized with device: {device}, mixed_precision: {cfg.mixed_precision}"
)
logging.info(f"Training on {accelerator.num_processes} processes")
else:
# Check device is available (original behavior)
device = get_safe_torch_device(cfg.policy.device, log=True)
# Only create wandb logger on main process when using accelerate
if cfg.wandb.enable and cfg.wandb.project: if cfg.wandb.enable and cfg.wandb.project:
if accelerator is None or accelerator.is_main_process: wandb_logger = WandBLogger(cfg)
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
else: else:
wandb_logger = None wandb_logger = None
if accelerator is None or accelerator.is_main_process: logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None: if cfg.seed is not None:
set_seed(cfg.seed) set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@@ -225,12 +200,6 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume: if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
# Prepare objects with Accelerate if enabled
if accelerator is not None:
policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
if accelerator.is_main_process:
logging.info("Policy, optimizer, and scheduler prepared with Accelerate")
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
@@ -266,11 +235,6 @@ def train(cfg: TrainPipelineConfig):
drop_last=False, drop_last=False,
prefetch_factor=2, prefetch_factor=2,
) )
# Prepare dataloader with Accelerate if enabled
if accelerator is not None:
dataloader = accelerator.prepare(dataloader)
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
@@ -288,68 +252,22 @@ def train(cfg: TrainPipelineConfig):
) )
logging.info("Start offline training on a fixed dataset") logging.info("Start offline training on a fixed dataset")
# Calculate gradient accumulation steps for multi-GPU training
# This ensures effective batch size matches single-GPU training
gradient_accumulation_steps = accelerator.num_processes if accelerator is not None else 1
if accelerator and accelerator.is_main_process:
logging.info(f"Using gradient accumulation: {gradient_accumulation_steps} steps")
logging.info(f"Effective batch size: {cfg.batch_size} (same as single-GPU)")
for _ in range(step, cfg.steps): for _ in range(step, cfg.steps):
policy.train() start_time = time.perf_counter()
optimizer.zero_grad() batch = next(dl_iter)
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
# Accumulate gradients over multiple mini-batches to match single-GPU effective batch size train_tracker, output_dict = update_policy(
accumulated_loss = 0 train_tracker,
accumulated_output_dict = {} policy,
batch,
for accum_step in range(gradient_accumulation_steps): optimizer,
start_time = time.perf_counter() cfg.optimizer.grad_clip_norm,
batch = next(dl_iter) grad_scaler=grad_scaler,
batch = preprocessor(batch) lr_scheduler=lr_scheduler,
if accum_step == 0: # Only track data loading time once per step use_amp=cfg.policy.use_amp,
train_tracker.dataloading_s = time.perf_counter() - start_time )
# Forward pass
start_time = time.perf_counter()
with accelerator.autocast() if accelerator else nullcontext():
loss, output_dict = policy.forward(batch)
# Scale loss by accumulation steps to get proper average
loss = loss / gradient_accumulation_steps
# Backward pass
if accelerator:
accelerator.backward(loss)
else:
grad_scaler.scale(loss).backward()
# Accumulate metrics
accumulated_loss += loss.item()
if accum_step == 0:
accumulated_output_dict = output_dict
# Gradient clipping and optimizer step
if accelerator:
accelerator.clip_grad_norm_(policy.parameters(), cfg.optimizer.grad_clip_norm)
optimizer.step()
else:
grad_scaler.unscale_(optimizer)
_ = torch.nn.utils.clip_grad_norm_(
policy.parameters(), cfg.optimizer.grad_clip_norm, error_if_nonfinite=False
)
grad_scaler.step(optimizer)
grad_scaler.update()
# Update learning rate scheduler
if lr_scheduler is not None:
lr_scheduler.step()
# Update metrics with accumulated values
train_tracker.loss = accumulated_loss
train_tracker.lr = optimizer.param_groups[0]["lr"]
train_tracker.update_s = time.perf_counter() - start_time
output_dict = accumulated_output_dict
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here. # increment `step` here.