refactor(datasets): fold deterministic mode into EpisodeAwareSampler

Instead of a parallel DeterministicEpisodeAwareSampler class, extend the
existing EpisodeAwareSampler with a deterministic=True mode (seeded
Feistel permutation, epoch auto-advance, state_dict/load_state_dict).

The default mode is behavior-identical: same torch.randperm consumption
and the same generator contract accelerate synchronizes; the O(N) Python
index list is replaced by O(num_episodes) boundary arrays in both modes,
with `indices` kept as a back-compat property. Passing a generator
together with deterministic=True is rejected, and the state/seek methods
raise outside deterministic mode.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 11:37:44 +02:00
parent 7416b714c0
commit 32b0d7d1ef
4 changed files with 98 additions and 121 deletions
+1 -2
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import DeterministicEpisodeAwareSampler, EpisodeAwareSampler, compute_sampler_state
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -64,7 +64,6 @@ __all__ = [
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"DeterministicEpisodeAwareSampler",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
+61 -94
View File
@@ -38,97 +38,27 @@ def _mix64(x: int) -> int:
class EpisodeAwareSampler:
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
"""Sampler that incorporates episode boundary information.
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
Frame indices are never materialized: only per-episode boundaries are stored (a few numpy
int64 per episode) and the mapping from a logical position to a frame index is computed on
the fly via `searchsorted`, so memory does not grow with the number of frames.
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
Two shuffling modes are supported:
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)
class DeterministicEpisodeAwareSampler:
"""Episode-aware sampler with O(num_episodes) memory and a deterministic, seekable shuffle.
Unlike `EpisodeAwareSampler`, frame indices are never materialized: only per-episode
boundaries are stored (a few numpy int64 per episode) and the mapping from a logical
position to a frame index is computed on the fly via `searchsorted`. Shuffling applies a
seeded Feistel permutation over `[0, num_frames)` (cycle-walking to the exact domain size)
instead of `torch.randperm`, so the data order is a pure function of `(seed, epoch)`:
- every distributed rank derives the identical order with no RNG state to synchronize,
- any position can be sought in O(1), enabling sample-exact resume of interrupted runs
(see `state_dict` / `load_state_dict`),
- memory and epoch-boundary cost do not grow with the number of frames.
The trade-off is that the shuffle is pseudo-random rather than a true uniform permutation,
which is the standard compromise in large-scale training loaders.
Each completed `__iter__` advances the internal epoch, so consecutive dataloader passes
yield different permutations without requiring `set_epoch` calls. During an epoch resumed
via `load_state_dict`, `__len__` still reports the full epoch length; the first pass simply
yields fewer samples.
- Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a
dedicated `generator`. Exposing the `generator` attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so every rank
draws the same permutation and batch shards stay disjoint.
- `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking
to the exact domain size). The data order becomes a pure function of `(seed, epoch)`:
nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact
resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset
size. The shuffle is pseudo-random rather than a true uniform permutation — the standard
trade-off in large-scale training loaders. Each completed `__iter__` advances the internal
epoch, so consecutive dataloader passes yield different permutations without `set_epoch`
calls. During an epoch resumed via `load_state_dict`, `__len__` still reports the full
epoch length; the first pass simply yields fewer samples.
"""
def __init__(
@@ -138,7 +68,9 @@ class DeterministicEpisodeAwareSampler:
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = True,
shuffle: bool = False,
generator: torch.Generator | None = None,
deterministic: bool = False,
seed: int = 0,
):
"""
@@ -149,13 +81,18 @@ class DeterministicEpisodeAwareSampler:
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle with the seeded Feistel permutation.
seed: Seed the permutation is derived from (together with the epoch).
shuffle: Whether to shuffle the indices.
generator: Generator used for default-mode shuffling. When None, shuffling falls
back to the global torch RNG. Incompatible with `deterministic=True`.
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
seed: Seed the deterministic permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
if deterministic and generator is not None:
raise ValueError("generator is unused in deterministic mode; pass seed instead.")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
@@ -192,6 +129,8 @@ class DeterministicEpisodeAwareSampler:
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.generator = generator
self.deterministic = deterministic
self.seed = seed
self._epoch = 0
self._start_index = 0
@@ -202,17 +141,35 @@ class DeterministicEpisodeAwareSampler:
self._half_bits = (bits + 1) // 2
self._half_mask = (1 << self._half_bits) - 1
@property
def indices(self) -> list[int]:
"""Materialized frame indices, in unshuffled order (back-compat / introspection only).
This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it.
"""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
"""Set the epoch the next `__iter__` will use (DistributedSampler convention)."""
self._require_deterministic("set_epoch")
self._epoch = epoch
def state_dict(self) -> dict:
self._require_deterministic("state_dict")
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._require_deterministic("load_state_dict")
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _require_deterministic(self, method: str) -> None:
if not self.deterministic:
raise RuntimeError(
f"{method} is only meaningful with deterministic=True: in default mode the "
"order is drawn from the (generator) RNG and cannot be sought."
)
def _round_keys(self, epoch: int) -> list[int]:
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
keys = []
@@ -238,14 +195,24 @@ class DeterministicEpisodeAwareSampler:
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
if not self.deterministic:
return self._iter_default()
# Capture and advance state eagerly so epoch bookkeeping is not deferred until the
# returned generator is first consumed.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
return self._iter_deterministic_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
def _iter_default(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(self._num_frames, generator=self.generator):
yield self._frame_index(int(i))
else:
for k in range(self._num_frames):
yield self._frame_index(k)
def _iter_deterministic_epoch(self, epoch: int, start: int) -> Iterator[int]:
keys = self._round_keys(epoch) if self.shuffle else None
for k in range(start, self._num_frames):
yield self._frame_index(self._permute(k, keys) if self.shuffle else k)
@@ -255,7 +222,7 @@ class DeterministicEpisodeAwareSampler:
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map a global optimization step to a `DeterministicEpisodeAwareSampler` state for resume.
"""Map a global optimization step to an `EpisodeAwareSampler` state for resume.
Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and
keeps every `num_processes`-th batch, so one optimization step consumes
+3 -7
View File
@@ -43,12 +43,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import (
DeterministicEpisodeAwareSampler,
EpisodeAwareSampler,
compute_sampler_state,
make_dataset,
)
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -397,12 +392,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
shuffle = False
sampler = DeterministicEpisodeAwareSampler(
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
deterministic=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
+33 -18
View File
@@ -163,17 +163,19 @@ def test_partial_episode_drop_warns(caplog):
assert "Episode 0" in caplog.text
# --- DeterministicEpisodeAwareSampler ---
# --- deterministic mode (seeded Feistel permutation) ---
from functools import partial # noqa: E402
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
deterministic_sampler = partial(EpisodeAwareSampler, deterministic=True)
from lerobot.datasets.sampler import ( # noqa: E402
DeterministicEpisodeAwareSampler,
compute_sampler_state,
)
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
def test_deterministic_sampler_unshuffled_matches_episode_aware():
def test_deterministic_mode_unshuffled_matches_default_mode():
for kwargs in (
{},
{"drop_n_first_frames": 1},
@@ -181,21 +183,34 @@ def test_deterministic_sampler_unshuffled_matches_episode_aware():
{"episode_indices_to_use": [0, 2]},
):
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
sampler = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
sampler = deterministic_sampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
assert list(sampler) == list(reference), kwargs
assert len(sampler) == len(reference), kwargs
def test_deterministic_mode_rejects_generator():
with pytest.raises(ValueError, match="generator is unused in deterministic mode"):
deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, generator=torch.Generator())
def test_state_methods_require_deterministic_mode():
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.set_epoch(1)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.state_dict()
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_a = deterministic_sampler([0], [100], shuffle=True, seed=42)
sampler_b = deterministic_sampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
@@ -203,15 +218,15 @@ def test_deterministic_sampler_epochs_reproduce_and_differ():
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
assert list(deterministic_sampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
reference = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
@@ -222,7 +237,7 @@ def test_deterministic_sampler_constant_memory():
# A trillion-frame dataset must instantiate instantly and seek anywhere in O(1):
# only per-episode boundaries are stored, never per-frame indices.
num_frames = 10**12
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
tail = list(sampler)
@@ -232,16 +247,16 @@ def test_deterministic_sampler_constant_memory():
def test_deterministic_sampler_validation_matches_episode_aware():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
DeterministicEpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
deterministic_sampler([0], [10], drop_n_first_frames=-1)
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
DeterministicEpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
deterministic_sampler([0], [10], drop_n_last_frames=-1)
with pytest.raises(ValueError, match="No valid frames remain"):
DeterministicEpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
deterministic_sampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_deterministic_sampler_partial_episode_drop_warns(caplog):
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = DeterministicEpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
sampler = deterministic_sampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
assert list(sampler) == [2, 3, 4, 5]
assert "Episode 0" in caplog.text