Compare commits

...

7 Commits

Author SHA1 Message Date
pepijn 69e8ab38bd chore(datasets): trim sampler comment and drop duplicate tests
Remove the verbose dataloader-guard comment and the two EpisodeAwareSampler tests
that duplicated existing validation/warning coverage (no coverage loss).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 17:31:50 +00:00
pepijn 926fb9c31e fix(train): download dataset once on the global main process
Gate the training dataset download on the global is_main_process (download once to the
shared dataset root, barrier, then every other rank reads the already-populated copy)
instead of per-node is_local_main_process. LeRobotDataset skips its snapshot_download
when try_load() succeeds, so no rank re-downloads. Assumes the dataset root / HF cache is
on storage shared across nodes.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 17:21:35 +00:00
pepijn 3f6909fb63 fix(datasets): address sampler review (batch_size resume guard + docs)
- Record batch_size in training_step.json alongside num_processes and feed
  the checkpoint's value into compute_sampler_state on resume; warn when it
  differs (per-rank sample-exactness needs the same batch size).
- Document the set_epoch vs __iter__ auto-advance coupling on EpisodeAwareSampler
  (callers should rely on exactly one mechanism per run).
- Note the broadened (reproducibility-breaking) sampler guard and the no-generator
  distributed sharding correctness in lerobot_train.py.
- Add load_training_batch_size + parallel tests.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 16:48:37 +00:00
pepijn 3d0e8681f3 refactor(datasets): make EpisodeAwareSampler always deterministic
With Feistel gone, deterministic and legacy modes were both just torch.randperm and the
deterministic path strictly dominated (reproducible across ranks via the (seed, epoch) seed,
no accelerate generator sync, resumable). Collapse to a single path and drop the redundant
flag:

- remove the `deterministic` and `generator` constructor args, `_iter_default`, and
  `_require_deterministic`; `set_epoch` / `state_dict` / `load_state_dict` are now unconditional
- remove the `deterministic_sampler` train config field and the legacy generator branch in
  lerobot_train.py (non-streaming map datasets always use the sampler)
- drop the now-obsolete generator/legacy tests

Note: removes the `generator` kwarg from EpisodeAwareSampler (back-compat break vs main); the
order is now a pure function of (seed, epoch), so no cross-rank RNG sync is needed.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 15:31:15 +00:00
pepijn 1e2057e3be refactor(datasets): use seeded torch.randperm instead of Feistel in EpisodeAwareSampler
Drop the Feistel permutation (and its SplitMix64 hash / cycle-walking) in favor of a
torch.randperm seeded from (seed, epoch). The deterministic mode keeps its key properties
- data order is a pure function of (seed, epoch), so it reproduces on every rank with no
  global-RNG synchronization, and
- state_dict / load_state_dict still resume sample-exactly, now by regenerating the epoch's
  permutation and slicing from the saved offset.

Construction stays O(num_episodes) (only episode boundaries are stored, never a per-frame
index list). The trade-off vs Feistel: the per-epoch shuffle is again O(num_frames) memory
(the randperm tensor) and no longer O(1)-seekable, in exchange for ~30 fewer LOC and a truly
uniform shuffle. Tests updated: the trillion-frame O(1) test is replaced with a
boundary-storage check and a scale resume-exactness test.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 15:15:14 +00:00
pepijn c094f40868 style: apply ruff-format to lerobot_train.py
Collapse the compute_sampler_state(...) call onto one line so the
ruff-format pre-commit hook passes (fixes the failing CI check).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 14:58:15 +00:00
pepijn 1aa937aad2 fix(datasets): make deterministic-sampler resume robust to world-size changes
compute_sampler_state mapped a checkpointed step back to (epoch, start_index)
using the *current* num_processes, but the number of sampler positions a step
consumes scales with the world size that produced it. Resuming on a different
GPU count therefore landed on the wrong epoch/offset, silently re-seeing or
skipping data.

Record num_processes in training_step.json at checkpoint time and feed the
checkpoint's value into compute_sampler_state on resume, so the data order
resumes at the right position regardless of the new world size. Warn when the
world size changed (the global offset is correct, but per-rank sample-exactness
needs the same topology). Old checkpoints without the field fall back to the
current world size.

Also document compute_sampler_state's assumptions explicitly: num_processes /
batch_size must match the checkpointing run, and accelerate's even_batches=True
padding is mirrored by the ceil(... / num_processes) term.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 14:37:26 +00:00
6 changed files with 165 additions and 210 deletions
+37 -4
View File
@@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -75,6 +96,8 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -100,6 +123,10 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -112,7 +139,9 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
def save_training_state(
@@ -120,6 +149,8 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -131,10 +162,12 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
-4
View File
@@ -99,10 +99,6 @@ class TrainPipelineConfig(HubMixin):
batch_size: int = 8
prefetch_factor: int = 4
persistent_workers: bool = True
# Deterministic data order (pure function of seed and epoch): immune to cross-rank RNG
# desync and enables sample-exact resume. Set to false for the legacy RNG-based shuffle.
# Ignored when dataset.streaming is enabled.
deterministic_sampler: bool = True
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
+35 -84
View File
@@ -22,39 +22,26 @@ import torch
logger = logging.getLogger(__name__)
_MASK_64 = (1 << 64) - 1
_FEISTEL_ROUNDS = 4
# Cycle-walking converges in <4 expected steps on the chosen domain; this bound is a generous
# safety net that should never be hit in practice.
_MAX_CYCLE_WALK_STEPS = 100
def _mix64(x: int) -> int:
"""SplitMix64 finalizer (64-bit integer hash)."""
x = (x + 0x9E3779B97F4A7C15) & _MASK_64
x ^= x >> 30
x = (x * 0xBF58476D1CE4E5B9) & _MASK_64
x ^= x >> 27
x = (x * 0x94D049BB133111EB) & _MASK_64
x ^= x >> 31
return x
class EpisodeAwareSampler:
"""Sampler over episode frames with O(num_episodes) memory.
"""Sampler over episode frames that stores only per-episode boundaries.
Only episode boundaries are stored; logical positions map to frame indices on the fly, so
memory does not grow with the number of frames.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
By default (`deterministic=True`) shuffling uses a seeded Feistel permutation over
`[0, num_frames)`: the data order is a pure function of `(seed, epoch)`, needs no RNG
synchronization across distributed ranks, and any position can be sought in O(1), enabling
sample-exact resume via `state_dict` / `load_state_dict`. Each completed `__iter__`
advances the epoch. The shuffle is pseudo-random rather than truly uniform — the standard
large-scale trade-off. During a resumed epoch, `__len__` still reports the full length.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
With `deterministic=False`, shuffling falls back to `torch.randperm` driven by `generator`
(accelerate synchronizes the generator across ranks when preparing the dataloader).
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
@@ -65,8 +52,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
deterministic: bool = True,
seed: int = 0,
):
"""
@@ -77,16 +62,12 @@ class EpisodeAwareSampler:
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator for non-deterministic shuffling (global torch RNG when None).
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
seed: Seed the deterministic permutation is derived from (together with the epoch).
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
if deterministic and generator is not None:
raise ValueError("generator is unused in deterministic mode; pass seed instead.")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
@@ -123,62 +104,30 @@ class EpisodeAwareSampler:
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.generator = generator
self.deterministic = deterministic
self.seed = seed
self._epoch = 0
self._start_index = 0
# Smallest even-bit-width power-of-two domain >= num_frames: equal Feistel halves,
# cycle-walking converges in <4 expected steps.
bits = max((self._num_frames - 1).bit_length(), 2)
self._half_bits = (bits + 1) // 2
self._half_mask = (1 << self._half_bits) - 1
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._require_deterministic("set_epoch")
self._epoch = epoch
def state_dict(self) -> dict:
self._require_deterministic("state_dict")
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._require_deterministic("load_state_dict")
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _require_deterministic(self, method: str) -> None:
if not self.deterministic:
raise RuntimeError(f"{method} requires deterministic=True: an RNG order cannot be sought.")
def _round_keys(self, epoch: int) -> list[int]:
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
keys = []
for _ in range(_FEISTEL_ROUNDS):
state = _mix64(state)
keys.append(state)
return keys
def _permute(self, index: int, keys: list[int]) -> int:
# Feistel network with cycle-walking: a bijection on [0, num_frames).
half_bits, half_mask = self._half_bits, self._half_mask
for _ in range(_MAX_CYCLE_WALK_STEPS):
left, right = index >> half_bits, index & half_mask
for key in keys:
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
index = (left << half_bits) | right
if index < self._num_frames:
return index
raise RuntimeError(
f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
"this should never happen for a valid domain."
)
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
@@ -186,27 +135,21 @@ class EpisodeAwareSampler:
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
if not self.deterministic:
return self._iter_default()
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_deterministic_epoch(epoch, start)
return self._iter_epoch(epoch, start)
def _iter_default(self) -> Iterator[int]:
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(self._num_frames, generator=self.generator):
yield self._frame_index(int(i))
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
else:
for k in range(self._num_frames):
for k in range(start, self._num_frames):
yield self._frame_index(k)
def _iter_deterministic_epoch(self, epoch: int, start: int) -> Iterator[int]:
keys = self._round_keys(epoch) if self.shuffle else None
for k in range(start, self._num_frames):
yield self._frame_index(self._permute(k, keys) if self.shuffle else k)
def __len__(self) -> int:
return self._num_frames
@@ -218,6 +161,14 @@ def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_proce
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
+36 -28
View File
@@ -36,6 +36,8 @@ from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
save_checkpoint,
update_last_checkpoint,
@@ -232,18 +234,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: the global main process downloads once to the shared
# dataset root, then a barrier lets every other rank read the already-populated copy.
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Other ranks read from the shared copy populated by the main process.
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -387,8 +388,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if cfg.deterministic_sampler and not cfg.dataset.streaming:
# Deterministic data order: no cross-rank RNG sync needed, sample-exact resume.
if not cfg.dataset.streaming:
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
# The order is a pure function of (seed, epoch), so every rank independently produces the
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
@@ -399,30 +403,32 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
sampler_state = compute_sampler_state(
step, len(sampler), cfg.batch_size, accelerator.num_processes
)
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
# use the values recorded in the checkpoint (falling back to the current ones for older
# ckpts that did not store them).
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
ckpt_num_processes = saved_num_processes or accelerator.num_processes
ckpt_batch_size = saved_batch_size or cfg.batch_size
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
logging.warning(
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
f"written with num_processes={saved_num_processes}. The data order resumes at the "
"right epoch/offset, but per-rank sample-exactness requires the same world size."
)
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
logging.warning(
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
"but per-rank sample-exactness requires the same batch size."
)
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
elif hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# Legacy RNG shuffle: a dedicated generator lets accelerate synchronize it across ranks.
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
deterministic=False,
generator=sampler_generator,
)
else:
shuffle = True
sampler = None
@@ -541,6 +547,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
+33 -90
View File
@@ -114,34 +114,17 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
sampler_b = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
assert list(sampler_a) == list(sampler_b)
def test_shuffle_is_reproducible_across_instances():
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
# produce the same permutation without any generator synchronization.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler(
[0], [6], shuffle=True, deterministic=False, generator=torch.Generator().manual_seed(42)
)
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True, deterministic=False)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
assert list(sampler_c) == epoch_0
def test_negative_drop_first_frames_raises():
@@ -169,54 +152,23 @@ def test_partial_episode_drop_warns(caplog):
assert "Episode 0" in caplog.text
# --- deterministic mode (seeded Feistel permutation) ---
from functools import partial # noqa: E402
# --- seeded (seed, epoch) shuffling, resume, and state ---
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
deterministic_sampler = partial(EpisodeAwareSampler, deterministic=True)
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
def test_deterministic_mode_unshuffled_matches_default_mode():
for kwargs in (
{},
{"drop_n_first_frames": 1},
{"drop_n_last_frames": 1},
{"episode_indices_to_use": [0, 2]},
):
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
sampler = deterministic_sampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
assert list(sampler) == list(reference), kwargs
assert len(sampler) == len(reference), kwargs
def test_deterministic_mode_rejects_generator():
with pytest.raises(ValueError, match="generator is unused in deterministic mode"):
deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, generator=torch.Generator())
def test_state_methods_require_deterministic_mode():
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, deterministic=False)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.set_epoch(1)
with pytest.raises(RuntimeError, match="deterministic=True"):
sampler.state_dict()
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=seed)
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = deterministic_sampler([0], [100], shuffle=True, seed=42)
sampler_b = deterministic_sampler([0], [100], shuffle=True, seed=42)
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
@@ -224,50 +176,41 @@ def test_deterministic_sampler_epochs_reproduce_and_differ():
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(deterministic_sampler([0], [100], shuffle=True, seed=7)) != epoch_0
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_deterministic_sampler_constant_memory():
# A trillion-frame dataset must instantiate instantly and seek anywhere in O(1):
# only per-episode boundaries are stored, never per-frame indices.
num_frames = 10**12
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
def test_deterministic_sampler_construction_stores_only_boundaries():
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
# instantiates from just its boundaries without materializing a per-frame index list.
num_frames = 1_000_000
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
# Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__
# (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The
# iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains.
tail = list(iter(sampler))
assert len(tail) == 3
assert all(0 <= idx < num_frames for idx in tail)
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
def test_deterministic_sampler_validation_matches_episode_aware():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
deterministic_sampler([0], [10], drop_n_first_frames=-1)
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
deterministic_sampler([0], [10], drop_n_last_frames=-1)
with pytest.raises(ValueError, match="No valid frames remain"):
deterministic_sampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_deterministic_sampler_partial_episode_drop_warns(caplog):
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = deterministic_sampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
assert list(sampler) == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
def test_deterministic_sampler_resume_is_exact_at_scale():
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
# permutation and slicing from the saved offset reproduces the remaining order exactly.
num_frames = 100_000
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
epoch_0 = list(reference)
assert sorted(epoch_0) == list(range(num_frames))
start = num_frames - 5
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
def test_compute_sampler_state():
+24
View File
@@ -20,6 +20,8 @@ from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
load_training_step,
save_checkpoint,
@@ -63,6 +65,28 @@ def test_load_training_step(tmp_path):
assert loaded_step == step
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
assert load_training_num_processes(tmp_path) == 4
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the world size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_num_processes(tmp_path) is None
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
assert load_training_batch_size(tmp_path) == 32
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the batch size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_batch_size(tmp_path) is None
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()