refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives

The custom episode pool becomes a pure `datasets` pipeline:

  split_dataset_by_node -> batch(by_column="episode_index")
    -> shuffle(buffer=episode_pool_size)            # episode pool
    -> map(explode + exact delta windows)           # episode -> frames
    -> shuffle(buffer=frame_shuffle_buffer_size)    # frame interleave

and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.

Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.

Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time - use a colocated
bucket (data_files_root) at large scale.

The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.

Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 21:03:09 +02:00
parent 984b400e5c
commit 894fc6bfb5
4 changed files with 258 additions and 567 deletions
+29 -16
View File
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
@@ -85,7 +85,16 @@ def make_dataloader(
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
)
loader = DataLoader(
# torchdata's StatefulDataLoader checkpoints each worker's dataset state through the
# dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back
# to the plain DataLoader (resume then requires num_workers=0).
try:
from torchdata.stateful_dataloader import StatefulDataLoader
loader_cls = StatefulDataLoader
except ImportError:
loader_cls = DataLoader
loader = loader_cls(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
@@ -124,13 +133,17 @@ def main() -> None:
# of it). Batches are moved to the device manually in the loop.
model, optimizer = accelerator.prepare(model, optimizer)
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
# worker re-derives its own skip. Same file works for every rank.
# Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
# the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
can_checkpoint_loader = hasattr(loader, "state_dict")
if args.resume_from is not None:
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
dataset.load_state_dict(state)
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
state = torch.load(state_path, weights_only=False) # plain dict of stream offsets # nosec B614
if can_checkpoint_loader:
loader.load_state_dict(state)
else:
dataset.load_state_dict(state)
accelerator.print(f"Resumed dataset stream from {state_path}")
step = 0
frames_seen = 0
@@ -157,15 +170,15 @@ def main() -> None:
)
window_start = time.perf_counter()
if step % args.save_freq == 0 and accelerator.is_main_process:
if step % args.save_freq == 0:
ckpt = output_dir / f"checkpoint-{step}"
ckpt.mkdir(parents=True, exist_ok=True)
# Save the consumed-batch counters so a restart fast-forwards to this position.
torch.save(
{"batches_consumed": step, "batch_size": args.batch_size},
ckpt / "dataset_state.pt",
)
if model is not None:
if accelerator.is_main_process:
ckpt.mkdir(parents=True, exist_ok=True)
accelerator.wait_for_everyone()
# Every rank saves its own stream state: shard positions differ per rank.
state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
if model is not None and accelerator.is_main_process:
accelerator.unwrap_model(model).save_pretrained(ckpt)
if step >= args.steps: