From 1aa937aad2eb2f1a5af18ea0c8596feb6be93897 Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 11 Jun 2026 14:37:19 +0000 Subject: [PATCH] fix(datasets): make deterministic-sampler resume robust to world-size changes compute_sampler_state mapped a checkpointed step back to (epoch, start_index) using the *current* num_processes, but the number of sampler positions a step consumes scales with the world size that produced it. Resuming on a different GPU count therefore landed on the wrong epoch/offset, silently re-seeing or skipping data. Record num_processes in training_step.json at checkpoint time and feed the checkpoint's value into compute_sampler_state on resume, so the data order resumes at the right position regardless of the new world size. Warn when the world size changed (the global offset is correct, but per-rank sample-exactness needs the same topology). Old checkpoints without the field fall back to the current world size. Also document compute_sampler_state's assumptions explicitly: num_processes / batch_size must match the checkpointing run, and accelerate's even_batches=True padding is mirrored by the ceil(... / num_processes) term. Co-Authored-By: Claude Fable 5 Co-authored-by: Cursor --- src/lerobot/common/train_utils.py | 23 +++++++++++++++++++---- src/lerobot/datasets/sampler.py | 8 ++++++++ src/lerobot/scripts/lerobot_train.py | 14 +++++++++++++- tests/utils/test_train_utils.py | 12 ++++++++++++ 4 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 21ee514de..1b069c63a 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -49,8 +49,13 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa return output_dir / CHECKPOINTS_DIR / step_identifier -def save_training_step(step: int, save_dir: Path) -> None: - write_json({"step": step}, save_dir / TRAINING_STEP) +def save_training_step(step: int, save_dir: Path, num_processes: int | None = None) -> None: + state: dict = {"step": step} + if num_processes is not None: + # Recorded so a resumed run can detect a changed world size: the deterministic sampler's + # resume offset is computed from the world size that produced `step` (see compute_sampler_state). + state["num_processes"] = num_processes + write_json(state, save_dir / TRAINING_STEP) def load_training_step(save_dir: Path) -> int: @@ -58,6 +63,11 @@ def load_training_step(save_dir: Path) -> int: return training_step["step"] +def load_training_num_processes(checkpoint_dir: Path) -> int | None: + """World size recorded at checkpoint time, or None for checkpoints written before it was stored.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes") + + def update_last_checkpoint(checkpoint_dir: Path) -> Path: last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK if last_checkpoint_dir.is_symlink(): @@ -75,6 +85,7 @@ def save_checkpoint( scheduler: LRScheduler | None = None, preprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None, + num_processes: int | None = None, ) -> None: """This function creates the following directory structure: @@ -100,6 +111,8 @@ def save_checkpoint( scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. postprocessor: The postprocessor/pipeline to save. Defaults to None. + num_processes (int | None, optional): Distributed world size to record for sample-exact + resume. Defaults to None (not recorded). """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) @@ -112,7 +125,7 @@ def save_checkpoint( preprocessor.save_pretrained(pretrained_dir) if postprocessor is not None: postprocessor.save_pretrained(pretrained_dir) - save_training_state(checkpoint_dir, step, optimizer, scheduler) + save_training_state(checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes) def save_training_state( @@ -120,6 +133,7 @@ def save_training_state( train_step: int, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, + num_processes: int | None = None, ) -> None: """ Saves the training step, optimizer state, scheduler state, and rng state. @@ -131,10 +145,11 @@ def save_training_state( Defaults to None. scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. Defaults to None. + num_processes (int | None, optional): Distributed world size to record. Defaults to None. """ save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir.mkdir(parents=True, exist_ok=True) - save_training_step(train_step, save_dir) + save_training_step(train_step, save_dir, num_processes=num_processes) save_rng_state(save_dir) if optimizer is not None: save_optimizer_state(optimizer, save_dir) diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 5af24b740..d23551218 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -218,6 +218,14 @@ def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_proce positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per epoch (`even_batches` padding included). The start index provably stays below `num_frames`; the `min` is defensive. + + Assumptions (resume is only sample-exact when they hold): + - `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how + many positions a step consumes, so the epoch/offset are wrong if either changed. The + caller passes the checkpoint's `num_processes` and warns on a mismatch. + - accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term + mirrors that padding; with `even_batches=False` the per-epoch batch count differs and + the boundary is off. """ batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes) epoch, batches_into_epoch = divmod(step, batches_per_epoch) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 134a28eec..469d38b62 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,6 +36,7 @@ from tqdm import tqdm from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_num_processes, load_training_state, save_checkpoint, update_last_checkpoint, @@ -399,8 +400,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): seed=cfg.seed if cfg.seed is not None else 0, ) if cfg.resume and step > 0: + # The resume offset depends on the world size that produced `step`, so use the world + # size recorded in the checkpoint (falling back to the current one for older ckpts). + saved_num_processes = load_training_num_processes(cfg.checkpoint_path) + ckpt_num_processes = saved_num_processes or accelerator.num_processes + if is_main_process and saved_num_processes not in (None, accelerator.num_processes): + logging.warning( + f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was " + f"written with num_processes={saved_num_processes}. The data order resumes at the " + "right epoch/offset, but per-rank sample-exactness requires the same world size." + ) sampler_state = compute_sampler_state( - step, len(sampler), cfg.batch_size, accelerator.num_processes + step, len(sampler), cfg.batch_size, ckpt_num_processes ) sampler.load_state_dict(sampler_state) if is_main_process: @@ -541,6 +552,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, + num_processes=accelerator.num_processes, ) update_last_checkpoint(checkpoint_dir) if wandb_logger: diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 8e5b3f167..ee07747d6 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -20,6 +20,7 @@ from unittest.mock import Mock, patch from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_num_processes, load_training_state, load_training_step, save_checkpoint, @@ -63,6 +64,17 @@ def test_load_training_step(tmp_path): assert loaded_step == step +def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4) + assert load_training_num_processes(tmp_path) == 4 + + +def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the world size was recorded must still load (back-compat). + save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_num_processes(tmp_path) is None + + def test_update_last_checkpoint(tmp_path): checkpoint = tmp_path / "0005" checkpoint.mkdir()