feat(streaming): episode-pool iteration with decode-on-exit, video prefetch, and exact resume

Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.
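The pool behaviour described above can be pictured with a small sketch. It is not the code in this commit: `iter_episode_pool`, its `(episode_id, num_frames)` input shape, and the use of `random.Random` are assumptions made for illustration, and video decoding is reduced to a comment at the point where a sample leaves the pool.

```python
import random
from typing import Iterable, Iterator


def iter_episode_pool(
    episodes: Iterable[tuple[int, int]],  # (episode_id, num_frames); hypothetical input shape
    pool_size: int,
    rng: random.Random,
) -> Iterator[tuple[int, int]]:
    """Yield (episode_id, frame_index) pairs, each exactly once, in pooled-random order."""
    source = iter(episodes)
    pool: list[tuple[int, list[int]]] = []  # (episode_id, not-yet-emitted frame indices)

    def admit() -> None:
        # Top the pool up to pool_size whole episodes, skipping empty ones.
        while len(pool) < pool_size:
            nxt = next(source, None)
            if nxt is None:
                return
            ep_id, num_frames = nxt
            if num_frames > 0:
                pool.append((ep_id, list(range(num_frames))))

    admit()
    while pool:
        # Choosing an episode with probability proportional to its remaining frames
        # is the same as choosing uniformly among all frames left in the pool.
        weights = [len(frames) for _, frames in pool]
        slot = rng.choices(range(len(pool)), weights=weights, k=1)[0]
        ep_id, frames = pool[slot]
        # Swap-remove a random remaining frame: O(1) per emitted sample.
        j = rng.randrange(len(frames))
        frames[j], frames[-1] = frames[-1], frames[j]
        yield ep_id, frames.pop()  # a real pipeline would decode video here, on exit
        if not frames:
            pool.pop(slot)  # evict the exhausted episode, then admit the next one
            admit()
```

Only integer frame indices live in the pool, which is the point of decode-on-exit: memory scales with the tabular rows of `pool_size` episodes, not with decoded frames.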

- Prefetch-on-admit: when streaming from a remote source, each pooled
  episode's video files download to a local cache in the background
  (refcounted, since v3 packs several episodes per file; deleted on
  eviction), so decode-on-exit reads local bytes instead of paying
  network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
  are decorrelated, runs are reproducible, and epochs reshuffle
  automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
  {batches_consumed, batch_size}; each worker re-derives its own skip
  from the DataLoader's round-robin batch assignment and replays
  tabular-only (no decode). Exact within an epoch, works with
  num_workers > 0, and the same state file serves every rank. Replaces
  the per-shard HF state_dict approach, which lived in worker processes
  and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
  shard); runtime warnings when the world size does not divide the
  shard count (the datasets library then degrades to read-everything
  splitting) and when workers are left without shards.
- episode_pool_size replaces buffer_size (deprecated, ignored with a
  warning); decoder cache sized to the pool working set, capped at 128.
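A rough sketch of the RNG and resume arithmetic from the list above, under stated assumptions: `consumer_rng` and `worker_skip_batches` are hypothetical names, the string-seed derivation is one possible choice and not necessarily the one this commit uses, and the skip formula assumes the DataLoader hands out batches to workers strictly round-robin and that no worker ran out of data before the checkpoint.

```python
import random


def consumer_rng(seed: int, epoch: int, rank: int, worker: int) -> random.Random:
    # random.Random hashes string seeds deterministically, so the same four
    # integers give the same stream in every process and after every restart.
    return random.Random(f"{seed}:{epoch}:{rank}:{worker}")


def worker_skip_batches(batches_consumed: int, num_workers: int, worker_id: int) -> int:
    """Batches worker_id had already produced once its rank had consumed batches_consumed."""
    if num_workers <= 1:
        return batches_consumed  # a single consumer produced every batch
    full_rounds, remainder = divmod(batches_consumed, num_workers)
    # Batch b (0-based) comes from worker b % num_workers, so the first
    # `remainder` workers are one batch ahead of the rest.
    return full_rounds + (1 if worker_id < remainder else 0)


# Trainer-side state, in the shape the commit message describes.
state = {"batches_consumed": 10, "batch_size": 4}
skips = [
    worker_skip_batches(state["batches_consumed"], 3, w) * state["batch_size"]
    for w in range(3)
]
print(skips)  # [16, 12, 12] frames to replay per worker, tabular-only
```

Because the skip depends only on the shared counters plus each worker's own id, one state file can serve every rank and worker.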

Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.
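For illustration, the three behavior contracts can be phrased against any stream without pinning its exact order. The sketch below uses a hypothetical `toy_stream` (a seeded permutation) as a stand-in for one consumer's epoch; it is not the test code from this commit.

```python
import random


def toy_stream(num_frames: int, seed: int, epoch: int) -> list[int]:
    """Hypothetical stand-in for one consumer's epoch: a seeded permutation of frame ids."""
    rng = random.Random(f"{seed}:{epoch}")  # string seeds are deterministic across processes
    order = list(range(num_frames))
    rng.shuffle(order)
    return order


def check_contracts(num_frames: int = 50, seed: int = 7) -> None:
    first = toy_stream(num_frames, seed, epoch=0)
    # Exactly-once coverage: every frame appears once, none dropped or repeated.
    assert sorted(first) == list(range(num_frames))
    # Per-seed determinism: the same (seed, epoch) reproduces the same order.
    assert first == toy_stream(num_frames, seed, epoch=0)
    # Epoch reshuffle: a new epoch reorders the same set of frames.
    second = toy_stream(num_frames, seed, epoch=1)
    assert second != first and sorted(second) == sorted(first)


check_contracts()
```

Contracts of this kind survive a change of shuffling algorithm, which step-by-step order replication does not.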

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Pepijn
2026-06-11 15:02:15 +02:00
parent 66ac901632
commit 1050c2fb6c
11 changed files with 521 additions and 650 deletions
+16 -10
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
-- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
+- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
@@ -57,7 +57,10 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
-"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
+"--episode_pool_size",
+type=int,
+default=64,
+help="Whole episodes open per consumer (randomness knob).",
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
@@ -78,7 +81,7 @@ def make_dataloader(
args.repo_id,
root=args.root,
delta_timestamps=delta_timestamps,
-buffer_size=args.buffer_size,
+episode_pool_size=args.episode_pool_size,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
)
@@ -121,13 +124,13 @@ def main() -> None:
# of it). Batches are moved to the device manually in the loop.
model, optimizer = accelerator.prepare(model, optimizer)
-# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
-# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
-# checkpoint this script wrote itself.
+# Resume: deterministic fast-forward. Every consumer's order is a pure function of
+# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
+# worker re-derives its own skip. Same file works for every rank.
if args.resume_from is not None:
-state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
+state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
dataset.load_state_dict(state)
-accelerator.print(f"Resumed dataset stream from {args.resume_from}")
+accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
step = 0
frames_seen = 0
@@ -157,8 +160,11 @@ def main() -> None:
if step % args.save_freq == 0 and accelerator.is_main_process:
ckpt = output_dir / f"checkpoint-{step}"
ckpt.mkdir(parents=True, exist_ok=True)
-# Save the dataset stream position alongside the model so a restart resumes mid-stream.
-torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
+# Save the consumed-batch counters so a restart fast-forwards to this position.
+torch.save(
+{"batches_consumed": step, "batch_size": args.batch_size},
+ckpt / "dataset_state.pt",
+)
if model is not None:
accelerator.unwrap_model(model).save_pretrained(ckpt)