diff --git a/pyproject.toml b/pyproject.toml index 2b4c22f12..f72cfa6dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot topreward = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] # Features @@ -231,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation # NOTE: Explicitly listing scipy helps flatten the dependecy tree. -aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] +aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"] pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] +libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"] # NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution # is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 2bf7ab922..64d871907 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -30,6 +30,7 @@ class EpisodeAwareSampler: drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, + generator: torch.Generator | None = None, ): """Sampler that optionally incorporates episode boundary information. @@ -41,6 +42,10 @@ class EpisodeAwareSampler: drop_n_first_frames: Number of frames to drop from the start of each episode. drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. + generator: Generator used for shuffling. Exposing this attribute (even when None) lets + `accelerate` register it as the synchronized RNG in distributed training, so + every rank draws the same permutation and batch shards stay disjoint. When + None, shuffling falls back to the global torch RNG. """ if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") @@ -73,10 +78,11 @@ class EpisodeAwareSampler: self.indices = indices self.shuffle = shuffle + self.generator = generator def __iter__(self) -> Iterator[int]: if self.shuffle: - for i in torch.randperm(len(self.indices)): + for i in torch.randperm(len(self.indices), generator=self.generator): yield self.indices[i] else: for i in self.indices: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4ddef3105..3d210f00b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: main process downloads first to avoid race conditions - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: each node's local main process downloads first to avoid + # race conditions (the global main process only exists on node 0, so gating on it would let + # all ranks of the other nodes download and build the Arrow cache concurrently). + if accelerator.is_local_main_process: + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset - if not is_main_process: + # Now all other processes can safely load the dataset from the local cache + if not accelerator.is_local_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if hasattr(active_cfg, "drop_n_last_frames"): shuffle = False + # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare + # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even + # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). + sampler_generator = torch.Generator() + if cfg.seed is not None: + sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, + generator=sampler_generator, ) else: shuffle = True diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 8bb3be8e9..95429c7ec 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -114,6 +114,30 @@ def test_shuffle(): assert set(sampler) == {0, 1, 2, 3, 4, 5} +def test_shuffle_with_generator_is_deterministic(): + # Two samplers shuffling with same-seed generators must yield identical permutations. + # This is what keeps batch shards disjoint across ranks in distributed training, where + # accelerate synchronizes the sampler's generator state instead of the global torch RNG. + sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + assert list(sampler_a) == list(sampler_b) + + # Desyncing the global RNG must not affect the permutation. + sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + order_before = list(sampler_c) + sampler_c.generator.manual_seed(42) + torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would + assert list(sampler_c) == order_before + + +def test_generator_attribute_defaults_to_none(): + # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, + # so the attribute must exist even when no generator is passed. + sampler = EpisodeAwareSampler([0], [6], shuffle=True) + assert sampler.generator is None + assert set(sampler) == {0, 1, 2, 3, 4, 5} + + def test_negative_drop_first_frames_raises(): with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): EpisodeAwareSampler([0], [10], drop_n_first_frames=-1) diff --git a/uv.lock b/uv.lock index 6acacab56..3a7129dac 100644 --- a/uv.lock +++ b/uv.lock @@ -1764,7 +1764,7 @@ wheels = [ [[package]] name = "gym-aloha" -version = "0.1.3" +version = "0.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dm-control" }, @@ -1772,14 +1772,14 @@ dependencies = [ { name = "imageio", extra = ["ffmpeg"] }, { name = "mujoco" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" }, + { url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" }, ] [[package]] name = "gym-hil" -version = "0.1.13" +version = "0.1.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gymnasium" }, @@ -1789,9 +1789,9 @@ dependencies = [ { name = "pygame" }, { name = "pynput" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" }, + { url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" }, ] [[package]] @@ -1881,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9 [[package]] name = "hf-libero" -version = "0.1.3" +version = "0.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "bddl", marker = "sys_platform == 'linux'" }, @@ -1902,7 +1902,10 @@ dependencies = [ { name = "transformers", marker = "sys_platform == 'linux'" }, { name = "wandb", marker = "sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" }, +] [[package]] name = "hf-xet" @@ -3090,12 +3093,12 @@ requires-dist = [ { name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" }, { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" }, - { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" }, - { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" }, + { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" }, + { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" }, { name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" }, { name = "gymnasium", specifier = ">=1.1.1,<2.0.0" }, { name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" }, - { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" }, + { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" }, { name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" }, { name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" }, { name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },