mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(datasets): make deterministic-sampler resume robust to world-size changes
compute_sampler_state mapped a checkpointed step back to (epoch, start_index) using the *current* num_processes, but the number of sampler positions a step consumes scales with the world size that produced it. Resuming on a different GPU count therefore landed on the wrong epoch/offset, silently re-seeing or skipping data. Record num_processes in training_step.json at checkpoint time and feed the checkpoint's value into compute_sampler_state on resume, so the data order resumes at the right position regardless of the new world size. Warn when the world size changed (the global offset is correct, but per-rank sample-exactness needs the same topology). Old checkpoints without the field fall back to the current world size. Also document compute_sampler_state's assumptions explicitly: num_processes / batch_size must match the checkpointing run, and accelerate's even_batches=True padding is mirrored by the ceil(... / num_processes) term. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -49,8 +49,13 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
|
||||
return output_dir / CHECKPOINTS_DIR / step_identifier
|
||||
|
||||
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
|
||||
write_json({"step": step}, save_dir / TRAINING_STEP)
|
||||
def save_training_step(step: int, save_dir: Path, num_processes: int | None = None) -> None:
|
||||
state: dict = {"step": step}
|
||||
if num_processes is not None:
|
||||
# Recorded so a resumed run can detect a changed world size: the deterministic sampler's
|
||||
# resume offset is computed from the world size that produced `step` (see compute_sampler_state).
|
||||
state["num_processes"] = num_processes
|
||||
write_json(state, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
|
||||
@@ -58,6 +63,11 @@ def load_training_step(save_dir: Path) -> int:
|
||||
return training_step["step"]
|
||||
|
||||
|
||||
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
|
||||
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
|
||||
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
|
||||
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
|
||||
if last_checkpoint_dir.is_symlink():
|
||||
@@ -75,6 +85,7 @@ def save_checkpoint(
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -100,6 +111,8 @@ def save_checkpoint(
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
@@ -112,7 +125,7 @@ def save_checkpoint(
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes)
|
||||
|
||||
|
||||
def save_training_state(
|
||||
@@ -120,6 +133,7 @@ def save_training_state(
|
||||
train_step: int,
|
||||
optimizer: Optimizer | None = None,
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -131,10 +145,11 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
|
||||
@@ -218,6 +218,14 @@ def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_proce
|
||||
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
|
||||
per epoch (`even_batches` padding included). The start index provably stays below
|
||||
`num_frames`; the `min` is defensive.
|
||||
|
||||
Assumptions (resume is only sample-exact when they hold):
|
||||
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
|
||||
many positions a step consumes, so the epoch/offset are wrong if either changed. The
|
||||
caller passes the checkpoint's `num_processes` and warns on a mismatch.
|
||||
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
|
||||
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
|
||||
the boundary is off.
|
||||
"""
|
||||
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
|
||||
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
|
||||
|
||||
@@ -36,6 +36,7 @@ from tqdm import tqdm
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
@@ -399,8 +400,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
# The resume offset depends on the world size that produced `step`, so use the world
|
||||
# size recorded in the checkpoint (falling back to the current one for older ckpts).
|
||||
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
|
||||
ckpt_num_processes = saved_num_processes or accelerator.num_processes
|
||||
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
|
||||
logging.warning(
|
||||
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
|
||||
f"written with num_processes={saved_num_processes}. The data order resumes at the "
|
||||
"right epoch/offset, but per-rank sample-exactness requires the same world size."
|
||||
)
|
||||
sampler_state = compute_sampler_state(
|
||||
step, len(sampler), cfg.batch_size, accelerator.num_processes
|
||||
step, len(sampler), cfg.batch_size, ckpt_num_processes
|
||||
)
|
||||
sampler.load_state_dict(sampler_state)
|
||||
if is_main_process:
|
||||
@@ -541,6 +552,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
@@ -20,6 +20,7 @@ from unittest.mock import Mock, patch
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
@@ -63,6 +64,17 @@ def test_load_training_step(tmp_path):
|
||||
assert loaded_step == step
|
||||
|
||||
|
||||
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
|
||||
assert load_training_num_processes(tmp_path) == 4
|
||||
|
||||
|
||||
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
|
||||
# Checkpoints written before the world size was recorded must still load (back-compat).
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
assert load_training_num_processes(tmp_path) is None
|
||||
|
||||
|
||||
def test_update_last_checkpoint(tmp_path):
|
||||
checkpoint = tmp_path / "0005"
|
||||
checkpoint.mkdir()
|
||||
|
||||
Reference in New Issue
Block a user