Compare commits

..

2 Commits

Author SHA1 Message Date
github-actions[bot] 058378e82f chore(dependencies): update uv.lock 2026-06-10 10:43:18 +00:00
Steven Palma d947983a78 chore(dependecies): update mujoco transitives 2026-06-10 12:11:10 +02:00
7 changed files with 926 additions and 1161 deletions
+4 -37
View File
@@ -49,19 +49,8 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(
step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None
) -> None:
state: dict = {"step": step}
# num_processes and batch_size are recorded so a resumed run can detect a changed world size or
# batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that
# produced `step`, since both scale how many sampler positions a step consumes (see
# compute_sampler_state).
if num_processes is not None:
state["num_processes"] = num_processes
if batch_size is not None:
state["batch_size"] = batch_size
write_json(state, save_dir / TRAINING_STEP)
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
@@ -69,16 +58,6 @@ def load_training_step(save_dir: Path) -> int:
return training_step["step"]
def load_training_num_processes(checkpoint_dir: Path) -> int | None:
"""World size recorded at checkpoint time, or None for checkpoints written before it was stored."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes")
def load_training_batch_size(checkpoint_dir: Path) -> int | None:
"""Per-process batch size recorded at checkpoint time, or None for older checkpoints."""
return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size")
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
@@ -96,8 +75,6 @@ def save_checkpoint(
scheduler: LRScheduler | None = None,
preprocessor: PolicyProcessorPipeline | None = None,
postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""This function creates the following directory structure:
@@ -123,10 +100,6 @@ def save_checkpoint(
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
num_processes (int | None, optional): Distributed world size to record for sample-exact
resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded).
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
@@ -139,9 +112,7 @@ def save_checkpoint(
preprocessor.save_pretrained(pretrained_dir)
if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir)
save_training_state(
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
@@ -149,8 +120,6 @@ def save_training_state(
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
num_processes: int | None = None,
batch_size: int | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
@@ -162,12 +131,10 @@ def save_training_state(
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_training_step(train_step, save_dir)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
+1 -2
View File
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler, compute_sampler_state
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
@@ -82,7 +82,6 @@ __all__ = [
"aggregate_stats",
"convert_image_to_video_dataset",
"create_initial_features",
"compute_sampler_state",
"create_lerobot_dataset_card",
"column_for_style",
"delete_episodes",
+32 -122
View File
@@ -14,36 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from collections.abc import Iterator
import numpy as np
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
"""Sampler over episode frames that stores only per-episode boundaries.
Logical positions map to frame indices on the fly (O(num_episodes) construction memory)
instead of materializing a Python list of every frame index.
Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order
is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the
global RNG (no `generator` to sync across distributed ranks), and `state_dict` /
`load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and
continuing from the saved offset. Each call to `__iter__` advances the epoch. During a
resumed epoch, `__len__` still reports the full length.
Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict`
set it explicitly. Within a single run callers should rely on exactly one of these mechanisms,
not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same
iterations would skip or repeat epochs. The training loop drives it purely through `__iter__`
(via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration
starts (e.g. on resume or in tests).
"""
def __init__(
self,
dataset_from_indices: list[int],
@@ -52,125 +30,57 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
seed: int = 0,
):
"""
"""Sampler that optionally incorporates episode boundary information.
Args:
dataset_from_indices: Start index of each episode in the dataset.
dataset_to_indices: End index of each episode in the dataset.
episode_indices_to_use: Episode indices to use; None means all.
drop_n_first_frames: Frames to drop from the start of each episode.
drop_n_last_frames: Frames to drop from the end of each episode.
dataset_from_indices: List of indices containing the start of each episode in the dataset.
dataset_to_indices: List of indices containing the end of each episode in the dataset.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
seed: Seed the permutation is derived from (together with the epoch).
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
if from_indices.shape != to_indices.shape:
raise ValueError(
f"dataset_from_indices and dataset_to_indices must have the same length, "
f"got {len(from_indices)} and {len(to_indices)}"
)
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
used = np.ones(len(from_indices), dtype=bool)
if episode_indices_to_use is not None:
used = np.zeros(len(from_indices), dtype=bool)
used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True
starts = from_indices + drop_n_first_frames
lengths = to_indices - drop_n_last_frames - starts
for episode_idx in np.flatnonzero(used & (lengths <= 0)):
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
to_indices[episode_idx] - from_indices[episode_idx],
drop_n_first_frames,
drop_n_last_frames,
)
used &= lengths > 0
if not used.any():
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self._starts = starts[used]
self._cum_lengths = np.cumsum(lengths[used])
self._num_frames = int(self._cum_lengths[-1])
self.indices = indices
self.shuffle = shuffle
self.seed = seed
self._epoch = 0
self._start_index = 0
@property
def indices(self) -> list[int]:
"""Materialized frame indices in unshuffled order; O(num_frames), introspection only."""
return [self._frame_index(k) for k in range(self._num_frames)]
def set_epoch(self, epoch: int) -> None:
self._epoch = epoch
def state_dict(self) -> dict:
return {"epoch": self._epoch, "start_index": self._start_index}
def load_state_dict(self, state: dict) -> None:
self._epoch = state["epoch"]
self._start_index = state["start_index"]
def _epoch_generator(self, epoch: int) -> torch.Generator:
# Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both
# and reproduces identically on every rank without touching the global RNG.
epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0])
return torch.Generator().manual_seed(epoch_seed)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode
def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator.
epoch, start = self._epoch, self._start_index
self._epoch += 1
self._start_index = 0
return self._iter_epoch(epoch, start)
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
if self.shuffle:
order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch))
for k in range(start, self._num_frames):
yield self._frame_index(int(order[k]))
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for k in range(start, self._num_frames):
yield self._frame_index(k)
for i in self.indices:
yield i
def __len__(self) -> int:
return self._num_frames
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
"""Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume.
Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler
positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches
per epoch (`even_batches` padding included). The start index provably stays below
`num_frames`; the `min` is defensive.
Assumptions (resume is only sample-exact when they hold):
- `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how
many positions a step consumes, so the epoch/offset are wrong if either changed. The
caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch.
- accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term
mirrors that padding; with `even_batches=False` the per-epoch batch count differs and
the boundary is off.
"""
batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes)
epoch, batches_into_epoch = divmod(step, batches_per_epoch)
start_index = min(batches_into_epoch * batch_size * num_processes, num_frames)
return {"epoch": epoch, "start_index": start_index}
return len(self.indices)
+5 -43
View File
@@ -36,8 +36,6 @@ from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
save_checkpoint,
update_last_checkpoint,
@@ -45,7 +43,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -234,16 +232,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: the global main process downloads once to the shared
# dataset root, then a barrier lets every other rank read the already-populated copy.
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Other ranks read from the shared copy populated by the main process.
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
@@ -388,47 +384,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if not cfg.dataset.streaming:
# All non-streaming (map-style) datasets use EpisodeAwareSampler.
# The order is a pure function of (seed, epoch), so every rank independently produces the
# same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard
# without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact.
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0,
)
if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
# use the values recorded in the checkpoint (falling back to the current ones for older
# ckpts that did not store them).
saved_num_processes = load_training_num_processes(cfg.checkpoint_path)
saved_batch_size = load_training_batch_size(cfg.checkpoint_path)
ckpt_num_processes = saved_num_processes or accelerator.num_processes
ckpt_batch_size = saved_batch_size or cfg.batch_size
if is_main_process and saved_num_processes not in (None, accelerator.num_processes):
logging.warning(
f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was "
f"written with num_processes={saved_num_processes}. The data order resumes at the "
"right epoch/offset, but per-rank sample-exactness requires the same world size."
)
if is_main_process and saved_batch_size not in (None, cfg.batch_size):
logging.warning(
f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with "
f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, "
"but per-rank sample-exactness requires the same batch size."
)
sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes)
sampler.load_state_dict(sampler_state)
if is_main_process:
logging.info(
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
else:
shuffle = True
sampler = None
@@ -547,8 +511,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
num_processes=accelerator.num_processes,
batch_size=cfg.batch_size,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
-97
View File
@@ -114,19 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_is_reproducible_across_instances():
# The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks)
# produce the same permutation without any generator synchronization.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == epoch_0
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
@@ -150,87 +137,3 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- seeded (seed, epoch) shuffling, resume, and state ---
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
for seed in (0, 1, 1234):
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
assert sorted(sampler) == list(range(num_frames))
def test_deterministic_sampler_epochs_reproduce_and_differ():
sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42)
epoch_0 = list(sampler_a)
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
assert epoch_1 != epoch_0
assert sorted(epoch_1) == sorted(epoch_0)
sampler_a.set_epoch(0)
assert list(sampler_a) == epoch_0
assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
def test_deterministic_sampler_resume_mid_epoch():
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
epoch_0 = list(reference)
epoch_1 = list(reference)
for start in (0, 1, 4, len(epoch_0)):
resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
assert list(resumed) == epoch_1
def test_deterministic_sampler_construction_stores_only_boundaries():
# Construction is O(num_episodes), not O(num_frames): a million-frame single episode
# instantiates from just its boundaries without materializing a per-frame index list.
num_frames = 1_000_000
sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,)
def test_deterministic_sampler_resume_is_exact_at_scale():
# Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's
# permutation and slicing from the saved offset reproduces the remaining order exactly.
num_frames = 100_000
reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
epoch_0 = list(reference)
assert sorted(epoch_0) == list(range(num_frames))
start = num_frames - 5
resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
resumed.load_state_dict({"epoch": 0, "start_index": start})
assert list(resumed) == epoch_0[start:]
def test_compute_sampler_state():
# 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch.
assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 0,
"start_index": 0,
}
# step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in
assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 40,
}
# uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank
assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == {
"epoch": 2,
"start_index": 40,
}
# uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads)
assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == {
"epoch": 1,
"start_index": 100,
}
-24
View File
@@ -20,8 +20,6 @@ from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_batch_size,
load_training_num_processes,
load_training_state,
load_training_step,
save_checkpoint,
@@ -65,28 +63,6 @@ def test_load_training_step(tmp_path):
assert loaded_step == step
def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4)
assert load_training_num_processes(tmp_path) == 4
def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the world size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_num_processes(tmp_path) is None
def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32)
assert load_training_batch_size(tmp_path) == 32
def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler):
# Checkpoints written before the batch size was recorded must still load (back-compat).
save_training_state(tmp_path, 10, optimizer, scheduler)
assert load_training_batch_size(tmp_path) is None
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()
Generated
+884 -836
View File
File diff suppressed because it is too large Load Diff