From 72e093dbff8daad2eaf59ad1ceaf2b4539e43632 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 10:01:43 +0200 Subject: [PATCH] fix(train): seed sampler generator and gate dataset download per node - Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. - Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently. Co-Authored-By: Claude Fable 5 --- src/lerobot/scripts/lerobot_train.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4ddef3105..3d210f00b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: main process downloads first to avoid race conditions - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: each node's local main process downloads first to avoid + # race conditions (the global main process only exists on node 0, so gating on it would let + # all ranks of the other nodes download and build the Arrow cache concurrently). + if accelerator.is_local_main_process: + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset - if not is_main_process: + # Now all other processes can safely load the dataset from the local cache + if not accelerator.is_local_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if hasattr(active_cfg, "drop_n_last_frames"): shuffle = False + # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare + # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even + # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). + sampler_generator = torch.Generator() + if cfg.seed is not None: + sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, + generator=sampler_generator, ) else: shuffle = True