mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives
The custom episode pool becomes a pure `datasets` pipeline:
split_dataset_by_node -> batch(by_column="episode_index")
-> shuffle(buffer=episode_pool_size) # episode pool
-> map(explode + exact delta windows) # episode -> frames
-> shuffle(buffer=frame_shuffle_buffer_size) # frame interleave
and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.
Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.
Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time - use a colocated
bucket (data_files_root) at large scale.
The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.
Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ streaming features of :class:`StreamingLeRobotDataset`:
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
|
||||
- native `datasets` resume: the loader checkpoints stream state via ``state_dict()`` (``torchdata`` StatefulDataLoader when available, so ``num_workers > 0`` resumes too);
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
@@ -85,7 +85,16 @@ def make_dataloader(
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
loader = DataLoader(
|
||||
# torchdata's StatefulDataLoader checkpoints each worker's dataset state through the
|
||||
# dataset's native state_dict protocol, making resume work with num_workers > 0. Fall back
|
||||
# to the plain DataLoader (resume then requires num_workers=0).
|
||||
try:
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
loader_cls = StatefulDataLoader
|
||||
except ImportError:
|
||||
loader_cls = DataLoader
|
||||
loader = loader_cls(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
@@ -124,13 +133,17 @@ def main() -> None:
|
||||
# of it). Batches are moved to the device manually in the loop.
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
|
||||
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
|
||||
# worker re-derives its own skip. Same file works for every rank.
|
||||
# Resume: native datasets stream state, saved per rank. With torchdata's StatefulDataLoader
|
||||
# the state covers every worker; with the plain DataLoader it is exact for num_workers=0.
|
||||
can_checkpoint_loader = hasattr(loader, "state_dict")
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
|
||||
state_path = Path(args.resume_from) / f"dataset_state_rank{accelerator.process_index}.pt"
|
||||
state = torch.load(state_path, weights_only=False) # plain dict of stream offsets # nosec B614
|
||||
if can_checkpoint_loader:
|
||||
loader.load_state_dict(state)
|
||||
else:
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {state_path}")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
@@ -157,15 +170,15 @@ def main() -> None:
|
||||
)
|
||||
window_start = time.perf_counter()
|
||||
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
if step % args.save_freq == 0:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the consumed-batch counters so a restart fast-forwards to this position.
|
||||
torch.save(
|
||||
{"batches_consumed": step, "batch_size": args.batch_size},
|
||||
ckpt / "dataset_state.pt",
|
||||
)
|
||||
if model is not None:
|
||||
if accelerator.is_main_process:
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
# Every rank saves its own stream state: shard positions differ per rank.
|
||||
state = loader.state_dict() if can_checkpoint_loader else dataset.state_dict()
|
||||
torch.save(state, ckpt / f"dataset_state_rank{accelerator.process_index}.pt")
|
||||
if model is not None and accelerator.is_main_process:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
if step >= args.steps:
|
||||
|
||||
Reference in New Issue
Block a user