feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup (#3773)

* feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup

* chore: address feedback
This commit is contained in:
Steven Palma
2026-06-11 19:07:28 +02:00
committed by GitHub
parent 6fbcf67249
commit 1edc83a0ef
5 changed files with 177 additions and 24 deletions
+3 -2
View File
@@ -115,7 +115,7 @@ dataset = [
]
training = [
"lerobot[dataset]",
"accelerate>=1.10.0,<2.0.0",
"lerobot[accelerate-dep]",
"wandb>=0.24.0,<0.25.0",
]
hardware = [
@@ -142,6 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
@@ -199,7 +200,7 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
+37 -14
View File
@@ -99,6 +99,9 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
@@ -158,6 +161,8 @@ def update_policy(
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
if torch.cuda.is_available():
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
return train_metrics, output_dict
@@ -434,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
# is actually optimizing. grad_norm and lr are already identical on every rank (post
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
# true straggler instead of rank 0's view.
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
# Derived from the post-reduce max step time; set once per log window on the main rank.
"samples_per_s": AverageMeter("smp/s", ":.0f"),
}
if torch.cuda.is_available():
# max() because headroom is gated by the worst-case rank.
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
@@ -491,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
# Collective reduce must run on every rank, before the main-process gate below.
train_tracker.reduce_across_ranks()
if is_main_process:
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
# reflects the slowest rank — which is what actually gates the next iteration.
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
if step_time > 0:
train_tracker.samples_per_s = effective_batch_size / step_time
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
+50 -1
View File
@@ -13,21 +13,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from typing import Any
import torch
from .utils import format_big_number
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
class AverageMeter:
"""
Computes and stores the average and current value
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
Args:
name: Display name of the metric.
fmt: Format string used when rendering the metric.
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
"""
def __init__(self, name: str, fmt: str = ":f"):
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
if reduction not in _VALID_REDUCTIONS:
raise ValueError(
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
)
self.name = name
self.fmt = fmt
self.reduction = reduction
self.reset()
def reset(self) -> None:
@@ -138,6 +156,37 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def reduce_across_ranks(self) -> None:
"""
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
across all distributed processes (in-place).
This is a collective operation and MUST be invoked on every rank — typically just before
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
reported by the main process only reflect rank 0; for bottleneck-style timings
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
"""
if self.accelerator is None or self.accelerator.num_processes <= 1:
return
buckets: dict[str, list[str]] = defaultdict(list)
for name, meter in self.metrics.items():
if meter.reduction != "none":
buckets[meter.reduction].append(name)
if not buckets:
return
device = self.accelerator.device
for reduction, names in buckets.items():
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
reduced = self.accelerator.reduce(tensor, reduction=reduction)
for name, value in zip(names, reduced.tolist(), strict=True):
meter = self.metrics[name]
# Preserve avg == sum / count so a later .update() on this meter accumulates
# against the cluster view, not the stale per-rank history.
meter.avg = value
meter.sum = value * meter.count
def __str__(self) -> str:
display_list = [
f"step:{format_big_number(self.steps)}",
+77 -1
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import pytest
import torch
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -25,8 +26,16 @@ def mock_metrics():
class MockAccelerator:
def __init__(self, num_processes: int):
def __init__(self, num_processes: int, reduce_fn=None):
self.num_processes = num_processes
self.device = torch.device("cpu")
self._reduce_fn = reduce_fn
def reduce(self, tensor, reduction="mean"):
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
if self._reduce_fn is not None:
return self._reduce_fn(tensor, reduction)
return tensor
def test_average_meter_initialization():
@@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics):
tracker.reset_averages()
assert tracker.loss.avg == 0.0
assert tracker.accuracy.avg == 0.0
def test_average_meter_invalid_reduction():
with pytest.raises(ValueError):
AverageMeter("loss", reduction="median")
def test_average_meter_reduction_stored():
meter = AverageMeter("updt_s", reduction="max")
assert meter.reduction == "max"
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op without accelerator
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_single_process():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=1),
)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op when world size is 1
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
captured = {}
def fake_reduce(tensor, reduction):
captured["reduction"] = reduction
captured["values"] = tensor.clone()
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
metrics = {
"loss": AverageMeter("loss"), # reduction="none" -> not touched
"update_s": AverageMeter("update_s", reduction="max"),
}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
)
tracker.loss = 1.0
tracker.update_s = 0.4
tracker.reduce_across_ranks()
assert captured["reduction"] == "max"
assert torch.allclose(captured["values"], torch.tensor([0.4]))
assert tracker.update_s.avg == pytest.approx(0.9)
# Metrics without a reduction stay untouched.
assert tracker.loss.avg == 1.0
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
# accumulate against the cluster view rather than the stale per-rank sum.
meter = tracker.update_s
assert meter.sum / meter.count == pytest.approx(meter.avg)
Generated
+10 -6
View File
@@ -59,7 +59,7 @@ wheels = [
[[package]]
name = "accelerate"
version = "1.13.0"
version = "1.14.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
@@ -71,9 +71,9 @@ dependencies = [
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8d/75/94cd5d389649578aca399e5aa822637eec18319a1dadc400ffe2f9a7493f/accelerate-1.14.0.tar.gz", hash = "sha256:41b9c4377a54e0b460a959b0defa1b736e4ca0a2373252d9a539964c2afe3c8d", size = 412167, upload-time = "2026-06-11T13:45:52.326Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
{ url = "https://files.pythonhosted.org/packages/a8/db/253133d7e7cb40d3af384bb2f5c0b4a2b7fdcffbc95c688cc67a20a3c103/accelerate-1.14.0-py3-none-any.whl", hash = "sha256:e94390c2863b873be18f623f9df48a0d8fe5eff13ea7f1a00092b0a7904888c6", size = 389246, upload-time = "2026-06-11T13:45:50.477Z" },
]
[[package]]
@@ -2687,6 +2687,9 @@ dependencies = [
]
[package.optional-dependencies]
accelerate-dep = [
{ name = "accelerate" },
]
all = [
{ name = "accelerate" },
{ name = "av" },
@@ -3073,8 +3076,7 @@ xvla = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'accelerate-dep'", specifier = ">=1.14.0,<2.0.0" },
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
@@ -3104,6 +3106,8 @@ requires-dist = [
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'smolvla'" },
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'training'" },
{ name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["async"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" },
@@ -3279,7 +3283,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "accelerate-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"