mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments
deterministic=True is now the class default as well as the training default; the legacy RNG path requires an explicit deterministic=False (the train script's non-deterministic branch passes it). Docstrings and inline comments slimmed down across the changed files. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -99,12 +99,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
batch_size: int = 8
|
||||
prefetch_factor: int = 4
|
||||
persistent_workers: bool = True
|
||||
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of RNG-based
|
||||
# index shuffling. The data order becomes a pure function of (seed, epoch), which makes
|
||||
# distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size,
|
||||
# and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather
|
||||
# than a true uniform permutation. Set to false to restore the legacy RNG-based shuffle
|
||||
# order. Ignored when dataset.streaming is enabled.
|
||||
# Deterministic data order (pure function of seed and epoch): immune to cross-rank RNG
|
||||
# desync and enables sample-exact resume. Set to false for the legacy RNG-based shuffle.
|
||||
# Ignored when dataset.streaming is enabled.
|
||||
deterministic_sampler: bool = True
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
|
||||
@@ -27,7 +27,7 @@ _FEISTEL_ROUNDS = 4
|
||||
|
||||
|
||||
def _mix64(x: int) -> int:
|
||||
"""SplitMix64 finalizer: a high-quality, cheap 64-bit integer hash."""
|
||||
"""SplitMix64 finalizer (64-bit integer hash)."""
|
||||
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
|
||||
x ^= x >> 30
|
||||
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
|
||||
@@ -38,27 +38,20 @@ def _mix64(x: int) -> int:
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
"""Sampler that incorporates episode boundary information.
|
||||
"""Sampler over episode frames with O(num_episodes) memory.
|
||||
|
||||
Frame indices are never materialized: only per-episode boundaries are stored (a few numpy
|
||||
int64 per episode) and the mapping from a logical position to a frame index is computed on
|
||||
the fly via `searchsorted`, so memory does not grow with the number of frames.
|
||||
Only episode boundaries are stored; logical positions map to frame indices on the fly, so
|
||||
memory does not grow with the number of frames.
|
||||
|
||||
Two shuffling modes are supported:
|
||||
By default (`deterministic=True`) shuffling uses a seeded Feistel permutation over
|
||||
`[0, num_frames)`: the data order is a pure function of `(seed, epoch)`, needs no RNG
|
||||
synchronization across distributed ranks, and any position can be sought in O(1), enabling
|
||||
sample-exact resume via `state_dict` / `load_state_dict`. Each completed `__iter__`
|
||||
advances the epoch. The shuffle is pseudo-random rather than truly uniform — the standard
|
||||
large-scale trade-off. During a resumed epoch, `__len__` still reports the full length.
|
||||
|
||||
- Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a
|
||||
dedicated `generator`. Exposing the `generator` attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so every rank
|
||||
draws the same permutation and batch shards stay disjoint.
|
||||
- `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking
|
||||
to the exact domain size). The data order becomes a pure function of `(seed, epoch)`:
|
||||
nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact
|
||||
resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset
|
||||
size. The shuffle is pseudo-random rather than a true uniform permutation — the standard
|
||||
trade-off in large-scale training loaders. Each completed `__iter__` advances the internal
|
||||
epoch, so consecutive dataloader passes yield different permutations without `set_epoch`
|
||||
calls. During an epoch resumed via `load_state_dict`, `__len__` still reports the full
|
||||
epoch length; the first pass simply yields fewer samples.
|
||||
With `deterministic=False`, shuffling falls back to `torch.randperm` driven by `generator`
|
||||
(accelerate synchronizes the generator across ranks when preparing the dataloader).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -70,20 +63,18 @@ class EpisodeAwareSampler:
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
deterministic: bool = False,
|
||||
deterministic: bool = True,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
dataset_from_indices: Start index of each episode in the dataset.
|
||||
dataset_to_indices: End index of each episode in the dataset.
|
||||
episode_indices_to_use: Episode indices to use; None means all.
|
||||
drop_n_first_frames: Frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for default-mode shuffling. When None, shuffling falls
|
||||
back to the global torch RNG. Incompatible with `deterministic=True`.
|
||||
generator: Generator for non-deterministic shuffling (global torch RNG when None). Only valid
with `deterministic=False`; incompatible with the default `deterministic=True`.
|
||||
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
|
||||
seed: Seed the deterministic permutation is derived from (together with the epoch).
|
||||
"""
|
||||
@@ -135,22 +126,18 @@ class EpisodeAwareSampler:
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
|
||||
# Feistel cipher domain: the smallest even-bit-width power of two >= num_frames,
|
||||
# so both halves have equal width and cycle-walking converges in <4 expected steps.
|
||||
# Smallest even-bit-width power-of-two domain >= num_frames: equal Feistel halves,
|
||||
# cycle-walking converges in <4 expected steps.
|
||||
bits = max((self._num_frames - 1).bit_length(), 2)
|
||||
self._half_bits = (bits + 1) // 2
|
||||
self._half_mask = (1 << self._half_bits) - 1
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices, in unshuffled order (back-compat / introspection only).
|
||||
|
||||
This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it.
|
||||
"""
|
||||
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Set the epoch the next `__iter__` will use (DistributedSampler convention)."""
|
||||
self._require_deterministic("set_epoch")
|
||||
self._epoch = epoch
|
||||
|
||||
@@ -165,10 +152,7 @@ class EpisodeAwareSampler:
|
||||
|
||||
def _require_deterministic(self, method: str) -> None:
|
||||
if not self.deterministic:
|
||||
raise RuntimeError(
|
||||
f"{method} is only meaningful with deterministic=True: in default mode the "
|
||||
"order is drawn from the (generator) RNG and cannot be sought."
|
||||
)
|
||||
raise RuntimeError(f"{method} requires deterministic=True: an RNG order cannot be sought.")
|
||||
|
||||
def _round_keys(self, epoch: int) -> list[int]:
|
||||
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
|
||||
@@ -197,8 +181,7 @@ class EpisodeAwareSampler:
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if not self.deterministic:
|
||||
return self._iter_default()
|
||||
# Capture and advance state eagerly so epoch bookkeeping is not deferred until the
|
||||
# returned generator is first consumed.
|
||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
@@ -222,17 +205,12 @@ class EpisodeAwareSampler:
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map a global optimization step to an `EpisodeAwareSampler` state for resume.
|
||||
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
|
||||
|
||||
Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and
|
||||
keeps every `num_processes`-th batch, so one optimization step consumes
|
||||
`batch_size * num_processes` consecutive sampler positions, and (with `even_batches` padding)
|
||||
each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch.
|
||||
|
||||
`batches_into_epoch * num_processes <= ceil(num_frames / batch_size) - 1` always holds, so the
|
||||
start index stays strictly below `num_frames`; the `min` is purely defensive. Resume is
|
||||
sample-exact up to the `even_batches` padding accelerate appends at epoch boundaries (at most
|
||||
`num_processes - 1` duplicated batches per epoch, the same duplication non-resumed runs get).
|
||||
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
|
||||
@@ -388,9 +388,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# create dataloader for offline training
|
||||
if cfg.deterministic_sampler and not cfg.dataset.streaming:
|
||||
# Data order is a pure function of (seed, epoch): nothing to synchronize across ranks,
|
||||
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
|
||||
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
|
||||
# Deterministic data order: no cross-rank RNG sync needed, sample-exact resume.
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
@@ -398,7 +396,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
shuffle=True,
|
||||
deterministic=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
@@ -413,9 +410,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
)
|
||||
elif hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
# Legacy RNG shuffle: a dedicated generator lets accelerate synchronize it across ranks.
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
@@ -425,6 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
deterministic=False,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -118,12 +118,18 @@ def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_a = EpisodeAwareSampler(
|
||||
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
sampler_b = EpisodeAwareSampler(
|
||||
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_c = EpisodeAwareSampler(
|
||||
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
|
||||
)
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
@@ -133,7 +139,7 @@ def test_shuffle_with_generator_is_deterministic():
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True, deterministic=False)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
@@ -194,7 +200,7 @@ def test_deterministic_mode_rejects_generator():
|
||||
|
||||
|
||||
def test_state_methods_require_deterministic_mode():
|
||||
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True)
|
||||
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, deterministic=False)
|
||||
with pytest.raises(RuntimeError, match="deterministic=True"):
|
||||
sampler.set_epoch(1)
|
||||
with pytest.raises(RuntimeError, match="deterministic=True"):
|
||||
|
||||
Reference in New Issue
Block a user