From 926fb9c31e9be4858012243146d5b7fc0e65fffa Mon Sep 17 00:00:00 2001
From: pepijn
Date: Thu, 11 Jun 2026 17:21:35 +0000
Subject: [PATCH] fix(train): download dataset once on the global main process

Gate the training dataset download on the global is_main_process
(download once to the shared dataset root, barrier, then every other
rank reads the already-populated copy) instead of per-node
is_local_main_process. LeRobotDataset skips its snapshot_download when
try_load() succeeds, so no rank re-downloads. Assumes the dataset root /
HF cache is on storage shared across nodes.

Co-authored-by: Cursor
---
 src/lerobot/scripts/lerobot_train.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 9d4b49714..438d7dff1 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -234,18 +234,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
-    # Dataset loading synchronization: each node's local main process downloads first to avoid
-    # race conditions (the global main process only exists on node 0, so gating on it would let
-    # all ranks of the other nodes download and build the Arrow cache concurrently).
-    if accelerator.is_local_main_process:
-        if is_main_process:
-            logging.info("Creating dataset")
+    # Dataset loading synchronization: the global main process downloads once to the shared
+    # dataset root, then a barrier lets every other rank read the already-populated copy.
+    # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
+    if is_main_process:
+        logging.info("Creating dataset")
         dataset = make_dataset(cfg)
 
     accelerator.wait_for_everyone()
 
-    # Now all other processes can safely load the dataset from the local cache
-    if not accelerator.is_local_main_process:
+    # Other ranks read from the shared copy populated by the main process.
+    if not is_main_process:
         dataset = make_dataset(cfg)
 
     # Create environment used for evaluating checkpoints during training on simulation data.