feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume

Add a sampler that never materializes frame indices: it stores only
per-episode boundaries (numpy, a few bytes per episode) and maps logical
positions to frame indices on the fly with searchsorted. Shuffling uses a
seeded Feistel permutation over [0, num_frames) (cycle-walking to the
exact domain), so the data order is a pure function of (seed, epoch):

- no RNG state to synchronize across distributed ranks,
- constant memory and zero epoch-boundary cost at any dataset size,
- O(1) seek to any position, enabling sample-exact resume.

Opt in with --deterministic_sampler=true. On resume, lerobot-train maps
the checkpointed step back to (epoch, start_index) via
compute_sampler_state and continues at the exact sample where the run
left off (up to accelerate's even_batches padding at epoch boundaries).
The shuffle is pseudo-random rather than a true uniform permutation, the
standard trade-off in large-scale training loaders.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 10:33:52 +02:00
parent 72e093dbff
commit 6fa495c6b0
5 changed files with 332 additions and 3 deletions
+12
View File
@@ -99,6 +99,12 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
# Use a deterministic O(1)-memory sampler (seeded Feistel permutation) instead of materialized
# index shuffling. The data order becomes a pure function of (seed, epoch), which makes
# distributed sharding immune to RNG desync, keeps sampler memory constant in dataset size,
# and enables sample-exact resume of interrupted runs. The shuffle is pseudo-random rather
# than a true uniform permutation. Not compatible with dataset.streaming.
deterministic_sampler: bool = False
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
@@ -133,6 +139,12 @@ class TrainPipelineConfig(HubMixin):
return self.policy # type: ignore[return-value]
def validate(self) -> None:
if self.deterministic_sampler and self.dataset.streaming:
raise ValueError(
"deterministic_sampler requires a map-style dataset and is not compatible with "
"dataset.streaming=true."
)
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
reward_model_path = parser.get_path_arg("reward_model")
+3 -1
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .sampler import DeterministicEpisodeAwareSampler, EpisodeAwareSampler, compute_sampler_state
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -64,6 +64,7 @@ __all__ = [
"DEFAULT_EPISODES_PATH",
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"DeterministicEpisodeAwareSampler",
"EpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
@@ -82,6 +83,7 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+181
View File
@@ -14,12 +14,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
_MASK_64 = (1 << 64) - 1
_FEISTEL_ROUNDS = 4
def _mix64(x: int) -> int:
"""SplitMix64 finalizer: a high-quality, cheap 64-bit integer hash."""
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
x ^= x >> 27
x = (x * 0x94D049BB133111EB) & _MASK_64
x ^= x >> 31
return x
class EpisodeAwareSampler:
def __init__(
@@ -90,3 +106,168 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class DeterministicEpisodeAwareSampler:
"""Episode-aware sampler with O(num_episodes) memory and a deterministic, seekable shuffle.
Unlike `EpisodeAwareSampler`, frame indices are never materialized: only per-episode
boundaries are stored (a few numpy int64 per episode) and the mapping from a logical
position to a frame index is computed on the fly via `searchsorted`. Shuffling applies a
seeded Feistel permutation over `[0, num_frames)` (cycle-walking to the exact domain size)
instead of `torch.randperm`, so the data order is a pure function of `(seed, epoch)`:
- every distributed rank derives the identical order with no RNG state to synchronize,
- any position can be sought in O(1), enabling sample-exact resume of interrupted runs
(see `state_dict` / `load_state_dict`),
- memory and epoch-boundary cost do not grow with the number of frames.
The trade-off is that the shuffle is pseudo-random rather than a true uniform permutation,
which is the standard compromise in large-scale training loaders.
Each completed `__iter__` advances the internal epoch, so consecutive dataloader passes
yield different permutations without requiring `set_epoch` calls. During an epoch resumed
via `load_state_dict`, `__len__` still reports the full epoch length; the first pass simply
yields fewer samples.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = True,
seed: int = 0,
):
"""
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle with the seeded Feistel permutation.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.seed = seed
self._epoch = 0
self._start_index = 0
# Feistel cipher domain: the smallest even-bit-width power of two >= num_frames,
# so both halves have equal width and cycle-walking converges in <4 expected steps.
bits = max((self._num_frames - 1).bit_length(), 2)
self._half_bits = (bits + 1) // 2
self._half_mask = (1 << self._half_bits) - 1
def set_epoch(self, epoch: int) -> None:
"""Set the epoch the next `__iter__` will use (DistributedSampler convention)."""
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _round_keys(self, epoch: int) -> list[int]:
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
keys = []
for _ in range(_FEISTEL_ROUNDS):
state = _mix64(state)
keys.append(state)
return keys
def _permute(self, index: int, keys: list[int]) -> int:
# Feistel network with cycle-walking: a bijection on [0, num_frames).
half_bits, half_mask = self._half_bits, self._half_mask
while True:
left, right = index >> half_bits, index & half_mask
for key in keys:
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
index = (left << half_bits) | right
if index < self._num_frames:
return index
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Capture and advance state eagerly so epoch bookkeeping is not deferred until the
# returned generator is first consumed.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
keys = self._round_keys(epoch) if self.shuffle else None
for k in range(start, self._num_frames):
yield self._frame_index(self._permute(k, keys) if self.shuffle else k)
def __len__(self) -> int:
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map a global optimization step to a `DeterministicEpisodeAwareSampler` state for resume.
Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and
keeps every `num_processes`-th batch, so one optimization step consumes
`batch_size * num_processes` consecutive sampler positions, and (with `even_batches` padding)
each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch.
`batches_into_epoch * num_processes <= ceil(num_frames / batch_size) - 1` always holds, so the
start index stays strictly below `num_frames`; the `min` is purely defensive. Resume is
sample-exact up to the `even_batches` padding accelerate appends at epoch boundaries (at most
`num_processes - 1` duplicated batches per epoch, the same duplication non-resumed runs get).
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
+30 -2
View File
@@ -43,7 +43,12 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.datasets import (
DeterministicEpisodeAwareSampler,
EpisodeAwareSampler,
compute_sampler_state,
make_dataset,
)
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -387,7 +392,30 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
if cfg.deterministic_sampler:
# Data order is a pure function of (seed, epoch): nothing to synchronize across ranks,
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
shuffle = False
sampler = DeterministicEpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
sampler_state = compute_sampler_state(
step, len(sampler), cfg.batch_size, accelerator.num_processes
)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
elif hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
+106
View File
@@ -161,3 +161,109 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- DeterministicEpisodeAwareSampler ---
from lerobot.datasets.sampler import ( # noqa: E402
DeterministicEpisodeAwareSampler,
compute_sampler_state,
)
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
def test_deterministic_sampler_unshuffled_matches_episode_aware():
for kwargs in (
{},
{"drop_n_first_frames": 1},
{"drop_n_last_frames": 1},
{"episode_indices_to_use": [0, 2]},
):
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
sampler = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
assert list(sampler) == list(reference), kwargs
assert len(sampler) == len(reference), kwargs
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
assert epoch_1 != epoch_0
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_deterministic_sampler_constant_memory():
# A trillion-frame dataset must instantiate instantly and seek anywhere in O(1):
# only per-episode boundaries are stored, never per-frame indices.
num_frames = 10**12
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
tail = list(sampler)
assert len(tail) == 3
assert all(0 <= idx < num_frames for idx in tail)
def test_deterministic_sampler_validation_matches_episode_aware():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
DeterministicEpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
DeterministicEpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
with pytest.raises(ValueError, match="No valid frames remain"):
DeterministicEpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_deterministic_sampler_partial_episode_drop_warns(caplog):
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = DeterministicEpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
assert list(sampler) == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
def test_compute_sampler_state():
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 0,
"start_index": 0,
}
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 40,
}
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
"epoch": 2,
"start_index": 40,
}
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 100,
}