mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup (#3773)
* feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup * chore: address feedback
This commit is contained in:
+3
-2
@@ -115,7 +115,7 @@ dataset = [
|
||||
]
|
||||
training = [
|
||||
"lerobot[dataset]",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
"lerobot[accelerate-dep]",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
]
|
||||
hardware = [
|
||||
@@ -142,6 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -199,7 +200,7 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
|
||||
@@ -99,6 +99,9 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
@@ -158,6 +161,8 @@ def update_policy(
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
if torch.cuda.is_available():
|
||||
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@@ -434,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
|
||||
# is actually optimizing. grad_norm and lr are already identical on every rank (post
|
||||
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
|
||||
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
|
||||
# true straggler instead of rank 0's view.
|
||||
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
|
||||
# Derived from the post-reduce max step time; set once per log window on the main rank.
|
||||
"samples_per_s": AverageMeter("smp/s", ":.0f"),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
# max() because headroom is gated by the worst-case rank.
|
||||
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
|
||||
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
@@ -491,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
train_tracker.reduce_across_ranks()
|
||||
if is_main_process:
|
||||
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
|
||||
# reflects the slowest rank — which is what actually gates the next iteration.
|
||||
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
|
||||
if step_time > 0:
|
||||
train_tracker.samples_per_s = effective_batch_size / step_time
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
|
||||
@@ -13,21 +13,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import format_big_number
|
||||
|
||||
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""
|
||||
Computes and stores the average and current value
|
||||
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
|
||||
|
||||
Args:
|
||||
name: Display name of the metric.
|
||||
fmt: Format string used when rendering the metric.
|
||||
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
|
||||
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
|
||||
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
|
||||
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
|
||||
if reduction not in _VALID_REDUCTIONS:
|
||||
raise ValueError(
|
||||
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
|
||||
)
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reduction = reduction
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -138,6 +156,37 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def reduce_across_ranks(self) -> None:
|
||||
"""
|
||||
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
|
||||
across all distributed processes (in-place).
|
||||
|
||||
This is a collective operation and MUST be invoked on every rank — typically just before
|
||||
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
|
||||
reported by the main process only reflect rank 0; for bottleneck-style timings
|
||||
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
|
||||
"""
|
||||
if self.accelerator is None or self.accelerator.num_processes <= 1:
|
||||
return
|
||||
|
||||
buckets: dict[str, list[str]] = defaultdict(list)
|
||||
for name, meter in self.metrics.items():
|
||||
if meter.reduction != "none":
|
||||
buckets[meter.reduction].append(name)
|
||||
if not buckets:
|
||||
return
|
||||
|
||||
device = self.accelerator.device
|
||||
for reduction, names in buckets.items():
|
||||
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
|
||||
reduced = self.accelerator.reduce(tensor, reduction=reduction)
|
||||
for name, value in zip(names, reduced.tolist(), strict=True):
|
||||
meter = self.metrics[name]
|
||||
# Preserve avg == sum / count so a later .update() on this meter accumulates
|
||||
# against the cluster view, not the stale per-rank history.
|
||||
meter.avg = value
|
||||
meter.sum = value * meter.count
|
||||
|
||||
def __str__(self) -> str:
|
||||
display_list = [
|
||||
f"step:{format_big_number(self.steps)}",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
@@ -25,8 +26,16 @@ def mock_metrics():
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
def __init__(self, num_processes: int, reduce_fn=None):
|
||||
self.num_processes = num_processes
|
||||
self.device = torch.device("cpu")
|
||||
self._reduce_fn = reduce_fn
|
||||
|
||||
def reduce(self, tensor, reduction="mean"):
|
||||
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
|
||||
if self._reduce_fn is not None:
|
||||
return self._reduce_fn(tensor, reduction)
|
||||
return tensor
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
@@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics):
|
||||
tracker.reset_averages()
|
||||
assert tracker.loss.avg == 0.0
|
||||
assert tracker.accuracy.avg == 0.0
|
||||
|
||||
|
||||
def test_average_meter_invalid_reduction():
|
||||
with pytest.raises(ValueError):
|
||||
AverageMeter("loss", reduction="median")
|
||||
|
||||
|
||||
def test_average_meter_reduction_stored():
|
||||
meter = AverageMeter("updt_s", reduction="max")
|
||||
assert meter.reduction == "max"
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op without accelerator
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_single_process():
|
||||
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=1),
|
||||
)
|
||||
tracker.update_s = 0.5
|
||||
tracker.reduce_across_ranks() # no-op when world size is 1
|
||||
assert tracker.update_s.avg == 0.5
|
||||
|
||||
|
||||
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
|
||||
captured = {}
|
||||
|
||||
def fake_reduce(tensor, reduction):
|
||||
captured["reduction"] = reduction
|
||||
captured["values"] = tensor.clone()
|
||||
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
|
||||
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
|
||||
|
||||
metrics = {
|
||||
"loss": AverageMeter("loss"), # reduction="none" -> not touched
|
||||
"update_s": AverageMeter("update_s", reduction="max"),
|
||||
}
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=metrics,
|
||||
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
|
||||
)
|
||||
tracker.loss = 1.0
|
||||
tracker.update_s = 0.4
|
||||
tracker.reduce_across_ranks()
|
||||
|
||||
assert captured["reduction"] == "max"
|
||||
assert torch.allclose(captured["values"], torch.tensor([0.4]))
|
||||
assert tracker.update_s.avg == pytest.approx(0.9)
|
||||
# Metrics without a reduction stay untouched.
|
||||
assert tracker.loss.avg == 1.0
|
||||
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
|
||||
# accumulate against the cluster view rather than the stale per-rank sum.
|
||||
meter = tracker.update_s
|
||||
assert meter.sum / meter.count == pytest.approx(meter.avg)
|
||||
|
||||
@@ -59,7 +59,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "1.13.0"
|
||||
version = "1.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
@@ -71,9 +71,9 @@ dependencies = [
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
|
||||
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/75/94cd5d389649578aca399e5aa822637eec18319a1dadc400ffe2f9a7493f/accelerate-1.14.0.tar.gz", hash = "sha256:41b9c4377a54e0b460a959b0defa1b736e4ca0a2373252d9a539964c2afe3c8d", size = 412167, upload-time = "2026-06-11T13:45:52.326Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/db/253133d7e7cb40d3af384bb2f5c0b4a2b7fdcffbc95c688cc67a20a3c103/accelerate-1.14.0-py3-none-any.whl", hash = "sha256:e94390c2863b873be18f623f9df48a0d8fe5eff13ea7f1a00092b0a7904888c6", size = 389246, upload-time = "2026-06-11T13:45:50.477Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2687,6 +2687,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
accelerate-dep = [
|
||||
{ name = "accelerate" },
|
||||
]
|
||||
all = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "av" },
|
||||
@@ -3073,8 +3076,7 @@ xvla = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
|
||||
{ name = "accelerate", marker = "extra == 'accelerate-dep'", specifier = ">=1.14.0,<2.0.0" },
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
@@ -3104,6 +3106,8 @@ requires-dist = [
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'training'" },
|
||||
{ name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["async"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" },
|
||||
@@ -3279,7 +3283,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "accelerate-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user