From 32b0d7d1ef4ee0edcdc68fd3ed14147d75794d9c Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 11:37:44 +0200 Subject: [PATCH] refactor(datasets): fold deterministic mode into EpisodeAwareSampler Instead of a parallel DeterministicEpisodeAwareSampler class, extend the existing EpisodeAwareSampler with a deterministic=True mode (seeded Feistel permutation, epoch auto-advance, state_dict/load_state_dict). The default mode is behavior-identical: same torch.randperm consumption and the same generator contract accelerate synchronizes; the O(N) Python index list is replaced by O(num_episodes) boundary arrays in both modes, with `indices` kept as a back-compat property. Passing a generator together with deterministic=True is rejected, and the state/seek methods raise outside deterministic mode. Co-Authored-By: Claude Fable 5 --- src/lerobot/datasets/__init__.py | 3 +- src/lerobot/datasets/sampler.py | 155 +++++++++++---------------- src/lerobot/scripts/lerobot_train.py | 10 +- tests/datasets/test_sampler.py | 51 +++++---- 4 files changed, 98 insertions(+), 121 deletions(-) diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index f7d0ff889..bd12a7248 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav -from .sampler import DeterministicEpisodeAwareSampler, EpisodeAwareSampler, compute_sampler_state +from .sampler import EpisodeAwareSampler, compute_sampler_state from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager @@ -64,7 +64,6 @@ __all__ = [ "DEFAULT_EPISODES_PATH", "DEFAULT_QUANTILES", "EVENT_ONLY_STYLES", - "DeterministicEpisodeAwareSampler", "EpisodeAwareSampler", "LANGUAGE_EVENTS", "LANGUAGE_PERSISTENT", diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 851bdbada..e88ae6f62 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -38,97 +38,27 @@ def _mix64(x: int) -> int: class EpisodeAwareSampler: - def __init__( - self, - dataset_from_indices: list[int], - dataset_to_indices: list[int], - episode_indices_to_use: list | None = None, - drop_n_first_frames: int = 0, - drop_n_last_frames: int = 0, - shuffle: bool = False, - generator: torch.Generator | None = None, - ): - """Sampler that optionally incorporates episode boundary information. + """Sampler that incorporates episode boundary information. - Args: - dataset_from_indices: List of indices containing the start of each episode in the dataset. - dataset_to_indices: List of indices containing the end of each episode in the dataset. - episode_indices_to_use: List of episode indices to use. If None, all episodes are used. - Assumes that episodes are indexed from 0 to N-1. - drop_n_first_frames: Number of frames to drop from the start of each episode. - drop_n_last_frames: Number of frames to drop from the end of each episode. - shuffle: Whether to shuffle the indices. - generator: Generator used for shuffling. Exposing this attribute (even when None) lets - `accelerate` register it as the synchronized RNG in distributed training, so - every rank draws the same permutation and batch shards stay disjoint. When - None, shuffling falls back to the global torch RNG. - """ - if drop_n_first_frames < 0: - raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") - if drop_n_last_frames < 0: - raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + Frame indices are never materialized: only per-episode boundaries are stored (a few numpy + int64 per episode) and the mapping from a logical position to a frame index is computed on + the fly via `searchsorted`, so memory does not grow with the number of frames. - indices = [] - for episode_idx, (start_index, end_index) in enumerate( - zip(dataset_from_indices, dataset_to_indices, strict=True) - ): - if episode_indices_to_use is None or episode_idx in episode_indices_to_use: - ep_length = end_index - start_index - if drop_n_first_frames + drop_n_last_frames >= ep_length: - logger.warning( - "Episode %d has %d frames but drop_n_first_frames=%d and " - "drop_n_last_frames=%d removes all frames. Skipping.", - episode_idx, - ep_length, - drop_n_first_frames, - drop_n_last_frames, - ) - continue - indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + Two shuffling modes are supported: - if not indices: - raise ValueError( - "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " - "All episodes were either filtered out or had too few frames." - ) - - self.indices = indices - self.shuffle = shuffle - self.generator = generator - - def __iter__(self) -> Iterator[int]: - if self.shuffle: - for i in torch.randperm(len(self.indices), generator=self.generator): - yield self.indices[i] - else: - for i in self.indices: - yield i - - def __len__(self) -> int: - return len(self.indices) - - -class DeterministicEpisodeAwareSampler: - """Episode-aware sampler with O(num_episodes) memory and a deterministic, seekable shuffle. - - Unlike `EpisodeAwareSampler`, frame indices are never materialized: only per-episode - boundaries are stored (a few numpy int64 per episode) and the mapping from a logical - position to a frame index is computed on the fly via `searchsorted`. Shuffling applies a - seeded Feistel permutation over `[0, num_frames)` (cycle-walking to the exact domain size) - instead of `torch.randperm`, so the data order is a pure function of `(seed, epoch)`: - - - every distributed rank derives the identical order with no RNG state to synchronize, - - any position can be sought in O(1), enabling sample-exact resume of interrupted runs - (see `state_dict` / `load_state_dict`), - - memory and epoch-boundary cost do not grow with the number of frames. - - The trade-off is that the shuffle is pseudo-random rather than a true uniform permutation, - which is the standard compromise in large-scale training loaders. - - Each completed `__iter__` advances the internal epoch, so consecutive dataloader passes - yield different permutations without requiring `set_epoch` calls. During an epoch resumed - via `load_state_dict`, `__len__` still reports the full epoch length; the first pass simply - yields fewer samples. + - Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a + dedicated `generator`. Exposing the `generator` attribute (even when None) lets + `accelerate` register it as the synchronized RNG in distributed training, so every rank + draws the same permutation and batch shards stay disjoint. + - `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking + to the exact domain size). The data order becomes a pure function of `(seed, epoch)`: + nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact + resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset + size. The shuffle is pseudo-random rather than a true uniform permutation — the standard + trade-off in large-scale training loaders. Each completed `__iter__` advances the internal + epoch, so consecutive dataloader passes yield different permutations without `set_epoch` + calls. During an epoch resumed via `load_state_dict`, `__len__` still reports the full + epoch length; the first pass simply yields fewer samples. """ def __init__( @@ -138,7 +68,9 @@ class DeterministicEpisodeAwareSampler: episode_indices_to_use: list | None = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, - shuffle: bool = True, + shuffle: bool = False, + generator: torch.Generator | None = None, + deterministic: bool = False, seed: int = 0, ): """ @@ -149,13 +81,18 @@ class DeterministicEpisodeAwareSampler: Assumes that episodes are indexed from 0 to N-1. drop_n_first_frames: Number of frames to drop from the start of each episode. drop_n_last_frames: Number of frames to drop from the end of each episode. - shuffle: Whether to shuffle with the seeded Feistel permutation. - seed: Seed the permutation is derived from (together with the epoch). + shuffle: Whether to shuffle the indices. + generator: Generator used for default-mode shuffling. When None, shuffling falls + back to the global torch RNG. Incompatible with `deterministic=True`. + deterministic: Use the seeded Feistel permutation instead of `torch.randperm`. + seed: Seed the deterministic permutation is derived from (together with the epoch). """ if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") if drop_n_last_frames < 0: raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + if deterministic and generator is not None: + raise ValueError("generator is unused in deterministic mode; pass seed instead.") from_indices = np.asarray(dataset_from_indices, dtype=np.int64) to_indices = np.asarray(dataset_to_indices, dtype=np.int64) @@ -192,6 +129,8 @@ class DeterministicEpisodeAwareSampler: self._cum_lengths = np.cumsum(lengths[used]) self._num_frames = int(self._cum_lengths[-1]) self.shuffle = shuffle + self.generator = generator + self.deterministic = deterministic self.seed = seed self._epoch = 0 self._start_index = 0 @@ -202,17 +141,35 @@ class DeterministicEpisodeAwareSampler: self._half_bits = (bits + 1) // 2 self._half_mask = (1 << self._half_bits) - 1 + @property + def indices(self) -> list[int]: + """Materialized frame indices, in unshuffled order (back-compat / introspection only). + + This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it. + """ + return [self._frame_index(k) for k in range(self._num_frames)] + def set_epoch(self, epoch: int) -> None: """Set the epoch the next `__iter__` will use (DistributedSampler convention).""" + self._require_deterministic("set_epoch") self._epoch = epoch def state_dict(self) -> dict: + self._require_deterministic("state_dict") return {"epoch": self._epoch, "start_index": self._start_index} def load_state_dict(self, state: dict) -> None: + self._require_deterministic("load_state_dict") self._epoch = state["epoch"] self._start_index = state["start_index"] + def _require_deterministic(self, method: str) -> None: + if not self.deterministic: + raise RuntimeError( + f"{method} is only meaningful with deterministic=True: in default mode the " + "order is drawn from the (generator) RNG and cannot be sought." + ) + def _round_keys(self, epoch: int) -> list[int]: state = _mix64(_mix64(self.seed) ^ _mix64(epoch)) keys = [] @@ -238,14 +195,24 @@ class DeterministicEpisodeAwareSampler: return int(self._starts[episode]) + position_in_episode def __iter__(self) -> Iterator[int]: + if not self.deterministic: + return self._iter_default() # Capture and advance state eagerly so epoch bookkeeping is not deferred until the # returned generator is first consumed. epoch, start = self._epoch, self._start_index self._epoch += 1 self._start_index = 0 - return self._iter_epoch(epoch, start) + return self._iter_deterministic_epoch(epoch, start) - def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]: + def _iter_default(self) -> Iterator[int]: + if self.shuffle: + for i in torch.randperm(self._num_frames, generator=self.generator): + yield self._frame_index(int(i)) + else: + for k in range(self._num_frames): + yield self._frame_index(k) + + def _iter_deterministic_epoch(self, epoch: int, start: int) -> Iterator[int]: keys = self._round_keys(epoch) if self.shuffle else None for k in range(start, self._num_frames): yield self._frame_index(self._permute(k, keys) if self.shuffle else k) @@ -255,7 +222,7 @@ class DeterministicEpisodeAwareSampler: def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict: - """Map a global optimization step to a `DeterministicEpisodeAwareSampler` state for resume. + """Map a global optimization step to an `EpisodeAwareSampler` state for resume. Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and keeps every `num_processes`-th batch, so one optimization step consumes diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a772db2ba..405928d45 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -43,12 +43,7 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import ( - DeterministicEpisodeAwareSampler, - EpisodeAwareSampler, - compute_sampler_state, - make_dataset, -) +from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -397,12 +392,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # O(1) memory in dataset size, and a resumed run continues at the exact sample where the # checkpoint left off (up to accelerate's even_batches padding at epoch boundaries). shuffle = False - sampler = DeterministicEpisodeAwareSampler( + sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), shuffle=True, + deterministic=True, seed=cfg.seed if cfg.seed is not None else 0, ) if cfg.resume and step > 0: diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 0c37cae0e..15e1df7d9 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -163,17 +163,19 @@ def test_partial_episode_drop_warns(caplog): assert "Episode 0" in caplog.text -# --- DeterministicEpisodeAwareSampler --- +# --- deterministic mode (seeded Feistel permutation) --- + +from functools import partial # noqa: E402 + +from lerobot.datasets.sampler import compute_sampler_state # noqa: E402 + +deterministic_sampler = partial(EpisodeAwareSampler, deterministic=True) -from lerobot.datasets.sampler import ( # noqa: E402 - DeterministicEpisodeAwareSampler, - compute_sampler_state, -) EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames -def test_deterministic_sampler_unshuffled_matches_episode_aware(): +def test_deterministic_mode_unshuffled_matches_default_mode(): for kwargs in ( {}, {"drop_n_first_frames": 1}, @@ -181,21 +183,34 @@ def test_deterministic_sampler_unshuffled_matches_episode_aware(): {"episode_indices_to_use": [0, 2]}, ): reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs) - sampler = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs) + sampler = deterministic_sampler(*EPISODE_BOUNDS, shuffle=False, **kwargs) assert list(sampler) == list(reference), kwargs assert len(sampler) == len(reference), kwargs +def test_deterministic_mode_rejects_generator(): + with pytest.raises(ValueError, match="generator is unused in deterministic mode"): + deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, generator=torch.Generator()) + + +def test_state_methods_require_deterministic_mode(): + sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True) + with pytest.raises(RuntimeError, match="deterministic=True"): + sampler.set_epoch(1) + with pytest.raises(RuntimeError, match="deterministic=True"): + sampler.state_dict() + + @pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100]) def test_deterministic_sampler_shuffle_is_permutation(num_frames): for seed in (0, 1, 1234): - sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed) + sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=seed) assert sorted(sampler) == list(range(num_frames)) def test_deterministic_sampler_epochs_reproduce_and_differ(): - sampler_a = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42) - sampler_b = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42) + sampler_a = deterministic_sampler([0], [100], shuffle=True, seed=42) + sampler_b = deterministic_sampler([0], [100], shuffle=True, seed=42) epoch_0 = list(sampler_a) assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch @@ -203,15 +218,15 @@ def test_deterministic_sampler_epochs_reproduce_and_differ(): assert sorted(epoch_1) == sorted(epoch_0) sampler_a.set_epoch(0) assert list(sampler_a) == epoch_0 - assert list(DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0 + assert list(deterministic_sampler([0], [100], shuffle=True, seed=7)) != epoch_0 def test_deterministic_sampler_resume_mid_epoch(): - reference = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + reference = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42) epoch_0 = list(reference) epoch_1 = list(reference) for start in (0, 1, 4, len(epoch_0)): - resumed = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + resumed = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42) resumed.load_state_dict({"epoch": 0, "start_index": start}) assert list(resumed) == epoch_0[start:] # the resumed sampler continues into the same epoch 1 as the uninterrupted one @@ -222,7 +237,7 @@ def test_deterministic_sampler_constant_memory(): # A trillion-frame dataset must instantiate instantly and seek anywhere in O(1): # only per-episode boundaries are stored, never per-frame indices. num_frames = 10**12 - sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0) assert len(sampler) == num_frames sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3}) tail = list(sampler) @@ -232,16 +247,16 @@ def test_deterministic_sampler_constant_memory(): def test_deterministic_sampler_validation_matches_episode_aware(): with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): - DeterministicEpisodeAwareSampler([0], [10], drop_n_first_frames=-1) + deterministic_sampler([0], [10], drop_n_first_frames=-1) with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"): - DeterministicEpisodeAwareSampler([0], [10], drop_n_last_frames=-1) + deterministic_sampler([0], [10], drop_n_last_frames=-1) with pytest.raises(ValueError, match="No valid frames remain"): - DeterministicEpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1) + deterministic_sampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1) def test_deterministic_sampler_partial_episode_drop_warns(caplog): with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"): - sampler = DeterministicEpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False) + sampler = deterministic_sampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False) assert list(sampler) == [2, 3, 4, 5] assert "Episode 0" in caplog.text