diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 11bdb480a..d074bc8a9 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -99,12 +99,9 @@ class TrainPipelineConfig(HubMixin): batch_size: int = 8 prefetch_factor: int = 4 persistent_workers: bool = True - # Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based - # index shuffling. The data order becomes a pure function of (seed, epoch), which makes - # distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size, - # and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather - # than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle - # order. Ignored when dataset.streaming is enabled. + # Deterministic data order (pure function of seed and epoch): immune to cross-rank RNG + # desync and enables sample-exact resume. Set to false for the legacy RNG-based shuffle. + # Ignored when dataset.streaming is enabled. deterministic_sampler: bool = True steps: int = 100_000 eval_freq: int = 20_000 diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index e88ae6f62..c73da7d0a 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -27,7 +27,7 @@ _FEISTEL_ROUNDS = 4 def _mix64(x: int) -> int: - """SplitMix64 finalizer: a high-quality, cheap 64-bit integer hash.""" + """SplitMix64 finalizer (64-bit integer hash).""" x = (x + 0x9E3779B97F4A7C15) & _MASK_64 x ^= x >> 30 x = (x * 0xBF58476D1CE4E5B9) & _MASK_64 @@ -38,27 +38,20 @@ def _mix64(x: int) -> int: class EpisodeAwareSampler: - """Sampler that incorporates episode boundary information. + """Sampler over episode frames with O(num_episodes) memory. 
- Frame indices are never materialized: only per-episode boundaries are stored (a few numpy - int64 per episode) and the mapping from a logical position to a frame index is computed on - the fly via `searchsorted`, so memory does not grow with the number of frames. + Only episode boundaries are stored; logical positions map to frame indices on the fly, so + memory does not grow with the number of frames. - Two shuffling modes are supported: + By default (`deterministic=True`) shuffling uses a seeded Feistel permutation over + `[0, num_frames)`: the data order is a pure function of `(seed, epoch)`, needs no RNG + synchronization across distributed ranks, and any position can be sought in O(1), enabling + sample-exact resume via `state_dict` / `load_state_dict`. Each completed `__iter__` + advances the epoch. The shuffle is pseudo-random rather than truly uniform — the standard + large-scale trade-off. During a resumed epoch, `__len__` still reports the full length. - - Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a - dedicated `generator`. Exposing the `generator` attribute (even when None) lets - `accelerate` register it as the synchronized RNG in distributed training, so every rank - draws the same permutation and batch shards stay disjoint. - - `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking - to the exact domain size). The data order becomes a pure function of `(seed, epoch)`: - nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact - resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset - size. The shuffle is pseudo-random rather than a true uniform permutation — the standard - trade-off in large-scale training loaders. Each completed `__iter__` advances the internal - epoch, so consecutive dataloader passes yield different permutations without `set_epoch` - calls. 
During an epoch resumed via `load_state_dict`, `__len__` still reports the full - epoch length; the first pass simply yields fewer samples. + With `deterministic=False`, shuffling falls back to `torch.randperm` driven by `generator` + (accelerate synchronizes the generator across ranks when preparing the dataloader). """ def __init__( @@ -70,20 +63,18 @@ class EpisodeAwareSampler: drop_n_last_frames: int = 0, shuffle: bool = False, generator: torch.Generator | None = None, - deterministic: bool = False, + deterministic: bool = True, seed: int = 0, ): """ Args: - dataset_from_indices: List of indices containing the start of each episode in the dataset. - dataset_to_indices: List of indices containing the end of each episode in the dataset. - episode_indices_to_use: List of episode indices to use. If None, all episodes are used. - Assumes that episodes are indexed from 0 to N-1. - drop_n_first_frames: Number of frames to drop from the start of each episode. - drop_n_last_frames: Number of frames to drop from the end of each episode. + dataset_from_indices: Start index of each episode in the dataset. + dataset_to_indices: End index of each episode in the dataset. + episode_indices_to_use: Episode indices to use; None means all. + drop_n_first_frames: Frames to drop from the start of each episode. + drop_n_last_frames: Frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. - generator: Generator used for default-mode shuffling. When None, shuffling falls - back to the global torch RNG. Incompatible with `deterministic=True`. + generator: Shuffle RNG (global torch RNG when None). Incompatible with `deterministic=True`. deterministic: Use the seeded Feistel permutation instead of `torch.randperm`. seed: Seed the deterministic permutation is derived from (together with the epoch). 
""" @@ -135,22 +126,18 @@ class EpisodeAwareSampler: self._epoch = 0 self._start_index = 0 - # Feistel cipher domain: the smallest even-bit-width power of two >= num_frames, - # so both halves have equal width and cycle-walking converges in <4 expected steps. + # Smallest even-bit-width power-of-two domain >= num_frames: equal Feistel halves, + # cycle-walking converges in <4 expected steps. bits = max((self._num_frames - 1).bit_length(), 2) self._half_bits = (bits + 1) // 2 self._half_mask = (1 << self._half_bits) - 1 @property def indices(self) -> list[int]: - """Materialized frame indices, in unshuffled order (back-compat / introspection only). - - This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it. - """ + """Materialized frame indices in unshuffled order; O(num_frames), introspection only.""" return [self._frame_index(k) for k in range(self._num_frames)] def set_epoch(self, epoch: int) -> None: - """Set the epoch the next `__iter__` will use (DistributedSampler convention).""" self._require_deterministic("set_epoch") self._epoch = epoch @@ -165,10 +152,7 @@ class EpisodeAwareSampler: def _require_deterministic(self, method: str) -> None: if not self.deterministic: - raise RuntimeError( - f"{method} is only meaningful with deterministic=True: in default mode the " - "order is drawn from the (generator) RNG and cannot be sought." - ) + raise RuntimeError(f"{method} requires deterministic=True: an RNG order cannot be sought.") def _round_keys(self, epoch: int) -> list[int]: state = _mix64(_mix64(self.seed) ^ _mix64(epoch)) @@ -197,8 +181,7 @@ class EpisodeAwareSampler: def __iter__(self) -> Iterator[int]: if not self.deterministic: return self._iter_default() - # Capture and advance state eagerly so epoch bookkeeping is not deferred until the - # returned generator is first consumed. + # Advance epoch state eagerly, not on first consumption of the generator. 
epoch, start = self._epoch, self._start_index self._epoch += 1 self._start_index = 0 @@ -222,17 +205,12 @@ class EpisodeAwareSampler: def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict: - """Map a global optimization step to an `EpisodeAwareSampler` state for resume. + """Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume. - Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and - keeps every `num_processes`-th batch, so one optimization step consumes - `batch_size * num_processes` consecutive sampler positions, and (with `even_batches` padding) - each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch. - - `batches_into_epoch * num_processes <= ceil(num_frames / batch_size) - 1` always holds, so the - start index stays strictly below `num_frames`; the `min` is purely defensive. Resume is - sample-exact up to the `even_batches` padding accelerate appends at epoch boundaries (at most - `num_processes - 1` duplicated batches per epoch, the same duplication non-resumed runs get). + Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler + positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches + per epoch (`even_batches` padding included). The start index provably stays below + `num_frames`; the `min` is defensive. 
""" batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes) epoch, batches_into_epoch = divmod(step, batches_per_epoch) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 186431c54..134a28eec 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -388,9 +388,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if cfg.deterministic_sampler and not cfg.dataset.streaming: - # Data order is a pure function of (seed, epoch): nothing to synchronize across ranks, - # O(1) memory in dataset size, and a resumed run continues at the exact sample where the - # checkpoint left off (up to accelerate's even_batches padding at epoch boundaries). + # Deterministic data order: no cross-rank RNG sync needed, sample-exact resume. shuffle = False sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], @@ -398,7 +396,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): episode_indices_to_use=dataset.episodes, drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), shuffle=True, - deterministic=True, seed=cfg.seed if cfg.seed is not None else 0, ) if cfg.resume and step > 0: @@ -413,9 +410,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): ) elif hasattr(active_cfg, "drop_n_last_frames"): shuffle = False - # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare - # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even - # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). + # Legacy RNG shuffle: a dedicated generator lets accelerate synchronize it across ranks. 
sampler_generator = torch.Generator() if cfg.seed is not None: sampler_generator.manual_seed(cfg.seed) @@ -425,6 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): episode_indices_to_use=dataset.episodes, drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, + deterministic=False, generator=sampler_generator, ) else: diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 15e1df7d9..7614c7dd8 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -118,12 +118,18 @@ def test_shuffle_with_generator_is_deterministic(): # Two samplers shuffling with same-seed generators must yield identical permutations. # This is what keeps batch shards disjoint across ranks in distributed training, where # accelerate synchronizes the sampler's generator state instead of the global torch RNG. - sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + sampler_a = EpisodeAwareSampler( + [0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42) + ) + sampler_b = EpisodeAwareSampler( + [0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42) + ) assert list(sampler_a) == list(sampler_b) # Desyncing the global RNG must not affect the permutation. - sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + sampler_c = EpisodeAwareSampler( + [0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42) + ) order_before = list(sampler_c) sampler_c.generator.manual_seed(42) torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. 
eval) would @@ -133,7 +139,7 @@ def test_shuffle_with_generator_is_deterministic(): def test_generator_attribute_defaults_to_none(): # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, # so the attribute must exist even when no generator is passed. - sampler = EpisodeAwareSampler([0], [6], shuffle=True) + sampler = EpisodeAwareSampler([0], [6], shuffle=True, deterministic=False) assert sampler.generator is None assert set(sampler) == {0, 1, 2, 3, 4, 5} @@ -194,7 +200,7 @@ def test_deterministic_mode_rejects_generator(): def test_state_methods_require_deterministic_mode(): - sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True) + sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, deterministic=False) with pytest.raises(RuntimeError, match="deterministic=True"): sampler.set_epoch(1) with pytest.raises(RuntimeError, match="deterministic=True"):