fix(logging): avoid double-counting samples across processes (#3045)

This commit is contained in:
Pepijn
2026-02-27 17:45:19 +01:00
committed by GitHub
parent baf9b50365
commit 04de496547
3 changed files with 42 additions and 4 deletions
+2 -2
View File
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
# Use effective batch size for proper epoch calculation in distributed training
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
effective_batch_size,
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
train_metrics,
+4 -2
View File
@@ -104,9 +104,10 @@ class MetricsTracker:
self.metrics = metrics
self.steps = initial_step
world_size = accelerator.num_processes if accelerator else 1
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
self.samples = self.steps * self._batch_size
self.samples = self.steps * self._batch_size * world_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
@@ -132,7 +133,8 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
world_size = self.accelerator.num_processes if self.accelerator else 1
self.samples += self._batch_size * world_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames