diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 21ee514de..2d23b4003 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa return output_dir / CHECKPOINTS_DIR / step_identifier -def save_training_step(step: int, save_dir: Path) -> None: - write_json({"step": step}, save_dir / TRAINING_STEP) +def save_training_step( + step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None +) -> None: + state: dict = {"step": step} + # num_processes and batch_size are recorded so a resumed run can detect a changed world size or + # batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that + # produced `step`, since both scale how many sampler positions a step consumes (see + # compute_sampler_state). + if num_processes is not None: + state["num_processes"] = num_processes + if batch_size is not None: + state["batch_size"] = batch_size + write_json(state, save_dir / TRAINING_STEP) def load_training_step(save_dir: Path) -> int: @@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int: return training_step["step"] +def load_training_num_processes(checkpoint_dir: Path) -> int | None: + """World size recorded at checkpoint time, or None for checkpoints written before it was stored.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes") + + +def load_training_batch_size(checkpoint_dir: Path) -> int | None: + """Per-process batch size recorded at checkpoint time, or None for older checkpoints.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size") + + def update_last_checkpoint(checkpoint_dir: Path) -> Path: last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK if last_checkpoint_dir.is_symlink(): @@ -75,6 +96,8 @@ def save_checkpoint( scheduler: LRScheduler | None = None, preprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None, + num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """This function creates the following directory structure: @@ -100,6 +123,10 @@ def save_checkpoint( scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. postprocessor: The postprocessor/pipeline to save. Defaults to None. + num_processes (int | None, optional): Distributed world size to record for sample-exact + resume. Defaults to None (not recorded). + batch_size (int | None, optional): Per-process batch size to record for sample-exact + resume. Defaults to None (not recorded). """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) @@ -112,7 +139,9 @@ def save_checkpoint( preprocessor.save_pretrained(pretrained_dir) if postprocessor is not None: postprocessor.save_pretrained(pretrained_dir) - save_training_state(checkpoint_dir, step, optimizer, scheduler) + save_training_state( + checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size + ) def save_training_state( @@ -120,6 +149,8 @@ def save_training_state( train_step: int, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, + num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """ Saves the training step, optimizer state, scheduler state, and rng state. @@ -131,10 +162,12 @@ def save_training_state( Defaults to None. scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. Defaults to None. + num_processes (int | None, optional): Distributed world size to record. Defaults to None. + batch_size (int | None, optional): Per-process batch size to record. Defaults to None. """ save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir.mkdir(parents=True, exist_ok=True) - save_training_step(train_step, save_dir) + save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size) save_rng_state(save_dir) if optimizer is not None: save_optimizer_state(optimizer, save_dir) diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 2a67858d2..bd12a7248 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav -from .sampler import EpisodeAwareSampler +from .sampler import EpisodeAwareSampler, compute_sampler_state from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager @@ -82,6 +82,7 @@ __all__ = [ "aggregate_stats", "convert_image_to_video_dataset", "create_initial_features", + "compute_sampler_state", "create_lerobot_dataset_card", "column_for_style", "delete_episodes", diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 64d871907..af85dff9b 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -14,14 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import math from collections.abc import Iterator +import numpy as np import torch logger = logging.getLogger(__name__) class EpisodeAwareSampler: + """Sampler over episode frames that stores only per-episode boundaries. + + Logical positions map to frame indices on the fly (O(num_episodes) construction memory) + instead of materializing a Python list of every frame index. + + Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order + is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the + global RNG (no `generator` to sync across distributed ranks), and `state_dict` / + `load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and + continuing from the saved offset. Each call to `__iter__` advances the epoch. During a + resumed epoch, `__len__` still reports the full length. + + Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict` + set it explicitly. Within a single run callers should rely on exactly one of these mechanisms, + not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same + iterations would skip or repeat epochs. The training loop drives it purely through `__iter__` + (via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration + starts (e.g. on resume or in tests). + """ + def __init__( self, dataset_from_indices: list[int], @@ -30,63 +52,125 @@ class EpisodeAwareSampler: drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, - generator: torch.Generator | None = None, + seed: int = 0, ): - """Sampler that optionally incorporates episode boundary information. - + """ Args: - dataset_from_indices: List of indices containing the start of each episode in the dataset. - dataset_to_indices: List of indices containing the end of each episode in the dataset. - episode_indices_to_use: List of episode indices to use. If None, all episodes are used. - Assumes that episodes are indexed from 0 to N-1. - drop_n_first_frames: Number of frames to drop from the start of each episode. - drop_n_last_frames: Number of frames to drop from the end of each episode. + dataset_from_indices: Start index of each episode in the dataset. + dataset_to_indices: End index of each episode in the dataset. + episode_indices_to_use: Episode indices to use; None means all. + drop_n_first_frames: Frames to drop from the start of each episode. + drop_n_last_frames: Frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. - generator: Generator used for shuffling. Exposing this attribute (even when None) lets - `accelerate` register it as the synchronized RNG in distributed training, so - every rank draws the same permutation and batch shards stay disjoint. When - None, shuffling falls back to the global torch RNG. + seed: Seed the permutation is derived from (together with the epoch). """ if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") if drop_n_last_frames < 0: raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") - indices = [] - for episode_idx, (start_index, end_index) in enumerate( - zip(dataset_from_indices, dataset_to_indices, strict=True) - ): - if episode_indices_to_use is None or episode_idx in episode_indices_to_use: - ep_length = end_index - start_index - if drop_n_first_frames + drop_n_last_frames >= ep_length: - logger.warning( - "Episode %d has %d frames but drop_n_first_frames=%d and " - "drop_n_last_frames=%d removes all frames. Skipping.", - episode_idx, - ep_length, - drop_n_first_frames, - drop_n_last_frames, - ) - continue - indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + from_indices = np.asarray(dataset_from_indices, dtype=np.int64) + to_indices = np.asarray(dataset_to_indices, dtype=np.int64) + if from_indices.shape != to_indices.shape: + raise ValueError( + f"dataset_from_indices and dataset_to_indices must have the same length, " + f"got {len(from_indices)} and {len(to_indices)}" + ) - if not indices: + used = np.ones(len(from_indices), dtype=bool) + if episode_indices_to_use is not None: + used = np.zeros(len(from_indices), dtype=bool) + used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True + + starts = from_indices + drop_n_first_frames + lengths = to_indices - drop_n_last_frames - starts + for episode_idx in np.flatnonzero(used & (lengths <= 0)): + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + to_indices[episode_idx] - from_indices[episode_idx], + drop_n_first_frames, + drop_n_last_frames, + ) + used &= lengths > 0 + if not used.any(): raise ValueError( "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " "All episodes were either filtered out or had too few frames." ) - self.indices = indices + self._starts = starts[used] + self._cum_lengths = np.cumsum(lengths[used]) + self._num_frames = int(self._cum_lengths[-1]) self.shuffle = shuffle - self.generator = generator + self.seed = seed + self._epoch = 0 + self._start_index = 0 + + @property + def indices(self) -> list[int]: + """Materialized frame indices in unshuffled order; O(num_frames), introspection only.""" + return [self._frame_index(k) for k in range(self._num_frames)] + + def set_epoch(self, epoch: int) -> None: + self._epoch = epoch + + def state_dict(self) -> dict: + return {"epoch": self._epoch, "start_index": self._start_index} + + def load_state_dict(self, state: dict) -> None: + self._epoch = state["epoch"] + self._start_index = state["start_index"] + + def _epoch_generator(self, epoch: int) -> torch.Generator: + # Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both + # and reproduces identically on every rank without touching the global RNG. + epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0]) + return torch.Generator().manual_seed(epoch_seed) + + def _frame_index(self, position: int) -> int: + episode = int(np.searchsorted(self._cum_lengths, position, side="right")) + position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0) + return int(self._starts[episode]) + position_in_episode def __iter__(self) -> Iterator[int]: + # Advance epoch state eagerly, not on first consumption of the generator. + epoch, start = self._epoch, self._start_index + self._epoch += 1 + self._start_index = 0 + return self._iter_epoch(epoch, start) + + def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]: if self.shuffle: - for i in torch.randperm(len(self.indices), generator=self.generator): - yield self.indices[i] + order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch)) + for k in range(start, self._num_frames): + yield self._frame_index(int(order[k])) else: - for i in self.indices: - yield i + for k in range(start, self._num_frames): + yield self._frame_index(k) def __len__(self) -> int: - return len(self.indices) + return self._num_frames + + +def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict: + """Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume. + + Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler + positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches + per epoch (`even_batches` padding included). The start index provably stays below + `num_frames`; the `min` is defensive. + + Assumptions (resume is only sample-exact when they hold): + - `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how + many positions a step consumes, so the epoch/offset are wrong if either changed. The + caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch. + - accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term + mirrors that padding; with `even_batches=False` the per-epoch batch count differs and + the boundary is off. + """ + batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes) + epoch, batches_into_epoch = divmod(step, batches_per_epoch) + start_index = min(batches_into_epoch * batch_size * num_processes, num_frames) + return {"epoch": epoch, "start_index": start_index} diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a35d4229d..70a5e9e9d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,6 +36,8 @@ from tqdm import tqdm from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, + load_training_num_processes, load_training_state, save_checkpoint, update_last_checkpoint, @@ -43,7 +45,7 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import EpisodeAwareSampler, make_dataset +from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -237,18 +239,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: each node's local main process downloads first to avoid - # race conditions (the global main process only exists on node 0, so gating on it would let - # all ranks of the other nodes download and build the Arrow cache concurrently). - if accelerator.is_local_main_process: - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: the global main process downloads once to the shared + # dataset root, then a barrier lets every other rank read the already-populated copy. + # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset from the local cache - if not accelerator.is_local_main_process: + # Other ranks read from the shared copy populated by the main process. + if not is_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -392,22 +393,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(active_cfg, "drop_n_last_frames"): + if not cfg.dataset.streaming: + # All non-streaming (map-style) datasets use EpisodeAwareSampler. + # The order is a pure function of (seed, epoch), so every rank independently produces the + # same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard + # without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact. shuffle = False - # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare - # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even - # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). - sampler_generator = torch.Generator() - if cfg.seed is not None: - sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, - drop_n_last_frames=active_cfg.drop_n_last_frames, + drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), shuffle=True, - generator=sampler_generator, + seed=cfg.seed if cfg.seed is not None else 0, ) + if cfg.resume and step > 0: + # The resume offset depends on the (num_processes, batch_size) that produced `step`, so + # use the values recorded in the checkpoint (falling back to the current ones for older + # ckpts that did not store them). + saved_num_processes = load_training_num_processes(cfg.checkpoint_path) + saved_batch_size = load_training_batch_size(cfg.checkpoint_path) + ckpt_num_processes = saved_num_processes or accelerator.num_processes + ckpt_batch_size = saved_batch_size or cfg.batch_size + if is_main_process and saved_num_processes not in (None, accelerator.num_processes): + logging.warning( + f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was " + f"written with num_processes={saved_num_processes}. The data order resumes at the " + "right epoch/offset, but per-rank sample-exactness requires the same world size." + ) + if is_main_process and saved_batch_size not in (None, cfg.batch_size): + logging.warning( + f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with " + f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, " + "but per-rank sample-exactness requires the same batch size." + ) + sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes) + sampler.load_state_dict(sampler_state) + if is_main_process: + logging.info( + f"Resuming data order at epoch {sampler_state['epoch']}, " + f"sample {sampler_state['start_index']}" + ) else: shuffle = True sampler = None @@ -544,6 +570,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, + num_processes=accelerator.num_processes, + batch_size=cfg.batch_size, ) update_last_checkpoint(checkpoint_dir) if wandb_logger: diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 95429c7ec..7a5fc0fe0 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -114,28 +114,17 @@ def test_shuffle(): assert set(sampler) == {0, 1, 2, 3, 4, 5} -def test_shuffle_with_generator_is_deterministic(): - # Two samplers shuffling with same-seed generators must yield identical permutations. - # This is what keeps batch shards disjoint across ranks in distributed training, where - # accelerate synchronizes the sampler's generator state instead of the global torch RNG. - sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - assert list(sampler_a) == list(sampler_b) - +def test_shuffle_is_reproducible_across_instances(): + # The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks) + # produce the same permutation without any generator synchronization. + sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) + sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) + epoch_0 = list(sampler_a) + assert list(sampler_b) == epoch_0 # Desyncing the global RNG must not affect the permutation. - sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - order_before = list(sampler_c) - sampler_c.generator.manual_seed(42) + sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would - assert list(sampler_c) == order_before - - -def test_generator_attribute_defaults_to_none(): - # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, - # so the attribute must exist even when no generator is passed. - sampler = EpisodeAwareSampler([0], [6], shuffle=True) - assert sampler.generator is None - assert set(sampler) == {0, 1, 2, 3, 4, 5} + assert list(sampler_c) == epoch_0 def test_negative_drop_first_frames_raises(): @@ -161,3 +150,87 @@ def test_partial_episode_drop_warns(caplog): # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 assert sampler.indices == [2, 3, 4, 5] assert "Episode 0" in caplog.text + + +# --- seeded (seed, epoch) shuffling, resume, and state --- + +from lerobot.datasets.sampler import compute_sampler_state # noqa: E402 + +EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames + + +@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100]) +def test_deterministic_sampler_shuffle_is_permutation(num_frames): + for seed in (0, 1, 1234): + sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed) + assert sorted(sampler) == list(range(num_frames)) + + +def test_deterministic_sampler_epochs_reproduce_and_differ(): + sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42) + sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42) + epoch_0 = list(sampler_a) + assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process + epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch + assert epoch_1 != epoch_0 + assert sorted(epoch_1) == sorted(epoch_0) + sampler_a.set_epoch(0) + assert list(sampler_a) == epoch_0 + assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0 + + +def test_deterministic_sampler_resume_mid_epoch(): + reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + epoch_0 = list(reference) + epoch_1 = list(reference) + for start in (0, 1, 4, len(epoch_0)): + resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + resumed.load_state_dict({"epoch": 0, "start_index": start}) + assert list(resumed) == epoch_0[start:] + # the resumed sampler continues into the same epoch 1 as the uninterrupted one + assert list(resumed) == epoch_1 + + +def test_deterministic_sampler_construction_stores_only_boundaries(): + # Construction is O(num_episodes), not O(num_frames): a million-frame single episode + # instantiates from just its boundaries without materializing a per-frame index list. + num_frames = 1_000_000 + sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + assert len(sampler) == num_frames + assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,) + + +def test_deterministic_sampler_resume_is_exact_at_scale(): + # Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's + # permutation and slicing from the saved offset reproduces the remaining order exactly. + num_frames = 100_000 + reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + epoch_0 = list(reference) + assert sorted(epoch_0) == list(range(num_frames)) + start = num_frames - 5 + resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + resumed.load_state_dict({"epoch": 0, "start_index": start}) + assert list(resumed) == epoch_0[start:] + + +def test_compute_sampler_state(): + # 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch. + assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 0, + "start_index": 0, + } + # step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in + assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 40, + } + # uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank + assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == { + "epoch": 2, + "start_index": 40, + } + # uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads) + assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 100, + } diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 8e5b3f167..c171763c2 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -20,6 +20,8 @@ from unittest.mock import Mock, patch from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, + load_training_num_processes, load_training_state, load_training_step, save_checkpoint, @@ -63,6 +65,28 @@ def test_load_training_step(tmp_path): assert loaded_step == step +def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4) + assert load_training_num_processes(tmp_path) == 4 + + +def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the world size was recorded must still load (back-compat). + save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_num_processes(tmp_path) is None + + +def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32) + assert load_training_batch_size(tmp_path) == 32 + + +def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the batch size was recorded must still load (back-compat). + save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_batch_size(tmp_path) is None + + def test_update_last_checkpoint(tmp_path): checkpoint = tmp_path / "0005" checkpoint.mkdir()