From 41166b39fb8bacdd8f916d700064c5f64892bc0a Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:07:42 +0200 Subject: [PATCH] fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768) * fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync In distributed training, accelerate can only synchronize the shuffle permutation across ranks when the sampler exposes a generator attribute. EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch shards relied on every rank's global CPU RNG staying in lockstep forever; any rank-asymmetric RNG consumption (e.g. eval rollouts on the main process only) silently desynced the permutations and ranks trained on overlapping/missing samples. * fix(train): seed sampler generator and gate dataset download per node - Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. - Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently. --- src/lerobot/datasets/sampler.py | 8 +++++++- src/lerobot/scripts/lerobot_train.py | 20 +++++++++++++++----- tests/datasets/test_sampler.py | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 2bf7ab922..64d871907 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -30,6 +30,7 @@ class EpisodeAwareSampler: drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, + generator: torch.Generator | None = None, ): """Sampler that optionally incorporates episode boundary information. @@ -41,6 +42,10 @@ class EpisodeAwareSampler: drop_n_first_frames: Number of frames to drop from the start of each episode. drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. + generator: Generator used for shuffling. Exposing this attribute (even when None) lets + `accelerate` register it as the synchronized RNG in distributed training, so + every rank draws the same permutation and batch shards stay disjoint. When + None, shuffling falls back to the global torch RNG. """ if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") @@ -73,10 +78,11 @@ class EpisodeAwareSampler: self.indices = indices self.shuffle = shuffle + self.generator = generator def __iter__(self) -> Iterator[int]: if self.shuffle: - for i in torch.randperm(len(self.indices)): + for i in torch.randperm(len(self.indices), generator=self.generator): yield self.indices[i] else: for i in self.indices: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4ddef3105..3d210f00b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: main process downloads first to avoid race conditions - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: each node's local main process downloads first to avoid + # race conditions (the global main process only exists on node 0, so gating on it would let + # all ranks of the other nodes download and build the Arrow cache concurrently). + if accelerator.is_local_main_process: + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset - if not is_main_process: + # Now all other processes can safely load the dataset from the local cache + if not accelerator.is_local_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if hasattr(active_cfg, "drop_n_last_frames"): shuffle = False + # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare + # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even + # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). + sampler_generator = torch.Generator() + if cfg.seed is not None: + sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, + generator=sampler_generator, ) else: shuffle = True diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 8bb3be8e9..95429c7ec 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -114,6 +114,30 @@ def test_shuffle(): assert set(sampler) == {0, 1, 2, 3, 4, 5} +def test_shuffle_with_generator_is_deterministic(): + # Two samplers shuffling with same-seed generators must yield identical permutations. + # This is what keeps batch shards disjoint across ranks in distributed training, where + # accelerate synchronizes the sampler's generator state instead of the global torch RNG. + sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + assert list(sampler_a) == list(sampler_b) + + # Desyncing the global RNG must not affect the permutation. + sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + order_before = list(sampler_c) + sampler_c.generator.manual_seed(42) + torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would + assert list(sampler_c) == order_before + + +def test_generator_attribute_defaults_to_none(): + # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, + # so the attribute must exist even when no generator is passed. + sampler = EpisodeAwareSampler([0], [6], shuffle=True) + assert sampler.generator is None + assert set(sampler) == {0, 1, 2, 3, 4, 5} + + def test_negative_drop_first_frames_raises(): with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)