Compare commits

...

9 Commits

Author SHA1 Message Date
Claude 7a62235bac fix(datasets): guard Feistel cycle-walking loop against non-convergence
Replace the unbounded while True in EpisodeAwareSampler._permute with a
bounded for loop capped at _MAX_CYCLE_WALK_STEPS (100) and raise
RuntimeError if the cycle-walk fails to land in [0, num_frames). The
loop is expected to converge in <4 steps on the chosen power-of-two
domain, so the bound is a safety net that should never trip in practice
but prevents a pathological infinite loop.

https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22
2026-06-11 13:20:31 +00:00
Pepijn 81f0ca9ce4 test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc
list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and
preallocates that many slots before iterating, OOMing even though the resumed epoch only
yields 3 frames. Collect through the iterator (no length hint) so the test exercises the
real O(1) seek/drain instead of CPython's list growth heuristic.
2026-06-11 10:39:13 +00:00
Pepijn 29ca0f53d9 feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments
deterministic=True is now the class default as well as the training
default; the legacy RNG path requires an explicit deterministic=False
(the train script's non-deterministic branch passes it). Docstrings and
inline comments slimmed down across the changed files.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:54:22 +02:00
Pepijn b2d5d4ccfc feat(train): enable deterministic_sampler by default
Deterministic data order (sample-exact resume, no cross-rank RNG sync,
O(1) sampler memory) is now the default for map-style training; set
deterministic_sampler=false to restore the legacy RNG-based shuffle.
Streaming datasets ignore the flag (the sampler path only applies to
map-style datasets), replacing the previous hard validation error so
streaming configs keep working with the new default.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:45:36 +02:00
Pepijn 32b0d7d1ef refactor(datasets): fold deterministic mode into EpisodeAwareSampler
Instead of a parallel DeterministicEpisodeAwareSampler class, extend the
existing EpisodeAwareSampler with a deterministic=True mode (seeded
Feistel permutation, epoch auto-advance, state_dict/load_state_dict).

The default mode is behavior-identical: same torch.randperm consumption
and the same generator contract accelerate synchronizes; the O(N) Python
index list is replaced by O(num_episodes) boundary arrays in both modes,
with `indices` kept as a back-compat property. Passing a generator
together with deterministic=True is rejected, and the state/seek methods
raise outside deterministic mode.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:37:44 +02:00
Pepijn 7416b714c0 Merge remote-tracking branch 'origin/main' into feat/deterministic-sampler 2026-06-11 11:33:44 +02:00
Pepijn 6fa495c6b0 feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume
Add a sampler that never materializes frame indices: it stores only
per-episode boundaries (numpy, a few bytes per episode) and maps logical
positions to frame indices on the fly with searchsorted. Shuffling uses a
seeded Feistel permutation over [0, num_frames) (cycle-walking to the
exact domain), so the data order is a pure function of (seed, epoch):

- no RNG state to synchronize across distributed ranks,
- constant memory and zero epoch-boundary cost at any dataset size,
- O(1) seek to any position, enabling sample-exact resume.

Opt in with --deterministic_sampler=true. On resume, lerobot-train maps
the checkpointed step back to (epoch, start_index) via
compute_sampler_state and continues at the exact sample where the run
left off (up to accelerate's even_batches padding at epoch boundaries).
The shuffle is pseudo-random rather than a true uniform permutation, the
standard trade-off in large-scale training loaders.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:33:52 +02:00
Pepijn 72e093dbff fix(train): seed sampler generator and gate dataset download per node
- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:01:43 +02:00
Pepijn 3d262a6c9e fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync
In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:01:42 +02:00
5 changed files with 334 additions and 46 deletions
+4
View File
@@ -99,6 +99,10 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
# Deterministic data order (pure function of seed and epoch): immune to cross-rank RNG
# desync and enables sample-exact resume. Set to false for the legacy RNG-based shuffle.
# Ignored when dataset.streaming is enabled.
deterministic_sampler: bool = True
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
+2 -1
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -82,6 +82,7 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+169 -36
View File
@@ -14,14 +14,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
_MASK_64 = (1 << 64) - 1
_FEISTEL_ROUNDS = 4
# Cycle-walking converges in <4 expected steps on the chosen domain; this bound is a generous
# safety net that should never be hit in practice.
_MAX_CYCLE_WALK_STEPS = 100
def _mix64(x: int) -> int:
"""SplitMix64 finalizer (64-bit integer hash)."""
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
x ^= x >> 27
x = (x * 0x94D049BB133111EB) & _MASK_64
x ^= x >> 31
return x
class EpisodeAwareSampler:
"""Sampler over episode frames with O(num_episodes) memory.
Only episode boundaries are stored; logical positions map to frame indices on the fly, so
memory does not grow with the number of frames.
By default (`deterministic=True`) shuffling uses a seeded Feistel permutation over
`[0, num_frames)`: the data order is a pure function of `(seed, epoch)`, needs no RNG
synchronization across distributed ranks, and any position can be sought in O(1), enabling
sample-exact resume via `state_dict` / `load_state_dict`. Each completed `__iter__`
advances the epoch. The shuffle is pseudo-random rather than truly uniform — the standard
large-scale trade-off. During a resumed epoch, `__len__` still reports the full length.
With `deterministic=False`, shuffling falls back to `torch.randperm` driven by `generator`
(accelerate synchronizes the generator across ranks when preparing the dataloader).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -31,62 +66,160 @@ class EpisodeAwareSampler:
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
deterministic: bool = True,
seed: int = 0,
):
"""Sampler that optionally incorporates episode boundary information.
"""
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
generator: Generator for non-deterministic shuffling (global torch RNG when None).
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
seed: Seed the deterministic permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
if deterministic and generator is not None:
raise ValueError("generator is unused in deterministic mode; pass seed instead.")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
if not indices:
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.generator = generator
self.deterministic = deterministic
self.seed = seed
self._epoch = 0
self._start_index = 0
# Smallest even-bit-width power-of-two domain >= num_frames: equal Feistel halves,
# cycle-walking converges in <4 expected steps.
bits = max((self._num_frames - 1).bit_length(), 2)
self._half_bits = (bits + 1) // 2
self._half_mask = (1 << self._half_bits) - 1
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._require_deterministic("set_epoch")
self._epoch = epoch
def state_dict(self) -> dict:
self._require_deterministic("state_dict")
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._require_deterministic("load_state_dict")
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _require_deterministic(self, method: str) -> None:
if not self.deterministic:
raise RuntimeError(f"{method} requires deterministic=True: an RNG order cannot be sought.")
def _round_keys(self, epoch: int) -> list[int]:
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
keys = []
for _ in range(_FEISTEL_ROUNDS):
state = _mix64(state)
keys.append(state)
return keys
def _permute(self, index: int, keys: list[int]) -> int:
# Feistel network with cycle-walking: a bijection on [0, num_frames).
half_bits, half_mask = self._half_bits, self._half_mask
for _ in range(_MAX_CYCLE_WALK_STEPS):
left, right = index >> half_bits, index & half_mask
for key in keys:
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
index = (left << half_bits) | right
if index < self._num_frames:
return index
raise RuntimeError(
f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
"this should never happen for a valid domain."
)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
if not self.deterministic:
return self._iter_default()
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_deterministic_epoch(epoch, start)
def _iter_default(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
for i in torch.randperm(self._num_frames, generator=self.generator):
yield self._frame_index(int(i))
else:
for i in self.indices:
yield i
for k in range(self._num_frames):
yield self._frame_index(k)
def _iter_deterministic_epoch(self, epoch: int, start: int) -> Iterator[int]:
keys = self._round_keys(epoch) if self.shuffle else None
for k in range(start, self._num_frames):
yield self._frame_index(self._permute(k, keys) if self.shuffle else k)
def __len__(self) -> int:
return len(self.indices)
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
+25 -5
View File
@@ -43,7 +43,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -387,11 +387,30 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
if cfg.deterministic_sampler and not cfg.dataset.streaming:
# Deterministic data order: no cross-rank RNG sync needed, sample-exact resume.
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
sampler_state = compute_sampler_state(
step, len(sampler), cfg.batch_size, accelerator.num_processes
)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
elif hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# Legacy RNG shuffle: a dedicated generator lets accelerate synchronize it across ranks.
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
@@ -401,6 +420,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
deterministic=False,
generator=sampler_generator,
)
else:
+134 -4
View File
@@ -118,12 +118,18 @@ def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_a = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
sampler_b = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_c = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
@@ -133,7 +139,7 @@ def test_shuffle_with_generator_is_deterministic():
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
sampler = EpisodeAwareSampler([0], [6], shuffle=True, deterministic=False)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
@@ -161,3 +167,127 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- deterministic mode (seeded Feistel permutation) ---
from functools import partial # noqa: E402
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
deterministic_sampler = partial(EpisodeAwareSampler, deterministic=True)
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
def test_deterministic_mode_unshuffled_matches_default_mode():
for kwargs in (
{},
{"drop_n_first_frames": 1},
{"drop_n_last_frames": 1},
{"episode_indices_to_use": [0, 2]},
):
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
sampler = deterministic_sampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
assert list(sampler) == list(reference), kwargs
assert len(sampler) == len(reference), kwargs
def test_deterministic_mode_rejects_generator():
with pytest.raises(ValueError, match="generator is unused in deterministic mode"):
deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, generator=torch.Generator())
def test_state_methods_require_deterministic_mode():
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, deterministic=False)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.set_epoch(1)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.state_dict()
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = deterministic_sampler([0], [100], shuffle=True, seed=42)
sampler_b = deterministic_sampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
assert epoch_1 != epoch_0
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(deterministic_sampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_deterministic_sampler_constant_memory():
# A trillion-frame dataset must instantiate instantly and seek anywhere in O(1):
# only per-episode boundaries are stored, never per-frame indices.
num_frames = 10**12
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
# Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__
# (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The
# iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains.
tail = list(iter(sampler))
assert len(tail) == 3
assert all(0 <= idx < num_frames for idx in tail)
def test_deterministic_sampler_validation_matches_episode_aware():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
deterministic_sampler([0], [10], drop_n_first_frames=-1)
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
deterministic_sampler([0], [10], drop_n_last_frames=-1)
with pytest.raises(ValueError, match="No valid frames remain"):
deterministic_sampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_deterministic_sampler_partial_episode_drop_warns(caplog):
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = deterministic_sampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
assert list(sampler) == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
def test_compute_sampler_state():
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 0,
"start_index": 0,
}
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 40,
}
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
"epoch": 2,
"start_index": 40,
}
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 100,
}