diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9d4b49714..438d7dff1 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -234,18 +234,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: each node's local main process downloads first to avoid - # race conditions (the global main process only exists on node 0, so gating on it would let - # all ranks of the other nodes download and build the Arrow cache concurrently). - if accelerator.is_local_main_process: - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: the global main process downloads once to the shared + # dataset root, then a barrier lets every other rank read the already-populated copy. + # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset from the local cache - if not accelerator.is_local_main_process: + # Other ranks read from the shared copy populated by the main process. + if not is_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data.