diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py
index 21ee514de..1b069c63a 100644
--- a/src/lerobot/common/train_utils.py
+++ b/src/lerobot/common/train_utils.py
@@ -49,8 +49,21 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
     return output_dir / CHECKPOINTS_DIR / step_identifier
 
 
-def save_training_step(step: int, save_dir: Path) -> None:
-    write_json({"step": step}, save_dir / TRAINING_STEP)
+def save_training_step(step: int, save_dir: Path, num_processes: int | None = None) -> None:
+    """Write the training step (and, when given, the world size) to `save_dir / TRAINING_STEP`.
+
+    Args:
+        step (int): Training step to record.
+        save_dir (Path): Directory in which the JSON file is written.
+        num_processes (int | None, optional): Distributed world size to record. Defaults to None
+            (the key is omitted).
+    """
+    state: dict[str, int] = {"step": step}
+    if num_processes is not None:
+        # Recorded so a resumed run can detect a changed world size: the deterministic sampler's
+        # resume offset is computed from the world size that produced `step` (see compute_sampler_state).
+        state["num_processes"] = num_processes
+    write_json(state, save_dir / TRAINING_STEP)
 
 
 def load_training_step(save_dir: Path) -> int:
@@ -58,6 +71,19 @@ def load_training_step(save_dir: Path) -> int:
     return training_step["step"]
 
 
+def load_training_num_processes(checkpoint_dir: Path) -> int | None:
+    """Return the world size recorded at checkpoint time.
+
+    Args:
+        checkpoint_dir (Path): Checkpoint root; the value is read from
+            `checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP`.
+
+    Returns:
+        int | None: The recorded `num_processes`, or None for checkpoints written before it was stored.
+    """
+    return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
+
+
 def update_last_checkpoint(checkpoint_dir: Path) -> Path:
     last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
     if last_checkpoint_dir.is_symlink():
@@ -75,6 +101,7 @@ def save_checkpoint(
     scheduler: LRScheduler | None = None,
     preprocessor: PolicyProcessorPipeline | None = None,
     postprocessor: PolicyProcessorPipeline | None = None,
+    num_processes: int | None = None,
 ) -> None:
     """This function creates the following directory structure:
 
@@ -100,6 +127,8 @@ def save_checkpoint(
         scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
         preprocessor: The preprocessor/pipeline to save. Defaults to None.
         postprocessor: The postprocessor/pipeline to save. Defaults to None.
+        num_processes (int | None, optional): Distributed world size to record for sample-exact
+            resume. Defaults to None (not recorded).
     """
     pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
     policy.save_pretrained(pretrained_dir)
@@ -112,7 +141,7 @@ def save_checkpoint(
         preprocessor.save_pretrained(pretrained_dir)
     if postprocessor is not None:
         postprocessor.save_pretrained(pretrained_dir)
-    save_training_state(checkpoint_dir, step, optimizer, scheduler)
+    save_training_state(checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes)
 
 
 def save_training_state(
@@ -120,6 +149,7 @@ def save_training_state(
     train_step: int,
     optimizer: Optimizer | None = None,
     scheduler: LRScheduler | None = None,
+    num_processes: int | None = None,
 ) -> None:
     """
     Saves the training step, optimizer state, scheduler state, and rng state.
@@ -131,10 +161,11 @@ def save_training_state(
             Defaults to None.
         scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
             Defaults to None.
+        num_processes (int | None, optional): Distributed world size to record. Defaults to None.
     """
     save_dir = checkpoint_dir / TRAINING_STATE_DIR
     save_dir.mkdir(parents=True, exist_ok=True)
-    save_training_step(train_step, save_dir)
+    save_training_step(train_step, save_dir, num_processes=num_processes)
     save_rng_state(save_dir)
     if optimizer is not None:
         save_optimizer_state(optimizer, save_dir)
diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py
index 5af24b740..d23551218 100644
--- a/src/lerobot/datasets/sampler.py
+++ b/src/lerobot/datasets/sampler.py
@@ -218,6 +218,14 @@ def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_proce
     positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches per
     epoch (`even_batches` padding included). The start index provably stays below `num_frames`;
     the `min` is defensive.
+
+    Assumptions (resume is only sample-exact when they hold):
+    - `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
+      many positions a step consumes, so the epoch/offset are wrong if either changed. The
+      caller passes the checkpoint's `num_processes` and warns on a mismatch.
+    - accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
+      mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
+      the boundary is off.
     """
     batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
     epoch, batches_into_epoch = divmod(step, batches_per_epoch)
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 134a28eec..469d38b62 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -36,6 +36,7 @@ from tqdm import tqdm
 from lerobot.common.train_utils import (
     get_step_checkpoint_dir,
     get_step_identifier,
+    load_training_num_processes,
     load_training_state,
     save_checkpoint,
     update_last_checkpoint,
@@ -399,8 +400,21 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
             seed=cfg.seed if cfg.seed is not None else 0,
         )
         if cfg.resume and step > 0:
+            # The resume offset depends on the world size that produced `step`, so use the world
+            # size recorded in the checkpoint (falling back to the current one for older ckpts).
+            saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
+            if saved_num_processes is None:
+                ckpt_num_processes = accelerator.num_processes
+            else:
+                ckpt_num_processes = saved_num_processes
+            if is_main_process and ckpt_num_processes != accelerator.num_processes:
+                logging.warning(
+                    f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
+                    f"written with num_processes={saved_num_processes}. The data order resumes at the "
+                    "right epoch/offset, but per-rank sample-exactness requires the same world size."
+                )
             sampler_state = compute_sampler_state(
-                step, len(sampler), cfg.batch_size, accelerator.num_processes
+                step, len(sampler), cfg.batch_size, ckpt_num_processes
             )
             sampler.load_state_dict(sampler_state)
             if is_main_process:
@@ -541,6 +555,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
                     scheduler=lr_scheduler,
                     preprocessor=preprocessor,
                     postprocessor=postprocessor,
+                    num_processes=accelerator.num_processes,
                 )
                 update_last_checkpoint(checkpoint_dir)
                 if wandb_logger:
diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py
index 8e5b3f167..ee07747d6 100644
--- a/tests/utils/test_train_utils.py
+++ b/tests/utils/test_train_utils.py
@@ -20,6 +20,7 @@ from unittest.mock import Mock, patch
 from lerobot.common.train_utils import (
     get_step_checkpoint_dir,
     get_step_identifier,
+    load_training_num_processes,
     load_training_state,
     load_training_step,
     save_checkpoint,
@@ -63,6 +64,17 @@ def test_load_training_step(tmp_path):
     assert loaded_step == step
 
 
+def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
+    save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
+    assert load_training_num_processes(tmp_path) == 4
+
+
+def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
+    # Checkpoints written before the world size was recorded must still load (back-compat).
+    save_training_state(tmp_path, 10, optimizer, scheduler)
+    assert load_training_num_processes(tmp_path) is None
+
+
 def test_update_last_checkpoint(tmp_path):
     checkpoint = tmp_path / "0005"
     checkpoint.mkdir()