fix(train): prepare eval dataloader with accelerator for multi-GPU

Prepare eval_dataloader through accelerator.prepare() so eval data is
sharded across ranks instead of duplicated. Reduce eval_loss across
ranks with mean reduction for consistent logging.
This commit is contained in:
Khalil Meftah
2026-06-22 22:26:57 +02:00
parent 84961c5591
commit d99d4eda0f
+10 -3
View File
@@ -493,9 +493,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
if eval_dataloader is not None:
policy, optimizer, dataloader, lr_scheduler, eval_dataloader = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler, eval_dataloader
)
else:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
@@ -614,6 +619,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
eval_loss_sum += loss.item()
n_eval_batches += 1
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
eval_loss = torch.tensor(eval_loss, device=device)
eval_loss = accelerator.reduce(eval_loss, reduction="mean").item()
policy.train()
if is_main_process: