Compare commits

...

5 Commits

Author SHA1 Message Date
github-actions[bot] e48f3bc96d chore(dependencies): update uv.lock 2026-06-12 05:01:38 +00:00
Steven Palma 87242cfced chore(dependecies): relax grpc-related bounds (#3777)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-06-11 19:13:14 +02:00
Steven Palma 1edc83a0ef feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup (#3773)
* feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup

* chore: address feedback
2026-06-11 19:07:28 +02:00
Steven Palma 6fbcf67249 chore: update readme (#3774)
* chore: update readme

* chore: update authors in project readme
2026-06-11 18:17:26 +02:00
Pepijn 41166b39fb fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768)
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync

In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

* fix(train): seed sampler generator and gate dataset download per node

- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.
2026-06-11 11:07:42 +02:00
8 changed files with 1199 additions and 945 deletions
+10 -7
View File
@@ -58,7 +58,7 @@ action = model.select_action(obs)
robot.send_action(action)
```
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1.
**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601.
While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot.
@@ -101,11 +101,13 @@ lerobot-train \
--dataset.repo_id=lerobot/aloha_mobile_cabinet
```
| Category | Models |
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| Category | Models |
| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) |
| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) |
| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
@@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
## Citation
@@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl
```bibtex
@misc{cadene2024lerobot,
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas},
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
howpublished = "\url{https://github.com/huggingface/lerobot}",
year = {2024}
+12 -6
View File
@@ -115,8 +115,8 @@ dataset = [
]
training = [
"lerobot[dataset]",
"accelerate>=1.10.0,<2.0.0",
"wandb>=0.24.0,<0.25.0",
"wandb>=0.24.0,<0.28.0",
"lerobot[accelerate-dep]",
]
hardware = [
"lerobot[pynput-dep]",
@@ -142,7 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"]
accelerate-dep = ["accelerate>=1.14.0,<2.0.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
@@ -177,7 +178,12 @@ unitree_g1 = [
"lerobot[matplotlib-dep]",
"lerobot[pygame-dep]",
]
reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions.
reachy2 = [
"reachy2_sdk>=1.0.15,<1.1.0",
"grpcio<=1.73.1",
"protobuf<=6.32.0",
]
# Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102
# leader (motorbridge-smart-servo / FashionStar UART servos).
rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"]
@@ -199,7 +205,7 @@ wallx = [
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
@@ -224,7 +230,7 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
+7 -1
View File
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
else:
for i in self.indices:
+52 -19
View File
@@ -99,6 +99,9 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
@@ -158,6 +161,8 @@ def update_policy(
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
if torch.cuda.is_available():
train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
return train_metrics, output_dict
@@ -232,15 +237,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -386,12 +394,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
@@ -424,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
# Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP
# is actually optimizing. grad_norm and lr are already identical on every rank (post
# gradient sync / deterministic scheduler) so reducing them would be a no-op collective.
"loss": AverageMeter("loss", ":.3f", reduction="mean"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
# Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the
# true straggler instead of rank 0's view.
"update_s": AverageMeter("updt_s", ":.3f", reduction="max"),
"dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"),
# Derived from the post-reduce max step time; set once per log window on the main rank.
"samples_per_s": AverageMeter("smp/s", ":.0f"),
}
if torch.cuda.is_available():
# max() because headroom is gated by the worst-case rank.
train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max")
# Keep global batch size for logging; MetricsTracker handles world size internally.
effective_batch_size = cfg.batch_size * accelerator.num_processes
@@ -481,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
# Collective reduce must run on every rank, before the main-process gate below.
train_tracker.reduce_across_ranks()
if is_main_process:
# Cluster-wide throughput, derived from the already-reduced (max) step time so it
# reflects the slowest rank — which is what actually gates the next iteration.
step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg
if step_time > 0:
train_tracker.samples_per_s = effective_batch_size / step_time
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
+50 -1
View File
@@ -13,21 +13,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from typing import Any
import torch
from .utils import format_big_number
_VALID_REDUCTIONS = ("none", "max", "mean", "sum")
class AverageMeter:
"""
Computes and stores the average and current value
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
Args:
name: Display name of the metric.
fmt: Format string used when rendering the metric.
reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks`
before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``,
or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or
update wall time) so multi-GPU runs report the slowest rank rather than rank 0.
"""
def __init__(self, name: str, fmt: str = ":f"):
def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"):
if reduction not in _VALID_REDUCTIONS:
raise ValueError(
f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}."
)
self.name = name
self.fmt = fmt
self.reduction = reduction
self.reset()
def reset(self) -> None:
@@ -138,6 +156,37 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def reduce_across_ranks(self) -> None:
"""
Synchronises the running averages of every metric whose ``reduction`` is not ``"none"``
across all distributed processes (in-place).
This is a collective operation and MUST be invoked on every rank — typically just before
logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics
reported by the main process only reflect rank 0; for bottleneck-style timings
(``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible.
"""
if self.accelerator is None or self.accelerator.num_processes <= 1:
return
buckets: dict[str, list[str]] = defaultdict(list)
for name, meter in self.metrics.items():
if meter.reduction != "none":
buckets[meter.reduction].append(name)
if not buckets:
return
device = self.accelerator.device
for reduction, names in buckets.items():
tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device)
reduced = self.accelerator.reduce(tensor, reduction=reduction)
for name, value in zip(names, reduced.tolist(), strict=True):
meter = self.metrics[name]
# Preserve avg == sum / count so a later .update() on this meter accumulates
# against the cluster view, not the stale per-rank history.
meter.avg = value
meter.sum = value * meter.count
def __str__(self) -> str:
display_list = [
f"step:{format_big_number(self.steps)}",
+24
View File
@@ -114,6 +114,30 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+77 -1
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import pytest
import torch
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -25,8 +26,16 @@ def mock_metrics():
class MockAccelerator:
def __init__(self, num_processes: int):
def __init__(self, num_processes: int, reduce_fn=None):
self.num_processes = num_processes
self.device = torch.device("cpu")
self._reduce_fn = reduce_fn
def reduce(self, tensor, reduction="mean"):
# In single-process tests we just want a deterministic stand-in for accelerate's reduce.
if self._reduce_fn is not None:
return self._reduce_fn(tensor, reduction)
return tensor
def test_average_meter_initialization():
@@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics):
tracker.reset_averages()
assert tracker.loss.avg == 0.0
assert tracker.accuracy.avg == 0.0
def test_average_meter_invalid_reduction():
with pytest.raises(ValueError):
AverageMeter("loss", reduction="median")
def test_average_meter_reduction_stored():
meter = AverageMeter("updt_s", reduction="max")
assert meter.reduction == "max"
def test_metrics_tracker_reduce_across_ranks_no_accelerator():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op without accelerator
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_single_process():
metrics = {"update_s": AverageMeter("update_s", reduction="max")}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=1),
)
tracker.update_s = 0.5
tracker.reduce_across_ranks() # no-op when world size is 1
assert tracker.update_s.avg == 0.5
def test_metrics_tracker_reduce_across_ranks_invokes_reduce():
captured = {}
def fake_reduce(tensor, reduction):
captured["reduction"] = reduction
captured["values"] = tensor.clone()
# Pretend the slowest rank reported 0.9 instead of this rank's 0.4.
return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device)
metrics = {
"loss": AverageMeter("loss"), # reduction="none" -> not touched
"update_s": AverageMeter("update_s", reduction="max"),
}
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=metrics,
accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce),
)
tracker.loss = 1.0
tracker.update_s = 0.4
tracker.reduce_across_ranks()
assert captured["reduction"] == "max"
assert torch.allclose(captured["values"], torch.tensor([0.4]))
assert tracker.update_s.avg == pytest.approx(0.9)
# Metrics without a reduction stay untouched.
assert tracker.loss.avg == 1.0
# Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls
# accumulate against the cluster view rather than the stale per-rank sum.
meter = tracker.update_s
assert meter.sum / meter.count == pytest.approx(meter.avg)
Generated
+967 -910
View File
File diff suppressed because it is too large Load Diff