Compare commits

29 Commits

Author SHA1 Message Date
pepijn 7b6f4f2b11 Add in-memory byte index and manifest-driven episode MP4 cache.
Build moov-derived byte ranges in RAM or from sidecar parquet, fetch tight mdat slices over the network, and decode via TorchCodec custom_frame_mappings to skip full-file metadata scans.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-16 15:03:17 +00:00
pepijn 4940281120 feat(streaming): random-episode admission via reshard() + multi-input-shard shuffle
Reshard parquet per row group (1 shard == 1 row group == 1 episode) and feed the
episode-pool shuffle with max_buffer_input_shards so the pool is a uniform random
sample of the corpus, independent of episodes-per-file. Add validate_row_groups
guardrails (collapsed-row-group + distributed divisibility), require datasets>=5.0.0,
make the test fixture write one row group per episode, and plumb max_buffer_input_shards
through the dataloading benchmark.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-15 13:33:27 +00:00
pepijn 3ec60da82b feat(streaming): add cluster dataloading benchmark example
Single-file SLURM-oriented benchmark comparing the map-style and native
streaming loaders on single-image samples: a self-submitting serial chain
that measures peak RSS, samples/s (and decoded frames/s), fetch-vs-decode
split, shuffle randomness, and p50/p95/p99 sample latency over a fixed
wall-clock window, including a 2-node split_dataset_by_node leg.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 14:23:15 +00:00
pepijn 7bcd5a1502 refactor(streaming): trim video_utils to the minimal readahead cap
Drop the transient-IO retry layer and the decoder-cache observability counters from
video_utils.py, keeping only the fsspec readahead cache that bounds per-handle RAM for
remote (hf://) decoders. Remove the now-orphaned instrumentation from StreamingLeRobotDataset
(video_decode_device/NVDEC, shared cache-counter tensor, video_decoder_cache_stats(),
timing_stats()). Retry is deferred to a separate, focused PR.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 09:50:43 +00:00
pepijn 674c990a39 feat(streaming): default episode pool 1024 and wire streaming into lerobot-train
Raise the default episode_pool_size to 1024 (DatasetConfig + StreamingLeRobotDataset)
for better default shuffle quality at scale.

Streaming is now a first-class option of the main train script: when cfg.dataset.streaming
is set, the dataloader is not handed to accelerate (the dataset is already rank-disjoint via
split_dataset_by_node, so IterableDatasetShard would drop (N-1)/N of each rank's stream),
batches are moved to device manually, and the episode-aware sampler is skipped. Remove the
standalone examples/scaling/train_streaming_multinode.py example in favor of this wiring.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-12 09:24:32 +00:00
Pepijn 38106ea6b4 chore(streaming): drop benchmark and SLURM scaffolding from the PR
The benchmarks/streaming harness (matrix submitter, summarizer, decode
diagnostic) and the robocasa SLURM scripts are cluster-specific tooling,
not part of the streaming feature. The example's --dummy mode covers
throughput measurement for reviewers. Recoverable from git history
(894fc6bfb) for cluster runs. Example docstring de-personalized.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 21:46:43 +02:00
Pepijn 894fc6bfb5 refactor(streaming): rebuild StreamingLeRobotDataset on native datasets primitives
The custom episode pool becomes a pure `datasets` pipeline:

  split_dataset_by_node -> batch(by_column="episode_index")
    -> shuffle(buffer=episode_pool_size)            # episode pool
    -> map(explode + exact delta windows)           # episode -> frames
    -> shuffle(buffer=frame_shuffle_buffer_size)    # frame interleave

and the torch IterableDataset wrapper keeps only per-sample video decode
(decode-on-exit), image transforms, task lookup, and decode/fetch timing.

Replaced by native machinery and deleted: the pooled-episode admission
loop, the refcounted video prefetcher, manual worker shard striding plus
the worker-split suppression patch, the per-(epoch, rank) shard-order
permutation, the per-consumer SplitMix64 RNG, and fast-forward resume.
DataLoader workers are split by `datasets` itself; .shuffle() permutes
shard order per epoch natively; resume delegates to the native
state_dict/load_state_dict (exact with num_workers=0; with workers use
torchdata's StatefulDataLoader, which checkpoints per-worker state
through the same protocol). An in-flight epoch counter ensures a
mid-iteration state_dict records the epoch the stream position belongs
to. Buffer contents are skipped on resume (documented datasets
behavior): never repeats data, drops at most ~pool + frame-buffer frames.

Randomness is unchanged: a batch still mixes up to episode_pool_size
episodes; delta windows are still exact in-episode slices with correct
boundary padding (value-verified against the map-style dataset). The
known trade accepted with this rewrite: no video prefetch-on-admit, so
remote decode pays per-frame range reads at yield time - use a colocated
bucket (data_files_root) at large scale.

The delta-consistency tests gained a scalar-comparison branch: they
silently skipped python-scalar keys before (stale `check` variable),
exposed by the new pipeline's key ordering.

Requires datasets with #8259 (pinned to the merge commit on this
branch). Example updated to per-rank native resume via torchdata's
StatefulDataLoader when available.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 21:03:09 +02:00
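As a rough illustration of the four-stage pipeline in the commit above, the following pure-Python sketch mimics its shape with generators. It does not call the `datasets` API: `group_by_episode`, `shuffle_buffer`, `explode`, and `pipeline` are hypothetical stand-ins, and the buffer behavior (fill a fixed-size buffer, then swap one random slot out per incoming item) is an assumption about how a streaming shuffle buffer typically works.

```python
import random
from itertools import groupby


def group_by_episode(rows):
    """Group consecutive rows sharing an episode_index into one episode (a list of rows)."""
    for _, group in groupby(rows, key=lambda r: r["episode_index"]):
        yield list(group)


def shuffle_buffer(items, buffer_size, rng):
    """Streaming shuffle: hold up to buffer_size items, emit a random one per incoming item."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    # Source exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def explode(episodes):
    """Turn each episode back into individual frames."""
    for episode in episodes:
        yield from episode


def pipeline(rows, episode_pool_size, frame_shuffle_buffer_size, seed):
    rng = random.Random(seed)
    episodes = shuffle_buffer(group_by_episode(rows), episode_pool_size, rng)  # episode pool
    return shuffle_buffer(explode(episodes), frame_shuffle_buffer_size, rng)  # frame interleave


# Toy stream: 5 episodes of 4 frames each, stored episode by episode.
rows = [{"episode_index": e, "frame_index": f} for e in range(5) for f in range(4)]
out = list(pipeline(rows, episode_pool_size=3, frame_shuffle_buffer_size=6, seed=0))
```

Every frame is emitted exactly once and the order is reproducible for a given seed, which is the behavior contract the rewritten tests in this branch describe.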
Pepijn 984b400e5c build(deps): pin datasets to the datasets#8259 merge commit
The native streaming pipeline calls .shuffle() on top of batch(by_column=...),
which crashes on released datasets 5.0.0 (batch-accumulator flag dropped on
shard/shuffle re-creation). The fix (datasets#8259) is merged but unreleased,
so pin datasets to the merge commit 2c45eab on this branch via [tool.uv.sources].
Drop this pin and bump the floor in `dependencies` once the next datasets
release ships the fix.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 18:28:41 +02:00
Pepijn 4e056081cb feat(streaming): seeded shard-order permutation per (seed, epoch, rank)
Shards were assigned to consumers in file-index order, so a sub-epoch
run over a corpus consolidated source-by-source trains on whatever the
first N% of files contains and drifts curriculum-style as sources change
under it. Permute the rank's shard list with a seeded RNG before worker
striding: a 30%-of-epoch run now sees a uniform 30% sample of files.

The permutation is seeded by (seed, epoch, rank) only - every DataLoader
worker of a rank must derive the identical list, since workers stride it
and disagreement would create overlapping shard assignments. It re-draws
each epoch, is the identity when shuffle=False, and stays deterministic
for fast-forward resume.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 17:08:26 +02:00
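A minimal sketch of the permutation described in the commit above, assuming a string-derived seed; `permuted_shards` and `worker_shards` are hypothetical names, not the functions in the branch. The point it shows is that the seed depends only on (seed, epoch, rank), so every worker of a rank derives the same list before striding it.

```python
import random


def permuted_shards(shards, seed, epoch, rank, shuffle=True):
    """Shard order for one rank; identical in every worker of that rank."""
    order = list(shards)
    if shuffle:
        # random.Random hashes a string seed deterministically, so separate
        # worker processes arrive at the same permutation.
        random.Random(f"{seed}:{epoch}:{rank}").shuffle(order)
    return order


def worker_shards(shards, seed, epoch, rank, worker_id, num_workers):
    """Stride the rank's permuted list so workers receive disjoint shards."""
    return permuted_shards(shards, seed, epoch, rank)[worker_id::num_workers]


shards = [f"file-{i:03d}.parquet" for i in range(10)]
per_worker = [
    worker_shards(shards, seed=0, epoch=0, rank=0, worker_id=w, num_workers=3)
    for w in range(3)
]
```

Because the worker id is not part of the seed, the three strided slices never overlap and together cover the rank's full shard list.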
Pepijn a164bb97bd feat(streaming): native datasets-5 episode batching and worker-split suppression
Allow datasets 5.x (pin >=4.7,<6; lockfile moves to 5.0.0) and use its
Arrow-native batch(by_column="episode_index") (huggingface/datasets#8194
sibling, #8172) for episode admission when available - one Arrow
accumulation per episode instead of one Python dict per row - with the
existing row loop as the 4.x fallback. A parity test asserts both paths
group identically.

Also fixes a latent worker bug this surfaced: `datasets` detects torch
DataLoader workers and re-splits its shards internally (_iter_pytorch),
on top of our explicit per-worker shard assignment. That second split
silently drops data whenever a per-worker stream has fewer internal
shards than there are workers (masked so far by single-file test
fixtures), and on datasets 5.0 it crashes by_column batching outright.
The worker context is now hidden from `datasets` while draining streams
we already partitioned (process-local patch, restored on exit).

The multi-shard shuffle buffer (huggingface/datasets#8194) is
intentionally NOT used: frame-level shuffling upstream of episode
grouping would fragment episodes and break delta windows. Its threaded
multi-source prefetch idea remains a follow-up for episode admission if
fetch timings warrant it.

Verified on both datasets 4.8.5 (fallback) and 5.0.0 (native): 27/27
streaming tests each; full datasets suite 469 passed under 5.0.0.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 16:10:53 +02:00
Pepijn 79b547de32 Merge remote episode-pool work into the full pool rewrite
The remote commit (2ab71231c) added an opt-in episode pool, deferred
decode in the legacy buffer path, decode/fetch timing instrumentation,
remote-IO retries (video_utils), and 32MB row-group writing
(dataset_tools). The pool rewrite on this side makes the episode pool
the only iteration path (with prefetch-on-admit, per-consumer seeding,
worker-exact fast-forward resume), so streaming_dataset.py resolves to
the rewrite with the remote instrumentation ported into it:

- 5-slot shared counters + timing_stats() (decode_s_total/fetch_s_total)
- fetch timed around episode admission, decode timed around emission
- benchmark/slurm keep the remote updates, with episode_pool_size as the
  knob (buffer_size deprecated and ignored)

video_utils retries and dataset_tools row groups are taken unchanged.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 15:17:04 +02:00
Pepijn a7b7f4964e fix(streaming): worker-exact resume arithmetic and multi-worker resume test
The fast-forward skip assumed every DataLoader worker delivers batches;
workers that own no shards yield nothing and are stopped, so the batch
round-robin runs over min(num_workers, num_shards) active workers. Use
that effective count (shard-less workers skip nothing). Adds a resume
test under num_workers=2 asserting exact continuation.

Note: the test fixtures write a single parquet file regardless of
data_files_size_in_mb, so worker-splitting tests exercise the degenerate
single-shard layout; multi-shard behavior is covered by the rank-level
split_dataset_by_node tests.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 15:11:00 +02:00
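The skip arithmetic in the commit above can be sketched as follows. `batches_to_skip` is a hypothetical helper, and the sketch assumes the shard-less workers are the ones with the highest worker ids; the branch may assign them differently.

```python
def batches_to_skip(batches_consumed, worker_id, num_workers, num_shards):
    """Batches that worker `worker_id` had already delivered when the trainer stopped."""
    active = min(num_workers, num_shards)  # workers without shards yield nothing
    if worker_id >= active:
        return 0  # shard-less workers skip nothing
    # Round-robin: batch b (0-based) comes from active worker b % active.
    return len(range(worker_id, batches_consumed, active))


# 7 batches consumed, 4 workers configured, but only 2 shards to hand out.
skips = [batches_to_skip(7, w, num_workers=4, num_shards=2) for w in range(4)]
```

Dividing by `num_workers` instead of the effective count would spread the 7 consumed batches over 4 workers and make each active worker skip too few batches.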
Pepijn 1050c2fb6c feat(streaming): episode-pool iteration with decode-on-exit, video prefetch, and exact resume
Replace the shard/Backtrackable/decoded-shuffle-buffer internals with an
episode pool: each (rank x worker) consumer keeps episode_pool_size whole
episodes' tabular rows in RAM and emits uniformly random frames across
them. delta_timestamps windows become exact in-RAM slices with correct
boundary padding (the Backtrackable machinery and its lookback/lookahead
ceilings are gone), and video is decoded only when a sample is emitted,
so pool memory stays tabular-sized instead of buffer_size decoded
samples.

- Prefetch-on-admit: when streaming from a remote source, each pooled
  episode's video files download to a local cache in the background
  (refcounted, since v3 packs several episodes per file; deleted on
  eviction), so decode-on-exit reads local bytes instead of paying
  network seek latency.
- Per-consumer RNG derived from (seed, epoch, rank, worker): consumers
  decorrelated, runs reproducible, epochs reshuffle automatically.
- Deterministic fast-forward resume: load_state_dict takes the trainer's
  {batches_consumed, batch_size}; each worker re-derives its own skip
  from the DataLoader's round-robin batch assignment and replays
  tabular-only (no decode). Exact within an epoch, works with
  num_workers > 0, and the same state file serves every rank. Replaces
  the per-shard HF state_dict approach, which lived in worker processes
  and could not be captured from the trainer.
- Shard-cap default removed (max_num_shards=None uses every parquet
  shard); runtime warnings for non-divisible world sizes (datasets
  degrades to read-everything splitting) and workers left without
  shards.
- episode_pool_size replaces buffer_size (deprecated, ignored with a
  warning); decoder cache sized to the pool working set, capped at 128.

Legacy order-replication tests asserted the old buffer algorithm
step-by-step and are rewritten as behavior contracts (exactly-once
coverage, per-seed determinism, epoch reshuffle). Value-level parity
tests against the map-style dataset pass unchanged.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 15:02:15 +02:00
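To make "exact in-RAM slices with correct boundary padding" concrete, here is a small sketch under two assumptions that the commit does not spell out: offsets are expressed in frames rather than seconds, and out-of-episode positions are filled by repeating the edge frame and flagging it as padding. `delta_window` is a hypothetical helper.

```python
def delta_window(frame_index, deltas, episode_length):
    """Return (indices, is_pad) for in-episode offsets, clamping at the episode boundaries."""
    indices, is_pad = [], []
    for delta in deltas:
        target = frame_index + delta
        clamped = min(max(target, 0), episode_length - 1)
        indices.append(clamped)
        is_pad.append(clamped != target)  # True where the window left the episode
    return indices, is_pad


# Frame 1 of a 3-frame episode, looking two frames back and one frame ahead.
indices, is_pad = delta_window(frame_index=1, deltas=[-2, -1, 0, 1], episode_length=3)
```

Since the whole episode is held in the pool, the window is a plain index lookup with no lookback or lookahead ceiling.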
Pepijn 66ac901632 fix(streaming): do not prepare the dataloader with accelerate
The dataset is already rank-disjoint via split_dataset_by_node;
accelerate's IterableDatasetShard wrapper kept only every Nth batch of
each rank's stream, silently training on 1/N of the data per pass while
decoding all of it. The --dummy benchmark path never prepared the
loader, so benchmarks were unaffected.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 12:21:20 +02:00
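The double sharding described in the commit above can be reproduced with a toy simulation. It is not Accelerate's code: `rank_stream` uses a simple strided split as a stand-in for `split_dataset_by_node`, and `keep_every_nth_batch` is a hypothetical model of a wrapper that keeps every Nth batch per rank.

```python
def rank_stream(samples, rank, world_size):
    """Rank-disjoint split: each rank already owns its own slice of the data."""
    return samples[rank::world_size]


def keep_every_nth_batch(batches, rank, world_size):
    """A second, per-rank sharding layer that keeps only every world_size-th batch."""
    return batches[rank::world_size]


def batched(items, batch_size):
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


world_size, batch_size = 4, 2
samples = list(range(64))
seen = set()
for rank in range(world_size):
    batches = batched(rank_stream(samples, rank, world_size), batch_size)
    for batch in keep_every_nth_batch(batches, rank, world_size):
        seen.update(batch)

fraction_trained_on = len(seen) / len(samples)
```

With 4 ranks, only a quarter of the corpus is ever seen per pass, even though every rank reads and decodes its full stream.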
Pepijn ce326207e6 Merge remote-tracking branch 'origin/main' into feat/streaming-hf-native
2026-06-11 12:19:32 +02:00
pepijn 2ab71231cd feat(streaming): defer video decode, episode-pool shuffle, and remote-IO retries
- streaming_dataset: defer torchcodec decode until a sample leaves the shuffle
  buffer (buffer now holds ~KB tabular rows, not MB of pixels) and add an opt-in
  episode-pool shuffle (episode_pool_size) with exact in-episode delta lookups;
  expose decode/fetch timing_stats.
- video_utils: retry transient hf:///fsspec/httpx transport errors during
  streaming decode (LEROBOT_REMOTE_IO_MAX_RETRIES).
- dataset_tools: write multiple ~32MB row groups with a page index to bound
  per-shard streaming memory.
- benchmarks/slurm: streaming benchmark + matrix submitter updates.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-06-11 10:08:28 +00:00
Pepijn 42d4788e4a fix(streaming): drop undeclared parquet columns that break batch collation
The data_files_root/bucket path reads an unversioned source (e.g. `main`), which can
carry extra annotation columns not in the dataset's feature contract — notably
`language_events`, a variable-length list (length 0..N per frame). Passed through to the
sample, these break default DataLoader collation ("each element in list of batch should
be of equal size"), which is why bucket jobs failed while the hub path (pinned to the
clean v3.0 revision) succeeded.

Drop any hf_dataset column not in meta.features after load. No-op on a clean revision;
removes language_events/language_persistent on main. Verified by reproducing the bucket
code path locally via --data_files_root hf://datasets/<repo> (parquet builder + main
columns): now decodes and collates instead of raising.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:24:30 +02:00
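The column filter from the commit above amounts to a set difference. A minimal sketch, with `undeclared_columns` as a hypothetical helper and a made-up feature contract:

```python
def undeclared_columns(column_names, declared_features):
    """Columns present in the loaded table but absent from the feature contract."""
    declared = set(declared_features)
    return [name for name in column_names if name not in declared]


# Hypothetical example: a contract with three features, read from a source
# that carries two extra annotation columns.
columns = ["action", "observation.state", "timestamp", "language_events", "language_persistent"]
features = {"action": None, "observation.state": None, "timestamp": None}
to_drop = undeclared_columns(columns, features)
```

With a `datasets` dataset, the resulting names would presumably be passed to `remove_columns`; on a clean revision the list is empty and nothing changes.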
Pepijn 2d1c17d971 docs(streaming): note AV1 is LeRobot's default codec (vcodec=libsvtav1)
So the A100/H100 no-AV1-NVDEC limitation applies to most LeRobot v3 datasets, not just
RoboCasa — GPU decode needs an Ada GPU, an hevc/h264-encoded dataset, or a re-encode.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:10:18 +02:00
Pepijn 7241f029c6 docs(streaming): A100/H100 NVDEC cannot decode AV1 — correct guidance
NVIDIA's decode support matrix: the compute GPUs A100 (GA100) and H100 (GH100) have no
AV1 NVDEC decoder; only Ada (L4/L40/RTX40) and some Ampere (A10/A40/A16) do. So on
A100/H100 nodes, AV1 datasets must be decoded on CPU or re-encoded to H.265/H.264 — no
torchcodec build enables cuda AV1 decode there. Also distinguish that error from
"Unsupported device: cuda (variant: ffmpeg)", which is a torchcodec-built-without-CUDA
issue. Update diagnose_decode.py message + benchmark README accordingly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:08:54 +02:00
Pepijn 06ddc59913 feat(streaming): CONDA_ENV knob for the matrix submitter
Add CONDA_ENV=<name> to run each matrix job via `conda run --no-capture-output -n
<name>` — works inside the dash `sbatch --wrap` without sourcing conda.sh / activating,
and streams logs live. Point it at a conda env with a modern torchcodec (>=0.11) +
datasets (>=4.7); the default cluster `base` env is often too old to decode AV1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:55:42 +02:00
Pepijn 23c58f5f9e feat(streaming): decode diagnostic + fail benchmark on 0 frames
- benchmark: raise SystemExit if 0 frames were measured, so a run that produces no
  batches (swallowed decode error, all batches dropped) fails loudly instead of being
  reported green with NaN/zero numbers (the misleading "COMPLETED" CUDA jobs).
- add benchmarks/streaming/diagnose_decode.py: isolates the streaming decode path
  (resolve path -> fsspec.open -> torchcodec VideoDecoder -> get one frame) and prints
  package versions + the first bytes of the handle. Pinpoints decode failures: bad/
  placeholder bytes vs ffmpeg/torchcodec build issue. RoboCasa videos are AV1; the
  failure message calls out AV1 decoder + NVDEC-on-Ada requirements explicitly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:40:24 +02:00
Pepijn b0ab57cedc fix(streaming): make matrix sbatch --wrap body POSIX-sh safe
`sbatch --wrap` runs the wrapped body under /bin/sh (dash), which has no
`set -o pipefail`, so every matrix job died on line 1 ("Illegal option -o pipefail")
before reaching the benchmark. The command has no pipes, so drop the bashism and chain
with `&&` (cd-guards the run) — fully POSIX-sh compatible. Runtime env expansion
(${HF_HOME:-$SCRATCH/hf_home}) is preserved.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:16:54 +02:00
Pepijn afdc084677 feat(streaming): serial-by-default matrix submitter (afterany dependency chain)
For a bandwidth-sensitive benchmark, concurrent jobs would share the network to the
Hub/bucket and corrupt throughput numbers. Chain the matrix jobs with
--dependency=afterany (captured via `sbatch --parsable`) so SLURM runs exactly one at a
time while keeping each config an isolated job (own log + per-job OOM reporting).
afterany keeps the chain going if one job fails/OOMs. SERIAL=0 restores parallel
submission for OOM-isolation-only testing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:55:58 +02:00
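The dependency chain in the commit above can be sketched in Python as follows. `submit_serial_chain` and `fake_submit` are hypothetical; only the `sbatch` options named in the commit (`--parsable`, `--dependency=afterany`, `--wrap`) are used, and the submitter is injected so the chain logic can be dry-run without a cluster.

```python
def submit_serial_chain(commands, submit):
    """Submit commands so each job starts only after the previous one ends, whatever its exit status.

    `submit` takes an sbatch argv list and returns the new job id as a string.
    """
    job_ids, previous = [], None
    for command in commands:
        argv = ["sbatch", "--parsable"]
        if previous is not None:
            # afterany: run once the previous job has terminated, failed or not.
            argv.append(f"--dependency=afterany:{previous}")
        argv += ["--wrap", command]
        previous = submit(argv)
        job_ids.append(previous)
    return job_ids


# Dry run with a fake submitter that records argv and hands out job ids.
recorded = []


def fake_submit(argv):
    recorded.append(argv)
    return str(1000 + len(recorded))


job_ids = submit_serial_chain(["echo a", "echo b", "echo c"], fake_submit)
```

On a real cluster, `submit` could wrap `subprocess.run(argv, check=True, capture_output=True, text=True)` and take the part of stdout before any `;`, since `--parsable` prints the job id optionally followed by a cluster name.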
Pepijn a32a2c647b feat(streaming): full-matrix SLURM submitter + results summarizer
slurm/run_streaming_matrix.sh fans the benchmark matrix (sources {hub,bucket,
warmed_bucket} x modes {single,sarm} x decode {cpu,cuda}) out as isolated single-GPU
SLURM jobs, so an OOM in one config is contained and reported per-job by SLURM. Worker
count and shuffle buffer are bounded (lower for cuda, which holds a CUDA context + NVDEC
session per worker) to avoid host/VRAM OOM. Source/mode/decode/workers/buffer/account/
partition are env-overridable; SOURCES/MODES/DECODES select subsets.

benchmarks/streaming/summarize_results.py collapses the per-run JSONs into one comparison
table + summary.csv (frames/s/node, first-batch + p50/p95/p99 latency, cache hit-rate).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:51:36 +02:00
Pepijn 343ecd7980 feat(streaming): optional GPU (NVDEC) video decode device
Add `video_decode_device` to StreamingLeRobotDataset and a `device` arg to
VideoDecoderCache, passed to torchcodec's VideoDecoder. "cuda" offloads H.264/H.265
decode to the GPU's dedicated NVDEC engine (independent of the training SMs); requires
a CUDA-enabled torchcodec build.

benchmark: `--video_decode_device` flag. With cuda + num_workers>0 it forces the
`spawn` start method (CUDA cannot init in forked workers) and disables CPU pin_memory
(frames are already on-GPU). Decode device is recorded in results and the output
filename. README documents the NVDEC option and its concurrency/IPC caveats.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:47:11 +02:00
Pepijn f7c8a526e8 feat(streaming): wallclock benchmark throughput, cross-worker cache stats, bucket source
- benchmark: frames_per_s_node now measures sustained wall-clock throughput over the
  post-warmup window. The previous metric summed inter-batch gaps, which collapse to ~0
  under async prefetch (consumer drains a pre-filled queue) and overstated throughput ~100x.
- VideoDecoderCache gains an optional shared [hits, misses, evictions] counter tensor;
  StreamingLeRobotDataset.video_decoder_cache_stats() aggregates it across DataLoader
  workers (lock-free, approximate; hit_rate preserved). Fixes empty cache stats with workers.
- StreamingLeRobotDataset.data_files_root: read bulk data/ + videos/ from an fsspec root
  (e.g. hf://buckets/<owner>/<name>) while metadata still loads from repo_id. Enables
  bucket / prewarmed-bucket benchmark sources without copying metadata. Exposed as
  benchmark --data_files_root.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:25:44 +02:00
Pepijn 77af66a29c fix(streaming): decode video at episode-local timestamp + from_timestamp offset
make_frame used `item["index"] / fps` (a dataset-global value) as the in-file
video timestamp. That only matches the file timeline when the whole dataset is a
single video (as in the test fixtures); on multi-file v3 datasets it decodes
out-of-range frames and crashes (e.g. RoboCasa: "Invalid frame index=23314614 ...
must be less than 41021").

Mirror the map-style reader: use the episode-local `timestamp` column as the base,
clamp delta query timestamps to per-camera episode-local bounds [0, duration], and
shift by the episode's `from_timestamp` per camera at decode time. For single-file
datasets `from_timestamp + timestamp == index / fps`, so existing parity tests are
unaffected; multi-file streaming is now correct.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 14:54:10 +02:00
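A worked example of the timestamp fix in the commit above, using made-up numbers. `in_file_timestamp` is a hypothetical helper; it follows the order the commit describes: start from the episode-local timestamp, clamp the query to the episode bounds, then shift by `from_timestamp`.

```python
def in_file_timestamp(episode_timestamp, from_timestamp, delta=0.0, duration=None):
    """Episode-local query time, clamped to the episode, then shifted into the video file."""
    local = episode_timestamp + delta
    if duration is not None:
        local = min(max(local, 0.0), duration)  # clamp to [0, duration]
    return from_timestamp + local


fps = 20
# Hypothetical frame: dataset-global index 120_030, 1.5 s into an episode
# that starts 40.0 s into its video file and lasts 10 s.
global_index, episode_timestamp, from_timestamp = 120_030, 1.5, 40.0

wrong = global_index / fps  # dataset-global time, far beyond the end of the file
right = in_file_timestamp(episode_timestamp, from_timestamp)
clamped = in_file_timestamp(episode_timestamp, from_timestamp, delta=-5.0, duration=10.0)
```

Here `wrong` is about 6001.5 s while the frame actually sits 41.5 s into its file; a delta reaching before the episode start clamps to the episode's first frame at 40.0 s instead of leaking into the previous episode.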
Pepijn 68fa5d80b0 feat(streaming): multinode example, dataloading benchmark, distributed smoke test
- examples/scaling/train_streaming_multinode.py: Accelerate-based distributed/
  resumable streaming training (no DistributedSampler; rank/world_size auto-resolved),
  checkpoints the dataset stream state, and supports a --dummy pure-dataloading path
  with throughput logging. SLURM launcher in slurm/train_streaming_robocasa.sh.
- benchmarks/streaming/benchmark_streaming.py: dummy-consumer dataloading benchmark
  (single / sarm frame modes) emitting frames/s/node, p50/p95/p99 sample latency,
  first-batch latency, and VideoDecoderCache reuse stats as JSON + CSV. SLURM launcher
  + README documenting the source/node/mode matrix and manual bucket prewarming.
- VideoDecoderCache: add hit/miss/eviction counters and a stats() method so the
  benchmark can surface decoder thrash (no new cache, no eviction-policy change).
- tests/datasets/test_streaming_distributed.py: accelerate-launch smoke test asserting
  per-rank disjointness; skips (does not false-pass) when <2 processes spawn.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 13:48:23 +02:00
Pepijn d1fc8e298c feat(streaming): distributed + resumable HF-native StreamingLeRobotDataset
Add the large-scale streaming pieces that were missing from the frame-streaming
internals, keeping the existing Backtrackable + output-reservoir frame-shuffle:

- split_dataset_by_node(rank, world_size) before the per-shard loop so each rank
  streams a disjoint set of shards (fixes duplicate data across GPUs). rank and
  world_size auto-resolve from Accelerate state / RANK,WORLD_SIZE env / (0, 1).
- get_worker_info() shard splitting so DataLoader workers within a rank don't
  yield duplicate frames.
- Dynamic Backtrackable window (dynamic_bounds=True) sized to the requested
  delta_timestamps, removing the fixed 100-frame ceiling so long horizons (e.g. a
  SARM window ~160 frames) reach real frames instead of silently padding. Fix the
  peek_back off-by-one: history = lookback + 1.
- video_decoder_cache_size knob; default (active_shards + 1) x num_cameras so the
  live decoder working set does not thrash the VideoDecoderCache LRU.
- state_dict()/load_state_dict() for resume (per-shard HF stream state + exhausted
  set + RNG). Reservoir is re-warmed, so resumption is not bit-exact (documented).
- factory.py wires buffer_size from a new DatasetConfig.streaming_buffer_size field
  instead of repurposing max_num_shards as the worker count.

Tests: tests/datasets/test_streaming_native.py covers distributed disjointness,
worker de-duplication, the SARM-length window, resume, schema parity vs map-style,
local video path resolution, and shuffle decorrelation. 21 passed (13 existing + 8).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 13:37:30 +02:00
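Two of the defaults from the commit above can be written down directly. This is a sketch with hypothetical helper names; the Accelerate-state lookup that the commit lists as the first resolution step is omitted, so only the environment-variable step and the (0, 1) fallback are shown.

```python
import os


def resolve_rank_world_size(env=None):
    """Read RANK/WORLD_SIZE from the environment, falling back to (0, 1)."""
    env = os.environ if env is None else env
    try:
        return int(env["RANK"]), int(env["WORLD_SIZE"])
    except (KeyError, ValueError):
        return 0, 1


def default_decoder_cache_size(active_shards, num_cameras):
    """(active_shards + 1) x num_cameras, the default stated in the commit."""
    return (active_shards + 1) * num_cameras


rank, world_size = resolve_rank_world_size({"RANK": "3", "WORLD_SIZE": "8"})
cache_size = default_decoder_cache_size(active_shards=4, num_cameras=3)
```

The extra `+ 1` leaves room for one more shard's decoders beyond those currently active, so the live working set does not evict itself from the LRU cache.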
20 changed files with 3273 additions and 566 deletions
@@ -0,0 +1,547 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Single-image dataloading benchmark across the LeRobot loaders, MADE TO RUN ON A COMPUTE CLUSTER (SLURM).

This one file is both the orchestrator and the worker:

  * Run it with no ``--scenario`` (from a login node) and it submits a SERIAL sbatch chain of all
    scenarios below (no two network-bound jobs overlap, so CDN numbers stay clean).
  * Run it with ``--scenario <name>`` and it executes that single benchmark (this is what each sbatch
    job calls). The 2-node scenario is launched with ``srun`` and reads ``RANK``/``WORLD_SIZE`` so the
    streaming dataset splits shards per node.

Scenarios (all single-frame / non-SARM):

  1. ``mmap_local``             map-style LeRobotDataset over a LOCAL copy (``--local_root``, no network).
  2. ``mmap_local_maxworkers``  same, but workers scaled to saturate the node's cores (decode-bound).
  3. ``stream_hub``             StreamingLeRobotDataset from the Hub (allenai/MolmoAct2-BimanualYAM-Dataset).
  4. ``stream_bucket``          StreamingLeRobotDataset from a warmed storage bucket (1 node).
  5. ``stream_bucket_2node``    same warmed bucket, 2 nodes (split_dataset_by_node, per-rank results).

Reported per run: peak process-tree RSS (max memory), parallel throughput (samples/s, where a sample
is one timestep, plus decoded_frames/s = samples/s x num_cameras),
single-process throughput, shuffle randomness fraction (distinct episodes per batch / batch size),
fetch vs decode split (% of single-process per-sample time), first-batch latency, and p50/p95/p99
sample latency. Results are written as JSON + CSV under ``--out_dir``.

Submit the whole chain (from a login node, inside the repo). Point the scheduler env vars at your own
cluster's account/partition/qos, and ``--local_root`` at a local copy of the map-style dataset:

    ACCOUNT=<account> PARTITION=<partition> QOS=<qos> \\
        python examples/scaling/benchmark_dataloading.py --local_root /path/to/local/dataset
"""

import argparse
import csv
import json
import os
import random
import statistics
import subprocess
import sys
import threading
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.datasets.partition import group_episodes_by_files, partition_episodes

ROBOCASA_REPO = "pepijn223/robocasa_pretrain_human300_v4"
MOLMO_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
MOLMO_BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# MolmoAct2 is published without a codebase-version git tag, so the version-safe loader would refuse
# it; "main" pins the branch directly and skips that check.
MOLMO_REVISION = "main"

# Per-scenario sbatch shape. mem is generous for the streaming legs (32k-episode, 3-camera, 2.35 TB
# dataset keeps many AV1 decoders open); the local map-style leg is light. Optional ``num_workers`` /
# ``cpus`` override the CLI defaults for that leg.
# ``mmap_local_maxworkers``: map-style decode is CPU-bound and each worker decodes its cameras on
# parallel threads, so the saturation point is ~num_cpus / num_cameras workers (~90 concurrent decode
# threads). The 96-core H100 nodes here schedule at most 92 cpus/task, so we take 92 cpus / 30 workers.
SCENARIOS = {
    "mmap_local": {"kind": "map", "nodes": 1, "mem": "64G", "time": "01:00:00"},
    "mmap_local_maxworkers": {
        "kind": "map",
        "nodes": 1,
        "mem": "128G",
        "time": "01:00:00",
        "num_workers": 30,
        "cpus": 92,
    },
    "stream_hub": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
    "stream_bucket": {"kind": "stream", "nodes": 1, "mem": "250G", "time": "03:00:00"},
    "stream_bucket_2node": {"kind": "stream", "nodes": 2, "mem": "250G", "time": "03:00:00"},
}


def _tree_rss_bytes() -> int:
    """Sum RSS of this process and all descendants via /proc (DataLoader workers are separate procs)."""
    try:
        children: dict[int, list[int]] = {}
        for entry in os.listdir("/proc"):
            if not entry.isdigit():
                continue
            try:
                with open(f"/proc/{entry}/stat") as f:
                    ppid = int(f.read().split(") ", 1)[1].split()[1])
                children.setdefault(ppid, []).append(int(entry))
            except (OSError, ValueError, IndexError):
                pass
        total, stack = 0, [os.getpid()]
        while stack:
            cur = stack.pop()
            try:
                with open(f"/proc/{cur}/statm") as f:
                    total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
            except (OSError, ValueError, IndexError):
                pass
            stack.extend(children.get(cur, []))
        return total
    except OSError:
        return 0


class PeakRSSSampler:
    """Background thread tracking peak process-tree RSS for the duration of the ``with`` block."""

    def __init__(self, interval_s: float = 0.5):
        self.interval_s = interval_s
        self.peak_bytes = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while not self._stop.is_set():
            self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
            self._stop.wait(self.interval_s)

    def __enter__(self) -> "PeakRSSSampler":
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        self._thread.join(timeout=2)


def percentile(values: list[float], pct: float) -> float:
    if not values:
        return float("nan")
    ordered = sorted(values)
    k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
    return ordered[k]


class _TimedStreaming(StreamingLeRobotDataset):
    """StreamingLeRobotDataset that times the fetch stage (parquet/network row) separately from the
    decode stage (video decode + torch conversion in ``_finalize_sample``), so a single-process pass
    can attribute per-sample cost to fetch vs decode. Timing lives here in the benchmark, not in the
    library, to keep the dataset itself instrumentation-free."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.fetch_s = 0.0
        self.decode_s = 0.0

    def __iter__(self):
        self._in_flight_epoch = self._epoch
        self._pipeline.set_epoch(self._in_flight_epoch)
        self._epoch += 1
        self.video_decoder_cache = self._make_video_decoder_cache()
        iterator = iter(self._pipeline)
        while True:
            t0 = time.perf_counter()
            try:
                row = next(iterator)
            except StopIteration:
                return
            t1 = time.perf_counter()
            sample = self._finalize_sample(row)
            t2 = time.perf_counter()
            self.fetch_s += t1 - t0
            self.decode_s += t2 - t1
            yield sample
def select_node_episodes(
meta: LeRobotDatasetMetadata, num_partitions: int, index: int, cap: int
) -> list[int]:
"""This node's episode share, mirroring lerobot_train ``--data_partition=node``: group episodes by
shared video files, LPT-balance the groups by frame count, take this node's bin (capped)."""
episodes = list(range(meta.total_episodes))
from_idx = meta.episodes["dataset_from_index"]
to_idx = meta.episodes["dataset_to_index"]
lengths = [int(to_idx[ep] - from_idx[ep]) for ep in episodes]
if meta.video_keys:
file_columns = {
key: (meta.episodes[f"videos/{key}/chunk_index"], meta.episodes[f"videos/{key}/file_index"])
for key in meta.video_keys
}
else:
file_columns = {"data": (meta.episodes["data/chunk_index"], meta.episodes["data/file_index"])}
episode_file_ids = [
[(key, chunks[ep], files[ep]) for key, (chunks, files) in file_columns.items()] for ep in episodes
]
groups = group_episodes_by_files(episode_file_ids)
if len(groups) < num_partitions:
groups = [[i] for i in range(len(episodes))]
group_lengths = [sum(lengths[i] for i in g) for g in groups]
bins = partition_episodes(group_lengths, num_partitions)
chosen = sorted(episodes[i] for g in bins[index] for i in groups[g])
return chosen[:cap] if cap and len(chosen) > cap else chosen
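The docstring of `select_node_episodes` says the groups are LPT-balanced by frame count, but `partition_episodes` itself is not part of this diff. As a rough illustration only, the usual longest-processing-time-first greedy can be sketched as below. The name `lpt_partition` and the sample weights are hypothetical, and the real helper may break ties or order bins differently.

```python
import heapq


def lpt_partition(weights: list[int], num_bins: int) -> list[list[int]]:
    """Hypothetical LPT greedy: place the heaviest item first, always into the lightest bin."""
    bins: list[list[int]] = [[] for _ in range(num_bins)]
    heap = [(0, b) for b in range(num_bins)]  # (current load, bin id)
    heapq.heapify(heap)
    # Visit item indices from heaviest to lightest.
    for item in sorted(range(len(weights)), key=lambda i: -weights[i]):
        load, b = heapq.heappop(heap)  # lightest bin so far
        bins[b].append(item)
        heapq.heappush(heap, (load + weights[item], b))
    return bins


# Illustrative frame counts per episode group (made-up numbers).
group_lengths = [900, 400, 350, 300, 50]
bins = lpt_partition(group_lengths, 2)
loads = [sum(group_lengths[i] for i in b) for b in bins]
print(bins, loads)
```

Each node would then take `bins[index]`, which is the role `bins[index]` plays in the function above.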
def build_dataset(scenario: str, args: argparse.Namespace):
"""Return (dataset, meta, is_map_style, info) for the scenario; single-frame (no delta windows)."""
if scenario.startswith("mmap_local"):
if not args.local_root:
raise SystemExit("mmap_local needs --local_root pointing at a local LeRobotDataset copy.")
meta = LeRobotDatasetMetadata(ROBOCASA_REPO, root=args.local_root)
episodes = select_node_episodes(meta, args.num_partitions, args.partition_index, args.max_episodes)
dataset = LeRobotDataset(ROBOCASA_REPO, root=args.local_root, episodes=episodes, tolerance_s=1e-3)
return dataset, meta, True, {"loaded_episodes": len(episodes)}
data_files_root = MOLMO_BUCKET if scenario.startswith("stream_bucket") else None
meta = LeRobotDatasetMetadata(MOLMO_REPO, revision=MOLMO_REVISION)
dataset = _TimedStreaming(
MOLMO_REPO,
revision=MOLMO_REVISION,
data_files_root=data_files_root,
episode_pool_size=args.episode_pool_size,
max_buffer_input_shards=args.max_buffer_input_shards,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
# Throughput benchmark: don't gate on the one-row-group-per-episode invariant (a public
# dataset may be collapsed); reshard() still yields per-episode shards where it holds.
validate_row_groups=False,
)
return dataset, meta, False, {"num_shards": dataset.num_shards, "data_files_root": data_files_root}
def _split(fetch_s: float, decode_s: float, getitem_s: float, n_probe: int) -> dict:
stage = fetch_s + decode_s
return {
"single_proc_samples_per_s": round(n_probe / getitem_s, 2) if getitem_s else None,
"fetch_pct": round(100 * fetch_s / stage, 1) if stage else None,
"decode_pct": round(100 * decode_s / stage, 1) if stage else None,
}
def measure_fetch_decode_stream(dataset: _TimedStreaming, n_probe: int, warmup: int) -> dict:
"""Single-process pass attributing per-sample time to fetch (parquet/network row) vs decode (video)."""
it = iter(dataset)
for _ in range(warmup): # exclude the cold shuffle-buffer fill from the ratio
next(it)
dataset.fetch_s = dataset.decode_s = 0.0
t0 = time.perf_counter()
for _ in range(n_probe):
next(it)
return _split(dataset.fetch_s, dataset.decode_s, time.perf_counter() - t0, n_probe)
def measure_fetch_decode_map(dataset: LeRobotDataset, n_probe: int, warmup: int) -> dict:
"""Same split for the map-style loader: fetch = raw tabular row (``get_raw_item``), decode = the rest
of ``__getitem__`` (video decode + transforms). Local reads make fetch tiny and decode dominant.
Random frames are resampled past any that torchcodec fails to decode, so a single flaky frame can't
abort the whole benchmark (the parallel DataLoader pass draws its own fresh random frames)."""
rng = random.Random(0)
n = len(dataset)
fetch_s = getitem_s = 0.0
warmed = measured = skipped = attempts = 0
while measured < n_probe and attempts < (warmup + n_probe) * 10:
attempts += 1
i = rng.randrange(n)
try:
t0 = time.perf_counter()
dataset.get_raw_item(i)
t1 = time.perf_counter()
dataset[i]
t2 = time.perf_counter()
except Exception:
skipped += 1
continue
if warmed < warmup:
warmed += 1
continue
fetch_s += t1 - t0
getitem_s += t2 - t1
measured += 1
if skipped:
print(f"map fetch/decode probe skipped {skipped} undecodable frame(s)", flush=True)
return _split(fetch_s, max(0.0, getitem_s - fetch_s), getitem_s, measured)
def run_scenario(scenario: str, args: argparse.Namespace) -> None:
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
device = torch.device(args.device)
dataset, meta, is_map_style, info = build_dataset(scenario, args)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=is_map_style, # map-style: global random shuffle; streaming: shuffled inside the dataset
pin_memory=device.type == "cuda",
drop_last=True,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
persistent_workers=args.num_workers > 0,
)
sample_latencies_ms: list[float] = []
episodes_per_batch: list[int] = []
samples = 0
first_batch_latency_s = None
steady_start = None
t_start = time.perf_counter()
t_prev = t_start
with PeakRSSSampler() as rss:
for i, batch in enumerate(loader):
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i == args.warmup_batches:
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
samples += args.batch_size
ep = batch.get("episode_index")
if torch.is_tensor(ep):
episodes_per_batch.append(int(torch.unique(ep).numel()))
t_prev = now
# Measure throughput over a fixed wall-clock window (after warmup) so every scenario is
# compared over the same duration regardless of its speed; num_batches is only a safety cap.
if steady_start is not None and (now - steady_start) >= args.duration_s:
break
if i + 1 >= args.num_batches:
break
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
now = time.perf_counter()
elapsed = now - t_start
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
if samples == 0:
raise SystemExit(
f"FAILED: 0 samples in {args.duration_s}s for scenario={scenario} "
"(inspect worker logs; try --num_workers 0 to surface the exception)."
)
# Single-process fetch/decode split + single-proc throughput. Run AFTER the DataLoader pass: this
# decodes video in the main process, which must stay decode-clean until the workers have forked
# (decoding before fork corrupts the workers' torchcodec state).
del loader
if is_map_style:
fetch_decode = measure_fetch_decode_map(dataset, args.probe_samples, args.probe_warmup)
else:
fetch_decode = measure_fetch_decode_stream(dataset, args.probe_samples, args.probe_warmup)
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
num_cameras = len(meta.video_keys)
results = {
"scenario": scenario,
"rank": rank,
"world_size": world_size,
"loader": "map_style" if is_map_style else "streaming",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"episode_pool_size": None if is_map_style else args.episode_pool_size,
"max_buffer_input_shards": None
if is_map_style
else (args.max_buffer_input_shards or args.episode_pool_size),
**info,
"num_cameras": num_cameras,
"image_shape": image_shape,
"fps": meta.fps,
"peak_rss_gb": peak_rss_gb,
"samples_measured": samples,
"steady_window_s": round(steady_elapsed_s, 2),
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 3),
# Parallel throughput over the steady window (excludes warmup + the prefetch queue it filled).
# A sample is one timestep (one dataset item); it decodes num_cameras video frames.
"samples_per_s": round(samples / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
"decoded_frames_per_s": round(samples / steady_elapsed_s * num_cameras, 2)
if steady_elapsed_s
else 0.0,
**fetch_decode,
# Distinct episodes per batch / batch size: ~1.0 ≈ map-style uniform, low ≈ correlated samples.
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
if episodes_per_batch
else None,
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
if sample_latencies_ms
else None,
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
"total_time_s": round(elapsed, 2),
}
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
tag = f"{scenario}_bs{args.batch_size}_w{args.num_workers}_r{rank}of{world_size}"
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
flat = {k: (json.dumps(v) if isinstance(v, (dict, list)) else v) for k, v in results.items()}
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=list(flat))
writer.writeheader()
writer.writerow(flat)
print(json.dumps(results, indent=2), flush=True)
print(f"Wrote {out_dir / tag}.json and .csv", flush=True)
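The `shuffle_randomness_frac` field reported above is the mean number of distinct episodes per batch divided by the batch size. A minimal stand-alone sketch of that arithmetic, using plain lists of episode indices instead of tensors (the helper name and the sample batches are illustrative assumptions):

```python
from statistics import mean


def shuffle_randomness_frac(batches: list[list[int]], batch_size: int) -> float:
    """Mean distinct-episode count per batch, normalised by the batch size."""
    distinct_per_batch = [len(set(episode_ids)) for episode_ids in batches]
    return round(mean(distinct_per_batch) / batch_size, 3)


# Correlated stream: each batch draws from only two episodes.
correlated = [[0, 0, 0, 1], [1, 1, 2, 2]]
# Well-mixed stream: every sample in a batch comes from a different episode.
mixed = [[0, 5, 9, 3], [7, 2, 8, 1]]

print(shuffle_randomness_frac(correlated, 4))
print(shuffle_randomness_frac(mixed, 4))
```

A value near 1.0 means almost every sample in a batch comes from a different episode; low values indicate correlated samples.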
def submit_chain(args: argparse.Namespace) -> None:
"""Submit every scenario as a serial sbatch chain (one network-bound job at a time).
Bodies are passed to ``sbatch --wrap`` as a single argv (no outer shell), so ``$SLURM_PROCID`` /
``$SLURM_NTASKS`` stay literal and expand at job runtime, not at submit time.
"""
this_file = Path(__file__).resolve()
repo_dir = str(this_file.parents[2]) # <repo>/examples/scaling/<this file>
logs = Path(repo_dir) / "logs"
logs.mkdir(exist_ok=True)
run = f"conda run --no-capture-output -n {args.conda_env} python"
common = (
f"--batch_size {args.batch_size} "
f"--prefetch_factor {args.prefetch_factor} --episode_pool_size {args.episode_pool_size} "
f"--video_decoder_cache_size {args.video_decoder_cache_size} --duration_s {args.duration_s} "
f"--num_batches {args.num_batches} --out_dir {args.out_dir}"
)
if args.max_buffer_input_shards is not None:
common += f" --max_buffer_input_shards {args.max_buffer_input_shards}"
if args.local_root:
common += f" --local_root {args.local_root}"
env_prefix = "export TOKENIZERS_PARALLELISM=false"
sched = []
for opt, env in (("--account", "ACCOUNT"), ("--partition", "PARTITION"), ("--qos", "QOS")):
if os.environ.get(env):
sched.append(f"{opt}={os.environ[env]}")
selected = args.scenarios.split(",") if args.scenarios else list(SCENARIOS)
prev = ""
for scenario in selected:
cfg = SCENARIOS[scenario]
nw = cfg.get("num_workers", args.num_workers)
cpus = cfg.get("cpus", nw + 4)
worker = f"{run} {this_file} --scenario {scenario} --num_workers {nw} {common}"
if cfg["nodes"] > 1:
# One task per node; each exports RANK/WORLD_SIZE so the stream splits shards per node.
inner = f"export RANK=$SLURM_PROCID WORLD_SIZE=$SLURM_NTASKS && cd {repo_dir} && {env_prefix} && {worker}"
body = f"srun --export=ALL bash -c '{inner}'"
node_flags = [f"--nodes={cfg['nodes']}", "--ntasks-per-node=1", "--gpus-per-node=1"]
else:
body = f"cd {repo_dir} && {env_prefix} && {worker}"
node_flags = ["--nodes=1", "--ntasks=1", "--gpus=1"]
cmd = [
"sbatch",
"--parsable",
f"--job-name=dlbench_{scenario}",
*node_flags,
f"--cpus-per-task={cpus}",
f"--mem={cfg['mem']}",
f"--time={cfg['time']}",
f"--output={logs}/%x-%j.out",
*sched,
]
if prev:
cmd.append(f"--dependency=afterany:{prev}")
cmd += ["--wrap", body]
jid = subprocess.check_output(cmd, text=True).strip().split(";")[0]
print(f"submitted {jid} dlbench_{scenario}{f' (after {prev})' if prev else ''}", flush=True)
prev = jid
print(f"\nSubmitted {len(selected)} jobs as a serial chain. Results: {args.out_dir}/*.json", flush=True)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument(
"--scenario",
choices=list(SCENARIOS),
default=None,
help="Run ONE scenario (worker mode). Omit to submit the whole chain (orchestrator mode).",
)
p.add_argument(
"--scenarios",
type=str,
default=None,
help="Orchestrator only: comma-separated subset of scenarios to submit (default: all).",
)
p.add_argument("--local_root", type=str, default=None, help="Local LeRobotDataset copy for mmap_local.")
p.add_argument(
"--num_partitions", type=int, default=8, help="Node count for mmap_local episode partition."
)
p.add_argument("--partition_index", type=int, default=0)
p.add_argument(
"--max_episodes", type=int, default=512, help="Cap mmap_local episodes to the local share."
)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--num_workers", type=int, default=8)
p.add_argument("--prefetch_factor", type=int, default=2)
p.add_argument(
"--episode_pool_size", type=int, default=1024, help="Streaming shuffle pool (randomness knob)."
)
p.add_argument(
"--max_buffer_input_shards",
type=int,
default=None,
        help="Concurrently-live random episodes feeding the pool after reshard() "
        "(default: episode_pool_size). Drives shuffle_randomness_frac; set >= batch_size for a value near 1.",
)
p.add_argument(
"--video_decoder_cache_size", type=int, default=32, help="Max open video decoders (bounds RAM)."
)
p.add_argument(
"--duration_s", type=float, default=60.0, help="Steady-state measurement window (seconds)."
)
p.add_argument(
"--num_batches", type=int, default=1_000_000, help="Safety cap; duration_s governs the window."
)
p.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state throughput.")
p.add_argument(
"--probe_samples", type=int, default=100, help="Single-process samples for fetch/decode split."
)
p.add_argument(
"--probe_warmup", type=int, default=10, help="Samples skipped before the fetch/decode probe."
)
p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
p.add_argument("--conda_env", type=str, default="lerobot", help="Conda env the chained jobs run in.")
p.add_argument("--out_dir", type=str, default="benchmarks/streaming/results_dataloading")
return p.parse_args()
def main() -> None:
args = parse_args()
if args.scenario is None:
if torch.cuda.is_available():
print(
"NOTE: no --scenario given, submitting the SLURM chain. This benchmark is meant to run on a "
"compute cluster; run from a login node with ACCOUNT/PARTITION/QOS set.",
file=sys.stderr,
)
submit_chain(args)
else:
run_scenario(args.scenario, args)
if __name__ == "__main__":
main()
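The fetch/decode attribution used by `_TimedStreaming` does not depend on the dataset class. A minimal sketch of the same pattern over any iterable follows; `StageTimer`, `slow_rows`, and `slow_decode` are hypothetical stand-ins, with sleeps in place of real parquet reads and video decode.

```python
import time
from typing import Any, Callable, Iterable, Iterator


class StageTimer:
    """Iterate ``rows`` and time the pull (fetch) separately from ``finalize`` (decode)."""

    def __init__(self, rows: Iterable[Any], finalize: Callable[[Any], Any]):
        self.rows = rows
        self.finalize = finalize
        self.fetch_s = 0.0
        self.decode_s = 0.0

    def __iter__(self) -> Iterator[Any]:
        iterator = iter(self.rows)
        while True:
            t0 = time.perf_counter()
            try:
                row = next(iterator)  # fetch stage: upstream produces one row
            except StopIteration:
                return
            t1 = time.perf_counter()
            sample = self.finalize(row)  # decode stage
            t2 = time.perf_counter()
            self.fetch_s += t1 - t0
            self.decode_s += t2 - t1
            yield sample


def slow_rows(n: int) -> Iterator[int]:
    for i in range(n):
        time.sleep(0.001)  # stand-in for a parquet/network read
        yield i


def slow_decode(row: int) -> int:
    time.sleep(0.005)  # stand-in for video decode
    return row * 2


timer = StageTimer(slow_rows(5), slow_decode)
samples = list(timer)
stage = timer.fetch_s + timer.decode_s
fetch_pct = round(100 * timer.fetch_s / stage, 1)
decode_pct = round(100 * timer.decode_s / stage, 1)
print(fetch_pct, decode_pct)
```

Because `t0` is taken only after the generator resumes, time the consumer spends between samples is not charged to either stage.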
+11 -1
@@ -95,7 +95,7 @@ dependencies = [
# ── Feature-scoped extras ──────────────────────────────────
dataset = [
-    "datasets>=4.7.0,<5.0.0",
+    "datasets>=5.0.0,<6.0.0", # StreamingLeRobotDataset needs reshard() + shuffle(max_buffer_input_shards=...)
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
"lerobot[av-dep]",
@@ -333,6 +333,16 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
# them, then drop this and rely on the `datasets>=5.0.0` floor in the `dataset` extra.
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
+51
@@ -0,0 +1,51 @@
#!/usr/bin/env python
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from lerobot.datasets.byte_index_builder import (
build_byte_index_tables,
load_existing_file_ids,
write_byte_index,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main() -> None:
parser = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
parser.add_argument("--repo-id", required=True)
parser.add_argument("--revision", default=None)
parser.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
parser.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
parser.add_argument("--workers", type=int, default=8)
parser.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
parser.add_argument("--no-keyframes", action="store_true")
args = parser.parse_args()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
output = args.output
existing = load_existing_file_ids(output)
if existing:
logger.info("resuming: %s files already indexed", len(existing))
files_tbl, episodes_tbl, keyframes_tbl = build_byte_index_tables(
meta,
args.data_root,
include_keyframes=not args.no_keyframes,
workers=args.workers,
existing_files=existing,
max_episodes=args.max_episodes,
)
write_byte_index(output, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
logger.info("wrote byte index to %s", output)
if __name__ == "__main__":
main()
+4
@@ -39,6 +39,10 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
# the pool holds tabular rows only. Ignored when streaming is False.
streaming_episode_pool_size: int = 1024
def __post_init__(self) -> None:
if self.episodes is not None:
+228
@@ -0,0 +1,228 @@
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
from .mp4_episode_slice import episode_custom_frame_mappings_json
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class EpisodeSliceLookup:
global_episode_id: int
file_id: int
mdat_offset: int
mdat_length: int
frame_count: int
first_pts: float
last_pts: float
avg_fps: float
@property
def fetch_bytes(self) -> int:
return self.mdat_length
@dataclass(frozen=True)
class FileLookup:
file_id: int
file_path: str
file_size: int
moov_offset: int
moov_length: int
header_length: int
faststart: bool
avg_fps: float
codec: str
class EpisodeByteIndex:
"""Columnar byte-index resident in numpy arrays for O(1) episode lookup."""
def __init__(
self,
index_dir: str | Path | None,
*,
video_keys: list[str],
num_episodes: int,
mmap: bool = True,
files_table: pa.Table | None = None,
episodes_table: pa.Table | None = None,
mp4_by_rel: dict[str, Any] | None = None,
):
self.index_dir = Path(index_dir) if index_dir is not None else None
self.video_keys = list(video_keys)
self.num_episodes = num_episodes
self.num_cameras = len(video_keys)
self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
self._mp4_by_rel = mp4_by_rel
self._frame_mappings_by_gid: dict[int, bytes] = {}
t0 = time.perf_counter()
if files_table is not None and episodes_table is not None:
files_tbl, episodes_tbl = files_table, episodes_table
else:
if self.index_dir is None:
raise ValueError("index_dir or in-memory tables required")
files_path = self.index_dir / FILES_NAME
episodes_path = self.index_dir / EPISODES_NAME
if not files_path.exists() or not episodes_path.exists():
raise FileNotFoundError(f"byte index missing under {self.index_dir}")
files_tbl = pq.read_table(files_path, memory_map=mmap)
episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)
self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
self.build_time_s = time.perf_counter() - t0
self.load_time_s = self.build_time_s
def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
def col(tbl, name: str):
array = tbl.column(name).combine_chunks()
if pa.types.is_boolean(array.type):
return array.to_numpy(zero_copy_only=False)
return array.to_numpy()
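        # Order the file columns by file_id so that file_id doubles as a row index in lookup() and
        # file_lookup(). The builder assigns ids in worker-completion order but writes rows sorted by
        # path, so row position and file_id can otherwise disagree. Assumes ids are dense (0..n-1).
        file_order = np.argsort(col(files_tbl, "file_id"), kind="stable")
        file_paths = files_tbl.column("file_path").to_pylist()
        codecs = files_tbl.column("codec").to_pylist()
        self.file_id = col(files_tbl, "file_id")[file_order]
        self.file_path = [file_paths[i] for i in file_order]
        self.file_size = col(files_tbl, "file_size")[file_order]
        self.moov_offset = col(files_tbl, "moov_offset")[file_order]
        self.moov_length = col(files_tbl, "moov_length")[file_order]
        self.header_length = col(files_tbl, "header_length")[file_order]
        self.faststart = col(files_tbl, "faststart")[file_order]
        self.file_avg_fps = col(files_tbl, "avg_fps")[file_order]
        self.codec = [codecs[i] for i in file_order]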
ep = episodes_tbl
n = len(ep)
gid = col(ep, "global_episode_id")
order = np.argsort(gid)
self._global_episode_id = gid[order]
self._episode_index = col(ep, "episode_index")[order]
self._camera_index = col(ep, "camera_index")[order]
self._file_id = col(ep, "file_id")[order]
self._mdat_offset = col(ep, "mdat_offset")[order]
self._mdat_length = col(ep, "mdat_length")[order]
self._frame_count = col(ep, "frame_count")[order]
self._first_pts = col(ep, "first_pts")[order]
self._last_pts = col(ep, "last_pts")[order]
expected = self.num_episodes * self.num_cameras
if n != expected:
raise ValueError(f"byte index episodes rows {n} != expected {expected}")
if self.index_dir is not None:
keyframes_path = self.index_dir / KEYFRAMES_NAME
if keyframes_path.exists():
kf_tbl = pq.read_table(keyframes_path, memory_map=mmap)
self._keyframes_rows = len(kf_tbl)
else:
self._keyframes_rows = 0
else:
self._keyframes_rows = 0
self.resident_bytes = int(
self._global_episode_id.nbytes
+ self._file_id.nbytes
+ self._mdat_offset.nbytes
+ self._mdat_length.nbytes
+ self.file_size.nbytes
)
@classmethod
def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)
@classmethod
def from_memory_build(
cls,
meta,
data_root: str,
*,
workers: int = 8,
max_episodes: int | None = None,
include_frame_mappings_cache: bool = True,
) -> EpisodeByteIndex:
"""Build a complete byte index in RAM (no parquet write, no dataset push)."""
from .byte_index_builder import build_byte_index_in_memory
return build_byte_index_in_memory(
meta,
data_root,
workers=workers,
max_episodes=max_episodes,
include_frame_mappings_cache=include_frame_mappings_cache,
)
def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
cam_idx = self._cam_to_idx[camera_key]
gid = episode_index * self.num_cameras + cam_idx
row = int(gid)
if row < 0 or row >= len(self._global_episode_id):
raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
file_id = int(self._file_id[row])
return EpisodeSliceLookup(
global_episode_id=gid,
file_id=file_id,
mdat_offset=int(self._mdat_offset[row]),
mdat_length=int(self._mdat_length[row]),
frame_count=int(self._frame_count[row]),
first_pts=float(self._first_pts[row]),
last_pts=float(self._last_pts[row]),
avg_fps=float(self.file_avg_fps[file_id]),
)
def file_lookup(self, file_id: int) -> FileLookup:
return FileLookup(
file_id=file_id,
file_path=self.file_path[file_id],
file_size=int(self.file_size[file_id]),
moov_offset=int(self.moov_offset[file_id]),
moov_length=int(self.moov_length[file_id]),
header_length=int(self.header_length[file_id]),
faststart=bool(self.faststart[file_id]),
avg_fps=float(self.file_avg_fps[file_id]),
codec=self.codec[file_id],
)
def header_byte_range(self, file_id: int) -> tuple[int, int]:
length = int(self.header_length[file_id])
return 0, max(0, length - 1)
def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
cam_idx = self._cam_to_idx[camera_key]
gid = episode_index * self.num_cameras + cam_idx
cached = self._frame_mappings_by_gid.get(gid)
if cached is not None:
return cached
if self._mp4_by_rel is None:
return None
lookup = self.lookup(episode_index, camera_key)
rel = self.file_path[lookup.file_id]
mp4_index = self._mp4_by_rel.get(rel)
if mp4_index is None:
return None
payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
self._frame_mappings_by_gid[gid] = payload
return payload
def stats_dict(self) -> dict[str, float | int]:
return {
"load_time_s": self.load_time_s,
"build_time_s": self.build_time_s,
"resident_bytes": self.resident_bytes,
"frame_mappings_cached": len(self._frame_mappings_by_gid),
"mp4_indices_cached": len(self._mp4_by_rel or {}),
"num_files": len(self.file_path),
"num_episode_rows": len(self._global_episode_id),
}
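`EpisodeByteIndex.lookup` relies on a dense id layout: `global_episode_id = episode_index * num_cameras + camera_index`, with rows sorted by that id so the id is also the row position. A small pure-Python sketch of the invariant follows; `SliceRow` and `DenseEpisodeIndex` are illustrative names, not classes from this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SliceRow:
    global_episode_id: int
    mdat_offset: int
    mdat_length: int


class DenseEpisodeIndex:
    """Rows sorted by a dense global id, so a lookup is a single list index."""

    def __init__(self, rows: list[SliceRow], camera_keys: list[str], num_episodes: int):
        self.camera_keys = list(camera_keys)
        self.cam_to_idx = {cam: i for i, cam in enumerate(self.camera_keys)}
        self.rows = sorted(rows, key=lambda r: r.global_episode_id)
        expected = num_episodes * len(self.camera_keys)
        # The O(1) lookup is only valid if ids are exactly 0..expected-1 with no gaps.
        if [r.global_episode_id for r in self.rows] != list(range(expected)):
            raise ValueError(f"ids are not dense over 0..{expected - 1}")

    def lookup(self, episode_index: int, camera_key: str) -> SliceRow:
        gid = episode_index * len(self.camera_keys) + self.cam_to_idx[camera_key]
        if not 0 <= gid < len(self.rows):
            raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
        return self.rows[gid]


# Two episodes, two cameras, rows supplied out of order (made-up byte ranges).
rows = [
    SliceRow(3, 7000, 900),
    SliceRow(0, 100, 1000),
    SliceRow(2, 5000, 1200),
    SliceRow(1, 2000, 800),
]
index = DenseEpisodeIndex(rows, ["front", "wrist"], num_episodes=2)
print(index.lookup(1, "front"))
```

Checking the ids against `range(expected)` is stricter than comparing row counts: it also rejects duplicated or missing ids.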
+281
@@ -0,0 +1,281 @@
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
from __future__ import annotations
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import fsspec
import pyarrow as pa
import pyarrow.parquet as pq
from .mp4_episode_slice import (
HEADER_PROBE_BYTES,
MAX_HEADER_PROBE_BYTES,
average_fps_from_index,
episode_keyframes,
parse_mp4_file_layout,
parse_mp4_index,
)
logger = logging.getLogger(__name__)
BYTE_INDEX_DIR = "meta/byte_index"
FILES_NAME = "files.parquet"
EPISODES_NAME = "episodes.parquet"
KEYFRAMES_NAME = "keyframes.parquet"
@dataclass
class IndexedFile:
file_id: int
file_path: str
file_size: int
moov_offset: int
moov_length: int
header_length: int
faststart: bool
avg_fps: float
codec: str
def fetch_header_bytes(path: str, file_size: int) -> bytes:
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
probe = HEADER_PROBE_BYTES
while True:
with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as f:
header = f.read(min(probe, file_size))
try:
parse_mp4_file_layout(header, file_size)
return header
except ValueError as exc:
if probe >= min(MAX_HEADER_PROBE_BYTES, file_size) or "mdat box not found" not in str(exc):
raise
probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size)
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
file_size = fs.info(path)["size"]
header = fetch_header_bytes(path, file_size)
layout = parse_mp4_file_layout(header, file_size)
if not layout.faststart:
logger.warning("non-faststart MP4 (moov after mdat): %s", path)
mp4_index = parse_mp4_index(header, file_size)
indexed = IndexedFile(
file_id=-1,
file_path=rel_path or path,
file_size=file_size,
moov_offset=layout.moov_offset,
moov_length=layout.moov_length,
header_length=layout.header_end,
faststart=layout.faststart,
avg_fps=average_fps_from_index(mp4_index),
codec=layout.codec,
)
return indexed, mp4_index
def build_byte_index_tables(
meta,
data_root: str,
*,
file_paths: list[str] | None = None,
include_keyframes: bool = True,
workers: int = 8,
existing_files: dict[str, int] | None = None,
max_episodes: int | None = None,
return_mp4_indices: bool = False,
complete_files_table: bool = False,
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
"""Build files/episodes/(optional keyframes) Arrow tables."""
video_keys = list(meta.video_keys)
n_cams = len(video_keys)
cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
rel_paths: set[str] = set()
for ep_idx in range(num_episodes):
for cam in video_keys:
rel_paths.add(str(meta.get_video_file_path(ep_idx, cam)))
path_by_rel = {rel: f"{data_root.rstrip('/')}/{rel}" for rel in sorted(rel_paths)}
if file_paths is None:
file_paths = list(path_by_rel.values())
rel_by_path = {path_by_rel[rel]: rel for rel in path_by_rel}
existing_files = existing_files or {}
file_meta_by_rel: dict[str, dict[str, Any]] = {}
mp4_by_rel: dict[str, Any] = {}
next_file_id = max(existing_files.values(), default=-1) + 1
to_index = [rel for rel in sorted(rel_paths) if rel not in existing_files]
if to_index:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in to_index
}
for fut in as_completed(futures):
rel = futures[fut]
indexed, mp4_index = fut.result()
indexed.file_id = next_file_id
mp4_by_rel[rel] = mp4_index
file_meta_by_rel[rel] = {
"file_id": indexed.file_id,
"file_path": rel,
"file_size": indexed.file_size,
"moov_offset": indexed.moov_offset,
"moov_length": indexed.moov_length,
"header_length": indexed.header_length,
"faststart": indexed.faststart,
"avg_fps": indexed.avg_fps,
"codec": indexed.codec,
}
existing_files[rel] = indexed.file_id
next_file_id += 1
missing_rels = {
str(meta.get_video_file_path(ep, cam))
for ep in range(num_episodes)
for cam in video_keys
} - set(mp4_by_rel.keys())
if missing_rels:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel
for rel in missing_rels
if rel not in mp4_by_rel
}
for fut in as_completed(futures):
rel = futures[fut]
_, mp4_index = fut.result()
mp4_by_rel[rel] = mp4_index
episode_rows: list[dict[str, Any]] = []
keyframe_rows: list[dict[str, Any]] = []
for ep_idx in range(num_episodes):
for cam in video_keys:
rel = str(meta.get_video_file_path(ep_idx, cam))
path = f"{data_root.rstrip('/')}/{rel}"
if rel not in existing_files:
raise KeyError(f"file not indexed: {rel}")
mp4_index = mp4_by_rel[rel]
ep = meta.episodes[ep_idx]
from_ts = float(ep[f"videos/{cam}/from_timestamp"])
to_ts = float(ep[f"videos/{cam}/to_timestamp"])
span = mp4_index.episode_byte_span(from_ts, to_ts)
global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
mdat_length = span.slice_hi - span.slice_lo + 1
episode_rows.append(
{
"global_episode_id": global_episode_id,
"episode_index": ep_idx,
"camera_key": cam,
"camera_index": cam_to_idx[cam],
"file_id": existing_files[rel],
"mdat_offset": span.slice_lo,
"mdat_length": mdat_length,
"frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
"first_pts": from_ts,
"last_pts": to_ts,
}
)
if include_keyframes:
timescale = mp4_index.timescale
for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts):
keyframe_rows.append(
{
"global_episode_id": global_episode_id,
"pts": int(round(pts_s * timescale)),
"byte_offset": byte_off,
}
)
referenced_rels = {
str(meta.get_video_file_path(ep, cam)) for ep in range(num_episodes) for cam in video_keys
}
if complete_files_table:
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(referenced_rels)])
elif to_index:
files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(to_index)])
else:
files_table = None
episodes_table = pa.Table.from_pylist(episode_rows)
keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None
if return_mp4_indices:
return files_table, episodes_table, keyframes_table, mp4_by_rel
return files_table, episodes_table, keyframes_table
def build_byte_index_in_memory(
meta,
data_root: str,
*,
workers: int = 8,
max_episodes: int | None = None,
include_frame_mappings_cache: bool = False,
):
"""Build a complete byte index resident in RAM (no parquet write, no dataset push)."""
from .byte_index import EpisodeByteIndex
num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
files_tbl, episodes_tbl, _, mp4_by_rel = build_byte_index_tables(
meta,
data_root,
include_keyframes=False,
workers=workers,
max_episodes=max_episodes,
return_mp4_indices=True,
complete_files_table=True,
)
index = EpisodeByteIndex(
None,
video_keys=list(meta.video_keys),
num_episodes=num_episodes,
files_table=files_tbl,
episodes_table=episodes_tbl,
mp4_by_rel=mp4_by_rel,
)
if include_frame_mappings_cache:
for ep_idx in range(num_episodes):
for cam in meta.video_keys:
index.custom_frame_mappings(ep_idx, cam)
return index
def write_byte_index(
output_dir: Path,
files_table: pa.Table | None,
episodes_table: pa.Table,
keyframes_table: pa.Table | None = None,
*,
merge_existing: bool = True,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
files_path = output_dir / FILES_NAME
episodes_path = output_dir / EPISODES_NAME
keyframes_path = output_dir / KEYFRAMES_NAME
if merge_existing and files_path.exists() and files_table is not None:
prev = pq.read_table(files_path)
files_table = pa.concat_tables([prev, files_table])
if files_table is not None:
pq.write_table(files_table, files_path)
pq.write_table(episodes_table, episodes_path)
if keyframes_table is not None:
if merge_existing and keyframes_path.exists():
keyframes_table = pa.concat_tables([pq.read_table(keyframes_path), keyframes_table])
pq.write_table(keyframes_table, keyframes_path)
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
files_path = index_dir / FILES_NAME
if not files_path.exists():
return {}
table = pq.read_table(files_path, columns=["file_id", "file_path"])
return {row["file_path"]: int(row["file_id"]) for row in table.to_pylist()}
+11 -2
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
ep_dataset = embed_images(ep_dataset)
table = ep_dataset.with_format("arrow")[:]
-writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
-writer.write_table(table)
+# Emit several row groups with a page index instead of one giant row group. A single row group forces
+# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
+# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
+# groups bounds per-shard memory while keeping groups large enough to scan
+# efficiently; the page index lets readers skip to the pages they need.
+target_row_group_bytes = 32 * 1024 * 1024
+row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
+writer = pq.ParquetWriter(
+path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
+)
+writer.write_table(table, row_group_size=row_group_size)
writer.close()
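To make the sizing rule in this hunk concrete, the sketch below reimplements the row-group arithmetic as a standalone function. `rows_per_group` is a hypothetical name used only for illustration; the formula mirrors the `row_group_size` expression above.

```python
def rows_per_group(num_rows: int, table_nbytes: int, target_bytes: int = 32 * 1024 * 1024) -> int:
    """Rows per row group so that each group holds roughly target_bytes uncompressed."""
    scaled = num_rows * target_bytes // max(table_nbytes, 1)
    return max(1, min(num_rows, scaled))


# 10,000 rows occupying 320 MiB in memory: about 1,000 rows (32 MiB) per group.
print(rows_per_group(10_000, 320 * 1024 * 1024))
# A table smaller than the target stays in a single row group.
print(rows_per_group(500, 1024 * 1024))
```

The `max(1, ...)` guard keeps the writer from receiving a zero row-group size when a single row is larger than the target.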
+263
@@ -0,0 +1,263 @@
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
from __future__ import annotations
import logging
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
import fsspec
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
from .mp4_episode_slice import SparseMp4Reader
from .torchcodec_utils import open_video_decoder
logger = logging.getLogger(__name__)
@dataclass
class CacheStats:
hits: int = 0
misses: int = 0
bytes_fetched: int = 0
full_file_fallbacks: int = 0
prefetch_submitted: int = 0
prefetch_waits: int = 0
mdat_slices: int = 0
prefix_fetches: int = 0
fetch_to_buffer_s: float = 0.0
buffer_to_decoder_s: float = 0.0
buffer_hit_decoder_s: float = 0.0
decode_frame_s: float = 0.0
decode_frames: int = 0
def merge(self, other: CacheStats) -> None:
for name in self.__dataclass_fields__:
setattr(self, name, getattr(self, name) + getattr(other, name))
def stats_dict(self) -> dict[str, int | float]:
avg_miss = self.bytes_fetched / max(1, self.misses)
return {
"byte_cache_hits": self.hits,
"byte_cache_misses": self.misses,
"byte_cache_bytes_fetched": self.bytes_fetched,
"byte_cache_bytes_per_miss": avg_miss,
"byte_cache_full_file_fallbacks": self.full_file_fallbacks,
"byte_cache_prefetch_submitted": self.prefetch_submitted,
"byte_cache_prefetch_waits": self.prefetch_waits,
"byte_cache_mdat_slices": self.mdat_slices,
"byte_cache_prefix_fetches": self.prefix_fetches,
"fetch_to_buffer_ms_per_miss": 1000 * self.fetch_to_buffer_s / max(1, self.misses),
"buffer_to_decoder_ms_per_miss": 1000 * self.buffer_to_decoder_s / max(1, self.misses),
"buffer_hit_decoder_ms_per_hit": 1000 * self.buffer_hit_decoder_s / max(1, self.hits),
"decode_ms_per_frame": 1000 * self.decode_frame_s / max(1, self.decode_frames),
}
@dataclass
class _EpisodeEntry:
decoders: dict[str, Any] = field(default_factory=dict)
ready: threading.Event = field(default_factory=threading.Event)
error: Exception | None = None
class RangeFetcher:
"""Sequential byte-range GETs via fsspec."""
def __init__(self, path: str):
self.path = path
self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
def fetch(self, lo: int, hi: int) -> bytes:
if hi < lo:
return b""
with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f:
f.seek(lo)
return f.read(hi - lo + 1)
class EpisodeByteCache:
"""Manifest-driven episode MP4 fetch + in-memory sparse decode."""
MAX_BYTES_PER_MISS = 25 * 1024 * 1024
def __init__(
self,
byte_index: EpisodeByteIndex,
max_bytes: int,
*,
data_root: str,
max_prefetch_workers: int = 4,
):
if max_bytes <= 0:
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
self.byte_index = byte_index
self.max_bytes = max_bytes
self.data_root = data_root.rstrip("/")
self._bytes_used = 0
self._lock = threading.Lock()
self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
self._header_cache: dict[int, bytes] = {}
self._fetcher_cache: dict[int, RangeFetcher] = {}
self._episodes: dict[int, _EpisodeEntry] = {}
self._stats = CacheStats()
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
self._futures: dict[int, Future] = {}
@property
def stats(self) -> CacheStats:
with self._lock:
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
def submit_prefetch(self, ep_idx: int) -> None:
with self._lock:
if ep_idx in self._episodes or ep_idx in self._futures:
return
self._stats.prefetch_submitted += 1
fut = self._executor.submit(self._prefetch_episode, ep_idx)
self._futures[ep_idx] = fut
    def ensure_ready(self, ep_idx: int) -> None:
        with self._lock:
            fut = self._futures.pop(ep_idx, None)
            if fut is not None:
                self._stats.prefetch_waits += 1
        if fut is not None:
            fut.result()
        entry = self._episodes.get(ep_idx)
        if entry is None:
            raise KeyError(f"episode {ep_idx} not prefetched")
        # Wait first: the worker records entry.error before it sets entry.ready, so checking
        # the error before waiting can miss a failure that is still in flight.
        entry.ready.wait()
        if entry.error is not None:
            raise entry.error
    def get_decoder(self, ep_idx: int, video_key: str) -> Any:
        entry = self._episodes[ep_idx]
        entry.ready.wait()
        if entry.error is not None:
            raise entry.error
        return entry.decoders[video_key]
def close(self) -> None:
self._executor.shutdown(wait=False, cancel_futures=True)
def _prefetch_episode(self, ep_idx: int) -> None:
entry = _EpisodeEntry()
self._episodes[ep_idx] = entry
try:
for cam in self.byte_index.video_keys:
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
except Exception as exc:
entry.error = exc
finally:
entry.ready.set()
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
key = (ep_idx, cam)
with self._lock:
cached = self._cache.get(key)
if cached is not None:
self._cache.move_to_end(key)
self._stats.hits += 1
payload, _ = cached
t0 = time.perf_counter()
dec = self._decoder_from_payload(payload, ep_idx, cam)
with self._lock:
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
return dec
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
with self._lock:
self._stats.misses += 1
if payload_bytes > self.MAX_BYTES_PER_MISS:
logger.warning(
"byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
payload_bytes / 1e6,
ep_idx,
cam,
)
self._evict_until(payload_bytes)
self._cache[key] = (payload, payload_bytes)
self._bytes_used += payload_bytes
return dec
def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
lookup = self.byte_index.lookup(ep_idx, cam)
file_info = self.byte_index.file_lookup(lookup.file_id)
fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)
t_fetch = time.perf_counter()
header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
lo = lookup.mdat_offset
hi = lo + lookup.mdat_length - 1
mdat = fetcher.fetch(lo, hi)
fetch_s = time.perf_counter() - t_fetch
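nbytes = len(header) + len(mdat)
with self._lock:
# Header bytes are counted once in _get_header_bytes when they are actually fetched;
# counting them again here would inflate bytes_fetched on every miss.
self._stats.bytes_fetched += len(mdat)
self._stats.mdat_slices += 1
self._stats.fetch_to_buffer_s += fetch_s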
def lazy_fetch(pos: int, end: int) -> bytes:
data = fetcher.fetch(pos, end)
with self._lock:
self._stats.bytes_fetched += len(data)
return data
reader = SparseMp4Reader(
file_size=file_info.file_size,
header=header,
mdat_lo=lo,
mdat_bytes=mdat,
lazy_fetch=lazy_fetch,
)
t_init = time.perf_counter()
dec = self._decoder_from_payload(reader, ep_idx, cam)
self._validate_decoder(dec, lookup)
init_s = time.perf_counter() - t_init
with self._lock:
self._stats.buffer_to_decoder_s += init_s
self._rewind_payload(reader)
return reader, nbytes, dec
def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
if file_id not in self._fetcher_cache:
path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
self._fetcher_cache[file_id] = RangeFetcher(path)
return self._fetcher_cache[file_id]
def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
if file_id in self._header_cache:
return self._header_cache[file_id]
hi = max(0, header_length - 1)
header = fetcher.fetch(0, hi)
with self._lock:
self._header_cache[file_id] = header
self._stats.bytes_fetched += len(header)
return header
def _decoder_from_payload(
self, payload: SparseMp4Reader, ep_idx: int, cam: str
) -> Any:
payload.seek(0)
mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
return open_video_decoder(payload, frame_mappings=mappings)
def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
duration = max(0.01, end - begin)
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
dec.get_frames_played_at([ts]).data
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
payload.seek(0)
def _evict_until(self, need: int) -> None:
while self._bytes_used + need > self.max_bytes and self._cache:
_, (_, size) = self._cache.popitem(last=False)
self._bytes_used -= size
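The eviction loop above bounds the cache by total payload bytes rather than by entry count. A minimal standalone sketch of that idea follows; `ByteBudgetLRU` is a hypothetical class written for illustration, not the `EpisodeByteCache` implementation, and it omits locking and decoders.

```python
from __future__ import annotations

from collections import OrderedDict


class ByteBudgetLRU:
    """Least-recently-used cache bounded by total payload bytes."""

    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.bytes_used = 0
        self._items: OrderedDict[str, bytes] = OrderedDict()

    def get(self, key: str) -> bytes | None:
        value = self._items.get(key)
        if value is not None:
            self._items.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key: str, value: bytes) -> None:
        if key in self._items:
            self.bytes_used -= len(self._items.pop(key))
        # Evict the oldest entries until the new payload fits or the cache is empty.
        while self._items and self.bytes_used + len(value) > self.max_bytes:
            _, old = self._items.popitem(last=False)
            self.bytes_used -= len(old)
        self._items[key] = value
        self.bytes_used += len(value)


cache = ByteBudgetLRU(max_bytes=10)
cache.put("a", b"1234")
cache.put("b", b"5678")
cache.get("a")           # "a" becomes most recent, so "b" is next to go
cache.put("c", b"abcd")  # 4 + 4 + 4 > 10, so "b" is evicted
print(sorted(cache._items), cache.bytes_used)
```

As in the code above, a single payload larger than the budget is still admitted after everything else is evicted, so the budget is a soft cap.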
+1 -1
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
-max_num_shards=cfg.num_workers,
+episode_pool_size=cfg.dataset.streaming_episode_pool_size,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
+555
@@ -0,0 +1,555 @@
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
For streaming we fetch only the file header plus the episode's contiguous mdat span
instead of the ``0..episode_end`` prefix.
"""
from __future__ import annotations
import io
import struct
import threading
from dataclasses import dataclass, field
from typing import Callable
KEYFRAME_PAD_S = 0.1
HEADER_PROBE_BYTES = 4 * 1024 * 1024
MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024
@dataclass
class Mp4FileLayout:
file_size: int
moov_offset: int
moov_length: int
header_end: int
mdat_offset: int
mdat_size: int
faststart: bool
codec: str
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
"""Return top-level MP4 layout (moov/mdat positions, faststart flag)."""
boxes = list(_iter_boxes(header_bytes))
moov_offset = mdat_offset = -1
moov_length = mdat_size = 0
for off, size, typ, _ in boxes:
if typ == b"moov" and moov_offset < 0:
moov_offset, moov_length = off, size
if typ == b"mdat" and mdat_offset < 0:
mdat_offset, mdat_size = off, size
if moov_offset < 0:
raise ValueError("moov box not found in header probe")
if mdat_offset < 0:
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
faststart = moov_offset < mdat_offset
header_end = mdat_offset
codec = _parse_video_codec(header_bytes)
return Mp4FileLayout(
file_size=file_size,
moov_offset=moov_offset,
moov_length=moov_length,
header_end=header_end,
mdat_offset=mdat_offset,
mdat_size=mdat_size,
faststart=faststart,
codec=codec,
)
def _parse_video_codec(header_bytes: bytes) -> str:
    moov = _find_box_payload(header_bytes, b"moov")
    if moov is None:
        return "unknown"
    trak = _find_video_trak(moov)
    if trak is None:
        return "unknown"
    stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd")
    # stsd payload: version(1) + flags(3) + entry_count(4), then the first sample entry,
    # which starts with entry_size(4) + codec fourcc(4). The fourcc is therefore bytes 12..15.
    if stsd is None or len(stsd) < 16:
        return "unknown"
    return stsd[12:16].decode("latin1", errors="replace").strip("\x00")
def average_fps_from_index(index: Mp4VideoIndex) -> float:
index.ensure_tables()
if index.num_samples < 2:
return 30.0
duration = index.sample_pts(index.num_samples - 1)
if duration <= 0:
return 30.0
# The last sample's PTS spans num_samples - 1 frame intervals, not num_samples.
return (index.num_samples - 1) / duration
def episode_custom_frame_mappings_json(
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> bytes:
"""Build TorchCodec ``custom_frame_mappings`` JSON for one episode span."""
import json
index.ensure_tables()
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
hi_idx = min(hi_idx, index.num_samples - 1)
lo_idx = _keyframe_back(index.sync_samples, lo_idx)
sync = set(index.sync_samples)
timescale = index.timescale
# stts deltas for duration per sample (expand stts entries to per-sample delta)
sample_deltas: list[int] = []
for count, delta in index.stts:
sample_deltas.extend([delta] * count)
while len(sample_deltas) < index.num_samples:
sample_deltas.append(sample_deltas[-1] if sample_deltas else timescale // 30)
frames = []
for idx in range(lo_idx, hi_idx + 1):
frames.append(
{
"pts": int(round(index._pts[idx] * timescale)),
"duration": int(sample_deltas[idx]),
"key_frame": int((idx + 1) in sync) if sync else int(idx == lo_idx),
}
)
return json.dumps({"frames": frames}).encode()
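The function above expands `stts` runs into per-frame `pts` and `duration` values and marks keyframes from the 1-based `stss` numbers. The sketch below shows the same expansion on a tiny hand-written table; `frame_mappings` is a hypothetical helper, it covers the whole track rather than one episode window, and the JSON keys simply mirror those used above.

```python
import json


def frame_mappings(stts, sync_samples):
    """Build {"frames": [...]} with pts and duration in timescale ticks.

    stts: list of (sample_count, delta) runs.
    sync_samples: 1-based sample numbers of keyframes, as stored in stss.
    """
    sync = set(sync_samples)
    frames = []
    pts = 0
    number = 1  # stss numbering is 1-based
    for count, delta in stts:
        for _ in range(count):
            frames.append({"pts": pts, "duration": delta, "key_frame": int(number in sync)})
            pts += delta
            number += 1
    return json.dumps({"frames": frames})


doc = json.loads(frame_mappings([(3, 512), (1, 1024)], sync_samples=[1, 3]))
print([f["pts"] for f in doc["frames"]])
print([f["key_frame"] for f in doc["frames"]])
```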
def episode_keyframes(
index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> list[tuple[float, int]]:
"""Return (pts_seconds, byte_offset) for sync samples in the episode span."""
index.ensure_tables()
span = index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)
lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)
if not index.sync_samples:
return [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
out: list[tuple[float, int]] = []
for sync_one_based in index.sync_samples:
idx = sync_one_based - 1
if lo_idx <= idx <= hi_idx:
out.append((index.sample_pts(idx), index.sample_offset(idx)))
return out or [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))]
@dataclass
class EpisodeByteSpan:
"""Absolute file byte ranges to fetch for one episode."""
file_size: int
header_end: int
slice_lo: int
slice_hi: int
@property
def header_bytes(self) -> tuple[int, int]:
return 0, self.header_end - 1
@property
def mdat_bytes(self) -> tuple[int, int]:
return self.slice_lo, self.slice_hi
@property
def total_fetch_bytes(self) -> int:
header = self.header_end
mdat = self.slice_hi - self.slice_lo + 1
return header + mdat
@dataclass
class Mp4VideoIndex:
file_size: int
header_end: int
mdat_offset: int
mdat_size: int
timescale: int
stts: list[tuple[int, int]]
stsz: list[int]
stsc: list[tuple[int, int, int]]
stco: list[int]
sync_samples: list[int]
_pts: list[float] = field(default_factory=list, repr=False)
_offsets: list[int] = field(default_factory=list, repr=False)
def ensure_tables(self) -> None:
if self._pts:
return
self._pts = _pts_from_stts(self.stts, self.timescale)
self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)
@property
def num_samples(self) -> int:
return len(self.stsz)
def sample_pts(self, index: int) -> float:
self.ensure_tables()
return self._pts[index]
def sample_offset(self, index: int) -> int:
self.ensure_tables()
index = max(0, min(index, len(self._offsets) - 1))
return self._offsets[index]
def sample_end(self, index: int) -> int:
return self.sample_offset(index) + self.stsz[index]
def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
self.ensure_tables()
n = self.num_samples
if n == 0:
raise ValueError("MP4 has no video samples")
pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
lo_ts = max(0.0, from_ts - pad)
hi_ts = to_ts + pad
lo_idx = _first_sample_at_or_after(self._pts, lo_ts)
hi_idx = _last_sample_at_or_before(self._pts, hi_ts)
hi_idx = min(hi_idx, n - 1)
lo_idx = min(lo_idx, n - 1)
lo_idx = _keyframe_back(self.sync_samples, lo_idx)
slice_lo = self.sample_offset(lo_idx)
# sample_end() is an exclusive end offset; slice_hi is an inclusive byte position.
slice_hi = self.sample_end(min(hi_idx, len(self._offsets) - 1)) - 1
return EpisodeByteSpan(
file_size=self.file_size,
header_end=self.header_end,
slice_lo=slice_lo,
slice_hi=min(slice_hi, self.file_size - 1),
)
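`episode_byte_span` widens the episode's timestamp window before mapping it to samples, and the resulting byte range is treated as inclusive on both ends. The sketch below isolates those two pieces of arithmetic; both helper names are hypothetical, and the default pad mirrors `KEYFRAME_PAD_S`.

```python
def padded_window(from_ts: float, to_ts: float, keyframe_pad_s: float = 0.1) -> tuple[float, float]:
    """Episode span widened on each side by max(fixed pad, 5% of the episode duration)."""
    pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
    return max(0.0, from_ts - pad), to_ts + pad


def inclusive_length(lo: int, hi: int) -> int:
    """Byte count of an inclusive range [lo, hi]."""
    return hi - lo + 1


print(padded_window(0.0, 60.0))   # long episode: the 5% term dominates the pad
print(padded_window(10.0, 10.5))  # short episode: the fixed pad dominates
print(inclusive_length(4096, 8191))
```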
class SparseMp4Reader(io.BufferedIOBase):
"""Range-backed MP4 reader: header + one mdat span at absolute offsets."""
def __init__(
self,
file_size: int,
header: bytes,
mdat_lo: int,
mdat_bytes: bytes,
lazy_fetch: Callable[[int, int], bytes] | None = None,
):
self._size = file_size
self._header = header
self._mdat_lo = mdat_lo
self._mdat_hi = mdat_lo + len(mdat_bytes)
self._mdat = mdat_bytes
self._lazy_fetch = lazy_fetch
self._pos = 0
self._lock = threading.Lock()
def readable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._pos
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
self._pos = offset
elif whence == io.SEEK_CUR:
self._pos += offset
elif whence == io.SEEK_END:
self._pos = self._size + offset
else:
raise ValueError(f"invalid whence: {whence}")
self._pos = max(0, min(self._pos, self._size))
return self._pos
def read(self, size: int = -1) -> bytes:
if size < 0:
size = self._size - self._pos
if size <= 0:
return b""
out = bytearray()
remaining = size
pos = self._pos
while remaining > 0 and pos < self._size:
chunk = self._read_at(pos, remaining)
if not chunk:
break
out.extend(chunk)
pos += len(chunk)
remaining -= len(chunk)
self._pos = pos
return bytes(out)
def _read_at(self, pos: int, n: int) -> bytes:
header_len = len(self._header)
if pos < header_len:
end = min(pos + n, header_len)
return self._header[pos:end]
if self._mdat_lo <= pos < self._mdat_hi:
end = min(pos + n, self._mdat_hi)
off = pos - self._mdat_lo
return self._mdat[off : off + (end - pos)]
if self._lazy_fetch is not None:
with self._lock:
end = min(pos + n, self._size)
return self._lazy_fetch(pos, end - 1)
return b"\x00" * min(n, self._size - pos)
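The reader above serves bytes from three places: the cached header, the cached mdat span, and a lazy remote fetch for anything else. The sketch below shows that routing as a single function. It is a simplified variant, not the class above: `read_sparse` is a hypothetical name, and unlike `_read_at` it clamps a lazy fetch so it stops at the start of the cached span instead of re-downloading bytes that are already in memory.

```python
def read_sparse(pos, n, header, mdat_lo, mdat, fetch):
    """Read up to n bytes at absolute offset pos from header, cached span, or remote."""
    out = bytearray()
    mdat_hi = mdat_lo + len(mdat)  # exclusive end of the cached span
    end = pos + n
    while pos < end:
        if pos < len(header):
            chunk = header[pos : min(end, len(header))]
        elif mdat_lo <= pos < mdat_hi:
            chunk = mdat[pos - mdat_lo : min(end, mdat_hi) - mdat_lo]
        else:
            # Outside both cached regions: fetch only up to the next cached region.
            stop = min(end, mdat_lo) if pos < mdat_lo else end
            chunk = fetch(pos, stop - 1)  # fetch takes an inclusive range
        if not chunk:
            break
        out += chunk
        pos += len(chunk)
    return bytes(out)


blob = bytes(range(100))  # stand-in for a remote file
calls = []


def fetch(lo, hi):
    calls.append((lo, hi))
    return blob[lo : hi + 1]


data = read_sparse(5, 50, header=blob[:10], mdat_lo=40, mdat=blob[40:60], fetch=fetch)
print(data == blob[5:55], calls)
```

Only the gap between the header and the cached span triggers a remote call here; a read that stays inside the two cached regions never touches the network.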
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
"""Parse moov sample tables from the file header (faststart layout)."""
layout = parse_mp4_file_layout(header_bytes, file_size)
mdat_offset, mdat_size = layout.mdat_offset, layout.mdat_size
moov = _find_box_payload(header_bytes, b"moov")
if moov is None:
raise ValueError("moov box not found in MP4 header probe")
trak = _find_video_trak(moov)
if trak is None:
raise ValueError("video trak not found in moov")
mdhd = _find_box_payload(trak, b"mdhd")
if mdhd is None:
raise ValueError("mdhd not found")
timescale = _parse_mdhd_timescale(mdhd)
stbl = _find_box_payload(trak, b"stbl")
if stbl is None:
raise ValueError("stbl not found")
stts = _parse_stts(_find_box_payload(stbl, b"stts"))
stsz = _parse_stsz(_find_box_payload(stbl, b"stsz"))
stsc = _parse_stsc(_find_box_payload(stbl, b"stsc"))
stco_payload = _find_box_payload(stbl, b"stco")
co64_payload = _find_box_payload(stbl, b"co64")
if stco_payload is not None:
stco = _parse_stco(stco_payload)
elif co64_payload is not None:
stco = _parse_co64(co64_payload)
else:
raise ValueError("stco/co64 not found")
stss_payload = _find_box_payload(stbl, b"stss")
sync_samples = _parse_stss(stss_payload) if stss_payload else []
return Mp4VideoIndex(
file_size=file_size,
header_end=layout.header_end,
mdat_offset=mdat_offset,
mdat_size=mdat_size,
timescale=timescale,
stts=stts,
stsz=stsz,
stsc=stsc,
stco=stco,
sync_samples=sync_samples,
)
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
if offset + 8 > len(data):
return None
size, typ = struct.unpack_from(">I4s", data, offset)
header = 8
if size == 1:
if offset + 16 > len(data):
return None
size = struct.unpack_from(">Q", data, offset + 8)[0]
header = 16
elif size == 0:
size = len(data) - offset
return size, typ, header
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
end = end if end is not None else len(data)
off = start
while off + 8 <= end:
hdr = _box_header(data, off)
if hdr is None or hdr[0] < hdr[2]:
break
size, typ, header = hdr
yield off, size, typ, data[off + header : off + size]
off += size
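The box walk above relies on the basic MP4 box framing: a 32-bit big-endian size that includes the 8-byte header, followed by a 4-byte type. The sketch below builds a tiny synthetic file and walks it. `make_box` and `list_top_level_boxes` are hypothetical helpers, and the 64-bit (`size == 1`) and to-end-of-file (`size == 0`) cases handled by `_box_header` are deliberately left out.

```python
import struct


def make_box(box_type: bytes, payload: bytes) -> bytes:
    """Serialize one box: 32-bit size (header included) + 4-byte type + payload."""
    return struct.pack(">I4s", 8 + len(payload), box_type) + payload


def list_top_level_boxes(data: bytes) -> list[tuple[int, int, bytes]]:
    """Return (offset, size, type) for each top-level box with a plain 32-bit size."""
    out = []
    off = 0
    while off + 8 <= len(data):
        size, typ = struct.unpack_from(">I4s", data, off)
        if size < 8:  # size 0 or 1 is not handled in this sketch
            break
        out.append((off, size, typ))
        off += size
    return out


# faststart layout: ftyp, then moov, then mdat
blob = make_box(b"ftyp", b"isom") + make_box(b"moov", b"\x00" * 16) + make_box(b"mdat", b"\xff" * 32)
for off, size, typ in list_top_level_boxes(blob):
    print(off, size, typ.decode())
```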
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
for _, _, typ, payload in _iter_boxes(data):
if typ == target:
return payload
if typ in (b"moov", b"trak", b"mdia", b"minf", b"stbl"):
found = _find_box_payload(payload, target)
if found is not None:
return found
return None
def _find_video_trak(moov: bytes) -> bytes | None:
for _, _, typ, payload in _iter_boxes(moov):
if typ != b"trak":
continue
hdlr = _find_box_payload(payload, b"hdlr")
if hdlr is not None and len(hdlr) >= 12 and hdlr[8:12] == b"vide":
return payload
return None
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
for off, size, typ, _ in _iter_boxes(header_bytes):
if typ == b"mdat":
return off, size
# The mdat box header must fall inside the probed bytes; file_size alone cannot locate it.
raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
def _parse_mdhd_timescale(mdhd: bytes) -> int:
version = mdhd[0]
if version == 0:
return struct.unpack_from(">I", mdhd, 12)[0]
return struct.unpack_from(">I", mdhd, 20)[0]
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
if stts is None:
raise ValueError("stts missing")
count = struct.unpack_from(">I", stts, 4)[0]
out = []
off = 8
for _ in range(count):
sample_count, delta = struct.unpack_from(">II", stts, off)
out.append((sample_count, delta))
off += 8
return out
def _parse_stsz(stsz: bytes | None) -> list[int]:
if stsz is None:
raise ValueError("stsz missing")
sample_size, sample_count = struct.unpack_from(">II", stsz, 4)
if sample_size != 0:
return [sample_size] * sample_count
off = 12
return list(struct.unpack_from(f">{sample_count}I", stsz, off))
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
if stsc is None:
raise ValueError("stsc missing")
count = struct.unpack_from(">I", stsc, 4)[0]
out = []
off = 8
for _ in range(count):
first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off)
out.append((first_chunk, samples_per_chunk, sample_desc))
off += 12
return out
def _parse_stco(stco: bytes) -> list[int]:
count = struct.unpack_from(">I", stco, 4)[0]
return list(struct.unpack_from(f">{count}I", stco, 8))
def _parse_co64(co64: bytes) -> list[int]:
count = struct.unpack_from(">I", co64, 4)[0]
return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)]
def _parse_stss(stss: bytes) -> list[int]:
count = struct.unpack_from(">I", stss, 4)[0]
return list(struct.unpack_from(f">{count}I", stss, 8))
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
pts: list[float] = []
t = 0
for count, delta in stts:
for _ in range(count):
pts.append(t / timescale)
t += delta
return pts
def _sample_byte_offsets(
stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
) -> list[int]:
if not stsc:
stsc = [(1, len(stsz), 1)]
offsets: list[int] = []
chunk_idx = 0
sample_idx = 0
sc_idx = 0
num_chunks = len(stco)
while chunk_idx < num_chunks and sample_idx < len(stsz):
first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
if sc_idx + 1 < len(stsc):
next_first = stsc[sc_idx + 1][0]
chunks_in_entry = next_first - first_chunk
else:
chunks_in_entry = num_chunks - chunk_idx
for _ in range(chunks_in_entry):
if chunk_idx >= num_chunks:
break
offset = stco[chunk_idx]
_, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
for _ in range(samples_per_chunk):
if sample_idx >= len(stsz):
break
offsets.append(offset)
offset += stsz[sample_idx]
sample_idx += 1
chunk_idx += 1
sc_idx += 1
if len(offsets) < len(stsz):
# Pad with last known offset progression for malformed stsc edge cases.
last = offsets[-1] if offsets else 0
while len(offsets) < len(stsz):
idx = len(offsets)
offsets.append(last)
last += stsz[idx]
return offsets
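`_sample_byte_offsets` combines three sample tables: `stsc` says how many samples each chunk holds, `stco` gives each chunk's absolute offset, and `stsz` gives each sample's size. A compact sketch of the same mapping on a small worked example follows. `sample_offsets` is a hypothetical helper that assumes a well-formed `stsc` whose first run starts at chunk 1, and it skips the padding fallback used above.

```python
def sample_offsets(stsc, stco, stsz):
    """Absolute byte offset of each sample.

    stsc: list of (first_chunk, samples_per_chunk, desc_index); first_chunk is 1-based.
    stco: absolute byte offset of each chunk.
    stsz: size in bytes of each sample.
    """
    offsets = []
    sample = 0
    for chunk_no, chunk_off in enumerate(stco, start=1):
        # The applicable stsc run is the last one whose first_chunk <= chunk_no.
        per_chunk = [spc for first, spc, _ in stsc if first <= chunk_no][-1]
        pos = chunk_off
        for _ in range(per_chunk):
            if sample >= len(stsz):
                break
            offsets.append(pos)
            pos += stsz[sample]  # samples are packed back to back inside a chunk
            sample += 1
    return offsets


# Chunks 1 and 2 hold two samples each; chunk 3 holds one.
stsc = [(1, 2, 1), (3, 1, 1)]
stco = [100, 400, 900]
stsz = [10, 20, 30, 40, 50]
print(sample_offsets(stsc, stco, stsz))
```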
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] < ts:
lo = mid + 1
else:
hi = mid
return min(lo, len(pts) - 1)
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] <= ts:
lo = mid + 1
else:
hi = mid
return max(0, lo - 1)
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
if not sync_samples:
return max(0, sample_idx - 2)
# stss stores 1-based sample numbers
one_based = sample_idx + 1
prev = [s for s in sync_samples if s <= one_based]
if prev:
return prev[-1] - 1
return 0
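`_keyframe_back` scans the whole `stss` list for every lookup. Because `stss` entries are sorted, the same result for the non-empty case can be found with a binary search. The sketch below is an illustrative alternative under that sortedness assumption; `keyframe_at_or_before` is a hypothetical name, and the empty-`stss` fallback used above is not reproduced.

```python
from bisect import bisect_right


def keyframe_at_or_before(sync_samples, sample_idx):
    """0-based index of the last sync sample at or before sample_idx.

    sync_samples holds sorted 1-based sample numbers, as in an stss box.
    Returns 0 when no sync sample precedes sample_idx.
    """
    pos = bisect_right(sync_samples, sample_idx + 1)
    return sync_samples[pos - 1] - 1 if pos else 0


sync = [1, 31, 61]  # keyframes at 0-based samples 0, 30, 60
print(keyframe_at_or_before(sync, 45))
```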
File diff suppressed because it is too large
+49
@@ -0,0 +1,49 @@
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
from __future__ import annotations
import json
from typing import Any
import torch
from torchcodec import FrameBatch, _core as core
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
data = json.loads(payload)
frames = data["frames"]
pts = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64)
key = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool)
dur = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64)
return pts, key, dur
class VideoDecoderLike:
"""Minimal VideoDecoder surface used by episode byte cache."""
def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
self._decoder = decoder
(
self.metadata,
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
self._num_frames,
) = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
return FrameBatch(*core.get_frames_by_pts(self._decoder, timestamps=seconds))
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
"""Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist."""
if frame_mappings is None:
decoder = core.create_from_file_like(source, "approximate")
core.add_video_stream(decoder)
return VideoDecoderLike(decoder)
mappings = frame_mappings_tensors(frame_mappings)
decoder = core.create_from_file_like(source, "custom_frame_mappings")
core.add_video_stream(decoder, custom_frame_mappings=mappings)
return VideoDecoderLike(decoder)
+10 -3
@@ -273,7 +273,11 @@ class VideoDecoderCache:
self._cache.move_to_end(video_path)
return entry[0]
-file_handle = fsspec.open(video_path).__enter__()
+# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
+# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
+# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
+# mostly-sequential reads torchcodec issues.
+file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
except Exception:
@@ -322,6 +326,7 @@ def decode_video_frames_torchcodec(
log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None,
return_uint8: bool = False,
+episode_decoder: Any | None = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
@@ -343,8 +348,10 @@ def decode_video_frames_torchcodec(
if decoder_cache is None:
decoder_cache = _default_decoder_cache
-# Use cached decoder instead of creating new one each time
-decoder = decoder_cache.get_decoder(str(video_path))
+if episode_decoder is not None:
+decoder = episode_decoder
+else:
+decoder = decoder_cache.get_decoder(str(video_path))
loaded_ts = []
loaded_frames = []
+14 -4
@@ -387,7 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
-if hasattr(active_cfg, "drop_n_last_frames"):
+if hasattr(active_cfg, "drop_n_last_frames") and not cfg.dataset.streaming:
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
@@ -426,9 +426,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Prepare everything with accelerator
accelerator.wait_for_everyone()
-policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-policy, optimizer, dataloader, lr_scheduler
-)
+if cfg.dataset.streaming:
+# The streaming IterableDataset is already rank-disjoint via split_dataset_by_node, so we must
+# NOT hand the dataloader to accelerate: its IterableDatasetShard would keep only every
+# world_size-th batch of each rank's already-disjoint stream (silently training on 1/N of the
+# data while decoding all of it). Batches are moved to the device manually in the loop below.
+policy, optimizer, lr_scheduler = accelerator.prepare(policy, optimizer, lr_scheduler)
+else:
+policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+policy, optimizer, dataloader, lr_scheduler
+)
dl_iter = cycle(dataloader)
policy.train()
@@ -468,6 +475,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
+if cfg.dataset.streaming:
+# The streaming dataloader is not accelerate-prepared (see above), so move to device here.
+batch = {k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) for k, v in batch.items()}
for cam_key in dataset.meta.camera_keys:
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
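The comment in this hunk argues that sharding an already rank-disjoint stream a second time leaves each rank with 1/N of its data. A toy model of that effect in plain Python follows. It assumes round-robin sharding at both levels, which is a simplification chosen for illustration; the actual behavior of accelerate's `IterableDatasetShard` is taken from the comment above and is not reproduced here.

```python
def rank_stream(rank: int, world_size: int, num_batches: int) -> list[int]:
    """Batches a rank sees when the dataset itself is split round-robin by rank."""
    return [b for b in range(num_batches) if b % world_size == rank]


def shard_again(stream: list[int], rank: int, world_size: int) -> list[int]:
    """A second round-robin shard applied on top of an already rank-local stream."""
    return stream[rank::world_size]


world_size, num_batches = 4, 32
once = [rank_stream(r, world_size, num_batches) for r in range(world_size)]
twice = [shard_again(s, r, world_size) for r, s in enumerate(once)]

print(sum(len(s) for s in once))   # 32: every batch is consumed exactly once
print(sum(len(s) for s in twice))  # 8: only 1/world_size of the batches survive
```

In this model every rank still produces all of its batches before the second shard discards most of them, which matches the "decoding all of it" cost described in the comment.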
+150
@@ -0,0 +1,150 @@
"""Acceptance tests for manifest byte-index sidecars.
Run on a compute node (not login-node):
srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
"""
from __future__ import annotations
import json
import socket
import pytest
pytest.importorskip("torchcodec")
REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
MAX_EPISODES = 64
COMPUTE_NODE = pytest.mark.skipif(
"login" in socket.gethostname(),
reason="run on compute node via srun (see module docstring), not login-node",
)
@pytest.fixture(scope="module")
def byte_index_dir(tmp_path_factory):
from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
out = tmp_path_factory.mktemp("byte_index")
meta = LeRobotDatasetMetadata(REPO, revision=REV)
files, episodes, _ = build_byte_index_tables(
meta, BUCKET, workers=4, max_episodes=MAX_EPISODES, include_keyframes=False
)
write_byte_index(out, files, episodes, None, merge_existing=False)
return out, meta
@pytest.mark.integration
@COMPUTE_NODE
def test_index_load_fast_and_small(byte_index_dir):
from lerobot.datasets.byte_index import EpisodeByteIndex
out, meta = byte_index_dir
index = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
assert index.load_time_s < 1.0
assert index.resident_bytes < 1_000_000_000
@pytest.mark.integration
@COMPUTE_NODE
def test_tight_fetch_under_25mb(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
stats = cache.stats.stats_dict()
assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024
@pytest.mark.integration
@COMPUTE_NODE
def test_in_memory_build_matches_parquet(byte_index_dir):
from lerobot.datasets.byte_index import EpisodeByteIndex
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
out, meta = byte_index_dir
disk = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES)
mem = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
for cam in meta.video_keys:
a = disk.lookup(ep, cam)
b = mem.lookup(ep, cam)
assert a.mdat_offset == b.mdat_offset
assert a.mdat_length == b.mdat_length
assert abs(a.first_pts - b.first_pts) < 1e-6
@pytest.mark.integration
@COMPUTE_NODE
def test_custom_frame_mappings_available(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cam = meta.video_keys[0]
ep = MAX_EPISODES // 2
payload = index.custom_frame_mappings(ep, cam)
assert payload is not None
data = json.loads(payload)
assert len(data["frames"]) > 10
assert any(f["key_frame"] for f in data["frames"])
assert all("pts" in f and "duration" in f for f in data["frames"])
@pytest.mark.integration
@COMPUTE_NODE
def test_metadata_skip_decoder_init(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)
cam = meta.video_keys[0]
ep = 0
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
dec = cache.get_decoder(ep, cam)
assert dec.metadata.num_frames is not None
assert dec.metadata.num_frames > 0
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
ts = begin + 0.5 * (end - begin)
frame = dec.get_frames_played_at([ts]).data
assert frame.ndim == 4
@pytest.mark.integration
@COMPUTE_NODE
def test_sparse_decode_produces_frames(byte_index_dir):
from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
from lerobot.datasets.episode_byte_cache import EpisodeByteCache
_, meta = byte_index_dir
index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
cam = meta.video_keys[0]
ep = 0
cache.submit_prefetch(ep)
cache.ensure_ready(ep)
dec = cache.get_decoder(ep, cam)
begin = float(dec.metadata.begin_stream_seconds)
end = float(dec.metadata.end_stream_seconds)
ts = begin + 0.5 * (end - begin)
frame = dec.get_frames_played_at([ts]).data
assert frame.ndim == 4
assert frame.numel() > 0
assert float(frame.float().std()) > 1.0
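Reviewer note: the shape that `test_custom_frame_mappings_available` expects from the payload can be written as a small standalone validator. This is a minimal sketch. The field names (`frames`, `pts`, `duration`, `key_frame`) come from the assertions in the test above, while `validate_frame_mappings` and the synthetic values are hypothetical and are not part of `lerobot` or TorchCodec.

```python
import json

# Per-frame fields the test above asserts on.
REQUIRED_FIELDS = frozenset({"pts", "duration", "key_frame"})


def validate_frame_mappings(payload, min_frames=1):
    """Parse a JSON frame-mapping payload and check the per-frame fields."""
    frames = json.loads(payload)["frames"]
    if len(frames) < min_frames:
        raise ValueError(f"expected at least {min_frames} frames, got {len(frames)}")
    for position, frame in enumerate(frames):
        missing = REQUIRED_FIELDS.difference(frame)
        if missing:
            raise ValueError(f"frame {position} is missing {sorted(missing)}")
    if not any(frame["key_frame"] for frame in frames):
        raise ValueError("payload contains no key frame, so seeking has no entry point")
    return frames


# Synthetic payload: 24 frames with a key frame every 12 frames (made-up values).
synthetic = [
    {"pts": i * 512, "duration": 512, "key_frame": int(i % 12 == 0)} for i in range(24)
]
checked = validate_frame_mappings(json.dumps({"frames": synthetic}), min_frames=11)
```

A check like this could run at index-build time so that a malformed sidecar fails early instead of at decoder construction.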
+30 -95
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
rng = np.random.default_rng(streaming_ds.seed)
buffer_size = streaming_ds.buffer_size
num_shards = streaming_ds.num_shards
shards_indices = []
for shard_idx in range(num_shards):
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
shard_indices = [item["index"] for item in shard]
shards_indices.append(shard_indices)
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
frames_buffer = []
expected_indices = []
while shard_iterators: # While there are still available shards
available_shard_keys = list(shard_iterators.keys())
if not available_shard_keys:
break
# Call _infinite_generator_over_elements with current available shards (key difference!)
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
try:
frame_index = next(shard_iterators[shard_key])
if len(frames_buffer) == buffer_size:
i = next(buffer_indices_generator)
expected_indices.append(frames_buffer[i])
frames_buffer[i] = frame_index
else:
frames_buffer.append(frame_index)
except StopIteration:
del shard_iterators[shard_key] # Remove exhausted shard
rng.shuffle(frames_buffer)
expected_indices.extend(frames_buffer)
return expected_indices
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
"""Test if frames are correctly accessed."""
ds_num_frames = 400
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
[False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
ds_num_frames = 400
ds_num_episodes = 10
buffer_size = 100
seed = 42
n_epochs = 3
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
expected_indices = get_frames_expected_order(streaming_ds)
for _ in range(n_epochs):
streaming_indices = [frame["index"] for frame in streaming_ds]
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
if shuffle:
assert not frames_match
else:
assert frames_match
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
for epoch_indices in epochs:
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
if shuffle:
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
else:
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
@pytest.mark.parametrize(
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
[False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
ds_num_frames = 100
ds_num_episodes = 10
buffer_size = 10
seed = 42
n_epochs = 3
data_file_size_mb = 0.001
chunks_size = 1
local_path = tmp_path / "test"
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
chunks_size=chunks_size,
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
for _ in range(n_epochs):
streaming_indices = [
frame["index"] for frame in streaming_ds
] # NOTE: this is the same as first_epoch_indices
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
def make_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
episode_pool_size=3,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
if shuffle:
assert not frames_match
else:
assert frames_match
first = [int(frame["index"]) for frame in make_ds()]
again = [int(frame["index"]) for frame in make_ds()]
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
assert first == again, "same seed must reproduce the same order"
@pytest.mark.parametrize(
@@ -288,6 +218,11 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
check = torch.allclose(left, right) and left.shape == right.shape
else:
# Scalar numerics: streaming yields python floats/ints where map-style yields
# 0-dim tensors (long-standing accepted difference). Compare by value.
check = float(left) == float(right)
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
@@ -0,0 +1,100 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
"""
import json
import shutil
import subprocess
import sys
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
from tests.fixtures.constants import DUMMY_REPO_ID
WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
json.dump(payload, f)
"""
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
total_frames = 160
repo_id = f"{DUMMY_REPO_ID}-acc"
root = tmp_path / "ds"
lerobot_dataset_factory(
root=root,
repo_id=repo_id,
total_episodes=8,
total_frames=total_frames,
use_videos=False,
data_files_size_in_mb=0.001,
chunks_size=1,
)
worker = tmp_path / "worker.py"
worker.write_text(WORKER)
out_dir = tmp_path / "out"
out_dir.mkdir()
cmd = [
"accelerate",
"launch",
"--num_processes=2",
"--num_machines=1",
"--mixed_precision=no",
"--dynamo_backend=no",
"--cpu",
str(worker),
str(root),
repo_id,
str(out_dir),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
assert result.returncode == 0, (
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
)
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
rank_sets = [set(p["indices"]) for p in payloads]
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-v"]))
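Reviewer note: the contract this smoke test checks (ranks are disjoint and jointly cover every frame) can be stated without any launcher. The sketch below assumes a round-robin assignment of equally sized episode shards to ranks. That assignment, the equal episode length, and the helper names are illustrative assumptions, not the behaviour of `split_dataset_by_node` itself.

```python
def assign_shards(num_shards, world_size):
    """Hypothetical round-robin assignment of shard ids to ranks."""
    return [list(range(rank, num_shards, world_size)) for rank in range(world_size)]


def frames_for(shards, frames_per_shard):
    """Global frame indices contained in the given shards (equal-length shards assumed)."""
    return {s * frames_per_shard + k for s in shards for k in range(frames_per_shard)}


def is_partition(rank_sets, total_frames):
    """True if the per-rank sets are pairwise disjoint and together cover range(total_frames)."""
    union = set().union(*rank_sets)
    no_overlap = sum(len(s) for s in rank_sets) == len(union)
    return no_overlap and union == set(range(total_frames))


# Same sizes as the test above: 8 episodes, 160 frames, 2 processes.
rank_sets = [frames_for(shards, 20) for shards in assign_shards(8, 2)]
ok = is_partition(rank_sets, 160)
```

Comparing the summed set sizes against the union size catches overlap for any number of ranks, whereas the `isdisjoint` call in the test only compares the first two.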
+430
@@ -0,0 +1,430 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
prefetching, deterministic fast-forward resume, and schema parity."""
import pytest
import torch
from torch.utils.data import DataLoader
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
factory(
root=root,
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
use_videos=use_videos,
data_files_size_in_mb=0.001,
chunks_size=1,
**kw,
)
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
return [int(frame["index"]) for frame in ds]
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
monkeypatch.delenv("RANK", raising=False)
monkeypatch.delenv("WORLD_SIZE", raising=False)
# No accelerate state, no env -> single process.
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
monkeypatch.setenv("RANK", "3")
monkeypatch.setenv("WORLD_SIZE", "4")
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
repo_id = f"{DUMMY_REPO_ID}-ranks"
total_frames, total_episodes = 200, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
world_size = 2
per_rank = []
for rank in range(world_size):
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=8,
max_num_shards=8,
rank=rank,
world_size=world_size,
)
per_rank.append(set(_stream_indices(ds)))
assert per_rank[0].isdisjoint(per_rank[1]), (
"ranks streamed overlapping frames (duplicate data across GPUs)"
)
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
repo_id = f"{DUMMY_REPO_ID}-workers"
total_frames, total_episodes = 120, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
)
loader = DataLoader(ds, batch_size=None, num_workers=2)
indices = [int(batch["index"]) for batch in loader]
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
"""
repo_id = f"{DUMMY_REPO_ID}-sarm"
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
episode_frames = 300
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
)
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
delta_timestamps = {ACTION: [0.0, horizon_s]}
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=1,
max_num_shards=1,
delta_timestamps=delta_timestamps,
)
horizon_frames = int(round(horizon_s * ds.fps))
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
checked = 0
for frame in ds:
idx = int(frame["index"])
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
if idx + horizon_frames < episode_frames:
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
)
checked += 1
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
repo_id = f"{DUMMY_REPO_ID}-seeds"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
def order(seed):
return _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=seed,
episode_pool_size=4,
max_num_shards=2,
)
)
assert order(0) == order(0), "same seed must reproduce the same order"
assert order(0) != order(1), "different seeds should give different orders"
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
repo_id = f"{DUMMY_REPO_ID}-epochs"
total_frames = 120
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
)
epoch_0 = _stream_indices(ds)
epoch_1 = _stream_indices(ds)
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
assert epoch_0 != epoch_1, "epoch did not reshuffle"
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
repo_id = f"{DUMMY_REPO_ID}-mix"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
)
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
repo_id = f"{DUMMY_REPO_ID}-parity"
map_ds = lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
)
stream_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
)
map_frame = map_ds[0]
stream_frame = next(iter(stream_ds))
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
for key, value in stream_frame.items():
ref = map_frame[key]
if isinstance(value, torch.Tensor):
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
)
elif isinstance(value, str):
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
else:
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
# (a long-standing, accepted difference). Compare by value rather than exact type.
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
import lerobot.datasets.streaming_dataset as sd
repo_id = f"{DUMMY_REPO_ID}-vpath"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
seen_paths = []
def fake_decode(video_path, query_ts, *args, **kwargs):
seen_paths.append(str(video_path))
return torch.zeros(len(query_ts), 3, 64, 96)
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
next(iter(ds))
assert seen_paths, "no video decode was issued"
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
repo_id = f"{DUMMY_REPO_ID}-shuf"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ordered = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
)
)
shuffled = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
)
)
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
assert shuffled != ordered, "shuffle did not decorrelate output order"
def test_native_resume_never_repeats_and_loss_is_bounded(tmp_path, lerobot_dataset_factory):
"""Native state_dict resume: no sample is re-yielded, and the number of skipped samples is bounded by the shuffle buffers."""
repo_id = f"{DUMMY_REPO_ID}-native-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=7,
episode_pool_size=2,
frame_shuffle_buffer_size=8,
)
ds = fresh_ds()
it = iter(ds)
consumed = [int(next(it)["index"]) for _ in range(30)]
state = ds.state_dict()
resumed_ds = fresh_ds()
resumed_ds.load_state_dict(state)
rest = [int(frame["index"]) for frame in resumed_ds]
assert not set(consumed) & set(rest), "resume re-yielded already-seen frames"
# in-flight buffer contents are skipped on resume (documented datasets behavior):
# bounded by the episode pool (2 episodes of <= ~30 frames here) + frame buffer (8)
covered = len(set(consumed) | set(rest))
max_in_flight = 2 * 30 + 8
assert covered >= total_frames - max_in_flight
assert covered + len(consumed) >= total_frames - max_in_flight
def test_pipeline_uses_native_primitives(tmp_path, lerobot_dataset_factory):
"""The tabular pipeline is pure datasets: batch(by_column) + shuffle + map + shuffle."""
repo_id = f"{DUMMY_REPO_ID}-native-pipe"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=4, total_frames=80)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=2)
import datasets as hf_datasets
assert isinstance(ds._pipeline, hf_datasets.IterableDataset)
state = ds._pipeline.state_dict() # the native resume protocol is available end-to-end
assert state is not None
# --- Plan B: random-episode admission via reshard() + multi-input-shard shuffle ---
def test_reshard_makes_one_shard_per_episode(tmp_path, lerobot_dataset_factory):
"""With one row group per episode (the writer's invariant), reshard() turns each episode into its
own shard, so num_shards == total_episodes even when many episodes share a single data file."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-reshard"
total_episodes = 3
# Default (large) data-file size packs all (unequal-length) episodes into one file, so the only way
# num_shards can reach total_episodes is per-row-group resharding.
lerobot_dataset_factory(
root=tmp_path / "ds",
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=90,
use_videos=False,
)
ds = StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
file_to_eps = ds._episode_files()
assert len(file_to_eps) == 1, "test expects all episodes packed into a single data file"
for (chunk_idx, file_idx), eps in file_to_eps.items():
rel = ds.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
assert pq.ParquetFile(str(ds.root / rel)).num_row_groups == len(eps)
assert ds.num_shards == total_episodes
def test_max_buffer_input_shards_admits_random_episodes(tmp_path, lerobot_dataset_factory):
"""max_buffer_input_shards (== concurrently-live random episodes) drives the per-batch episode mix:
a single batch should already span most of the live episodes."""
repo_id = f"{DUMMY_REPO_ID}-frac"
total_episodes = 8
lerobot_dataset_factory(
root=tmp_path / "ds",
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=240,
use_videos=False,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=True,
seed=0,
episode_pool_size=total_episodes,
max_buffer_input_shards=total_episodes,
)
assert ds.max_buffer_input_shards == total_episodes
batch = 32
head = {int(frame["episode_index"]) for _, frame in zip(range(batch), ds, strict=False)}
assert len(head) >= min(total_episodes, batch) - 2, f"batch did not mix random episodes: {head}"
def test_collapsed_row_groups_raise(tmp_path, lerobot_dataset_factory):
"""A data file that collapses several episodes into a single row group (bulk df.to_parquet /
push_to_hub) must be rejected with an actionable error: reshard() cannot address its episodes."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-collapsed"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
# Rewrite every data file as a single row group (simulating the aggregate/push_to_hub collapse).
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
pq.write_table(pq.read_table(parquet_path), parquet_path)
with pytest.raises(ValueError, match="ONE ROW GROUP PER EPISODE"):
StreamingLeRobotDataset(repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3)
def test_collapsed_row_groups_can_be_bypassed(tmp_path, lerobot_dataset_factory):
"""validate_row_groups=False skips the row-group check (collapsed datasets still load, degraded)."""
import pyarrow.parquet as pq
repo_id = f"{DUMMY_REPO_ID}-collapsed-bypass"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
for parquet_path in (tmp_path / "ds" / "data").rglob("*.parquet"):
pq.write_table(pq.read_table(parquet_path), parquet_path)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, validate_row_groups=False
)
assert sorted(int(frame["index"]) for frame in ds) == list(range(90))
def test_distributed_divisibility_guard_raises(tmp_path, lerobot_dataset_factory):
"""When num_shards (== episodes after reshard) is not divisible by world_size, every rank would
stream the whole dataset; the guard must raise instead of silently degrading."""
repo_id = f"{DUMMY_REPO_ID}-divis"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=3, total_frames=90, use_videos=False
)
with pytest.raises(ValueError, match="not divisible by world_size"):
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=3, rank=0, world_size=2
)
# Bypassing the guard downgrades it to a warning (no raise).
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
episode_pool_size=3,
rank=0,
world_size=2,
validate_row_groups=False,
)
assert ds.num_shards == 3
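Reviewer note: the behaviour that `test_distributed_divisibility_guard_raises` pins down (raise by default, warn when validation is bypassed) fits in a few lines. This is a minimal sketch under the assumption that the guard only compares `num_shards` with `world_size`. The function name `check_shard_divisibility` and the exact message are hypothetical, and the real check inside `StreamingLeRobotDataset` may look different.

```python
import warnings


def check_shard_divisibility(num_shards, world_size, validate=True):
    """Raise (or warn, when validate=False) if shards cannot be split evenly across ranks."""
    if world_size <= 1 or num_shards % world_size == 0:
        return
    message = (
        f"num_shards={num_shards} is not divisible by world_size={world_size}; "
        "ranks cannot be given equal, disjoint sets of episode shards"
    )
    if validate:
        raise ValueError(message)
    # Bypassed guard: report the problem but let the caller continue.
    warnings.warn(message, stacklevel=2)


# 3 episode shards on 2 ranks, as in the test above.
try:
    check_shard_divisibility(3, 2)
    guard_raised = False
except ValueError:
    guard_raised = True

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_shard_divisibility(3, 2, validate=False)
```

Failing at construction time keeps the error next to the configuration that caused it, rather than surfacing later as duplicated data across ranks.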
+22 -3
@@ -17,6 +17,7 @@ from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import pytest
from datasets import Dataset
@@ -35,6 +36,24 @@ from lerobot.datasets.utils import (
)
def _to_parquet_one_row_group_per_episode(hf_dataset: Dataset, path: Path) -> None:
"""Write ``hf_dataset`` to ``path`` with one Parquet row group per episode.
Mirrors the LeRobot recording writer (one ``write_table`` per episode) so each episode stays an
independently addressable shard after ``datasets.IterableDataset.reshard()``, which
``StreamingLeRobotDataset`` relies on. ``Dataset.to_parquet`` would collapse the file into a
single row group instead.
"""
table = hf_dataset.with_format("arrow")[:]
episode_index = np.asarray(hf_dataset["episode_index"])
boundaries = np.where(np.diff(episode_index) != 0)[0] + 1
starts = [0, *boundaries.tolist()]
ends = [*boundaries.tolist(), len(episode_index)]
with pq.ParquetWriter(str(path), table.schema) as writer:
for start, end in zip(starts, ends, strict=True):
writer.write_table(table.slice(start, end - start))
def write_hf_dataset(
hf_dataset: Dataset,
local_dir: Path,
@@ -67,7 +86,7 @@ def write_hf_dataset(
# If the dataset is small enough, write it to a single file.
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
path.parent.mkdir(parents=True, exist_ok=True)
hf_dataset.to_parquet(path)
_to_parquet_one_row_group_per_episode(hf_dataset, path)
return
# If the dataset is too large, split it into smaller chunks, keeping episodes whole.
@@ -114,8 +133,8 @@ def write_hf_dataset(
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
# Write the shard to a Parquet file.
dataset_shard.to_parquet(path)
# Write the shard to a Parquet file (one row group per episode).
_to_parquet_one_row_group_per_episode(dataset_shard, path)
# Update chunk and file indices for the next iteration.
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
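Reviewer note: the boundary computation in `_to_parquet_one_row_group_per_episode` is easy to check in isolation. The sketch below is a dependency-free restatement of the `np.diff`-based logic from the helper, applied to a plain list. The name `episode_slices` is hypothetical, and the input is assumed to be sorted by episode as in the fixture.

```python
def episode_slices(episode_index):
    """Return (start, end) row ranges, one per run of equal episode ids."""
    # A boundary is every position whose episode id differs from the previous row.
    boundaries = [
        i for i in range(1, len(episode_index)) if episode_index[i] != episode_index[i - 1]
    ]
    starts = [0, *boundaries]
    ends = [*boundaries, len(episode_index)]
    return list(zip(starts, ends))


# Three episodes of 3, 2 and 4 frames.
slices = episode_slices([0, 0, 0, 1, 1, 2, 2, 2, 2])
row_group_sizes = [end - start for start, end in slices]
```

Each `(start, end)` pair corresponds to one `writer.write_table(table.slice(start, end - start))` call in the helper, so the number of pairs equals the number of row groups written.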
Generated
+3 -7
@@ -1084,8 +1084,8 @@ wheels = [
[[package]]
name = "datasets"
version = "4.8.5"
source = { registry = "https://pypi.org/simple" }
version = "5.0.1.dev0"
source = { git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3#2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
dependencies = [
{ name = "dill" },
{ name = "filelock" },
@@ -1102,10 +1102,6 @@ dependencies = [
{ name = "tqdm" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
]
[[package]]
name = "debugpy"
@@ -3078,7 +3074,7 @@ requires-dist = [
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?rev=2c45eab1bb975ac3d846f2aa6217b82adec8eba3" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },