mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
refactor(datasets): fold deterministic mode into EpisodeAwareSampler
Instead of a parallel DeterministicEpisodeAwareSampler class, extend the existing EpisodeAwareSampler with a deterministic=True mode (seeded Feistel permutation, epoch auto-advance, state_dict/load_state_dict). The default mode is behavior-identical: same torch.randperm consumption and the same generator contract accelerate synchronizes; the O(N) Python index list is replaced by O(num_episodes) boundary arrays in both modes, with `indices` kept as a back-compat property. Passing a generator together with deterministic=True is rejected, and the state/seek methods raise outside deterministic mode. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import DeterministicEpisodeAwareSampler, EpisodeAwareSampler, compute_sampler_state
|
||||
from .sampler import EpisodeAwareSampler, compute_sampler_state
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
@@ -64,7 +64,6 @@ __all__ = [
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"DeterministicEpisodeAwareSampler",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
|
||||
@@ -38,97 +38,27 @@ def _mix64(x: int) -> int:
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
"""Sampler that incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
Frame indices are never materialized: only per-episode boundaries are stored (a few numpy
|
||||
int64 per episode) and the mapping from a logical position to a frame index is computed on
|
||||
the fly via `searchsorted`, so memory does not grow with the number of frames.
|
||||
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
Two shuffling modes are supported:
|
||||
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class DeterministicEpisodeAwareSampler:
|
||||
"""Episode-aware sampler with O(num_episodes) memory and a deterministic, seekable shuffle.
|
||||
|
||||
Unlike `EpisodeAwareSampler`, frame indices are never materialized: only per-episode
|
||||
boundaries are stored (a few numpy int64 per episode) and the mapping from a logical
|
||||
position to a frame index is computed on the fly via `searchsorted`. Shuffling applies a
|
||||
seeded Feistel permutation over `[0, num_frames)` (cycle-walking to the exact domain size)
|
||||
instead of `torch.randperm`, so the data order is a pure function of `(seed, epoch)`:
|
||||
|
||||
- every distributed rank derives the identical order with no RNG state to synchronize,
|
||||
- any position can be sought in O(1), enabling sample-exact resume of interrupted runs
|
||||
(see `state_dict` / `load_state_dict`),
|
||||
- memory and epoch-boundary cost do not grow with the number of frames.
|
||||
|
||||
The trade-off is that the shuffle is pseudo-random rather than a true uniform permutation,
|
||||
which is the standard compromise in large-scale training loaders.
|
||||
|
||||
Each completed `__iter__` advances the internal epoch, so consecutive dataloader passes
|
||||
yield different permutations without requiring `set_epoch` calls. During an epoch resumed
|
||||
via `load_state_dict`, `__len__` still reports the full epoch length; the first pass simply
|
||||
yields fewer samples.
|
||||
- Default (`deterministic=False`): `torch.randperm` over positions, optionally driven by a
|
||||
dedicated `generator`. Exposing the `generator` attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so every rank
|
||||
draws the same permutation and batch shards stay disjoint.
|
||||
- `deterministic=True`: a seeded Feistel permutation over `[0, num_frames)` (cycle-walking
|
||||
to the exact domain size). The data order becomes a pure function of `(seed, epoch)`:
|
||||
nothing to synchronize across ranks, O(1) seek to any position (enabling sample-exact
|
||||
resume via `state_dict` / `load_state_dict`), and zero epoch-boundary cost at any dataset
|
||||
size. The shuffle is pseudo-random rather than a true uniform permutation — the standard
|
||||
trade-off in large-scale training loaders. Each completed `__iter__` advances the internal
|
||||
epoch, so consecutive dataloader passes yield different permutations without `set_epoch`
|
||||
calls. During an epoch resumed via `load_state_dict`, `__len__` still reports the full
|
||||
epoch length; the first pass simply yields fewer samples.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -138,7 +68,9 @@ class DeterministicEpisodeAwareSampler:
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = True,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
deterministic: bool = False,
|
||||
seed: int = 0,
|
||||
):
|
||||
"""
|
||||
@@ -149,13 +81,18 @@ class DeterministicEpisodeAwareSampler:
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle with the seeded Feistel permutation.
|
||||
seed: Seed the permutation is derived from (together with the epoch).
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for default-mode shuffling. When None, shuffling falls
|
||||
back to the global torch RNG. Incompatible with `deterministic=True`.
|
||||
deterministic: Use the seeded Feistel permutation instead of `torch.randperm`.
|
||||
seed: Seed the deterministic permutation is derived from (together with the epoch).
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
if deterministic and generator is not None:
|
||||
raise ValueError("generator is unused in deterministic mode; pass seed instead.")
|
||||
|
||||
from_indices = np.asarray(dataset_from_indices, dtype=np.int64)
|
||||
to_indices = np.asarray(dataset_to_indices, dtype=np.int64)
|
||||
@@ -192,6 +129,8 @@ class DeterministicEpisodeAwareSampler:
|
||||
self._cum_lengths = np.cumsum(lengths[used])
|
||||
self._num_frames = int(self._cum_lengths[-1])
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
self.deterministic = deterministic
|
||||
self.seed = seed
|
||||
self._epoch = 0
|
||||
self._start_index = 0
|
||||
@@ -202,17 +141,35 @@ class DeterministicEpisodeAwareSampler:
|
||||
self._half_bits = (bits + 1) // 2
|
||||
self._half_mask = (1 << self._half_bits) - 1
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
"""Materialized frame indices, in unshuffled order (back-compat / introspection only).
|
||||
|
||||
This builds an O(num_frames) list — avoid on very large datasets; iteration never uses it.
|
||||
"""
|
||||
return [self._frame_index(k) for k in range(self._num_frames)]
|
||||
|
||||
def set_epoch(self, epoch: int) -> None:
|
||||
"""Set the epoch the next `__iter__` will use (DistributedSampler convention)."""
|
||||
self._require_deterministic("set_epoch")
|
||||
self._epoch = epoch
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
self._require_deterministic("state_dict")
|
||||
return {"epoch": self._epoch, "start_index": self._start_index}
|
||||
|
||||
def load_state_dict(self, state: dict) -> None:
|
||||
self._require_deterministic("load_state_dict")
|
||||
self._epoch = state["epoch"]
|
||||
self._start_index = state["start_index"]
|
||||
|
||||
def _require_deterministic(self, method: str) -> None:
|
||||
if not self.deterministic:
|
||||
raise RuntimeError(
|
||||
f"{method} is only meaningful with deterministic=True: in default mode the "
|
||||
"order is drawn from the (generator) RNG and cannot be sought."
|
||||
)
|
||||
|
||||
def _round_keys(self, epoch: int) -> list[int]:
|
||||
state = _mix64(_mix64(self.seed) ^ _mix64(epoch))
|
||||
keys = []
|
||||
@@ -238,14 +195,24 @@ class DeterministicEpisodeAwareSampler:
|
||||
return int(self._starts[episode]) + position_in_episode
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if not self.deterministic:
|
||||
return self._iter_default()
|
||||
# Capture and advance state eagerly so epoch bookkeeping is not deferred until the
|
||||
# returned generator is first consumed.
|
||||
epoch, start = self._epoch, self._start_index
|
||||
self._epoch += 1
|
||||
self._start_index = 0
|
||||
return self._iter_epoch(epoch, start)
|
||||
return self._iter_deterministic_epoch(epoch, start)
|
||||
|
||||
def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
def _iter_default(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(self._num_frames, generator=self.generator):
|
||||
yield self._frame_index(int(i))
|
||||
else:
|
||||
for k in range(self._num_frames):
|
||||
yield self._frame_index(k)
|
||||
|
||||
def _iter_deterministic_epoch(self, epoch: int, start: int) -> Iterator[int]:
|
||||
keys = self._round_keys(epoch) if self.shuffle else None
|
||||
for k in range(start, self._num_frames):
|
||||
yield self._frame_index(self._permute(k, keys) if self.shuffle else k)
|
||||
@@ -255,7 +222,7 @@ class DeterministicEpisodeAwareSampler:
|
||||
|
||||
|
||||
def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict:
|
||||
"""Map a global optimization step to a `DeterministicEpisodeAwareSampler` state for resume.
|
||||
"""Map a global optimization step to an `EpisodeAwareSampler` state for resume.
|
||||
|
||||
Under accelerate's batch-level sharding, every rank iterates the same underlying sampler and
|
||||
keeps every `num_processes`-th batch, so one optimization step consumes
|
||||
|
||||
@@ -43,12 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import (
|
||||
DeterministicEpisodeAwareSampler,
|
||||
EpisodeAwareSampler,
|
||||
compute_sampler_state,
|
||||
make_dataset,
|
||||
)
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -397,12 +392,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# O(1) memory in dataset size, and a resumed run continues at the exact sample where the
|
||||
# checkpoint left off (up to accelerate's even_batches padding at epoch boundaries).
|
||||
shuffle = False
|
||||
sampler = DeterministicEpisodeAwareSampler(
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||
shuffle=True,
|
||||
deterministic=True,
|
||||
seed=cfg.seed if cfg.seed is not None else 0,
|
||||
)
|
||||
if cfg.resume and step > 0:
|
||||
|
||||
@@ -163,17 +163,19 @@ def test_partial_episode_drop_warns(caplog):
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- DeterministicEpisodeAwareSampler ---
|
||||
# --- deterministic mode (seeded Feistel permutation) ---
|
||||
|
||||
from functools import partial # noqa: E402
|
||||
|
||||
from lerobot.datasets.sampler import compute_sampler_state # noqa: E402
|
||||
|
||||
deterministic_sampler = partial(EpisodeAwareSampler, deterministic=True)
|
||||
|
||||
from lerobot.datasets.sampler import ( # noqa: E402
|
||||
DeterministicEpisodeAwareSampler,
|
||||
compute_sampler_state,
|
||||
)
|
||||
|
||||
EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames
|
||||
|
||||
|
||||
def test_deterministic_sampler_unshuffled_matches_episode_aware():
|
||||
def test_deterministic_mode_unshuffled_matches_default_mode():
|
||||
for kwargs in (
|
||||
{},
|
||||
{"drop_n_first_frames": 1},
|
||||
@@ -181,21 +183,34 @@ def test_deterministic_sampler_unshuffled_matches_episode_aware():
|
||||
{"episode_indices_to_use": [0, 2]},
|
||||
):
|
||||
reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
|
||||
sampler = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
|
||||
sampler = deterministic_sampler(*EPISODE_BOUNDS, shuffle=False, **kwargs)
|
||||
assert list(sampler) == list(reference), kwargs
|
||||
assert len(sampler) == len(reference), kwargs
|
||||
|
||||
|
||||
def test_deterministic_mode_rejects_generator():
|
||||
with pytest.raises(ValueError, match="generator is unused in deterministic mode"):
|
||||
deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, generator=torch.Generator())
|
||||
|
||||
|
||||
def test_state_methods_require_deterministic_mode():
|
||||
sampler = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True)
|
||||
with pytest.raises(RuntimeError, match="deterministic=True"):
|
||||
sampler.set_epoch(1)
|
||||
with pytest.raises(RuntimeError, match="deterministic=True"):
|
||||
sampler.state_dict()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100])
|
||||
def test_deterministic_sampler_shuffle_is_permutation(num_frames):
|
||||
for seed in (0, 1, 1234):
|
||||
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=seed)
|
||||
assert sorted(sampler) == list(range(num_frames))
|
||||
|
||||
|
||||
def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
sampler_a = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_a = deterministic_sampler([0], [100], shuffle=True, seed=42)
|
||||
sampler_b = deterministic_sampler([0], [100], shuffle=True, seed=42)
|
||||
epoch_0 = list(sampler_a)
|
||||
assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process
|
||||
epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch
|
||||
@@ -203,15 +218,15 @@ def test_deterministic_sampler_epochs_reproduce_and_differ():
|
||||
assert sorted(epoch_1) == sorted(epoch_0)
|
||||
sampler_a.set_epoch(0)
|
||||
assert list(sampler_a) == epoch_0
|
||||
assert list(DeterministicEpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
assert list(deterministic_sampler([0], [100], shuffle=True, seed=7)) != epoch_0
|
||||
|
||||
|
||||
def test_deterministic_sampler_resume_mid_epoch():
|
||||
reference = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
reference = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
epoch_0 = list(reference)
|
||||
epoch_1 = list(reference)
|
||||
for start in (0, 1, 4, len(epoch_0)):
|
||||
resumed = DeterministicEpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
resumed = deterministic_sampler(*EPISODE_BOUNDS, shuffle=True, seed=42)
|
||||
resumed.load_state_dict({"epoch": 0, "start_index": start})
|
||||
assert list(resumed) == epoch_0[start:]
|
||||
# the resumed sampler continues into the same epoch 1 as the uninterrupted one
|
||||
@@ -222,7 +237,7 @@ def test_deterministic_sampler_constant_memory():
|
||||
# A trillion-frame dataset must instantiate instantly and seek anywhere in O(1):
|
||||
# only per-episode boundaries are stored, never per-frame indices.
|
||||
num_frames = 10**12
|
||||
sampler = DeterministicEpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0)
|
||||
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
|
||||
tail = list(sampler)
|
||||
@@ -232,16 +247,16 @@ def test_deterministic_sampler_constant_memory():
|
||||
|
||||
def test_deterministic_sampler_validation_matches_episode_aware():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
DeterministicEpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
deterministic_sampler([0], [10], drop_n_first_frames=-1)
|
||||
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
|
||||
DeterministicEpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
||||
deterministic_sampler([0], [10], drop_n_last_frames=-1)
|
||||
with pytest.raises(ValueError, match="No valid frames remain"):
|
||||
DeterministicEpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
|
||||
deterministic_sampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
|
||||
|
||||
|
||||
def test_deterministic_sampler_partial_episode_drop_warns(caplog):
|
||||
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
|
||||
sampler = DeterministicEpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
|
||||
sampler = deterministic_sampler([0, 1], [1, 6], drop_n_first_frames=1, shuffle=False)
|
||||
assert list(sampler) == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
Reference in New Issue
Block a user