From 6fa495c6b0861824932a5e3b9977363a4723b74f Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 10:33:52 +0200 Subject: [PATCH] feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume Add a sampler that never materializes frame indices: it stores only per-episode boundaries (numpy, a few bytes per episode) and maps logical positions to frame indices on the fly with searchsorted. Shuffling uses a seeded Feistel permutation over [0, num_frames) (cycle-walking to the exact domain), so the data order is a pure function of (seed, epoch): - no RNG state to synchronize across distributed ranks, - constant memory and zero epoch-boundary cost at any dataset size, - O(1) seek to any position, enabling sample-exact resume. Opt in with --deterministic_sampler=true. On resume, lerobot-train maps the checkpointed step back to (epoch, start_index) via compute_sampler_state and continues at the exact sample where the run left off (up to accelerate's even_batches padding at epoch boundaries). The shuffle is pseudo-random rather than a true uniform permutation, the standard trade-off in large-scale training loaders. Co-Authored-By: Claude Fable 5 --- src/lerobot/configs/train.py | 12 ++ src/lerobot/datasets/__init__.py | 4 +- src/lerobot/datasets/sampler.py | 181 +++++++++++++++++++++++++++ src/lerobot/scripts/lerobot_train.py | 32 ++++- tests/datasets/test_sampler.py | 106 ++++++++++++++++ 5 files changed, 332 insertions(+), 3 deletions(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index bac1a946b..ab34d43c7 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -99,6 +99,12 @@ class TrainPipelineConfig(HubMixin): batch_size: int = 8 prefetch_factor: int = 4 persistent_workers: bool = True + # Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of materialized + # index shuffling. The data order becomes a pure function of (seed, epoch), which makes + # distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size, + # and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather + # than a true uniform permutation. Not compatible with dataset.streaming. + deterministic_sampler: bool = False steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 @@ -133,6 +139,12 @@ class TrainPipelineConfig(HubMixin): return self.policy # type: ignore[return-value] def validate(self) -> None: + if self.deterministic_sampler and self.dataset.streaming: + raise ValueError( + "deterministic_sampler requires a map-style dataset and is not compatible with " + "dataset.streaming=true." + ) + # HACK: We parse again the cli args here to get the pretrained paths if there was some. policy_path = parser.get_path_arg("policy") reward_model_path = parser.get_path_arg("reward_model") diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 2a67858d2..f7d0ff889 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav -from .sampler import EpisodeAwareSampler +from .sampler import DeterministicEpisodeAwareSampler, EpisodeAwareSampler, compute_sampler_state from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager @@ -64,6 +64,7 @@ __all__ = [ "DEFAULT_EPISODES_PATH", "DEFAULT_QUANTILES", "EVENT_ONLY_STYLES", + "DeterministicEpisodeAwareSampler", "EpisodeAwareSampler", "LANGUAGE_EVENTS", "LANGUAGE_PERSISTENT", @@ -82,6 +83,7 @@ __all__ = [ "aggregate_stats", "convert_image_to_video_dataset", "create_initial_features", + "compute_sampler_state", "create_lerobot_dataset_card", "column_for_style", "delete_episodes", diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 64d871907..851bdbada 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -14,12 +14,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import math from collections.abc import Iterator +import numpy as np import torch logger = logging.getLogger(__name__) +_MASK_64 = (1 << 64) - 1 +_FEISTEL_ROUNDS = 4 + + +def _mix64(x: int) -> int: + """SplitMix64 finalizer: a high-quality, cheap 64-bit integer hash.""" + x = (x + 0x9E3779B97F4A7C15) & _MASK_64 + x ^= x >> 30 + x = (x * 0xBF58476D1CE4E5B9) & _MASK_64 + x ^= x >> 27 + x = (x * 0x94D049BB133111EB) & _MASK_64 + x ^= x >> 31 + return x + class EpisodeAwareSampler: def __init__( @@ -90,3 +106,168 @@ class EpisodeAwareSampler: def __len__(self) -> int: return len(self.indices) + + +class DeterministicEpisodeAwareSampler: + """Episode-aware sampler with O(num_episodes) memory and a deterministic, seekable shuffle. + + Unlike `EpisodeAwareSampler`, frame indices are never materialized: only per-episode + boundaries are stored (a few numpy int64 per episode) and the mapping from a logical + position to a frame index is computed on the fly via `searchsorted`. Shuffling applies a + seeded Feistel permutation over `[0, num_frames)` (cycle-walking to the exact domain size) + instead of `torch.randperm`, so the data order is a pure function of `(seed, epoch)`: + + - every distributed rank derives the identical order with no RNG state to synchronize, + - any position can be sought in O(1), enabling sample-exact resume of interrupted runs + (see `state_dict` / `load_state_dict`), + - memory and epoch-boundary cost do not grow with the number of frames. + + The trade-off is that the shuffle is pseudo-random rather than a true uniform permutation, + which is the standard compromise in large-scale training loaders. + + Each completed `__iter__` advances the internal epoch, so consecutive dataloader passes + yield different permutations without requiring `set_epoch` calls. During an epoch resumed + via `load_state_dict`, `__len__` still reports the full epoch length; the first pass simply + yields fewer samples. + """ + + def __init__( + self, + dataset_from_indices: list[int], + dataset_to_indices: list[int], + episode_indices_to_use: list | None = None, + drop_n_first_frames: int = 0, + drop_n_last_frames: int = 0, + shuffle: bool = True, + seed: int = 0, + ): + """ + Args: + dataset_from_indices: List of indices containing the start of each episode in the dataset. + dataset_to_indices: List of indices containing the end of each episode in the dataset. + episode_indices_to_use: List of episode indices to use. If None, all episodes are used. + Assumes that episodes are indexed from 0 to N-1. + drop_n_first_frames: Number of frames to drop from the start of each episode. + drop_n_last_frames: Number of frames to drop from the end of each episode. + shuffle: Whether to shuffle with the seeded Feistel permutation. + seed: Seed the permutation is derived from (together with the epoch). + """ + if drop_n_first_frames < 0: + raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") + if drop_n_last_frames < 0: + raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + + from_indices = np.asarray(dataset_from_indices, dtype=np.int64) + to_indices = np.asarray(dataset_to_indices, dtype=np.int64) + if from_indices.shape != to_indices.shape: + raise ValueError( + f"dataset_from_indices and dataset_to_indices must have the same length, " + f"got {len(from_indices)} and {len(to_indices)}" + ) + + used = np.ones(len(from_indices), dtype=bool) + if episode_indices_to_use is not None: + used = np.zeros(len(from_indices), dtype=bool) + used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True + + starts = from_indices + drop_n_first_frames + lengths = to_indices - drop_n_last_frames - starts + for episode_idx in np.flatnonzero(used & (lengths <= 0)): + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + to_indices[episode_idx] - from_indices[episode_idx], + drop_n_first_frames, + drop_n_last_frames, + ) + used &= lengths > 0 + if not used.any(): + raise ValueError( + "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " + "All episodes were either filtered out or had too few frames." + ) + + self._starts = starts[used] + self._cum_lengths = np.cumsum(lengths[used]) + self._num_frames = int(self._cum_lengths[-1]) + self.shuffle = shuffle + self.seed = seed + self._epoch = 0 + self._start_index = 0 + + # Feistel cipher domain: the smallest even-bit-width power of two >= num_frames, + # so both halves have equal width and cycle-walking converges in <4 expected steps. + bits = max((self._num_frames - 1).bit_length(), 2) + self._half_bits = (bits + 1) // 2 + self._half_mask = (1 << self._half_bits) - 1 + + def set_epoch(self, epoch: int) -> None: + """Set the epoch the next `__iter__` will use (DistributedSampler convention).""" + self._epoch = epoch + + def state_dict(self) -> dict: + return {"epoch": self._epoch, "start_index": self._start_index} + + def load_state_dict(self, state: dict) -> None: + self._epoch = state["epoch"] + self._start_index = state["start_index"] + + def _round_keys(self, epoch: int) -> list[int]: + state = _mix64(_mix64(self.seed) ^ _mix64(epoch)) + keys = [] + for _ in range(_FEISTEL_ROUNDS): + state = _mix64(state) + keys.append(state) + return keys + + def _permute(self, index: int, keys: list[int]) -> int: + # Feistel network with cycle-walking: a bijection on [0, num_frames). + half_bits, half_mask = self._half_bits, self._half_mask + while True: + left, right = index >> half_bits, index & half_mask + for key in keys: + left, right = right, left ^ (_mix64(right ^ key) & half_mask) + index = (left << half_bits) | right + if index < self._num_frames: + return index + + def _frame_index(self, position: int) -> int: + episode = int(np.searchsorted(self._cum_lengths, position, side="right")) + position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0) + return int(self._starts[episode]) + position_in_episode + + def __iter__(self) -> Iterator[int]: + # Capture and advance state eagerly so epoch bookkeeping is not deferred until the + # returned generator is first consumed. + epoch, start = self._epoch, self._start_index + self._epoch += 1 + self._start_index = 0 + return self._iter_epoch(epoch, start) + + def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]: + keys = self._round_keys(epoch) if self.shuffle else None + for k in range(start, self._num_frames): + yield self._frame_index(self._permute(k, keys) if self.shuffle else k) + + def __len__(self) -> int: + return self._num_frames + + +def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict: + """Map a global optimization step to a `DeterministicEpisodeAwareSampler` state for resume. + + Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and + keeps every `num_processes`-th batch, so one optimization step consumes + `batch_size * num_processes` consecutive sampler positions, and (with `even_batches` padding) + each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch. + + `batches_into_epoch * num_processes <= ceil(num_frames / batch_size) - 1` always holds, so the + start index stays strictly below `num_frames`; the `min` is purely defensive. Resume is + sample-exact up to the `even_batches` padding accelerate appends at epoch boundaries (at most + `num_processes - 1` duplicated batches per epoch, the same duplication non-resumed runs get). + """ + batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes) + epoch, batches_into_epoch = divmod(step, batches_per_epoch) + start_index = min(batches_into_epoch * batch_size * num_processes, num_frames) + return {"epoch": epoch, "start_index": start_index} diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3d210f00b..a772db2ba 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -43,7 +43,12 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import EpisodeAwareSampler, make_dataset +from lerobot.datasets import ( + DeterministicEpisodeAwareSampler, + EpisodeAwareSampler, + compute_sampler_state, + make_dataset, +) from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -387,7 +392,30 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(active_cfg, "drop_n_last_frames"): + if cfg.deterministic_sampler: + # Data order is a pure function of (seed, epoch): nothing to synchronize across ranks, + # O(1) memory in dataset size, and a resumed run continues at the exact sample where the + # checkpoint left off (up to accelerate's even_batches padding at epoch boundaries). + shuffle = False + sampler = DeterministicEpisodeAwareSampler( + dataset.meta.episodes["dataset_from_index"], + dataset.meta.episodes["dataset_to_index"], + episode_indices_to_use=dataset.episodes, + drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), + shuffle=True, + seed=cfg.seed if cfg.seed is not None else 0, + ) + if cfg.resume and step > 0: + sampler_state = compute_sampler_state( + step, len(sampler), cfg.batch_size, accelerator.num_processes + ) + sampler.load_state_dict(sampler_state) + if is_main_process: + logging.info( + f"Resuming data order at epoch {sampler_state['epoch']}, " + f"sample {sampler_state['start_index']}" + ) + elif hasattr(active_cfg, "drop_n_last_frames"): shuffle = False # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 95429c7ec..0c37cae0e 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -161,3 +161,109 @@ def test_partial_episode_drop_warns(caplog): # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 assert sampler.indices == [2, 3, 4, 5] assert "Episode 0" in caplog.text + + +# --- DeterministicEpisodeAwareSampler --- + +from lerobot.datasets.sampler import ( # noqa: E402 + DeterministicEpisodeAwareSampler, + compute_sampler_state, +) + +EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames + + +def test_deterministic_sampler_unshuffled_matches_episode_aware(): + for kwargs in ( + {}, + {"drop_n_first_frames": 1}, + {"drop_n_last_frames": 1}, + {"episode_indices_to_use": [0, 2]}, + ): + reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs) + sampler = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs) + assert list(sampler) == list(reference), kwargs + assert len(sampler) == len(reference), kwargs + + +@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100]) +def test_deterministic_sampler_shuffle_is_permutation(num_frames): + for seed in (0, 1, 1234): + sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed) + assert sorted(sampler) == list(range(num_frames)) + + +def test_deterministic_sampler_epochs_reproduce_and_differ(): + sampler_a = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42) + sampler_b = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42) + epoch_0 = list(sampler_a) + assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process + epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch + assert epoch_1 != epoch_0 + assert sorted(epoch_1) == sorted(epoch_0) + sampler_a.set_epoch(0) + assert list(sampler_a) == epoch_0 + assert list(DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0 + + +def test_deterministic_sampler_resume_mid_epoch(): + reference = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + epoch_0 = list(reference) + epoch_1 = list(reference) + for start in (0, 1, 4, len(epoch_0)): + resumed = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + resumed.load_state_dict({"epoch": 0, "start_index": start}) + assert list(resumed) == epoch_0[start:] + # the resumed sampler continues into the same epoch 1 as the uninterrupted one + assert list(resumed) == epoch_1 + + +def test_deterministic_sampler_constant_memory(): + # A trillion-frame dataset must instantiate instantly and seek anywhere in O(1): + # only per-episode boundaries are stored, never per-frame indices. + num_frames = 10**12 + sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + assert len(sampler) == num_frames + sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3}) + tail = list(sampler) + assert len(tail) == 3 + assert all(0 <= idx < num_frames for idx in tail) + + +def test_deterministic_sampler_validation_matches_episode_aware(): + with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): + DeterministicEpisodeAwareSampler([0], [10], drop_n_first_frames=-1) + with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"): + DeterministicEpisodeAwareSampler([0], [10], drop_n_last_frames=-1) + with pytest.raises(ValueError, match="No valid frames remain"): + DeterministicEpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1) + + +def test_deterministic_sampler_partial_episode_drop_warns(caplog): + with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"): + sampler = DeterministicEpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False) + assert list(sampler) == [2, 3, 4, 5] + assert "Episode 0" in caplog.text + + +def test_compute_sampler_state(): + # 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch. + assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 0, + "start_index": 0, + } + # step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in + assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 40, + } + # uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank + assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == { + "epoch": 2, + "start_index": 40, + } + # uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads) + assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 100, + }