mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feat(streaming): episode-pool iteration with decode-on-exit, video prefetch, and exact resume
Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.
- Prefetch-on-admit: when streaming from a remote source, each pooled
episode's video files download to a local cache in the background
(refcounted, since v3 packs several episodes per file; deleted on
eviction), so decode-on-exit reads local bytes instead of paying
network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
decorrelated, runs reproducible, epochs reshuffle automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
{batches_consumed, batch_size}; each worker re-derives its own skip
from the DataLoader's round-robin batch assignment and replays
tabular-only (no decode). Exact within an epoch, works with
num_workers > 0, and the same state file serves every rank. Replaces
the per-shard HF state_dict approach, which lived in worker processes
and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
shard); runtime warnings for non-divisible world sizes (datasets
degrades to read-everything splitting) and workers left without
shards.
- episode_pool_size replaces buffer_size (deprecated, ignored with a
warning); decoder cache sized to the pool working set, capped at 128.
Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
|
||||
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
@@ -57,7 +57,10 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
|
||||
"--episode_pool_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Whole episodes open per consumer (randomness knob).",
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
|
||||
@@ -78,7 +81,7 @@ def make_dataloader(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
@@ -121,13 +124,13 @@ def main() -> None:
|
||||
# of it). Batches are moved to the device manually in the loop.
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
|
||||
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
|
||||
# checkpoint this script wrote itself.
|
||||
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
|
||||
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
|
||||
# worker re-derives its own skip. Same file works for every rank.
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
|
||||
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
@@ -157,8 +160,11 @@ def main() -> None:
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
|
||||
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
|
||||
# Save the consumed-batch counters so a restart fast-forwards to this position.
|
||||
torch.save(
|
||||
{"batches_consumed": step, "batch_size": args.batch_size},
|
||||
ckpt / "dataset_state.pt",
|
||||
)
|
||||
if model is not None:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user