From 3f6909fb630b629f2b99915a98971f5299087461 Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 11 Jun 2026 16:48:37 +0000 Subject: [PATCH] fix(datasets): address sampler review (batch_size resume guard + docs) - Record batch_size in training_step.json alongside num_processes and feed the checkpoint's value into compute_sampler_state on resume; warn when it differs (per-rank sample-exactness needs the same batch size). - Document the set_epoch vs __iter__ auto-advance coupling on EpisodeAwareSampler (callers should rely on exactly one mechanism per run). - Note the broadened (reproducibility-breaking) sampler guard and the no-generator distributed sharding correctness in lerobot_train.py. - Add load_training_batch_size + parallel tests. Co-authored-by: Cursor --- src/lerobot/common/train_utils.py | 28 +++++++++++++++++++++++----- src/lerobot/datasets/sampler.py | 9 ++++++++- src/lerobot/scripts/lerobot_train.py | 27 ++++++++++++++++++++++----- tests/utils/test_train_utils.py | 12 ++++++++++++ 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 1b069c63a..2d23b4003 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -49,12 +49,18 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa return output_dir / CHECKPOINTS_DIR / step_identifier -def save_training_step(step: int, save_dir: Path, num_processes: int | None = None) -> None: +def save_training_step( + step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None +) -> None: state: dict = {"step": step} + # num_processes and batch_size are recorded so a resumed run can detect a changed world size or + # batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that + # produced `step`, since both scale how many sampler positions a step consumes (see + # compute_sampler_state). if num_processes is not None: - # Recorded so a resumed run can detect a changed world size: the deterministic sampler's - # resume offset is computed from the world size that produced `step` (see compute_sampler_state). state["num_processes"] = num_processes + if batch_size is not None: + state["batch_size"] = batch_size write_json(state, save_dir / TRAINING_STEP) @@ -68,6 +74,11 @@ def load_training_num_processes(checkpoint_dir: Path) -> int | None: return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes") +def load_training_batch_size(checkpoint_dir: Path) -> int | None: + """Per-process batch size recorded at checkpoint time, or None for older checkpoints.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size") + + def update_last_checkpoint(checkpoint_dir: Path) -> Path: last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK if last_checkpoint_dir.is_symlink(): @@ -86,6 +97,7 @@ def save_checkpoint( preprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None, num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """This function creates the following directory structure: @@ -113,6 +125,8 @@ def save_checkpoint( postprocessor: The postprocessor/pipeline to save. Defaults to None. num_processes (int | None, optional): Distributed world size to record for sample-exact resume. Defaults to None (not recorded). + batch_size (int | None, optional): Per-process batch size to record for sample-exact + resume. Defaults to None (not recorded). """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) @@ -125,7 +139,9 @@ def save_checkpoint( preprocessor.save_pretrained(pretrained_dir) if postprocessor is not None: postprocessor.save_pretrained(pretrained_dir) - save_training_state(checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes) + save_training_state( + checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size + ) def save_training_state( @@ -134,6 +150,7 @@ def save_training_state( optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """ Saves the training step, optimizer state, scheduler state, and rng state. @@ -146,10 +163,11 @@ def save_training_state( scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. Defaults to None. num_processes (int | None, optional): Distributed world size to record. Defaults to None. + batch_size (int | None, optional): Per-process batch size to record. Defaults to None. """ save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir.mkdir(parents=True, exist_ok=True) - save_training_step(train_step, save_dir, num_processes=num_processes) + save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size) save_rng_state(save_dir) if optimizer is not None: save_optimizer_state(optimizer, save_dir) diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 9d102e58f..af85dff9b 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -35,6 +35,13 @@ class EpisodeAwareSampler: `load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and continuing from the saved offset. Each call to `__iter__` advances the epoch. During a resumed epoch, `__len__` still reports the full length. + + Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict` + set it explicitly. Within a single run callers should rely on exactly one of these mechanisms, + not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same + iterations would skip or repeat epochs. The training loop drives it purely through `__iter__` + (via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration + starts (e.g. on resume or in tests). """ def __init__( @@ -158,7 +165,7 @@ def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_proce Assumptions (resume is only sample-exact when they hold): - `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how many positions a step consumes, so the epoch/offset are wrong if either changed. The - caller passes the checkpoint's `num_processes` and warns on a mismatch. + caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch. - accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term mirrors that padding; with `even_batches=False` the per-epoch batch count differs and the boundary is off. diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 6787d1c4a..9d4b49714 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,6 +36,7 @@ from tqdm import tqdm from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, load_training_num_processes, load_training_state, save_checkpoint, @@ -389,8 +390,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if not cfg.dataset.streaming: - # Deterministic data order (pure function of seed and epoch): no cross-rank RNG sync - # needed and sample-exact resume. + # All non-streaming (map-style) datasets use EpisodeAwareSampler. This is broader than the + # historical `hasattr(active_cfg, "drop_n_last_frames")` guard: configs that previously fell + # back to DataLoader's default random shuffle now get this sampler instead, so their data + # order changes for a given seed (a deliberate, reproducibility-breaking improvement). + # + # The order is a pure function of (seed, epoch), so every rank independently produces the + # same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard + # without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact. shuffle = False sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], @@ -401,17 +408,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): seed=cfg.seed if cfg.seed is not None else 0, ) if cfg.resume and step > 0: - # The resume offset depends on the world size that produced `step`, so use the world - # size recorded in the checkpoint (falling back to the current one for older ckpts). + # The resume offset depends on the (num_processes, batch_size) that produced `step`, so + # use the values recorded in the checkpoint (falling back to the current ones for older + # ckpts that did not store them). saved_num_processes = load_training_num_processes(cfg.checkpoint_path) + saved_batch_size = load_training_batch_size(cfg.checkpoint_path) ckpt_num_processes = saved_num_processes or accelerator.num_processes + ckpt_batch_size = saved_batch_size or cfg.batch_size if is_main_process and saved_num_processes not in (None, accelerator.num_processes): logging.warning( f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was " f"written with num_processes={saved_num_processes}. The data order resumes at the " "right epoch/offset, but per-rank sample-exactness requires the same world size." ) - sampler_state = compute_sampler_state(step, len(sampler), cfg.batch_size, ckpt_num_processes) + if is_main_process and saved_batch_size not in (None, cfg.batch_size): + logging.warning( + f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with " + f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, " + "but per-rank sample-exactness requires the same batch size." + ) + sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes) sampler.load_state_dict(sampler_state) if is_main_process: logging.info( @@ -537,6 +553,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): preprocessor=preprocessor, postprocessor=postprocessor, num_processes=accelerator.num_processes, + batch_size=cfg.batch_size, ) update_last_checkpoint(checkpoint_dir) if wandb_logger: diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index ee07747d6..c171763c2 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -20,6 +20,7 @@ from unittest.mock import Mock, patch from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, load_training_num_processes, load_training_state, load_training_step, @@ -75,6 +76,17 @@ def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, sc assert load_training_num_processes(tmp_path) is None +def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32) + assert load_training_batch_size(tmp_path) == 32 + + +def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the batch size was recorded must still load (back-compat). + save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_batch_size(tmp_path) is None + + def test_update_last_checkpoint(tmp_path): checkpoint = tmp_path / "0005" checkpoint.mkdir()