feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments

deterministic=True is now the class default as well as the training
default; the legacy RNG path requires an explicit deterministic=False
(the train script's non-deterministic branch passes it). Docstrings and
inline comments slimmed down across the changed files.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 11:54:22 +02:00
parent b2d5d4ccfc
commit 29ca0f53d9
4 changed files with 46 additions and 69 deletions
+3 -6
View File
@@ -99,12 +99,9 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based
# index shuffling. The data order becomes a pure function of (seed, epoch), which makes
# distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size,
# and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather
# than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle
# order. Ignored when dataset.streaming is enabled.
# Deterministic data order (pure function of seed and epoch): immune to cross-rank RNG
# desync and enables sample-exact resume. Set to false for the legacy RNG-based shuffle.
# Ignored when dataset.streaming is enabled.
deterministic_sampler: bool = True
steps: int = 100_000
eval_freq: int = 20_000
+29 -51
View File
@@ -27,7 +27,7 @@ _FEISTEL_ROUNDS = 4
def _mix64(x: int) -> int:
"""SplitMix64 finalizer: a high-quality, cheap 64-bit integer hash."""
"""SplitMix64 finalizer (64-bit integer hash)."""
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
@@ -38,27 +38,20 @@ def _mix64(x: int) -> int:
class EpisodeAwareSampler:
"""Sampler that incorporates episode boundary information.
"""Sampler over episode frames with O(num_episodes) memory.
Frame indices are never materialized: only per-episode boundaries are stored (a few numpy
int64 per episode) and the mapping from a logical position to a frame index is computed on
the fly via `searchsorted`, so memory does not grow with the number of frames.
Only episode boundaries are stored; logical positions map to frame indices on the fly, so
memory does not grow with the number of frames.
Two shuffling modes are supported:
By default (`deterministic=True`) shuffling uses a seeded Feistel permutation over
`[0, num_frames)`: the data order is a pure function of `(seed, epoch)`, needs no RNG
synchronization across distributed ranks, and any position can be sought in O(1), enabling
sample-exact resume via `state_dict` / `load_state_dict`. Each completed `__iter__`
advances the epoch. The shuffle is pseudo-random rather than truly uniform — the standard
large-scale trade-off. During a resumed epoch, `__len__` still reports the full length.
- Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a
dedicated `generator`. Exposing the `generator` attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so every rank
draws the same permutation and batch shards stay disjoint.
- `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking
to the exact domain size). The data order becomes a pure function of `(seed, epoch)`:
nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact
resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset
size. The shuffle is pseudo-random rather than a true uniform permutation — the standard
trade-off in large-scale training loaders. Each completed `__iter__` advances the internal
epoch, so consecutive dataloader passes yield different permutations without `set_epoch`
calls. During an epoch resumed via `load_state_dict`, `__len__` still reports the full
epoch length; the first pass simply yields fewer samples.
With `deterministic=False`, shuffling falls back to `torch.randperm` driven by `generator`
(accelerate synchronizes the generator across ranks when preparing the dataloader).
"""
def __init__(
@@ -70,20 +63,18 @@ class EpisodeAwareSampler:
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
deterministic: bool = False,
deterministic: bool = True,
seed: int = 0,
):
"""
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for default-mode shuffling. When None, shuffling falls
back to the global torch RNG. Incompatible with `deterministic=True`.
generator: Generator for non-deterministic shuffling (global torch RNG when None).
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
seed: Seed the deterministic permutation is derived from (together with the epoch).
"""
@@ -135,22 +126,18 @@ class EpisodeAwareSampler:
self._epoch = 0
self._start_index = 0
# Feistel cipher domain: the smallest even-bit-width power of two >= num_frames,
# so both halves have equal width and cycle-walking converges in <4 expected steps.
# Smallest even-bit-width power-of-two domain >= num_frames: equal Feistel halves,
# cycle-walking converges in <4 expected steps.
bits = max((self._num_frames - 1).bit_length(), 2)
self._half_bits = (bits + 1) // 2
self._half_mask = (1 << self._half_bits) - 1
@property
def indices(self) -> list[int]:
"""Materialized frame indices, in unshuffled order (back-compat / introspection only).
This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it.
"""
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
"""Set the epoch the next `__iter__` will use (DistributedSampler convention)."""
self._require_deterministic("set_epoch")
self._epoch = epoch
@@ -165,10 +152,7 @@ class EpisodeAwareSampler:
def _require_deterministic(self, method: str) -> None:
if not self.deterministic:
raise RuntimeError(
f"{method} is only meaningful with deterministic=True: in default mode the "
"order is drawn from the (generator) RNG and cannot be sought."
)
raise RuntimeError(f"{method} requires deterministic=True: an RNG order cannot be sought.")
def _round_keys(self, epoch: int) -> list[int]:
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
@@ -197,8 +181,7 @@ class EpisodeAwareSampler:
def __iter__(self) -> Iterator[int]:
if not self.deterministic:
return self._iter_default()
# Capture and advance state eagerly so epoch bookkeeping is not deferred until the
# returned generator is first consumed.
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
@@ -222,17 +205,12 @@ class EpisodeAwareSampler:
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map a global optimization step to an `EpisodeAwareSampler` state for resume.
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and
keeps every `num_processes`-th batch, so one optimization step consumes
`batch_size * num_processes` consecutive sampler positions, and (with `even_batches` padding)
each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch.
`batches_into_epoch * num_processes <= ceil(num_frames / batch_size) - 1` always holds, so the
start index stays strictly below `num_frames`; the `min` is purely defensive. Resume is
sample-exact up to the `even_batches` padding accelerate appends at epoch boundaries (at most
`num_processes - 1` duplicated batches per epoch, the same duplication non-resumed runs get).
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
+3 -7
View File
@@ -388,9 +388,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if cfg.deterministic_sampler and not cfg.dataset.streaming:
# Data order is a pure function of (seed, epoch): nothing to synchronize across ranks,
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
# Deterministic data order: no cross-rank RNG sync needed, sample-exact resume.
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
@@ -398,7 +396,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
deterministic=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
@@ -413,9 +410,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
)
elif hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
# Legacy RNG shuffle: a dedicated generator lets accelerate synchronize it across ranks.
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
@@ -425,6 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
deterministic=False,
generator=sampler_generator,
)
else:
+11 -5
View File
@@ -118,12 +118,18 @@ def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_a = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
sampler_b = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_c = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
@@ -133,7 +139,7 @@ def test_shuffle_with_generator_is_deterministic():
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
sampler = EpisodeAwareSampler([0], [6], shuffle=True, deterministic=False)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
@@ -194,7 +200,7 @@ def test_deterministic_mode_rejects_generator():
def test_state_methods_require_deterministic_mode():
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True)
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, deterministic=False)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.set_epoch(1)
with pytest.raises(RuntimeError, match="deterministic=True"):