Compare commits

...

16 Commits

Author SHA1 Message Date
pepijn 69e8ab38bd chore(datasets): trim sampler comment and drop duplicate tests
Remove the verbose dataloader-guard comment and the two EpisodeAwareSampler tests
that duplicated existing validation/warning coverage (no coverage loss).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 17:31:50 +00:00
pepijn 926fb9c31e fix(train): download dataset once on the global main process
Gate the training dataset download on the global is_main_process (download once to the
shared dataset root, barrier, then every other rank reads the already-populated copy)
instead of per-node is_local_main_process. LeRobotDataset skips its snapshot_download
when try_load() succeeds, so no rank re-downloads. Assumes the dataset root / HF cache is
on storage shared across nodes.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 17:21:35 +00:00
pepijn 3f6909fb63 fix(datasets): address sampler review (batch_size resume guard + docs)
- Record batch_size in training_step.json alongside num_processes and feed
  the checkpoint's value into compute_sampler_state on resume; warn when it
  differs (per-rank sample-exactness needs the same batch size).
- Document the set_epoch vs __iter__ auto-advance coupling on EpisodeAwareSampler
  (callers should rely on exactly one mechanism per run).
- Note the broadened (reproducibility-breaking) sampler guard and the no-generator
  distributed sharding correctness in lerobot_train.py.
- Add load_training_batch_size + parallel tests.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 16:48:37 +00:00
pepijn 3d0e8681f3 refactor(datasets): make EpisodeAwareSampler always deterministic
With Feistel gone, deterministic and legacy modes were both just torch.randperm and the
deterministic path strictly dominated (reproducible across ranks via the (seed, epoch) seed,
no accelerate generator sync, resumable). Collapse to a single path and drop the redundant
flag:

- remove the `deterministic` and `generator` constructor args, `_iter_default`, and
  `_require_deterministic`; `set_epoch` / `state_dict` / `load_state_dict` are now unconditional
- remove the `deterministic_sampler` train config field and the legacy generator branch in
  lerobot_train.py (non-streaming map datasets always use the sampler)
- drop the now-obsolete generator/legacy tests

Note: removes the `generator` kwarg from EpisodeAwareSampler (back-compat break vs main); the
order is now a pure function of (seed, epoch), so no cross-rank RNG sync is needed.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 15:31:15 +00:00
pepijn 1e2057e3be refactor(datasets): use seeded torch.randperm instead of Feistel in EpisodeAwareSampler
Drop the Feistel permutation (and its SplitMix64 hash / cycle-walking) in favor of a
torch.randperm seeded from (seed, epoch). The deterministic mode keeps its key properties
- data order is a pure function of (seed, epoch), so it reproduces on every rank with no
  global-RNG synchronization, and
- state_dict / load_state_dict still resume sample-exactly, now by regenerating the epoch's
  permutation and slicing from the saved offset.

Construction stays O(num_episodes) (only episode boundaries are stored, never a per-frame
index list). The trade-off vs Feistel: the per-epoch shuffle is again O(num_frames) memory
(the randperm tensor) and no longer O(1)-seekable, in exchange for ~30 fewer LOC and a truly
uniform shuffle. Tests updated: the trillion-frame O(1) test is replaced with a
boundary-storage check and a scale resume-exactness test.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 15:15:14 +00:00
pepijn c094f40868 style: apply ruff-format to lerobot_train.py
Collapse the compute_sampler_state(...) call onto one line so the
ruff-format pre-commit hook passes (fixes the failing CI check).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 14:58:15 +00:00
pepijn 1aa937aad2 fix(datasets): make deterministic-sampler resume robust to world-size changes
compute_sampler_state mapped a checkpointed step back to (epoch, start_index)
using the *current* num_processes, but the number of sampler positions a step
consumes scales with the world size that produced it. Resuming on a different
GPU count therefore landed on the wrong epoch/offset, silently re-seeing or
skipping data.

Record num_processes in training_step.json at checkpoint time and feed the
checkpoint's value into compute_sampler_state on resume, so the data order
resumes at the right position regardless of the new world size. Warn when the
world size changed (the global offset is correct, but per-rank sample-exactness
needs the same topology). Old checkpoints without the field fall back to the
current world size.

Also document compute_sampler_state's assumptions explicitly: num_processes /
batch_size must match the checkpointing run, and accelerate's even_batches=True
padding is mirrored by the ceil(... / num_processes) term.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 14:37:26 +00:00
Claude 7a62235bac fix(datasets): guard Feistel cycle-walking loop against non-convergence
Replace the unbounded while True in EpisodeAwareSampler._permute with a
bounded for loop capped at _MAX_CYCLE_WALK_STEPS (100) and raise
RuntimeError if the cycle-walk fails to land in [0, num_frames). The
loop is expected to converge in <4 steps on the chosen power-of-two
domain, so the bound is a safety net that should never trip in practice
but prevents a pathological infinite loop.

https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22
2026-06-11 13:20:31 +00:00
Pepijn 81f0ca9ce4 test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc
list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and
preallocates that many slots before iterating, OOMing even though the resumed epoch only
yields 3 frames. Collect through the iterator (no length hint) so the test exercises the
real O(1) seek/drain instead of CPython's list growth heuristic.
2026-06-11 10:39:13 +00:00
Pepijn 29ca0f53d9 feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments
deterministic=True is now the class default as well as the training
default; the legacy RNG path requires an explicit deterministic=False
(the train script's non-deterministic branch passes it). Docstrings and
inline comments slimmed down across the changed files.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:54:22 +02:00
Pepijn b2d5d4ccfc feat(train): enable deterministic_sampler by default
Deterministic data order (sample-exact resume, no cross-rank RNG sync,
O(1) sampler memory) is now the default for map-style training; set
deterministic_sampler=false to restore the legacy RNG-based shuffle.
Streaming datasets ignore the flag (the sampler path only applies to
map-style datasets), replacing the previous hard validation error so
streaming configs keep working with the new default.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:45:36 +02:00
Pepijn 32b0d7d1ef refactor(datasets): fold deterministic mode into EpisodeAwareSampler
Instead of a parallel DeterministicEpisodeAwareSampler class, extend the
existing EpisodeAwareSampler with a deterministic=True mode (seeded
Feistel permutation, epoch auto-advance, state_dict/load_state_dict).

The default mode is behavior-identical: same torch.randperm consumption
and the same generator contract accelerate synchronizes; the O(N) Python
index list is replaced by O(num_episodes) boundary arrays in both modes,
with `indices` kept as a back-compat property. Passing a generator
together with deterministic=True is rejected, and the state/seek methods
raise outside deterministic mode.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 11:37:44 +02:00
Pepijn 7416b714c0 Merge remote-tracking branch 'origin/main' into feat/deterministic-sampler 2026-06-11 11:33:44 +02:00
Pepijn 6fa495c6b0 feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume
Add a sampler that never materializes frame indices: it stores only
per-episode boundaries (numpy, a few bytes per episode) and maps logical
positions to frame indices on the fly with searchsorted. Shuffling uses a
seeded Feistel permutation over [0, num_frames) (cycle-walking to the
exact domain), so the data order is a pure function of (seed, epoch):

- no RNG state to synchronize across distributed ranks,
- constant memory and zero epoch-boundary cost at any dataset size,
- O(1) seek to any position, enabling sample-exact resume.

Opt in with --deterministic_sampler=true. On resume, lerobot-train maps
the checkpointed step back to (epoch, start_index) via
compute_sampler_state and continues at the exact sample where the run
left off (up to accelerate's even_batches padding at epoch boundaries).
The shuffle is pseudo-random rather than a true uniform permutation, the
standard trade-off in large-scale training loaders.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:33:52 +02:00
Pepijn 72e093dbff fix(train): seed sampler generator and gate dataset download per node
- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:01:43 +02:00
Pepijn 3d262a6c9e fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync
In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 10:01:42 +02:00
6 changed files with 324 additions and 81 deletions
+37 -4
View File
@@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -75,6 +96,8 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -100,6 +123,10 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -112,7 +139,9 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
def save_training_state(
@@ -120,6 +149,8 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -131,10 +162,12 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
+2 -1
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -82,6 +82,7 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+122 -38
View File
@@ -14,14 +14,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
"""Sampler over episode frames that stores only per-episode boundaries.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -30,63 +52,125 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
seed: int = 0,
):
"""Sampler that optionally incorporates episode boundary information.
"""
Args:
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
if not indices:
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.shuffle = shuffle
self.generator = generator
self.seed = seed
self._epoch = 0
self._start_index = 0
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
else:
for i in self.indices:
yield i
for k in range(start, self._num_frames):
yield self._frame_index(k)
def __len__(self) -> int:
return len(self.indices)
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
+46 -18
View File
@@ -36,6 +36,8 @@ from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
save_checkpoint,
update_last_checkpoint,
@@ -43,7 +45,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -232,18 +234,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: the global main process downloads once to the shared
# dataset root, then a barrier lets every other rank read the already-populated copy.
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Other ranks read from the shared copy populated by the main process.
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -387,22 +388,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
if not cfg.dataset.streaming:
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
# The order is a pure function of (seed, epoch), so every rank independently produces the
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True,
generator=sampler_generator,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
# use the values recorded in the checkpoint (falling back to the current ones for older
# ckpts that did not store them).
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
ckpt_num_processes = saved_num_processes or accelerator.num_processes
ckpt_batch_size = saved_batch_size or cfg.batch_size
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
logging.warning(
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
f"written with num_processes={saved_num_processes}. The data order resumes at the "
"right epoch/offset, but per-rank sample-exactness requires the same world size."
)
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
logging.warning(
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
"but per-rank sample-exactness requires the same batch size."
)
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
else:
shuffle = True
sampler = None
@@ -521,6 +547,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
+93 -20
View File
@@ -114,28 +114,17 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
def test_shuffle_is_reproducible_across_instances():
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
# produce the same permutation without any generator synchronization.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
assert list(sampler_c) == epoch_0
def test_negative_drop_first_frames_raises():
@@ -161,3 +150,87 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- seeded (seed, epoch) shuffling, resume, and state ---
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
assert epoch_1 != epoch_0
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_deterministic_sampler_construction_stores_only_boundaries():
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
# instantiates from just its boundaries without materializing a per-frame index list.
num_frames = 1_000_000
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
def test_deterministic_sampler_resume_is_exact_at_scale():
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
# permutation and slicing from the saved offset reproduces the remaining order exactly.
num_frames = 100_000
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
epoch_0 = list(reference)
assert sorted(epoch_0) == list(range(num_frames))
start = num_frames - 5
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
def test_compute_sampler_state():
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 0,
"start_index": 0,
}
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 40,
}
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
"epoch": 2,
"start_index": 40,
}
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 100,
}
+24
View File
@@ -20,6 +20,8 @@ from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
load_training_step,
save_checkpoint,
@@ -63,6 +65,28 @@ def test_load_training_step(tmp_path):
assert loaded_step == step
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
assert load_training_num_processes(tmp_path) == 4
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the world size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_num_processes(tmp_path) is None
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
assert load_training_batch_size(tmp_path) == 32
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the batch size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_batch_size(tmp_path) is None
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()