diff --git a/pyproject.toml b/pyproject.toml index f72cfa6dd..89200d1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ dataset = [ ] training = [ "lerobot[dataset]", - "accelerate>=1.10.0,<2.0.0", + "lerobot[accelerate-dep]", "wandb>=0.24.0,<0.25.0", ] hardware = [ @@ -142,6 +142,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"] # (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available. placo-dep = ["placo>=0.9.6,<0.9.16"] transformers-dep = ["transformers>=5.4.0,<5.6.0"] +accelerate-dep = ["accelerate>=1.14.0,<2.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] @@ -199,7 +200,7 @@ wallx = [ ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"] -smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"] multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"] groot = [ "lerobot[transformers-dep]", diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3d210f00b..a35d4229d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -99,6 +99,9 @@ def update_policy( start_time = time.perf_counter() policy.train() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + # Compute sample weights if a weighter is provided sample_weights = None weight_stats = None @@ -158,6 +161,8 @@ def update_policy( train_metrics.grad_norm = grad_norm.item() train_metrics.lr = optimizer.param_groups[0]["lr"] train_metrics.update_s = time.perf_counter() - start_time + if torch.cuda.is_available(): + train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3) return train_metrics, output_dict @@ -434,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): policy.train() train_metrics = { - "loss": AverageMeter("loss", ":.3f"), + # Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP + # is actually optimizing. grad_norm and lr are already identical on every rank (post + # gradient sync / deterministic scheduler) so reducing them would be a no-op collective. + "loss": AverageMeter("loss", ":.3f", reduction="mean"), "grad_norm": AverageMeter("grdn", ":.3f"), "lr": AverageMeter("lr", ":0.1e"), - "update_s": AverageMeter("updt_s", ":.3f"), - "dataloading_s": AverageMeter("data_s", ":.3f"), + # Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the + # true straggler instead of rank 0's view. + "update_s": AverageMeter("updt_s", ":.3f", reduction="max"), + "dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"), + # Derived from the post-reduce max step time; set once per log window on the main rank. + "samples_per_s": AverageMeter("smp/s", ":.0f"), } + if torch.cuda.is_available(): + # max() because headroom is gated by the worst-case rank. + train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max") # Keep global batch size for logging; MetricsTracker handles world size internally. effective_batch_size = cfg.batch_size * accelerator.num_processes @@ -491,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: progbar.update(1) train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 if is_log_step: - logging.info(train_tracker) - if wandb_logger: - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - # Log sample weighting statistics if enabled - if sample_weighter is not None: - weighter_stats = sample_weighter.get_stats() - wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) - wandb_logger.log_dict(wandb_log_dict, step) + # Collective reduce must run on every rank, before the main-process gate below. + train_tracker.reduce_across_ranks() + if is_main_process: + # Cluster-wide throughput, derived from the already-reduced (max) step time so it + # reflects the slowest rank — which is what actually gates the next iteration. + step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg + if step_time > 0: + train_tracker.samples_per_s = effective_batch_size / step_time + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + # Log sample weighting statistics if enabled + if sample_weighter is not None: + weighter_stats = sample_weighter.get_stats() + wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) + wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index 0ce596f55..20673fc30 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -13,21 +13,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict from collections.abc import Callable from typing import Any +import torch + from .utils import format_big_number +_VALID_REDUCTIONS = ("none", "max", "mean", "sum") + class AverageMeter: """ Computes and stores the average and current value Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py + + Args: + name: Display name of the metric. + fmt: Format string used when rendering the metric. + reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks` + before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``, + or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or + update wall time) so multi-GPU runs report the slowest rank rather than rank 0. """ - def __init__(self, name: str, fmt: str = ":f"): + def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"): + if reduction not in _VALID_REDUCTIONS: + raise ValueError( + f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}." + ) self.name = name self.fmt = fmt + self.reduction = reduction self.reset() def reset(self) -> None: @@ -138,6 +156,37 @@ class MetricsTracker: self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames + def reduce_across_ranks(self) -> None: + """ + Synchronises the running averages of every metric whose ``reduction`` is not ``"none"`` + across all distributed processes (in-place). + + This is a collective operation and MUST be invoked on every rank — typically just before + logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics + reported by the main process only reflect rank 0; for bottleneck-style timings + (``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible. + """ + if self.accelerator is None or self.accelerator.num_processes <= 1: + return + + buckets: dict[str, list[str]] = defaultdict(list) + for name, meter in self.metrics.items(): + if meter.reduction != "none": + buckets[meter.reduction].append(name) + if not buckets: + return + + device = self.accelerator.device + for reduction, names in buckets.items(): + tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device) + reduced = self.accelerator.reduce(tensor, reduction=reduction) + for name, value in zip(names, reduced.tolist(), strict=True): + meter = self.metrics[name] + # Preserve avg == sum / count so a later .update() on this meter accumulates + # against the cluster view, not the stale per-rank history. + meter.avg = value + meter.sum = value * meter.count + def __str__(self) -> str: display_list = [ f"step:{format_big_number(self.steps)}", diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 1207534c0..aa851bd2a 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import pytest +import torch from lerobot.utils.logging_utils import AverageMeter, MetricsTracker @@ -25,8 +26,16 @@ def mock_metrics(): class MockAccelerator: - def __init__(self, num_processes: int): + def __init__(self, num_processes: int, reduce_fn=None): self.num_processes = num_processes + self.device = torch.device("cpu") + self._reduce_fn = reduce_fn + + def reduce(self, tensor, reduction="mean"): + # In single-process tests we just want a deterministic stand-in for accelerate's reduce. + if self._reduce_fn is not None: + return self._reduce_fn(tensor, reduction) + return tensor def test_average_meter_initialization(): @@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics): tracker.reset_averages() assert tracker.loss.avg == 0.0 assert tracker.accuracy.avg == 0.0 + + +def test_average_meter_invalid_reduction(): + with pytest.raises(ValueError): + AverageMeter("loss", reduction="median") + + +def test_average_meter_reduction_stored(): + meter = AverageMeter("updt_s", reduction="max") + assert meter.reduction == "max" + + +def test_metrics_tracker_reduce_across_ranks_no_accelerator(): + metrics = {"update_s": AverageMeter("update_s", reduction="max")} + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics) + tracker.update_s = 0.5 + tracker.reduce_across_ranks() # no-op without accelerator + assert tracker.update_s.avg == 0.5 + + +def test_metrics_tracker_reduce_across_ranks_single_process(): + metrics = {"update_s": AverageMeter("update_s", reduction="max")} + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=metrics, + accelerator=MockAccelerator(num_processes=1), + ) + tracker.update_s = 0.5 + tracker.reduce_across_ranks() # no-op when world size is 1 + assert tracker.update_s.avg == 0.5 + + +def test_metrics_tracker_reduce_across_ranks_invokes_reduce(): + captured = {} + + def fake_reduce(tensor, reduction): + captured["reduction"] = reduction + captured["values"] = tensor.clone() + # Pretend the slowest rank reported 0.9 instead of this rank's 0.4. + return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device) + + metrics = { + "loss": AverageMeter("loss"), # reduction="none" -> not touched + "update_s": AverageMeter("update_s", reduction="max"), + } + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=metrics, + accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce), + ) + tracker.loss = 1.0 + tracker.update_s = 0.4 + tracker.reduce_across_ranks() + + assert captured["reduction"] == "max" + assert torch.allclose(captured["values"], torch.tensor([0.4])) + assert tracker.update_s.avg == pytest.approx(0.9) + # Metrics without a reduction stay untouched. + assert tracker.loss.avg == 1.0 + # Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls + # accumulate against the cluster view rather than the stale per-rank sum. + meter = tracker.update_s + assert meter.sum / meter.count == pytest.approx(meter.avg) diff --git a/uv.lock b/uv.lock index 3a7129dac..f4f854b62 100644 --- a/uv.lock +++ b/uv.lock @@ -59,7 +59,7 @@ wheels = [ [[package]] name = "accelerate" -version = "1.13.0" +version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, @@ -71,9 +71,9 @@ dependencies = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/75/94cd5d389649578aca399e5aa822637eec18319a1dadc400ffe2f9a7493f/accelerate-1.14.0.tar.gz", hash = "sha256:41b9c4377a54e0b460a959b0defa1b736e4ca0a2373252d9a539964c2afe3c8d", size = 412167, upload-time = "2026-06-11T13:45:52.326Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" }, + { url = "https://files.pythonhosted.org/packages/a8/db/253133d7e7cb40d3af384bb2f5c0b4a2b7fdcffbc95c688cc67a20a3c103/accelerate-1.14.0-py3-none-any.whl", hash = "sha256:e94390c2863b873be18f623f9df48a0d8fe5eff13ea7f1a00092b0a7904888c6", size = 389246, upload-time = "2026-06-11T13:45:50.477Z" }, ] [[package]] @@ -2687,6 +2687,9 @@ dependencies = [ ] [package.optional-dependencies] +accelerate-dep = [ + { name = "accelerate" }, +] all = [ { name = "accelerate" }, { name = "av" }, @@ -3073,8 +3076,7 @@ xvla = [ [package.metadata] requires-dist = [ - { name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" }, - { name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" }, + { name = "accelerate", marker = "extra == 'accelerate-dep'", specifier = ">=1.14.0,<2.0.0" }, { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" }, { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" }, { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" }, @@ -3104,6 +3106,8 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" }, { name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" }, { name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" }, + { name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'smolvla'" }, + { name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'training'" }, { name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["async"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" }, @@ -3279,7 +3283,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "accelerate-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt"