Compare commits

...

13 Commits

Author SHA1 Message Date
Pepijn 42d4788e4a fix(streaming): drop undeclared parquet columns that break batch collation
The data_files_root/bucket path reads an unversioned source (e.g. `main`), which can
carry extra annotation columns not in the dataset's feature contract — notably
`language_events`, a variable-length list (length 0..N per frame). Passed through to the
sample, these break default DataLoader collation ("each element in list of batch should
be of equal size"), which is why bucket jobs failed while the hub path (pinned to the
clean v3.0 revision) succeeded.

Drop any hf_dataset column not in meta.features after load. No-op on a clean revision;
removes language_events/language_persistent on main. Verified by reproducing the bucket
code path locally via --data_files_root hf://datasets/<repo> (parquet builder + main
columns): now decodes and collates instead of raising.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:24:30 +02:00
Pepijn 2d1c17d971 docs(streaming): note AV1 is LeRobot's default codec (vcodec=libsvtav1)
So the A100/H100 no-AV1-NVDEC limitation applies to most LeRobot v3 datasets, not just
RoboCasa — GPU decode needs an Ada GPU, an hevc/h264-encoded dataset, or a re-encode.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:10:18 +02:00
Pepijn 7241f029c6 docs(streaming): A100/H100 NVDEC cannot decode AV1 — correct guidance
NVIDIA's decode support matrix: the compute GPUs A100 (GA100) and H100 (GH100) have no
AV1 NVDEC decoder; only Ada (L4/L40/RTX40) and some Ampere (A10/A40/A16) do. So on
A100/H100 nodes, AV1 datasets must be decoded on CPU or re-encoded to H.265/H.264 — no
torchcodec build enables cuda AV1 decode there. Also distinguish that error from
"Unsupported device: cuda (variant: ffmpeg)", which is a torchcodec-built-without-CUDA
issue. Update diagnose_decode.py message + benchmark README accordingly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 17:08:54 +02:00
Pepijn 06ddc59913 feat(streaming): CONDA_ENV knob for the matrix submitter
Add CONDA_ENV=<name> to run each matrix job via `conda run --no-capture-output -n
<name>` — works inside the dash `sbatch --wrap` without sourcing conda.sh / activating,
and streams logs live. Point it at a conda env with a modern torchcodec (>=0.11) +
datasets (>=4.7); the default cluster `base` env is often too old to decode AV1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:55:42 +02:00
Pepijn 23c58f5f9e feat(streaming): decode diagnostic + fail benchmark on 0 frames
- benchmark: raise SystemExit if 0 frames were measured, so a run that produces no
  batches (swallowed decode error, all batches dropped) fails loudly instead of being
  reported green with NaN/zero numbers (the misleading "COMPLETED" CUDA jobs).
- add benchmarks/streaming/diagnose_decode.py: isolates the streaming decode path
  (resolve path -> fsspec.open -> torchcodec VideoDecoder -> get one frame) and prints
  package versions + the first bytes of the handle. Pinpoints decode failures: bad/
  placeholder bytes vs ffmpeg/torchcodec build issue. RoboCasa videos are AV1; the
  failure message calls out AV1 decoder + NVDEC-on-Ada requirements explicitly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:40:24 +02:00
Pepijn b0ab57cedc fix(streaming): make matrix sbatch --wrap body POSIX-sh safe
`sbatch --wrap` runs the wrapped body under /bin/sh (dash), which has no
`set -o pipefail`, so every matrix job died on line 1 ("Illegal option -o pipefail")
before reaching the benchmark. The command has no pipes, so drop the bashism and chain
with `&&` (cd-guards the run) — fully POSIX-sh compatible. Runtime env expansion
(${HF_HOME:-$SCRATCH/hf_home}) is preserved.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 16:16:54 +02:00
Pepijn afdc084677 feat(streaming): serial-by-default matrix submitter (afterany dependency chain)
For a bandwidth-sensitive benchmark, concurrent jobs would share the network to the
Hub/bucket and corrupt throughput numbers. Chain the matrix jobs with
--dependency=afterany (captured via `sbatch --parsable`) so SLURM runs exactly one at a
time while keeping each config an isolated job (own log + per-job OOM reporting).
afterany keeps the chain going if one job fails/OOMs. SERIAL=0 restores parallel
submission for OOM-isolation-only testing.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:55:58 +02:00
Pepijn a32a2c647b feat(streaming): full-matrix SLURM submitter + results summarizer
slurm/run_streaming_matrix.sh fans the benchmark matrix (sources {hub,bucket,
warmed_bucket} x modes {single,sarm} x decode {cpu,cuda}) out as isolated single-GPU
SLURM jobs, so an OOM in one config is contained and reported per-job by SLURM. Worker
count and shuffle buffer are bounded (lower for cuda, which holds a CUDA context + NVDEC
session per worker) to avoid host/VRAM OOM. Source/mode/decode/workers/buffer/account/
partition are env-overridable; SOURCES/MODES/DECODES select subsets.

benchmarks/streaming/summarize_results.py collapses the per-run JSONs into one comparison
table + summary.csv (frames/s/node, first-batch + p50/p95/p99 latency, cache hit-rate).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:51:36 +02:00
Pepijn 343ecd7980 feat(streaming): optional GPU (NVDEC) video decode device
Add `video_decode_device` to StreamingLeRobotDataset and a `device` arg to
VideoDecoderCache, passed to torchcodec's VideoDecoder. "cuda" offloads H.264/H.265
decode to the GPU's dedicated NVDEC engine (independent of the training SMs); requires
a CUDA-enabled torchcodec build.

benchmark: `--video_decode_device` flag. With cuda + num_workers>0 it forces the
`spawn` start method (CUDA cannot init in forked workers) and disables CPU pin_memory
(frames are already on-GPU). Decode device is recorded in results and the output
filename. README documents the NVDEC option and its concurrency/IPC caveats.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:47:11 +02:00
Pepijn f7c8a526e8 feat(streaming): wallclock benchmark throughput, cross-worker cache stats, bucket source
- benchmark: frames_per_s_node now measures sustained wall-clock throughput over the
  post-warmup window. The previous metric summed inter-batch gaps, which collapse to ~0
  under async prefetch (consumer drains a pre-filled queue) and overstated throughput ~100x.
- VideoDecoderCache gains an optional shared [hits, misses, evictions] counter tensor;
  StreamingLeRobotDataset.video_decoder_cache_stats() aggregates it across DataLoader
  workers (lock-free, approximate; hit_rate preserved). Fixes empty cache stats with workers.
- StreamingLeRobotDataset.data_files_root: read bulk data/ + videos/ from an fsspec root
  (e.g. hf://buckets/<owner>/<name>) while metadata still loads from repo_id. Enables
  bucket / prewarmed-bucket benchmark sources without copying metadata. Exposed as
  benchmark --data_files_root.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 15:25:44 +02:00
Pepijn 77af66a29c fix(streaming): decode video at episode-local timestamp + from_timestamp offset
make_frame used `item["index"] / fps` (a dataset-global value) as the in-file
video timestamp. That only matches the file timeline when the whole dataset is a
single video (as in the test fixtures); on multi-file v3 datasets it decodes
out-of-range frames and crashes (e.g. RoboCasa: "Invalid frame index=23314614 ...
must be less than 41021").

Mirror the map-style reader: use the episode-local `timestamp` column as the base,
clamp delta query timestamps to per-camera episode-local bounds [0, duration], and
shift by the episode's `from_timestamp` per camera at decode time. For single-file
datasets `from_timestamp + timestamp == index / fps`, so existing parity tests are
unaffected; multi-file streaming is now correct.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 14:54:10 +02:00
Pepijn 68fa5d80b0 feat(streaming): multinode example, dataloading benchmark, distributed smoke test
- examples/scaling/train_streaming_multinode.py: Accelerate-based distributed/
  resumable streaming training (no DistributedSampler; rank/world_size auto-resolved),
  checkpoints the dataset stream state, and supports a --dummy pure-dataloading path
  with throughput logging. SLURM launcher in slurm/train_streaming_robocasa.sh.
- benchmarks/streaming/benchmark_streaming.py: dummy-consumer dataloading benchmark
  (single / sarm frame modes) emitting frames/s/node, p50/p95/p99 sample latency,
  first-batch latency, and VideoDecoderCache reuse stats as JSON + CSV. SLURM launcher
  + README documenting the source/node/mode matrix and manual bucket prewarming.
- VideoDecoderCache: add hit/miss/eviction counters and a stats() method so the
  benchmark can surface decoder thrash (no new cache, no eviction-policy change).
- tests/datasets/test_streaming_distributed.py: accelerate-launch smoke test asserting
  per-rank disjointness; skips (does not false-pass) when <2 processes spawn.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 13:48:23 +02:00
Pepijn d1fc8e298c feat(streaming): distributed + resumable HF-native StreamingLeRobotDataset
Add the large-scale streaming pieces that were missing from the frame-streaming
internals, keeping the existing Backtrackable + output-reservoir frame-shuffle:

- split_dataset_by_node(rank, world_size) before the per-shard loop so each rank
  streams a disjoint set of shards (fixes duplicate data across GPUs). rank and
  world_size auto-resolve from Accelerate state / RANK,WORLD_SIZE env / (0, 1).
- get_worker_info() shard splitting so DataLoader workers within a rank don't
  yield duplicate frames.
- Dynamic Backtrackable window (dynamic_bounds=True) sized to the requested
  delta_timestamps, removing the fixed 100-frame ceiling so long horizons (e.g. a
  SARM window ~160 frames) reach real frames instead of silently padding. Fix the
  peek_back off-by-one: history = lookback + 1.
- video_decoder_cache_size knob; default (active_shards + 1) x num_cameras so the
  live decoder working set does not thrash the VideoDecoderCache LRU.
- state_dict()/load_state_dict() for resume (per-shard HF stream state + exhausted
  set + RNG). Reservoir is re-warmed, so resumption is not bit-exact (documented).
- factory.py wires buffer_size from a new DatasetConfig.streaming_buffer_size field
  instead of repurposing max_num_shards as the worker count.

Tests: tests/datasets/test_streaming_native.py covers distributed disjointness,
worker de-duplication, the SARM-length window, resume, schema parity vs map-style,
local video path resolution, and shuffle decorrelation. 21 passed (13 existing + 8).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 13:37:30 +02:00
14 changed files with 1458 additions and 26 deletions
+91
View File
@@ -0,0 +1,91 @@
# Streaming dataloading benchmark
Measures **dataloading only** (no model) for `StreamingLeRobotDataset`: parquet read + video decode +
delta windowing + shuffle. A dummy consumer pulls batches and moves them to the device, so the numbers
isolate the data pipeline. Use it to compare sources (Hub vs. storage bucket vs. prewarmed bucket),
frame modes, and node counts, and to catch p95/p99 video-decode regressions.
## Run
```bash
python benchmarks/streaming/benchmark_streaming.py \
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
--mode sarm --batch_size 64 --num_workers 12 --num_batches 200 \
--source hub --out_dir benchmarks/streaming/results
```
Multinode (per-node throughput) goes through Accelerate under SLURM:
```bash
sbatch slurm/benchmark_streaming_robocasa.sh
```
## Matrix
| Axis | Values |
| ---------- | -------------------------------------------------------------------------------------------------------------------- |
| Source | `hub` (verify now), `bucket`, `warmed_bucket` (bucket + prewarming; with user's help later) |
| Baseline | current `main` `StreamingLeRobotDataset` on Hub streaming |
| Nodes | 1 and 2 (per-node throughput should be independent) |
| Frame mode | `single` (1 frame, all cameras; target ≥ 120 frames/s/node) · `sarm` (8 steps spaced 1s; target ≥ 320 frames/s/node) |
`--source` is a label only; the actual source is whatever `--repo_id` / `--root` / `--data_files_root`
point at.
### GPU (NVDEC) decoding
By default video is decoded on the **CPU** in each DataLoader worker, so throughput is CPU-decode-bound and
scales with `--num_workers` (capped by the dataset's `num_shards`). Pass `--video_decode_device cuda` to
offload H.264/H.265 decode to the GPU's dedicated **NVDEC** engine, which runs independently of the SMs used
for training (see <https://developer.nvidia.com/video-codec-sdk>). This requires a CUDA-enabled torchcodec
build, and because CUDA cannot initialize in forked workers the benchmark switches to the `spawn` start
method automatically when `--num_workers > 0`.
```bash
# GPU/NVDEC decode, 6 workers, bucket source
python benchmarks/streaming/benchmark_streaming.py \
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
--data_files_root hf://buckets/pepijn223/robocasa-stream \
--mode sarm --batch_size 64 --num_workers 6 --num_batches 200 \
--video_decode_device cuda --source bucket
```
Caveats with `cuda` + many workers: each worker creates its own CUDA context (VRAM overhead) and NVDEC has a
limited number of concurrent decode sessions per GPU; if you hit session/IPC limits, reduce `--num_workers`
or compare against `--num_workers 0` (single-process NVDEC, which often saturates the decode engine on its
own). Result files include the decode device in their name (`..._w6_cuda.json`).
> **Codec ⇄ NVDEC compatibility (important).** NVDEC can only decode codecs its hardware supports. LeRobot's
> **default video codec is AV1** (`VideoEncoderConfig.vcodec = "libsvtav1"`), so most v3 datasets are
> AV1-encoded — and the **A100 and H100 compute GPUs have no AV1 NVDEC decoder**
> (per NVIDIA's [decode support matrix](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new));
> only Ada (L4/L40/RTX40) and a few Ampere cards (A10/A40/A16) do. On A100/H100, AV1 must be decoded on
> **CPU**, or the dataset re-encoded to H.265/H.264 (which those GPUs' NVDEC do support). Run
> `diagnose_decode.py --video_decode_device cuda` to check your exact node before relying on `cuda` decode.
> A `cuda` torchcodec build also needs an FFmpeg with NVDEC; see
> <https://github.com/meta-pytorch/torchcodec#installing-cuda-enabled-torchcodec>.
Reference data root: bucket sources resolve through `--data_files_root hf://buckets/<owner>/<name>` (metadata
still loads from `--repo_id`). The local `single`/`sarm` CPU baselines on this dataset were ~176 / ~212
frames/s/node at `--num_workers 3` (3 cameras, fps 20).
## Metrics emitted (JSON + CSV)
`frames_per_s_node`, `samples_per_s`, `first_batch_latency_s`, `p50/p95/p99_sample_latency_ms`,
`wallclock_s`, and `video_decoder_cache` (`hits`, `misses`, `evictions`, `hit_rate`, `size`). A low
cache `hit_rate` with high `p99` is the decoder-thrash signature — raise `--video_decoder_cache_size`
or `--buffer_size`, or reduce `num_workers`.
## Bucket sources & prewarming (manual)
Prewarming is a **server-side** Hugging Face storage-bucket feature — there is no client script. To
benchmark the `warmed_bucket` source:
1. Attach a storage bucket to the dataset and enable it (see
<https://huggingface.co/docs/hub/storage-buckets>). Buckets resolve through `fsspec`, the same as
`hf://`, so no code change is needed — point `--repo_id`/`--revision` (or `--root`) at the bucket.
2. Enable **prewarming** in the bucket settings and wait for warm-up to complete.
3. Run the benchmark with `--source warmed_bucket`. Compare against the cold `--source bucket` and the
`--source hub` baseline.
Manual only — not run in CI.
+209
View File
@@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloading-only benchmark for StreamingLeRobotDataset.
A dummy consumer pulls batches and moves them to the device; no model runs, so the numbers isolate the
data pipeline (parquet read + video decode + delta windowing + shuffle). Reports per-node throughput and
sample-latency percentiles, plus video-decoder-cache reuse stats, and emits JSON + CSV.
Frame modes (matching the streaming design targets):
- ``single``: one frame, all cameras (target >= 120 frames/s/node).
- ``sarm``: an 8-step window spaced 1s (delta over 8s) (target >= 320 frames/s/node).
Example (stream from the Hub, single node):
python benchmarks/streaming/benchmark_streaming.py \
--repo_id pepijn223/robocasa_pretrain_human300_v4 --mode sarm \
--batch_size 64 --num_workers 12 --num_batches 200 --out_dir benchmarks/streaming/results
Distributed / multinode runs go through Accelerate; see ``slurm/benchmark_streaming_robocasa.sh``. Set
``--source`` purely for labeling the output (``hub`` / ``bucket`` / ``warmed_bucket``); the actual source
is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarming.
"""
import argparse
import csv
import json
import statistics
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, required=True)
parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).")
parser.add_argument(
"--data_files_root",
type=str,
default=None,
help="fsspec root for bulk data/videos, e.g. hf://buckets/<owner>/<name>. Metadata still loads "
"from --repo_id on the Hub. Use for bucket / warmed_bucket sources.",
)
parser.add_argument("--mode", choices=["single", "sarm"], default="single")
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--buffer_size", type=int, default=2000)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument(
"--video_decode_device",
type=str,
default="cpu",
help="Decode device passed to torchcodec. 'cuda' offloads decode to the GPU's NVDEC engine "
"(needs a CUDA-enabled torchcodec build). With num_workers>0 this forces the 'spawn' start method.",
)
parser.add_argument("--num_batches", type=int, default=200)
parser.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state stats.")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--out_dir", type=str, default="benchmarks/streaming/results")
return parser.parse_args()
def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> StreamingLeRobotDataset:
# sarm: an 8-step window spaced 1s => an 8s delta window (the SARM stress case).
delta_timestamps = {ACTION: [float(t) for t in range(8)]} if args.mode == "sarm" else None
return StreamingLeRobotDataset(
args.repo_id,
root=args.root,
data_files_root=args.data_files_root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
video_decoder_cache_size=args.video_decoder_cache_size,
video_decode_device=args.video_decode_device,
tolerance_s=1e-3,
)
def percentile(values: list[float], pct: float) -> float:
if not values:
return float("nan")
ordered = sorted(values)
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
return ordered[k]
def main() -> None:
args = parse_args()
device = torch.device(args.device)
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
dataset = build_dataset(args, meta)
gpu_decode = args.video_decode_device.startswith("cuda")
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
# GPU-decoded frames are already on the GPU, so CPU pinning is irrelevant (and pinning CUDA
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
pin_memory=device.type == "cuda" and not gpu_decode,
drop_last=True,
prefetch_factor=2 if args.num_workers > 0 else None,
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
)
sample_latencies_ms: list[float] = []
frames = 0
first_batch_latency_s = None
steady_start = None # wall-clock start of the post-warmup measurement window
t_start = time.perf_counter()
t_prev = t_start
for i, batch in enumerate(loader):
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
t_prev = now
if i + 1 >= args.num_batches:
break
now = time.perf_counter()
elapsed = now - t_start
# Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
cache_stats = dataset.video_decoder_cache_stats()
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
# never reported green with NaN/zero numbers.
if frames == 0:
raise SystemExit(
f"FAILED: measured 0 frames over {args.num_batches} requested batches "
f"(cache misses={cache_stats.get('misses', 0)}, hits={cache_stats.get('hits', 0)}). "
"The data pipeline yielded no usable batches — inspect worker logs for decode errors. "
"Try --num_workers 0 to surface the underlying exception directly."
)
results = {
"repo_id": args.repo_id,
"source": args.source,
"mode": args.mode,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"buffer_size": args.buffer_size,
"num_cameras": len(meta.video_keys),
"fps": meta.fps,
"device": str(device),
"video_decode_device": args.video_decode_device,
"frames_measured": frames,
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
"samples_per_s": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
if sample_latencies_ms
else None,
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
"wallclock_s": round(elapsed, 2),
"video_decoder_cache": cache_stats,
}
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=list(flat))
writer.writeheader()
writer.writerow(flat)
print("Command config:", vars(args))
print(json.dumps(results, indent=2))
print(f"Wrote {out_dir / tag}.json and .csv")
if __name__ == "__main__":
main()
+112
View File
@@ -0,0 +1,112 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isolate the streaming video-decode path — no SLURM, no DataLoader, no benchmark loop.
Reproduces exactly what StreamingLeRobotDataset does for one video (resolve path -> fsspec.open ->
torchcodec VideoDecoder -> get one frame) and prints the environment + the first bytes of the handle, so
a decode failure ("No valid stream found in input file") can be pinpointed: bad/placeholder bytes vs a
torchcodec/ffmpeg build issue vs a device issue.
python benchmarks/streaming/diagnose_decode.py --repo_id pepijn223/robocasa_pretrain_human300_v4
python benchmarks/streaming/diagnose_decode.py --repo_id … --data_files_root hf://buckets/<o>/<n>
python benchmarks/streaming/diagnose_decode.py --repo_id … --video_decode_device cuda
"""
import argparse
import importlib.metadata as im
import fsspec
from lerobot.datasets import LeRobotDatasetMetadata
def _version(pkg: str) -> str:
try:
return im.version(pkg)
except Exception:
return "MISSING"
def main() -> None:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--repo_id", required=True)
p.add_argument("--data_files_root", default=None, help="e.g. hf://buckets/<owner>/<name>")
p.add_argument("--revision", default=None)
p.add_argument("--video_decode_device", default="cpu")
p.add_argument("--episode", type=int, default=0)
args = p.parse_args()
print("== environment ==")
for pkg in ("torchcodec", "av", "huggingface_hub", "hf_xet", "datasets", "fsspec"):
print(f" {pkg}: {_version(pkg)}")
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
video_key = meta.video_keys[0]
rel_path = meta.get_video_file_path(args.episode, video_key)
root = args.data_files_root.rstrip("/") if args.data_files_root else meta.url_root
video_path = f"{root}/{rel_path}"
print("\n== target ==")
print(f" video_key: {video_key}")
print(f" video_path: {video_path}")
print("\n== fsspec handle ==")
try:
fh = fsspec.open(video_path).__enter__()
head = fh.read(32)
print(f" first 32 bytes (hex): {head.hex()}")
# A valid MP4/MOV has an 'ftyp' box near the start; anything else (HTML/JSON/empty) means the
# handle resolved to a placeholder or error page, not the video bytes.
looks_mp4 = b"ftyp" in head
print(f" looks like MP4 (contains 'ftyp'): {looks_mp4}")
if not looks_mp4:
print(f" !! first bytes as text: {head[:32]!r}")
fh.seek(0)
except Exception as e:
print(f" !! fsspec.open/read FAILED: {type(e).__name__}: {e}")
return
print("\n== torchcodec VideoDecoder ==")
try:
from torchcodec.decoders import VideoDecoder
decoder = VideoDecoder(fh, seek_mode="approximate", device=args.video_decode_device)
md = decoder.metadata
print(f" OK: {md.num_frames} frames, {md.average_fps} fps, codec={getattr(md, 'codec', '?')}")
frame = decoder.get_frames_at(indices=[0])
print(f" decoded frame 0: shape={tuple(frame.data.shape)}, device={frame.data.device}")
print("\nDECODE OK — the streaming pipeline can read this video on this machine.")
except Exception as e:
print(f" !! VideoDecoder FAILED: {type(e).__name__}: {e}")
print(
"\nDECODE FAILED. If the bytes above look like MP4 (ftyp=True), this is a torchcodec/ffmpeg "
"build issue, NOT bad bytes. Common cause for LeRobot v3 datasets: the videos are AV1-encoded "
"(see the 'codec' line on a working machine). Then:\n"
" - CPU decode needs an ffmpeg built with an AV1 decoder (libdav1d/libaom); a build without it "
"reports 'No valid stream found'.\n"
" - GPU/NVDEC decode of AV1 is only on AV1-capable NVDEC GPUs: Ada (L4/L40/RTX40) and some "
"Ampere (A10/A40/A16). The COMPUTE GPUs A100 and H100 have NO AV1 NVDEC decoder (per NVIDIA's "
"support matrix), so no torchcodec build enables cuda decode of AV1 on them.\n"
" - 'Unsupported device: cuda (variant: ffmpeg)' instead means torchcodec was built without "
"the CUDA backend; install a CUDA-enabled wheel (see README) — but on A100/H100 that still "
"won't decode AV1.\n"
"Fix: decode on CPU, run NVDEC on an Ada GPU, or re-encode the dataset to H.265/H.264 (which "
"A100/H100 NVDEC do support).\n"
"If ftyp=False instead, the handle resolved to a placeholder/error page (auth, revision, or Xet "
"resolution) rather than the video bytes."
)
if __name__ == "__main__":
main()
+79
View File
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collapse a directory of benchmark JSON results into one comparison table (and a combined CSV).
python benchmarks/streaming/summarize_results.py benchmarks/streaming/results
"""
import csv
import json
import sys
from pathlib import Path
COLUMNS = [
("source", "source"),
("mode", "mode"),
("video_decode_device", "decode"),
("num_workers", "workers"),
("batch_size", "bs"),
("frames_per_s_node", "frames/s/node"),
("first_batch_latency_s", "first_batch_s"),
("p50_sample_latency_ms", "p50_ms"),
("p95_sample_latency_ms", "p95_ms"),
("p99_sample_latency_ms", "p99_ms"),
]
def main() -> None:
results_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "benchmarks/streaming/results")
files = sorted(results_dir.rglob("*.json"))
if not files:
print(f"No JSON results under {results_dir}")
return
rows = []
for f in files:
d = json.loads(f.read_text())
d["hit_rate"] = d.get("video_decoder_cache", {}).get("hit_rate")
rows.append(d)
rows.sort(key=lambda r: (r.get("source", ""), r.get("mode", ""), r.get("video_decode_device", "")))
headers = [label for _, label in COLUMNS] + ["cache_hit_rate"]
widths = {h: len(h) for h in headers}
table = []
for r in rows:
row = {label: r.get(key, "") for key, label in COLUMNS}
row["cache_hit_rate"] = r.get("hit_rate", "")
table.append(row)
for h in headers:
widths[h] = max(widths[h], len(str(row[h])))
line = " ".join(h.ljust(widths[h]) for h in headers)
print(line)
print(" ".join("-" * widths[h] for h in headers))
for row in table:
print(" ".join(str(row[h]).ljust(widths[h]) for h in headers))
combined = results_dir / "summary.csv"
with open(combined, "w", newline="") as fh:
writer = csv.DictWriter(fh, fieldnames=headers)
writer.writeheader()
writer.writerows(table)
print(f"\nWrote {combined}")
if __name__ == "__main__":
main()
@@ -0,0 +1,169 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Distributed, resumable streaming training on a large HF-hosted dataset.
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
streaming features of :class:`StreamingLeRobotDataset`:
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
Launch with Accelerate (single node, N GPUs):
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
"""
import argparse
import time
from pathlib import Path
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, default="lerobot/droid_1.0.1")
parser.add_argument(
"--root", type=str, default=None, help="Local/prewarmed dataset root (else stream from Hub)."
)
parser.add_argument("--output_dir", type=str, default="outputs/train/streaming_multinode")
parser.add_argument("--steps", type=int, default=1000)
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
parser.add_argument("--save_freq", type=int, default=200)
parser.add_argument("--log_freq", type=int, default=20)
parser.add_argument("--resume_from", type=str, default=None, help="Checkpoint dir to resume from.")
parser.add_argument("--dummy", action="store_true", help="Skip the model; measure dataloading only.")
return parser.parse_args()
def make_dataloader(
args: argparse.Namespace, meta: LeRobotDatasetMetadata
) -> tuple[DataLoader, StreamingLeRobotDataset]:
# Supervise an action chunk; delta_timestamps drive the SARM-style temporal window.
delta_timestamps = {ACTION: [t / meta.fps for t in range(args.n_action_steps)]}
# rank / world_size are resolved automatically from the Accelerate state inside the dataset.
dataset = StreamingLeRobotDataset(
args.repo_id,
root=args.root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
video_decoder_cache_size=args.video_decoder_cache_size,
tolerance_s=1e-3,
)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True,
prefetch_factor=2 if args.num_workers > 0 else None,
)
return loader, dataset
def main() -> None:
args = parse_args()
accelerator = Accelerator()
output_dir = Path(args.output_dir)
if accelerator.is_main_process:
output_dir.mkdir(parents=True, exist_ok=True)
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
loader, dataset = make_dataloader(args, meta)
if args.dummy:
model = optimizer = None
else:
from lerobot.policies.act import ACTConfig, ACTPolicy
from lerobot.utils.feature_utils import dataset_to_policy_features
features = dataset_to_policy_features(meta.features)
output_features = {k: ft for k, ft in features.items() if k == ACTION}
input_features = {k: ft for k, ft in features.items() if k not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
model = ACTPolicy(cfg)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
# checkpoint this script wrote itself.
if args.resume_from is not None:
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
dataset.load_state_dict(state)
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
step = 0
frames_seen = 0
window_start = time.perf_counter()
done = False
while not done:
for batch in loader:
if model is not None:
batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
loss, _ = model.forward(batch)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
step += 1
frames_seen += args.batch_size
if step % args.log_freq == 0:
elapsed = time.perf_counter() - window_start
fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
total_fps = fps_per_proc * accelerator.num_processes
accelerator.print(
f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
+ ("" if model is None else f" | loss {loss.item():.3f}")
)
window_start = time.perf_counter()
if step % args.save_freq == 0 and accelerator.is_main_process:
ckpt = output_dir / f"checkpoint-{step}"
ckpt.mkdir(parents=True, exist_ok=True)
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
if model is not None:
accelerator.unwrap_model(model).save_pretrained(ckpt)
if step >= args.steps:
done = True
break
accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
if __name__ == "__main__":
main()
+40
View File
@@ -0,0 +1,40 @@
#!/bin/bash
#SBATCH --job-name=bench_stream
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --time=02:00:00
#SBATCH --output=logs/%x-%j.out
# Per-node dataloading benchmark for StreamingLeRobotDataset across 1-2 nodes. Each node runs an
# independent dummy-consumer benchmark; per-node throughput should be independent (separate network).
# Results are written per (node, source, mode) under --out_dir.
#
# Submit with: sbatch slurm/benchmark_streaming_robocasa.sh
# Override the source label for cold/warm bucket runs: SOURCE=warmed_bucket sbatch slurm/benchmark_streaming_robocasa.sh
set -euo pipefail
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
SOURCE=${SOURCE:-hub}
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
export TOKENIZERS_PARALLELISM=false
# One benchmark process per node (each saturates the node's DataLoader workers + network independently).
srun --kill-on-bad-exit=1 bash -c '
for MODE in single sarm; do
python benchmarks/streaming/benchmark_streaming.py \
--repo_id '"$REPO_ID"' \
--source '"$SOURCE"' \
--mode $MODE \
--batch_size 64 \
--num_workers 12 \
--buffer_size 4000 \
--num_batches 300 \
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
done
'
+100
View File
@@ -0,0 +1,100 @@
#!/bin/bash
# Submit the FULL streaming dataloading-benchmark matrix as isolated single-GPU SLURM jobs.
#
# sources : hub (Hub streaming) | bucket (cold HF bucket) | warmed_bucket (prewarmed HF bucket)
# modes : single (1 frame, all cameras) | sarm (8-step / 8s delta window)
# decode : cpu (torchcodec on CPU, scales with workers) | cuda (NVDEC, offloads decode to the GPU)
#
# => 3 x 2 x 2 = 12 jobs. Each runs in its OWN job (1 node, 1 GPU) so an OOM is isolated and reported
# per-job by SLURM (check `sacct -j <id> --format=JobID,State,MaxRSS,ReqMem`). Submit from a login node
# inside the repo: bash slurm/run_streaming_matrix.sh
#
# SERIAL (default 1): chain the jobs with --dependency=afterany so SLURM runs exactly ONE at a time. This
# is important for a bandwidth benchmark — concurrent jobs would share the network to the Hub/bucket and
# corrupt every throughput number. `afterany` means a failed/OOM'd job does not stall the chain. Set
# SERIAL=0 to let the scheduler run them in parallel (only for OOM-isolation testing, not for throughput).
#
# Knobs (env overrides):
# REPO_ID, BUCKET, WARM_BUCKET, OUT_DIR, NUM_BATCHES, TIME, MEM, GPUS, SERIAL
# CPU_WORKERS / CPU_BUFFER (cpu-decode jobs) GPU_WORKERS / GPU_BUFFER (cuda-decode jobs, kept low to
# bound VRAM + NVDEC sessions). RUN ("python" by default; set RUN="uv run python" if using uv).
# SOURCES / MODES / DECODES to run a subset (e.g. SOURCES="hub bucket" DECODES="cpu").
# ACCOUNT / PARTITION / QOS passed through to sbatch if set.
set -euo pipefail
REPO_DIR=$(git rev-parse --show-toplevel)
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
BUCKET=${BUCKET:-hf://buckets/pepijn223/robocasa-stream}
WARM_BUCKET=${WARM_BUCKET:-hf://buckets/pepijn223/robocasa-stream-warm}
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
NUM_BATCHES=${NUM_BATCHES:-200}
TIME=${TIME:-01:00:00}
MEM=${MEM:-64G}
GPUS=${GPUS:-1}
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
CPU_WORKERS=${CPU_WORKERS:-8}
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
CPU_BUFFER=${CPU_BUFFER:-4000}
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
BATCH_SIZE=${BATCH_SIZE:-64}
RUN=${RUN:-python}
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
# + datasets (>=4.7) — the default `base` env on many clusters is too old to decode AV1 / lacks CUDA.
CONDA_ENV=${CONDA_ENV:-}
if [ -n "$CONDA_ENV" ] && [ "$RUN" = "python" ]; then
RUN="conda run --no-capture-output -n $CONDA_ENV python"
fi
SOURCES=${SOURCES:-"hub bucket warmed_bucket"}
MODES=${MODES:-"single sarm"}
DECODES=${DECODES:-"cpu cuda"}
mkdir -p "$REPO_DIR/logs" "$REPO_DIR/$OUT_DIR"
data_root_for () {
case "$1" in
hub) echo "" ;;
bucket) echo "$BUCKET" ;;
warmed_bucket) echo "$WARM_BUCKET" ;;
esac
}
n=0
prev_jid=""
for SOURCE in $SOURCES; do
DATA_ROOT=$(data_root_for "$SOURCE")
ROOTFLAG=""
[ -n "$DATA_ROOT" ] && ROOTFLAG="--data_files_root $DATA_ROOT"
for MODE in $MODES; do
for DECODE in $DECODES; do
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
# Run strictly after the previous job so only one job touches the network at a time.
DEPFLAG=""
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
jid=$(sbatch --parsable \
--job-name="bench_${SOURCE}_${MODE}_${DECODE}" \
--nodes=1 --ntasks=1 --gpus="$GPUS" --cpus-per-task=$((W + 4)) \
--mem="$MEM" --time="$TIME" --output="$REPO_DIR/logs/%x-%j.out" \
$DEPFLAG \
${ACCOUNT:+--account=$ACCOUNT} ${PARTITION:+--partition=$PARTITION} ${QOS:+--qos=$QOS} \
--wrap "cd '$REPO_DIR' && \
export TOKENIZERS_PARALLELISM=false && export HF_HOME=\${HF_HOME:-\$SCRATCH/hf_home} && \
$RUN benchmarks/streaming/benchmark_streaming.py \
--repo_id $REPO_ID $ROOTFLAG \
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
prev_jid=$jid
n=$((n + 1))
done
done
done
echo
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
+49
View File
@@ -0,0 +1,49 @@
#!/bin/bash
#SBATCH --job-name=stream_robocasa
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --time=24:00:00
#SBATCH --output=logs/%x-%j.out
# Multinode streaming training over a large HF-hosted RoboCasa dataset (never touches local disk).
# Launches examples/scaling/train_streaming_multinode.py with Accelerate. Each rank streams a disjoint
# set of shards via split_dataset_by_node (auto-resolved from the Accelerate state), so per-node
# throughput scales independently. For an even split, ensure n_shards % (nodes * gpus_per_node) == 0.
#
# Submit with: sbatch slurm/train_streaming_robocasa.sh
set -euo pipefail
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
GPUS_PER_NODE=8
NUM_PROCESSES=$((SLURM_NNODES * GPUS_PER_NODE))
# Rendezvous: use the first node in the allocation as the main process.
MAIN_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
MAIN_PORT=${MAIN_PORT:-29500}
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
# Avoid each rank fighting over the tokenizers' internal thread pool.
export TOKENIZERS_PARALLELISM=false
srun --kill-on-bad-exit=1 bash -c '
accelerate launch \
--num_machines '"$SLURM_NNODES"' \
--num_processes '"$NUM_PROCESSES"' \
--machine_rank $SLURM_NODEID \
--main_process_ip '"$MAIN_ADDR"' \
--main_process_port '"$MAIN_PORT"' \
--mixed_precision bf16 \
--dynamo_backend no \
examples/scaling/train_streaming_multinode.py \
--repo_id '"$REPO_ID"' \
--batch_size 64 \
--num_workers 12 \
--buffer_size 4000 \
--steps 200000 \
--save_freq 2000 \
--log_freq 50
'
+3
View File
@@ -39,6 +39,9 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
# of host RAM. Ignored when streaming is False.
streaming_buffer_size: int = 1000
def __post_init__(self) -> None:
if self.episodes is not None:
+1 -1
View File
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
buffer_size=cfg.dataset.streaming_buffer_size,
tolerance_s=cfg.tolerance_s,
return_uint8=True,
)
+214 -23
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from collections import deque
from collections.abc import Callable, Generator, Iterable, Iterator
from pathlib import Path
@@ -21,6 +24,7 @@ import datasets
import numpy as np
import torch
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
@@ -38,6 +42,8 @@ from .video_utils import (
decode_video_frames_torchcodec,
)
logger = logging.getLogger(__name__)
class LookBackError(Exception):
"""
@@ -252,6 +258,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
rng: np.random.Generator | None = None,
shuffle: bool = True,
return_uint8: bool = False,
rank: int | None = None,
world_size: int | None = None,
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
video_decode_device: str = "cpu",
):
"""Initialize a StreamingLeRobotDataset.
@@ -272,6 +283,25 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
seed (int, optional): Reproducibility random seed.
rng (np.random.Generator | None, optional): Random number generator.
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training.
Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is
resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
world_size (int | None, optional): Total number of distributed processes. When omitted, resolved
from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding).
For an even per-rank split, ``num_shards % world_size == 0`` should hold.
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
trees (e.g. an HF storage bucket ``hf://buckets/<owner>/<name>``). When set, parquet and
video frames are read from there while metadata still loads from ``repo_id`` on the Hub.
Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket
sources without copying the (small) metadata.
video_decode_device (str, optional): Device for video decoding, passed to the torchcodec
``VideoDecoder``. Defaults to ``"cpu"``. Set to ``"cuda"`` to offload H.264/H.265 decode to
the GPU's dedicated NVDEC engine (independent of the training SMs), which requires a
CUDA-enabled torchcodec build. Note: ``"cuda"`` decode inside ``DataLoader`` workers needs
the ``spawn`` start method (CUDA cannot init in forked workers).
"""
super().__init__()
self.repo_id = repo_id
@@ -289,10 +319,21 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.streaming = streaming
self.buffer_size = buffer_size
self.max_num_shards = max_num_shards
self._return_uint8 = return_uint8
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
self.video_decoder_cache_size = video_decoder_cache_size
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
self.video_decode_device = video_decode_device
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
# one place the main process can read after iteration (see video_decoder_cache_stats()).
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
# Resume state captured by load_state_dict() and consumed at the next __iter__.
self._resume_state: dict | None = None
if self._requested_root is not None:
self.root.mkdir(exist_ok=True, parents=True)
@@ -314,13 +355,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.delta_timestamps = delta_timestamps
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
self.hf_dataset: datasets.IterableDataset = load_dataset(
self.repo_id if not self.streaming_from_local else str(self.root),
split="train",
streaming=self.streaming,
data_files="data/*/*.parquet",
revision=self.revision,
)
if self.data_files_root is not None:
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
self.hf_dataset: datasets.IterableDataset = load_dataset(
"parquet",
split="train",
streaming=self.streaming,
data_files=f"{self.data_files_root}/data/*/*.parquet",
)
else:
self.hf_dataset = load_dataset(
self.repo_id if not self.streaming_from_local else str(self.root),
split="train",
streaming=self.streaming,
data_files="data/*/*.parquet",
revision=self.revision,
)
# Drop any parquet columns not declared in the dataset's feature contract. Some revisions / sources
# (e.g. an unversioned bucket holding `main`) carry extra, possibly variable-length annotation
# columns such as `language_events`; left in, they leak into the sample and break default DataLoader
# collation across frames of differing length. On a clean revision this is a no-op.
known_columns = set(self.meta.features)
extra_columns = [c for c in (self.hf_dataset.column_names or []) if c not in known_columns]
if extra_columns:
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
@@ -348,22 +407,99 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
while True:
yield rng.choice(elements)
@staticmethod
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
"""Resolve (rank, world_size) for distributed streaming.
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
"""
if rank is not None and world_size is not None:
return rank, world_size
try:
from accelerate.state import PartialState
if PartialState._shared_state: # only read it if already initialized; never initialize here
state = PartialState()
return state.process_index, state.num_processes
except Exception:
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
env_rank = os.environ.get("RANK")
env_world = os.environ.get("WORLD_SIZE")
if env_rank is not None and env_world is not None:
return int(env_rank), int(env_world)
return 0, 1
def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache:
"""Size the decoder cache to the working set of live shards so it does not thrash.
Each shard mid-episode keeps one open decoder per camera; with several shards iterated
concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of
margin so the round-robin never evicts a still-live decoder.
"""
if self.video_decoder_cache_size is not None:
return VideoDecoderCache(
max_size=self.video_decoder_cache_size,
counters=self._cache_counters,
device=self.video_decode_device,
)
num_cameras = len(self.meta.video_keys)
if num_cameras == 0:
return VideoDecoderCache(counters=self._cache_counters, device=self.video_decode_device)
return VideoDecoderCache(
max_size=(num_active_shards + 1) * num_cameras,
counters=self._cache_counters,
device=self.video_decode_device,
)
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
# The current sequential iteration is a bottleneck. A producer-consumer pattern
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
# in parallel, feeding a queue from which this iterator will yield processed items.
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
if self.video_decoder_cache is None:
self.video_decoder_cache = VideoDecoderCache()
# Distributed correctness: each rank streams a disjoint set of shards (order preserved).
ds = self.hf_dataset
if self.world_size > 1:
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
num_shards = min(ds.num_shards, self.max_num_shards)
shard_indices = list(range(num_shards))
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
shard_indices = shard_indices[worker_info.id :: worker_info.num_workers]
self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices))
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
# Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The
# shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable
# for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead.
resume = self._resume_state
self._resume_state = None
self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set()
if resume is not None:
rng.bit_generator.state = resume["rng"]
self._shards: dict[int, datasets.IterableDataset] = {}
for idx in shard_indices:
shard = safe_shard(ds, idx, num_shards)
if resume is not None and str(idx) in resume["shards"]:
shard.load_state_dict(resume["shards"][str(idx)])
self._shards[idx] = shard
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
idx_to_backtrack_dataset = {
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
for idx in range(self.num_shards)
idx: self._make_backtrackable_dataset(shard)
for idx, shard in self._shards.items()
if idx not in self._exhausted
}
# This buffer is populated while iterating on the dataset's shards
@@ -389,11 +525,47 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
StopIteration,
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
self._exhausted.add(shard_key)
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
rng.shuffle(frames_buffer)
yield from frames_buffer
def state_dict(self) -> dict:
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
Must be called after iteration has started (so the shard streams exist). Restore the returned
dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so
resumption is not bit-exact — see :meth:`__iter__`.
"""
if not hasattr(self, "_shards"):
raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.")
return {
"shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()},
"exhausted": sorted(self._exhausted),
"rng": self.rng.bit_generator.state,
}
def load_state_dict(self, state_dict: dict) -> None:
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
self._resume_state = state_dict
def video_decoder_cache_stats(self) -> dict[str, int | float]:
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
total = hits + misses
return {
"hits": hits,
"misses": misses,
"evictions": evictions,
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def _get_window_steps(
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
) -> tuple[int, int]:
@@ -405,19 +577,23 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
lookback = LOOKBACK_BACKTRACKTABLE
lookahead = LOOKAHEAD_BACKTRACKTABLE
else:
# Dynamically adjust the windows based on the given delta_timesteps
# Dynamically size the windows to exactly cover the requested delta_timestamps (in frames).
# This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for
# long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20).
all_timestamps = sum(delta_timestamps.values(), [])
lookback = min(all_timestamps) * self.fps
lookahead = max(all_timestamps) * self.fps
lookback = math.floor(min(all_timestamps) * self.fps)
lookahead = math.ceil(max(all_timestamps) * self.fps)
# When lookback is >=0 it means no negative timesteps have been provided
lookback = 0 if lookback >= 0 else (lookback * -1)
lookback = 0 if lookback >= 0 else -lookback
return lookback, lookahead
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True)
# Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires
# history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1.
return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead))
def _make_timestamps_from_indices(
self, start_ts: float, indices: dict[str, list[int]] | None = None
@@ -473,13 +649,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Get episode index from the item
ep_idx = item["episode_index"]
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
current_ts = item["index"] / self.fps
# `timestamp` is episode-local (restarts at 0 each episode). The absolute in-file timestamp is
# `from_timestamp + timestamp`, applied per camera at decode time (see `_query_videos`), mirroring
# the map-style reader. Using `index / fps` here is a dataset-global value that only matches the
# file timeline when the whole dataset is a single video (e.g. small test fixtures), and otherwise
# decodes out-of-range frames on multi-file v3 datasets.
current_ts = float(item["timestamp"])
# Per-camera episode-local bounds [0, duration]. Query timestamps are clamped into this range so
# out-of-episode deltas pad rather than decode against a neighbouring episode in the same file.
episode_boundaries_ts = {
key: (
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
0.0,
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
)
for key in self.meta.video_keys
}
@@ -552,11 +735,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
item = {}
for video_key, query_ts in query_timestamps.items():
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
if self.data_files_root is not None:
root = self.data_files_root
elif self.streaming and not self.streaming_from_local:
root = self.meta.url_root
else:
root = self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path,
query_ts,
shifted_query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
+40 -2
View File
@@ -242,7 +242,12 @@ class VideoDecoderCache:
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
def __init__(
self,
max_size: int | None | object = _SENTINEL,
counters: "torch.Tensor | None" = None,
device: str = "cpu",
):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
@@ -250,6 +255,18 @@ class VideoDecoderCache:
self.max_size: int | None = max_size # type: ignore[assignment]
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
self._lock = Lock()
# Decode device for the underlying torchcodec VideoDecoder. "cuda" offloads H.264/H.265 decode to
# the GPU's dedicated NVDEC engine (independent of the SMs used for training); requires a
# CUDA-enabled torchcodec/FFmpeg build. See https://developer.nvidia.com/video-codec-sdk.
self.device = device
# Observability counters (cheap, updated under the lock) for benchmarking decoder reuse.
self.hits = 0
self.misses = 0
self.evictions = 0
# Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place
# (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so
# treat the aggregate as approximate; the hit-rate ratio is preserved.
self._counters = counters
def __contains__(self, video_path: object) -> bool:
with self._lock:
@@ -271,11 +288,17 @@ class VideoDecoderCache:
entry = self._cache.get(video_path)
if entry is not None:
self._cache.move_to_end(video_path)
self.hits += 1
if self._counters is not None:
self._counters[0] += 1
return entry[0]
self.misses += 1
if self._counters is not None:
self._counters[1] += 1
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
except Exception:
file_handle.close()
raise
@@ -287,6 +310,9 @@ class VideoDecoderCache:
if self.max_size is not None:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
self.evictions += 1
if self._counters is not None:
self._counters[2] += 1
with contextlib.suppress(Exception):
evicted_handle.close()
@@ -305,6 +331,18 @@ class VideoDecoderCache:
with self._lock:
return len(self._cache)
def stats(self) -> dict[str, int | float]:
"""Return reuse counters (hits/misses/evictions, hit rate, current size) for benchmarking."""
with self._lock:
total = self.hits + self.misses
return {
"hits": self.hits,
"misses": self.misses,
"evictions": self.evictions,
"hit_rate": self.hits / total if total else 0.0,
"size": len(self._cache),
}
class FrameTimestampError(ValueError):
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
@@ -0,0 +1,100 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
"""
import json
import shutil
import subprocess
import sys
import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
from tests.fixtures.constants import DUMMY_REPO_ID
WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
json.dump(payload, f)
"""
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
total_frames = 160
repo_id = f"{DUMMY_REPO_ID}-acc"
root = tmp_path / "ds"
lerobot_dataset_factory(
root=root,
repo_id=repo_id,
total_episodes=8,
total_frames=total_frames,
use_videos=False,
data_files_size_in_mb=0.001,
chunks_size=1,
)
worker = tmp_path / "worker.py"
worker.write_text(WORKER)
out_dir = tmp_path / "out"
out_dir.mkdir()
cmd = [
"accelerate",
"launch",
"--num_processes=2",
"--num_machines=1",
"--mixed_precision=no",
"--dynamo_backend=no",
"--cpu",
str(worker),
str(root),
repo_id,
str(out_dir),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
assert result.returncode == 0, (
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
)
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
rank_sets = [set(p["indices"]) for p in payloads]
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-v"]))
+251
View File
@@ -0,0 +1,251 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
import pytest
import torch
from torch.utils.data import DataLoader
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
from tests.fixtures.constants import DUMMY_REPO_ID
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
factory(
root=root,
repo_id=repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
use_videos=use_videos,
data_files_size_in_mb=0.001,
chunks_size=1,
**kw,
)
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
return [int(frame["index"]) for frame in ds]
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
monkeypatch.delenv("RANK", raising=False)
monkeypatch.delenv("WORLD_SIZE", raising=False)
# No accelerate state, no env -> single process.
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
monkeypatch.setenv("RANK", "3")
monkeypatch.setenv("WORLD_SIZE", "4")
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
repo_id = f"{DUMMY_REPO_ID}-ranks"
total_frames, total_episodes = 200, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
world_size = 2
per_rank = []
for rank in range(world_size):
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=8,
max_num_shards=8,
rank=rank,
world_size=world_size,
)
per_rank.append(set(_stream_indices(ds)))
assert per_rank[0].isdisjoint(per_rank[1]), (
"ranks streamed overlapping frames (duplicate data across GPUs)"
)
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
repo_id = f"{DUMMY_REPO_ID}-workers"
total_frames, total_episodes = 120, 8
_make_local_dataset(
lerobot_dataset_factory,
tmp_path / "ds",
repo_id,
total_episodes=total_episodes,
total_frames=total_frames,
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
)
loader = DataLoader(ds, batch_size=None, num_workers=2)
indices = [int(batch["index"]) for batch in loader]
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
"""
repo_id = f"{DUMMY_REPO_ID}-sarm"
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
episode_frames = 300
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
)
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
delta_timestamps = {ACTION: [0.0, horizon_s]}
ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=tmp_path / "ds",
shuffle=False,
buffer_size=1,
max_num_shards=1,
delta_timestamps=delta_timestamps,
)
horizon_frames = int(round(horizon_s * ds.fps))
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
checked = 0
for frame in ds:
idx = int(frame["index"])
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
if idx + horizon_frames < episode_frames:
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
)
checked += 1
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
repo_id = f"{DUMMY_REPO_ID}-resume"
total_frames = 100
_make_local_dataset(
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
)
def fresh_ds():
return StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
ds = fresh_ds()
it = iter(ds)
stop_after = 40
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
state = ds.state_dict()
assert set(state) == {"shards", "exhausted", "rng"}
resumed_ds = fresh_ds()
resumed_ds.load_state_dict(state)
resumed = _stream_indices(resumed_ds)
# Resume continues rather than replaying: the full first pass is not re-yielded.
assert len(resumed) < total_frames
overlap = set(seen_before) & set(resumed)
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
repo_id = f"{DUMMY_REPO_ID}-parity"
map_ds = lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
)
stream_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
)
map_frame = map_ds[0]
stream_frame = next(iter(stream_ds))
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
for key, value in stream_frame.items():
ref = map_frame[key]
if isinstance(value, torch.Tensor):
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
)
elif isinstance(value, str):
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
else:
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
# (a long-standing, accepted difference). Compare by value rather than exact type.
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
import lerobot.datasets.streaming_dataset as sd
repo_id = f"{DUMMY_REPO_ID}-vpath"
lerobot_dataset_factory(
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
)
ds = StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
seen_paths = []
def fake_decode(video_path, query_ts, *args, **kwargs):
seen_paths.append(str(video_path))
return torch.zeros(len(query_ts), 3, 64, 96)
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
next(iter(ds))
assert seen_paths, "no video decode was issued"
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
repo_id = f"{DUMMY_REPO_ID}-shuf"
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
ordered = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
)
)
shuffled = _stream_indices(
StreamingLeRobotDataset(
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
)
)
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
assert shuffled != ordered, "shuffle did not decorrelate output order"