fix(logging): correct multi-rank "max" metric reduction

Accelerate's ``Accelerator.reduce`` only implements "sum" and "mean"; with
reduction="max" it silently returned the SUM across ranks, inflating
max-reduced metrics by up to a factor of num_processes. Gather the per-rank
values and reduce them explicitly for max/sum/mean instead.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-29 14:14:24 +00:00
parent ec5df4db7a
commit e1cf646e84
+13 -2
View File
@@ -176,10 +176,21 @@ class MetricsTracker:
if not buckets:
return
# NB: don't use ``accelerator.reduce(..., reduction="max")`` — accelerate only implements
# "sum"/"mean" (it always all-reduces with SUM and divides for "mean"), so "max" silently
# returns the SUM across ranks, inflating every "max" metric by up to a factor of
# ``num_processes`` (e.g. a 3.5s step reported as 28s on 8 GPUs). Gather the per-rank values
# and reduce them explicitly instead.
device = self.accelerator.device
num_processes = self.accelerator.num_processes
for reduction, names in buckets.items():
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
reduced = self.accelerator.reduce(tensor, reduction=reduction)
local = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
gathered = self.accelerator.gather(local).view(num_processes, len(names))
if reduction == "max":
reduced = gathered.amax(dim=0)
elif reduction == "sum":
reduced = gathered.sum(dim=0)
else: # "mean"
reduced = gathered.mean(dim=0)
for name, value in zip(names, reduced.tolist(), strict=True):
meter = self.metrics[name]
# Preserve avg == sum / count so a later .update() on this meter accumulates