mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-11 13:49:43 +00:00
Compare commits
19 Commits
| SHA1 |
|---|
| 79b547de32 |
| a7b7f4964e |
| 1050c2fb6c |
| 66ac901632 |
| ce326207e6 |
| 2ab71231cd |
| 42d4788e4a |
| 2d1c17d971 |
| 7241f029c6 |
| 06ddc59913 |
| 23c58f5f9e |
| b0ab57cedc |
| afdc084677 |
| a32a2c647b |
| 343ecd7980 |
| f7c8a526e8 |
| 77af66a29c |
| 68fa5d80b0 |
| d1fc8e298c |
@@ -0,0 +1,91 @@
# Streaming dataloading benchmark

Measures **dataloading only** (no model) for `StreamingLeRobotDataset`: parquet read + video decode +
delta windowing + shuffle. A dummy consumer pulls batches and moves them to the device, so the numbers
isolate the data pipeline. Use it to compare sources (Hub vs. storage bucket vs. prewarmed bucket),
frame modes, and node counts, and to catch p95/p99 video-decode regressions.
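
As a rough illustration of what "dummy consumer" means here, the sketch below drains an iterable of batches and counts samples only after a warm-up window. It is a simplified stand-in, not the benchmark script itself: `measure_throughput` is a hypothetical name, and a plain list plays the role of the DataLoader.

```python
import time
from typing import Iterable


def measure_throughput(batches: Iterable, batch_size: int, warmup_batches: int = 5) -> dict:
    """Consume batches and report samples/s over the post-warm-up window.

    Hypothetical helper for illustration; any iterable stands in for a DataLoader.
    """
    frames = 0
    steady_start = None
    for i, _batch in enumerate(batches):
        now = time.perf_counter()
        if i == warmup_batches:
            # Open the measurement window once the warm-up batches are consumed.
            steady_start = now
        elif i > warmup_batches:
            frames += batch_size
    end = time.perf_counter()
    steady_s = (end - steady_start) if steady_start is not None else 0.0
    rate = frames / steady_s if steady_s > 0 else 0.0
    return {"frames_measured": frames, "steady_time_s": steady_s, "samples_per_s": rate}


# Ten synthetic batches of four samples each; the first three (i = 0..2) are not counted.
stats = measure_throughput([[0] * 4 for _ in range(10)], batch_size=4, warmup_batches=2)
print(stats["frames_measured"])
```

Throughput is taken from wall-clock time over the whole window rather than from summed inter-batch gaps, since with asynchronous prefetch those gaps can shrink to near zero.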

## Run

```bash
python benchmarks/streaming/benchmark_streaming.py \
  --repo_id pepijn223/robocasa_pretrain_human300_v4 \
  --mode sarm --batch_size 64 --num_workers 12 --num_batches 200 \
  --source hub --out_dir benchmarks/streaming/results
```

Multinode (per-node throughput) goes through Accelerate under SLURM:

```bash
sbatch slurm/benchmark_streaming_robocasa.sh
```

## Matrix

| Axis       | Values |
| ---------- | ------ |
| Source     | `hub` (verify now), `bucket`, `warmed_bucket` (bucket + prewarming; with user's help later) |
| Baseline   | current `main` `StreamingLeRobotDataset` on Hub streaming |
| Nodes      | 1 and 2 (per-node throughput should be independent) |
| Frame mode | `single` (1 frame, all cameras; target ≥ 120 frames/s/node) · `sarm` (8 steps spaced 1 s; target ≥ 320 frames/s/node) |

`--source` is a label only; the actual source is whatever `--repo_id` / `--root` / `--data_files_root`
point at.

### GPU (NVDEC) decoding

By default video is decoded on the **CPU** in each DataLoader worker, so throughput is CPU-decode-bound and
scales with `--num_workers` (capped by the dataset's `num_shards`). Pass `--video_decode_device cuda` to
offload H.264/H.265 decode to the GPU's dedicated **NVDEC** engine, which runs independently of the SMs used
for training (see <https://developer.nvidia.com/video-codec-sdk>). This requires a CUDA-enabled torchcodec
build, and because CUDA cannot initialize in forked workers the benchmark switches to the `spawn` start
method automatically when `--num_workers > 0`.

```bash
# GPU/NVDEC decode, 6 workers, bucket source
python benchmarks/streaming/benchmark_streaming.py \
  --repo_id pepijn223/robocasa_pretrain_human300_v4 \
  --data_files_root hf://buckets/pepijn223/robocasa-stream \
  --mode sarm --batch_size 64 --num_workers 6 --num_batches 200 \
  --video_decode_device cuda --source bucket
```
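
The start-method and pinning rules described above can be sketched as a small pure function. This is an illustrative reading of the behavior, not the script's code: `loader_settings` is a hypothetical helper, the fixed `prefetch_factor` of 2 is an assumption, and the keys it returns are standard `torch.utils.data.DataLoader` keyword arguments.

```python
def loader_settings(num_workers: int, video_decode_device: str, consumer_device: str) -> dict:
    """Derive DataLoader keyword arguments from the decode and consumer devices.

    Hypothetical helper for illustration; it does not import or call torch.
    """
    gpu_decode = video_decode_device.startswith("cuda")
    return {
        "num_workers": num_workers,
        # Pin host memory only when frames are decoded on CPU and then copied to a CUDA device.
        "pin_memory": consumer_device.startswith("cuda") and not gpu_decode,
        # prefetch_factor must stay None when there are no worker processes.
        "prefetch_factor": 2 if num_workers > 0 else None,
        # CUDA cannot initialize in forked workers, so GPU decode in workers needs "spawn".
        "multiprocessing_context": "spawn" if gpu_decode and num_workers > 0 else None,
    }


print(loader_settings(6, "cuda", "cuda"))
print(loader_settings(12, "cpu", "cuda"))
```

With `--num_workers 0` decode runs in the main process, so no start method needs to be forced.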

Caveats with `cuda` + many workers: each worker creates its own CUDA context (VRAM overhead) and NVDEC has a
limited number of concurrent decode sessions per GPU; if you hit session/IPC limits, reduce `--num_workers`
or compare against `--num_workers 0` (single-process NVDEC, which often saturates the decode engine on its
own). Result files include the decode device at the end of their name (e.g. `..._w6_pf2_ep64_cuda.json`).

> **Codec ⇄ NVDEC compatibility (important).** NVDEC can only decode codecs its hardware supports. LeRobot's
> **default video codec is AV1** (`VideoEncoderConfig.vcodec = "libsvtav1"`), so most v3 datasets are
> AV1-encoded, and the **A100 and H100 compute GPUs have no AV1 NVDEC decoder**
> (per NVIDIA's [decode support matrix](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new));
> only Ada (L4/L40/RTX40) and a few Ampere cards (A10/A40/A16) do. On A100/H100, AV1 must be decoded on
> **CPU**, or the dataset re-encoded to H.265/H.264 (which those GPUs' NVDEC do support). Run
> `diagnose_decode.py --video_decode_device cuda` to check your exact node before relying on `cuda` decode.
> A `cuda` torchcodec build also needs an FFmpeg with NVDEC; see
> <https://github.com/meta-pytorch/torchcodec#installing-cuda-enabled-torchcodec>.
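
Before blaming the codec or the GPU, it helps to rule out bad bytes. The sketch below shows the kind of check `diagnose_decode.py` performs on the first bytes of the file handle: an MP4/MOV file carries an `ftyp` box near the start, while an HTML or JSON error page does not. `looks_like_mp4` is a hypothetical name and both byte strings are synthetic examples.

```python
def looks_like_mp4(head: bytes) -> bool:
    """Return True if the first 32 bytes contain the 'ftyp' box tag of an MP4/MOV file."""
    return b"ftyp" in head[:32]


# Synthetic headers for illustration: an ftyp box prefix and the start of an HTML error page.
mp4_head = b"\x00\x00\x00\x20ftypisom\x00\x00\x02\x00isomiso2avc1mp41"
html_head = b"<!DOCTYPE html><html><head><title>"
print(looks_like_mp4(mp4_head), looks_like_mp4(html_head))
```

If the bytes look like MP4 and decode still fails, the cause is more likely the torchcodec/FFmpeg build or the decode device than the data source.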

Reference data root: bucket sources resolve through `--data_files_root hf://buckets/<owner>/<name>` (metadata
still loads from `--repo_id`). The local `single`/`sarm` CPU baselines on this dataset were ~176 / ~212
frames/s/node at `--num_workers 3` (3 cameras, fps 20).
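
To put those baselines next to the targets from the Matrix table, a quick ratio is enough. The sketch below uses only the numbers quoted in this README (which are approximate); `target_ratio` is a hypothetical helper.

```python
# Design targets (from the Matrix table) and approximate local CPU baselines at --num_workers 3.
targets = {"single": 120.0, "sarm": 320.0}
baselines = {"single": 176.0, "sarm": 212.0}


def target_ratio(mode: str) -> float:
    """Baseline divided by target for a frame mode; 1.0 or more means the target is met."""
    return round(baselines[mode] / targets[mode], 2)


for mode in targets:
    status = "meets target" if target_ratio(mode) >= 1.0 else "below target"
    print(mode, target_ratio(mode), status)
```

At 3 workers the `single` baseline clears its target while `sarm` reaches roughly two thirds of its target, which is one reason to compare higher worker counts and other sources.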

## Metrics emitted (JSON + CSV)

`frames_per_s_node`, `samples_per_s`, `first_batch_latency_s`, `p50/p95/p99_sample_latency_ms`,
`wallclock_s`, and `video_decoder_cache` (`hits`, `misses`, `evictions`, `hit_rate`, `size`). A low
cache `hit_rate` with high `p99` is the decoder-thrash signature: raise `--video_decoder_cache_size`
or `--episode_pool_size`, or reduce `--num_workers`.
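
A minimal sketch of how the latency percentiles and the thrash signature can be evaluated from a result dict. The `percentile` function follows the index-rounding approach used in `benchmark_streaming.py`; `looks_like_decoder_thrash` and its two thresholds are assumptions made for illustration, not values defined by the benchmark.

```python
import math


def percentile(values: list[float], pct: float) -> float:
    """Pick the element at the rounded rank of a sorted copy; NaN for an empty list."""
    if not values:
        return math.nan
    ordered = sorted(values)
    k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
    return ordered[k]


def looks_like_decoder_thrash(
    result: dict, max_hit_rate: float = 0.5, min_p99_ms: float = 50.0
) -> bool:
    """Flag a low cache hit rate combined with a high p99 (both thresholds are assumptions)."""
    hit_rate = (result.get("video_decoder_cache") or {}).get("hit_rate")
    p99 = result.get("p99_sample_latency_ms")
    if hit_rate is None or p99 is None:
        return False
    return hit_rate < max_hit_rate and p99 > min_p99_ms


latencies_ms = [float(v) for v in range(1, 101)]  # synthetic latencies: 1.0 .. 100.0
print(percentile(latencies_ms, 95), percentile(latencies_ms, 99))
print(looks_like_decoder_thrash({"p99_sample_latency_ms": 120.0, "video_decoder_cache": {"hit_rate": 0.2}}))
```

The thresholds should be tuned per dataset; the point is only that the two metrics are read together.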

## Bucket sources & prewarming (manual)

Prewarming is a **server-side** Hugging Face storage-bucket feature; there is no client script. To
benchmark the `warmed_bucket` source:

1. Attach a storage bucket to the dataset and enable it (see
   <https://huggingface.co/docs/hub/storage-buckets>). Buckets resolve through `fsspec`, the same as
   `hf://`, so no code change is needed: point `--data_files_root` at the bucket
   (`hf://buckets/<owner>/<name>`); metadata still loads from `--repo_id`.
2. Enable **prewarming** in the bucket settings and wait for warm-up to complete.
3. Run the benchmark with `--source warmed_bucket`. Compare against the cold `--source bucket` and the
   `--source hub` baseline.

Manual only; not run in CI.
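
The three-way comparison in step 3 can be sketched as a ratio against the Hub baseline. This assumes each run's result has been loaded as a dict with the `source` and `frames_per_s_node` keys the benchmark emits; `compare_sources` is a hypothetical helper and the numbers are synthetic, not measurements.

```python
def compare_sources(results: list[dict], baseline: str = "hub") -> dict[str, float]:
    """Return frames/s/node of each source relative to the baseline source."""
    by_source = {r["source"]: r["frames_per_s_node"] for r in results}
    base = by_source[baseline]
    return {source: round(fps / base, 2) for source, fps in by_source.items()}


# Synthetic numbers for illustration only; real values come from the per-run JSON files.
runs = [
    {"source": "hub", "frames_per_s_node": 200.0},
    {"source": "bucket", "frames_per_s_node": 250.0},
    {"source": "warmed_bucket", "frames_per_s_node": 300.0},
]
print(compare_sources(runs))
```

Keep mode, batch size, and worker count identical across the three runs so the ratio reflects the source alone.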
@@ -0,0 +1,322 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloading-only benchmark for StreamingLeRobotDataset.

A dummy consumer pulls batches and moves them to the device; no model runs, so the numbers isolate the
data pipeline (parquet read + video decode + delta windowing + shuffle). Reports per-node throughput and
sample-latency percentiles, plus video-decoder-cache reuse stats, and emits JSON + CSV.

Frame modes (matching the streaming design targets):
  - ``single``: one frame, all cameras (target >= 120 frames/s/node).
  - ``sarm``: an 8-step window spaced 1s (delta over 8s) (target >= 320 frames/s/node).

Example (stream from the Hub, single node):

    python benchmarks/streaming/benchmark_streaming.py \
        --repo_id pepijn223/robocasa_pretrain_human300_v4 --mode sarm \
        --batch_size 64 --num_workers 12 --num_batches 200 --out_dir benchmarks/streaming/results

Distributed / multinode runs go through Accelerate; see ``slurm/benchmark_streaming_robocasa.sh``. Set
``--source`` purely for labeling the output (``hub`` / ``bucket`` / ``warmed_bucket``); the actual source
is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarming.
"""

import argparse
import csv
import json
import os
import statistics
import threading
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION


def _tree_rss_bytes() -> int:
    """Sum RSS of this process and all its descendants via /proc (Linux only; 0 elsewhere).

    DataLoader workers are separate processes, so the parent's own RSS misses most of the pipeline's
    memory. Walking the process tree captures the real footprint (parquet buffers + decoders + shuffle).
    """
    try:
        children: dict[int, list[int]] = {}
        for entry in os.listdir("/proc"):
            if not entry.isdigit():
                continue
            try:
                with open(f"/proc/{entry}/stat") as f:
                    # The comm field may itself contain ") ", so split on the last occurrence.
                    ppid = int(f.read().rsplit(") ", 1)[1].split()[1])
                children.setdefault(ppid, []).append(int(entry))
            except (OSError, ValueError, IndexError):
                pass
        total, stack = 0, [os.getpid()]
        while stack:
            cur = stack.pop()
            try:
                with open(f"/proc/{cur}/statm") as f:
                    total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
            except (OSError, ValueError, IndexError):
                pass
            stack.extend(children.get(cur, []))
        return total
    except OSError:
        return 0


class PeakRSSSampler:
    """Background thread tracking peak process-tree RSS for the duration of the `with` block."""

    def __init__(self, interval_s: float = 0.5):
        self.interval_s = interval_s
        self.peak_bytes = 0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while not self._stop.is_set():
            self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
            self._stop.wait(self.interval_s)

    def __enter__(self) -> "PeakRSSSampler":
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        self._thread.join(timeout=2)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo_id", type=str, required=True)
    parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).")
    parser.add_argument(
        "--data_files_root",
        type=str,
        default=None,
        help="fsspec root for bulk data/videos, e.g. hf://buckets/<owner>/<name>. Metadata still loads "
        "from --repo_id on the Hub. Use for bucket / warmed_bucket sources.",
    )
    parser.add_argument("--mode", choices=["single", "sarm"], default="single")
    parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--prefetch_factor",
        type=int,
        default=2,
        help="DataLoader batches prefetched per worker. Higher hides IO/decode latency but raises RAM "
        "(prefetch_factor x num_workers x batch_size decoded frames held in flight). Ignored if num_workers=0.",
    )
    parser.add_argument("--buffer_size", type=int, default=None, help="Deprecated; ignored.")
    parser.add_argument(
        "--max_num_shards",
        type=int,
        default=16,
        help="Cap on concurrently-open stream shards. Each open shard holds ~one parquet row group in "
        "RAM; reading from an hf:// bucket buffers ~5x more per shard than hf:// datasets, so lower this "
        "(e.g. to num_workers) for bucket sources to avoid OOM. All data is still covered via re-sharding.",
    )
    parser.add_argument("--video_decoder_cache_size", type=int, default=None)
    parser.add_argument(
        "--episode_pool_size",
        type=int,
        default=64,
        help="Whole episodes each consumer keeps open to shuffle across (the randomness knob).",
    )
    parser.add_argument(
        "--video_decode_device",
        type=str,
        default="cpu",
        help="Decode device passed to torchcodec. 'cuda' offloads decode to the GPU's NVDEC engine "
        "(needs a CUDA-enabled torchcodec build). With num_workers>0 this forces the 'spawn' start method.",
    )
    parser.add_argument("--num_batches", type=int, default=200)
    parser.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state stats.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--out_dir", type=str, default="benchmarks/streaming/results")
    return parser.parse_args()


def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> StreamingLeRobotDataset:
    # sarm: an 8-step window spaced 1s => an 8s delta window (the SARM stress case).
    delta_timestamps = {ACTION: [float(t) for t in range(8)]} if args.mode == "sarm" else None
    return StreamingLeRobotDataset(
        args.repo_id,
        root=args.root,
        data_files_root=args.data_files_root,
        delta_timestamps=delta_timestamps,
        max_num_shards=args.max_num_shards,
        video_decoder_cache_size=args.video_decoder_cache_size,
        video_decode_device=args.video_decode_device,
        episode_pool_size=args.episode_pool_size,
        tolerance_s=1e-3,
    )


def percentile(values: list[float], pct: float) -> float:
    if not values:
        return float("nan")
    ordered = sorted(values)
    k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
    return ordered[k]


def main() -> None:
    args = parse_args()
    device = torch.device(args.device)
    meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
    dataset = build_dataset(args, meta)

    gpu_decode = args.video_decode_device.startswith("cuda")
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        # GPU-decoded frames are already on the GPU, so CPU pinning is irrelevant (and pinning CUDA
        # tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
        pin_memory=device.type == "cuda" and not gpu_decode,
        drop_last=True,
        prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        # CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
        multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
    )

    sample_latencies_ms: list[float] = []
    episodes_per_batch: list[int] = []  # shuffle-randomness proxy: distinct episodes within a batch
    frames = 0
    first_batch_latency_s = None
    steady_start = None  # wall-clock start of the post-warmup measurement window

    t_start = time.perf_counter()
    t_prev = t_start
    with PeakRSSSampler() as rss:
        for i, batch in enumerate(loader):
            # Dummy consume: move tensors to the device, mimicking what a real trainer would do.
            for value in batch.values():
                if torch.is_tensor(value):
                    value.to(device, non_blocking=device.type == "cuda")
            now = time.perf_counter()
            if first_batch_latency_s is None:
                first_batch_latency_s = now - t_start

            if i == args.warmup_batches:
                # Start the steady window here; the slow first batch and the prefetch queue it filled are
                # excluded so throughput reflects sustained production, not draining a pre-filled queue.
                steady_start = now
            elif i > args.warmup_batches:
                sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
                frames += args.batch_size
                ep = batch.get("episode_index")
                if torch.is_tensor(ep):
                    episodes_per_batch.append(int(torch.unique(ep).numel()))
            t_prev = now
            if i + 1 >= args.num_batches:
                break
    peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None

    now = time.perf_counter()
    elapsed = now - t_start
    # Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those
    # gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
    steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
    cache_stats = dataset.video_decoder_cache_stats()
    timing = dataset.timing_stats()  # cumulative decode/fetch seconds summed across workers
    # Image (camera frame) resolution as decoded, e.g. [C, H, W]. Read from the dataset feature contract.
    image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
    # Decode/fetch overlap in wall-clock (workers run in parallel), so normalize against the total worker
    # budget (num_workers x wallclock) to express each stage as a fraction of available worker time.
    worker_budget_s = max(args.num_workers, 1) * elapsed
    decode_pct = round(100 * timing["decode_s_total"] / worker_budget_s, 1) if worker_budget_s else None
    fetch_pct = round(100 * timing["fetch_s_total"] / worker_budget_s, 1) if worker_budget_s else None

    # A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
    # error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
    # never reported green with NaN/zero numbers.
    if frames == 0:
        raise SystemExit(
            f"FAILED: measured 0 frames over {args.num_batches} requested batches "
            f"(cache misses={cache_stats.get('misses', 0)}, hits={cache_stats.get('hits', 0)}). "
            "The data pipeline yielded no usable batches; inspect worker logs for decode errors. "
            "Try --num_workers 0 to surface the underlying exception directly."
        )

    results = {
        "repo_id": args.repo_id,
        "source": args.source,
        "mode": args.mode,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "prefetch_factor": args.prefetch_factor if args.num_workers > 0 else None,
        "buffer_size": args.buffer_size,
        "episode_pool_size": args.episode_pool_size,
        "episodes_per_batch_mean": round(statistics.mean(episodes_per_batch), 1)
        if episodes_per_batch
        else None,
        # Fraction of a batch that is distinct episodes; ~1.0 ≈ map-style uniform, low ≈ correlated.
        "shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
        if episodes_per_batch
        else None,
        "num_cameras": len(meta.video_keys),
        "image_shape": image_shape,
        "fps": meta.fps,
        "device": str(device),
        "video_decode_device": args.video_decode_device,
        "peak_rss_gb": peak_rss_gb,
        "frames_measured": frames,
        "first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
        "frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
        "samples_per_s": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
        "p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
        if sample_latencies_ms
        else None,
        "p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
        "p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
        "total_time_s": round(elapsed, 2),
        "steady_time_s": round(steady_elapsed_s, 2),
        "wallclock_s": round(elapsed, 2),
        "decode_s_total": timing["decode_s_total"],
        "fetch_s_total": timing["fetch_s_total"],
        "decode_pct_worker_time": decode_pct,
        "fetch_pct_worker_time": fetch_pct,
        "video_decoder_cache": cache_stats,
    }

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    pool_tag = f"_ep{args.episode_pool_size}" if args.episode_pool_size else ""
    tag = (
        f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}"
        f"_pf{args.prefetch_factor}{pool_tag}_{args.video_decode_device}"
    )
    (out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
    flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
    with open(out_dir / f"{tag}.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(flat))
        writer.writeheader()
        writer.writerow(flat)

    print("Command config:", vars(args))
    print(json.dumps(results, indent=2))
    print(f"Wrote {out_dir / tag}.json and .csv")


if __name__ == "__main__":
    main()
@@ -0,0 +1,112 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Isolate the streaming video-decode path: no SLURM, no DataLoader, no benchmark loop.

Reproduces exactly what StreamingLeRobotDataset does for one video (resolve path -> fsspec.open ->
torchcodec VideoDecoder -> get one frame) and prints the environment + the first bytes of the handle, so
a decode failure ("No valid stream found in input file") can be pinpointed: bad/placeholder bytes vs a
torchcodec/ffmpeg build issue vs a device issue.

    python benchmarks/streaming/diagnose_decode.py --repo_id pepijn223/robocasa_pretrain_human300_v4
    python benchmarks/streaming/diagnose_decode.py --repo_id … --data_files_root hf://buckets/<o>/<n>
    python benchmarks/streaming/diagnose_decode.py --repo_id … --video_decode_device cuda
"""

import argparse
import importlib.metadata as im

import fsspec

from lerobot.datasets import LeRobotDatasetMetadata


def _version(pkg: str) -> str:
    try:
        return im.version(pkg)
    except Exception:
        return "MISSING"


def main() -> None:
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--repo_id", required=True)
    p.add_argument("--data_files_root", default=None, help="e.g. hf://buckets/<owner>/<name>")
    p.add_argument("--revision", default=None)
    p.add_argument("--video_decode_device", default="cpu")
    p.add_argument("--episode", type=int, default=0)
    args = p.parse_args()

    print("== environment ==")
    for pkg in ("torchcodec", "av", "huggingface_hub", "hf_xet", "datasets", "fsspec"):
        print(f"  {pkg}: {_version(pkg)}")

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    video_key = meta.video_keys[0]
    rel_path = meta.get_video_file_path(args.episode, video_key)
    root = args.data_files_root.rstrip("/") if args.data_files_root else meta.url_root
    video_path = f"{root}/{rel_path}"
    print("\n== target ==")
    print(f"  video_key: {video_key}")
    print(f"  video_path: {video_path}")

    print("\n== fsspec handle ==")
    try:
        fh = fsspec.open(video_path).__enter__()
        head = fh.read(32)
        print(f"  first 32 bytes (hex): {head.hex()}")
        # A valid MP4/MOV has an 'ftyp' box near the start; anything else (HTML/JSON/empty) means the
        # handle resolved to a placeholder or error page, not the video bytes.
        looks_mp4 = b"ftyp" in head
        print(f"  looks like MP4 (contains 'ftyp'): {looks_mp4}")
        if not looks_mp4:
            print(f"  !! first bytes as text: {head[:32]!r}")
        fh.seek(0)
    except Exception as e:
        print(f"  !! fsspec.open/read FAILED: {type(e).__name__}: {e}")
        return

    print("\n== torchcodec VideoDecoder ==")
    try:
        from torchcodec.decoders import VideoDecoder

        decoder = VideoDecoder(fh, seek_mode="approximate", device=args.video_decode_device)
        md = decoder.metadata
        print(f"  OK: {md.num_frames} frames, {md.average_fps} fps, codec={getattr(md, 'codec', '?')}")
        frame = decoder.get_frames_at(indices=[0])
        print(f"  decoded frame 0: shape={tuple(frame.data.shape)}, device={frame.data.device}")
        print("\nDECODE OK: the streaming pipeline can read this video on this machine.")
    except Exception as e:
        print(f"  !! VideoDecoder FAILED: {type(e).__name__}: {e}")
        print(
            "\nDECODE FAILED. If the bytes above look like MP4 (ftyp=True), this is a torchcodec/ffmpeg "
            "build issue, NOT bad bytes. Common cause for LeRobot v3 datasets: the videos are AV1-encoded "
            "(see the 'codec' line on a working machine). Then:\n"
            "  - CPU decode needs an ffmpeg built with an AV1 decoder (libdav1d/libaom); a build without it "
            "reports 'No valid stream found'.\n"
            "  - GPU/NVDEC decode of AV1 is only on AV1-capable NVDEC GPUs: Ada (L4/L40/RTX40) and some "
            "Ampere (A10/A40/A16). The COMPUTE GPUs A100 and H100 have NO AV1 NVDEC decoder (per NVIDIA's "
            "support matrix), so no torchcodec build enables cuda decode of AV1 on them.\n"
            "  - 'Unsupported device: cuda (variant: ffmpeg)' instead means torchcodec was built without "
            "the CUDA backend; install a CUDA-enabled wheel (see README), but on A100/H100 that still "
            "won't decode AV1.\n"
            "Fix: decode on CPU, run NVDEC on an Ada GPU, or re-encode the dataset to H.265/H.264 (which "
            "A100/H100 NVDEC do support).\n"
            "If ftyp=False instead, the handle resolved to a placeholder/error page (auth, revision, or Xet "
            "resolution) rather than the video bytes."
        )


if __name__ == "__main__":
    main()
@@ -0,0 +1,79 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collapse a directory of benchmark JSON results into one comparison table (and a combined CSV).

    python benchmarks/streaming/summarize_results.py benchmarks/streaming/results
"""

import csv
import json
import sys
from pathlib import Path

COLUMNS = [
    ("source", "source"),
    ("mode", "mode"),
    ("video_decode_device", "decode"),
    ("num_workers", "workers"),
    ("batch_size", "bs"),
    ("frames_per_s_node", "frames/s/node"),
    ("first_batch_latency_s", "first_batch_s"),
    ("p50_sample_latency_ms", "p50_ms"),
    ("p95_sample_latency_ms", "p95_ms"),
    ("p99_sample_latency_ms", "p99_ms"),
]


def main() -> None:
    results_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "benchmarks/streaming/results")
    files = sorted(results_dir.rglob("*.json"))
    if not files:
        print(f"No JSON results under {results_dir}")
        return

    rows = []
    for f in files:
        d = json.loads(f.read_text())
        # Tolerate a missing or null cache block so one odd result file does not abort the summary.
        d["hit_rate"] = (d.get("video_decoder_cache") or {}).get("hit_rate")
        rows.append(d)

    rows.sort(key=lambda r: (r.get("source", ""), r.get("mode", ""), r.get("video_decode_device", "")))

    headers = [label for _, label in COLUMNS] + ["cache_hit_rate"]
    widths = {h: len(h) for h in headers}
    table = []
    for r in rows:
        row = {label: r.get(key, "") for key, label in COLUMNS}
        row["cache_hit_rate"] = r.get("hit_rate", "")
        table.append(row)
        for h in headers:
            widths[h] = max(widths[h], len(str(row[h])))

    line = " ".join(h.ljust(widths[h]) for h in headers)
    print(line)
    print(" ".join("-" * widths[h] for h in headers))
    for row in table:
        print(" ".join(str(row[h]).ljust(widths[h]) for h in headers))

    combined = results_dir / "summary.csv"
    with open(combined, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=headers)
        writer.writeheader()
        writer.writerows(table)
    print(f"\nWrote {combined}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,179 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Distributed, resumable streaming training on a large HF-hosted dataset.

This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
streaming features of :class:`StreamingLeRobotDataset`:

- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
  are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
- DataLoader-worker shard splitting (no duplicate frames within a rank);
- deterministic fast-forward resume via ``dataset.load_state_dict()`` (trainer-side counters only);
- an explicit video-decoder cache size so the working set of open decoders does not thrash.

Launch with Accelerate (single node, N GPUs):

    accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
        --repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64

Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.

Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
"""

import argparse
import time
from pathlib import Path

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader

from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, default="lerobot/droid_1.0.1")
|
||||
parser.add_argument(
|
||||
"--root", type=str, default=None, help="Local/prewarmed dataset root (else stream from Hub)."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default="outputs/train/streaming_multinode")
|
||||
parser.add_argument("--steps", type=int, default=1000)
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--episode_pool_size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Whole episodes open per consumer (randomness knob).",
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
|
||||
parser.add_argument("--save_freq", type=int, default=200)
|
||||
parser.add_argument("--log_freq", type=int, default=20)
|
||||
parser.add_argument("--resume_from", type=str, default=None, help="Checkpoint dir to resume from.")
|
||||
parser.add_argument("--dummy", action="store_true", help="Skip the model; measure dataloading only.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_dataloader(
|
||||
args: argparse.Namespace, meta: LeRobotDatasetMetadata
|
||||
) -> tuple[DataLoader, StreamingLeRobotDataset]:
|
||||
# Supervise an action chunk; delta_timestamps drive the SARM-style temporal window.
|
||||
delta_timestamps = {ACTION: [t / meta.fps for t in range(args.n_action_steps)]}
|
||||
# rank / world_size are resolved automatically from the Accelerate state inside the dataset.
|
||||
dataset = StreamingLeRobotDataset(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
)
|
||||
return loader, dataset
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
accelerator = Accelerator()
|
||||
output_dir = Path(args.output_dir)
|
||||
if accelerator.is_main_process:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
|
||||
loader, dataset = make_dataloader(args, meta)
|
||||
|
||||
if args.dummy:
|
||||
model = optimizer = None
|
||||
else:
|
||||
from lerobot.policies.act import ACTConfig, ACTPolicy
|
||||
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
features = dataset_to_policy_features(meta.features)
|
||||
output_features = {k: ft for k, ft in features.items() if k == ACTION}
|
||||
input_features = {k: ft for k, ft in features.items() if k not in output_features}
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
model = ACTPolicy(cfg)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
# Do NOT prepare the dataloader: the dataset is already rank-disjoint via
|
||||
# split_dataset_by_node, and accelerate's IterableDatasetShard would keep only every
|
||||
# world_size-th batch of it (silently training on 1/N of the data while decoding all
|
||||
# of it). Batches are moved to the device manually in the loop.
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
# Resume: deterministic fast-forward. Every consumer's order is a pure function of
|
||||
# (seed, epoch, rank, worker), so resuming only needs the trainer-side counters; each rank and
|
||||
# worker re-derives its own skip. Same file works for every rank.
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=True)
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resuming dataset stream: {state['batches_consumed']} batches consumed")
|
||||
|
||||
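# On resume, continue the global step count so later checkpoints record the true stream position.
step = state["batches_consumed"] if args.resume_from is not None else 0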
|
||||
frames_seen = 0
|
||||
window_start = time.perf_counter()
|
||||
done = False
|
||||
while not done:
|
||||
for batch in loader:
|
||||
if model is not None:
|
||||
batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
loss, _ = model.forward(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
step += 1
|
||||
frames_seen += args.batch_size
|
||||
if step % args.log_freq == 0:
|
||||
elapsed = time.perf_counter() - window_start
|
||||
fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
|
||||
total_fps = fps_per_proc * accelerator.num_processes
|
||||
accelerator.print(
|
||||
f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
|
||||
+ ("" if model is None else f" | loss {loss.item():.3f}")
|
||||
)
|
||||
window_start = time.perf_counter()
|
||||
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the consumed-batch counters so a restart fast-forwards to this position.
|
||||
torch.save(
|
||||
{"batches_consumed": step, "batch_size": args.batch_size},
|
||||
ckpt / "dataset_state.pt",
|
||||
)
|
||||
if model is not None:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
if step >= args.steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
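The resume comment in `main()` above states that every consumer's order is a pure function of (seed, epoch, rank, worker), so a checkpoint only has to store counters. The following is a minimal sketch of that idea, not the `StreamingLeRobotDataset` implementation: `consumer_order` and `stream` are hypothetical names, and the toy shuffles a range of integers instead of episodes and frames.

```python
import random
from itertools import islice


def consumer_order(seed: int, epoch: int, rank: int, worker: int, num_items: int) -> list[int]:
    """Hypothetical stand-in: a shuffled order that depends only on its arguments."""
    # A string seed is hashed deterministically by random.Random, so the same
    # (seed, epoch, rank, worker) tuple always yields the same permutation.
    rng = random.Random(f"{seed}:{epoch}:{rank}:{worker}")
    order = list(range(num_items))
    rng.shuffle(order)
    return order


def stream(seed: int, rank: int, worker: int, num_items: int, skip: int = 0):
    """Yield (epoch, item) forever, fast-forwarding past the first `skip` items."""
    epoch = 0
    while True:
        order = consumer_order(seed, epoch, rank, worker, num_items)
        if skip >= len(order):
            # The whole epoch was consumed before the checkpoint: skip it without yielding.
            skip -= len(order)
        else:
            for item in order[skip:]:
                yield epoch, item
            skip = 0
        epoch += 1


# An uninterrupted run versus a run resumed after 10 consumed items.
uninterrupted = list(islice(stream(seed=42, rank=0, worker=0, num_items=8), 25))
resumed = list(islice(stream(seed=42, rank=0, worker=0, num_items=8, skip=10), 15))
print(resumed == uninterrupted[10:])
```

Because nothing but the counter is stored, the same state file can be handed to every rank. In this toy each consumer would derive its own `skip` from the shared counter; how the real dataset splits the counter across ranks and workers is not shown in this diff.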
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench_stream
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=02:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Per-node dataloading benchmark for StreamingLeRobotDataset across 1-2 nodes. Each node runs an
|
||||
# independent dummy-consumer benchmark; per-node throughput should be independent (separate network).
|
||||
# Results are written per (node, source, mode) under --out_dir.
|
||||
#
|
||||
# Submit with: sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
# Override the source label for cold/warm bucket runs: SOURCE=warmed_bucket sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
SOURCE=${SOURCE:-hub}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# One benchmark process per node (each saturates the node's DataLoader workers + network independently).
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
for MODE in single sarm; do
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--source '"$SOURCE"' \
|
||||
--mode $MODE \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--episode_pool_size 64 \
|
||||
--num_batches 300 \
|
||||
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
|
||||
done
|
||||
'
|
||||
@@ -0,0 +1,107 @@
|
||||
#!/bin/bash
|
||||
# Submit the FULL streaming dataloading-benchmark matrix as isolated single-GPU SLURM jobs.
|
||||
#
|
||||
# sources : hub (Hub streaming) | bucket (cold HF bucket) | warmed_bucket (prewarmed HF bucket)
|
||||
# modes : single (1 frame, all cameras) | sarm (8-step / 8s delta window)
|
||||
# decode : cpu (torchcodec on CPU, scales with workers) | cuda (NVDEC, offloads decode to the GPU)
|
||||
#
|
||||
# => 3 x 2 x 2 = 12 jobs. Each runs in its OWN job (1 node, 1 GPU) so an OOM is isolated and reported
|
||||
# per-job by SLURM (check `sacct -j <id> --format=JobID,State,MaxRSS,ReqMem`). Submit from a login node
|
||||
# inside the repo: bash slurm/run_streaming_matrix.sh
|
||||
#
|
||||
# SERIAL (default 1): chain the jobs with --dependency=afterany so SLURM runs exactly ONE at a time. This
|
||||
# is important for a bandwidth benchmark — concurrent jobs would share the network to the Hub/bucket and
|
||||
# corrupt every throughput number. `afterany` means a failed/OOM'd job does not stall the chain. Set
|
||||
# SERIAL=0 to let the scheduler run them in parallel (only for OOM-isolation testing, not for throughput).
|
||||
#
|
||||
# Knobs (env overrides):
|
||||
# REPO_ID, BUCKET, WARM_BUCKET, OUT_DIR, NUM_BATCHES, TIME, MEM, GPUS, SERIAL
|
||||
# CPU_WORKERS / CPU_BUFFER (cpu-decode jobs) GPU_WORKERS / GPU_BUFFER (cuda-decode jobs, kept low to
|
||||
# bound VRAM + NVDEC sessions). RUN ("python" by default; set RUN="uv run python" if using uv).
|
||||
# SOURCES / MODES / DECODES to run a subset (e.g. SOURCES="hub bucket" DECODES="cpu").
|
||||
# ACCOUNT / PARTITION / QOS passed through to sbatch if set.
|
||||
set -euo pipefail
|
||||
|
||||
REPO_DIR=$(git rev-parse --show-toplevel)
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
BUCKET=${BUCKET:-hf://buckets/pepijn223/robocasa-stream}
|
||||
WARM_BUCKET=${WARM_BUCKET:-hf://buckets/pepijn223/robocasa-stream-warm}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
NUM_BATCHES=${NUM_BATCHES:-200}
|
||||
TIME=${TIME:-01:00:00}
|
||||
MEM=${MEM:-64G}
|
||||
GPUS=${GPUS:-1}
|
||||
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
|
||||
CPU_WORKERS=${CPU_WORKERS:-8}
|
||||
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
|
||||
CPU_BUFFER=${CPU_BUFFER:-64} # episode pool size (whole episodes per consumer; tabular-only RAM)
|
||||
GPU_BUFFER=${GPU_BUFFER:-32} # smaller episode pool bounds in-flight decoded frames
|
||||
# Cap concurrently-open stream shards. Each open shard holds ~one parquet row group in RAM, and reading
|
||||
# from an hf:// bucket buffers ~5x more per shard than hf:// datasets (~1.2GB vs ~0.26GB). So for bucket
|
||||
# sources default to num_workers (1 shard/worker); hub keeps 16. Override globally with MAX_SHARDS.
|
||||
MAX_SHARDS=${MAX_SHARDS:-}
|
||||
BATCH_SIZE=${BATCH_SIZE:-64}
|
||||
PREFETCH=${PREFETCH:-2} # DataLoader batches prefetched per worker (higher = more throughput + RAM)
|
||||
RUN=${RUN:-python}
|
||||
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
|
||||
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
|
||||
# + datasets (>=4.7) — the default `base` env on many clusters is too old to decode AV1 / lacks CUDA.
|
||||
CONDA_ENV=${CONDA_ENV:-}
|
||||
if [ -n "$CONDA_ENV" ] && [ "$RUN" = "python" ]; then
|
||||
RUN="conda run --no-capture-output -n $CONDA_ENV python"
|
||||
fi
|
||||
|
||||
SOURCES=${SOURCES:-"hub bucket warmed_bucket"}
|
||||
MODES=${MODES:-"single sarm"}
|
||||
DECODES=${DECODES:-"cpu cuda"}
|
||||
|
||||
mkdir -p "$REPO_DIR/logs" "$REPO_DIR/$OUT_DIR"
|
||||
|
||||
data_root_for () {
|
||||
case "$1" in
|
||||
hub) echo "" ;;
|
||||
bucket) echo "$BUCKET" ;;
|
||||
warmed_bucket) echo "$WARM_BUCKET" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
n=0
|
||||
prev_jid=""
|
||||
for SOURCE in $SOURCES; do
|
||||
DATA_ROOT=$(data_root_for "$SOURCE")
|
||||
ROOTFLAG=""
|
||||
[ -n "$DATA_ROOT" ] && ROOTFLAG="--data_files_root $DATA_ROOT"
|
||||
for MODE in $MODES; do
|
||||
for DECODE in $DECODES; do
|
||||
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
|
||||
if [ -n "$MAX_SHARDS" ]; then S=$MAX_SHARDS; elif [ "$SOURCE" = hub ]; then S=16; else S=$W; fi
|
||||
# Run strictly after the previous job so only one job touches the network at a time.
|
||||
DEPFLAG=""
|
||||
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
|
||||
jid=$(sbatch --parsable \
|
||||
--job-name="bench_${SOURCE}_${MODE}_${DECODE}" \
|
||||
--nodes=1 --ntasks=1 --gpus="$GPUS" --cpus-per-task=$((W + 4)) \
|
||||
--mem="$MEM" --time="$TIME" --output="$REPO_DIR/logs/%x-%j.out" \
|
||||
$DEPFLAG \
|
||||
${ACCOUNT:+--account=$ACCOUNT} ${PARTITION:+--partition=$PARTITION} ${QOS:+--qos=$QOS} \
|
||||
--wrap "cd '$REPO_DIR' && \
|
||||
export TOKENIZERS_PARALLELISM=false && export HF_HOME=\${HF_HOME:-\$SCRATCH/hf_home} && \
|
||||
$RUN benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id $REPO_ID $ROOTFLAG \
|
||||
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --prefetch_factor $PREFETCH \
|
||||
--episode_pool_size $B --max_num_shards $S \
|
||||
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
|
||||
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
|
||||
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
|
||||
prev_jid=$jid
|
||||
n=$((n + 1))
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
|
||||
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
|
||||
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_pf<prefetch>_<decode>.{json,csv}"
|
||||
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
|
||||
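The matrix script above loops over sources, modes, and decode devices, picks worker counts and a shard cap per job, and chains jobs with `--dependency=afterany`. The sketch below mirrors that selection logic in Python so the resulting 12 jobs can be inspected without submitting anything. `max_shards` and `job_matrix` are hypothetical helper names; the defaults are copied from the script's environment defaults.

```python
from itertools import product
from typing import Optional

SOURCES = ("hub", "bucket", "warmed_bucket")
MODES = ("single", "sarm")
DECODES = ("cpu", "cuda")


def max_shards(source: str, workers: int, override: Optional[int] = None) -> int:
    """Shard cap: explicit override, else 16 for the Hub, else one shard per worker."""
    if override is not None:
        return override
    return 16 if source == "hub" else workers


def job_matrix(cpu_workers: int = 8, gpu_workers: int = 2, override: Optional[int] = None) -> list[dict]:
    """Enumerate jobs in the same nesting order as the shell loops (source, mode, decode)."""
    jobs = []
    previous = None
    for source, mode, decode in product(SOURCES, MODES, DECODES):
        workers = cpu_workers if decode == "cpu" else gpu_workers
        name = f"bench_{source}_{mode}_{decode}"
        jobs.append(
            {
                "name": name,
                "workers": workers,
                "cpus_per_task": workers + 4,  # matches --cpus-per-task=$((W + 4))
                "max_num_shards": max_shards(source, workers, override),
                "after": previous,  # serial chain: each job waits for the previous one
            }
        )
        previous = name
    return jobs


for job in job_matrix():
    print(job)
```

In the serial chain every job except the first carries a dependency on its predecessor, which is what keeps only one job on the network at a time.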
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=stream_robocasa
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Multinode streaming training over a large HF-hosted RoboCasa dataset (never touches local disk).
|
||||
# Launches examples/scaling/train_streaming_multinode.py with Accelerate. Each rank streams a disjoint
|
||||
# set of shards via split_dataset_by_node (auto-resolved from the Accelerate state), so per-node
|
||||
# throughput scales independently. For an even split, ensure n_shards % (nodes * gpus_per_node) == 0.
|
||||
#
|
||||
# Submit with: sbatch slurm/train_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
GPUS_PER_NODE=8
|
||||
NUM_PROCESSES=$((SLURM_NNODES * GPUS_PER_NODE))
|
||||
|
||||
# Rendezvous: use the first node in the allocation as the main process.
|
||||
MAIN_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
|
||||
MAIN_PORT=${MAIN_PORT:-29500}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
# Avoid each rank fighting over the tokenizers' internal thread pool.
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
accelerate launch \
|
||||
--num_machines '"$SLURM_NNODES"' \
|
||||
--num_processes '"$NUM_PROCESSES"' \
|
||||
--machine_rank $SLURM_NODEID \
|
||||
--main_process_ip '"$MAIN_ADDR"' \
|
||||
--main_process_port '"$MAIN_PORT"' \
|
||||
--mixed_precision bf16 \
|
||||
--dynamo_backend no \
|
||||
examples/scaling/train_streaming_multinode.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--episode_pool_size 64 \
|
||||
--steps 200000 \
|
||||
--save_freq 2000 \
|
||||
--log_freq 50
|
||||
'
|
||||
@@ -39,6 +39,10 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Whole episodes each streaming consumer keeps open to shuffle across (the randomness knob).
|
||||
# Larger mixes more episodes per batch at the cost of cold-start latency; RAM stays small because
|
||||
# the pool holds tabular rows only. Ignored when streaming is False.
|
||||
streaming_episode_pool_size: int = 64
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
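The row-group sizing in `_write_parquet` above is plain integer arithmetic: scale the row count by the ratio of the 32 MiB target to the table's in-memory size, then clamp to between 1 and the number of rows. A small worked sketch follows; `rows_per_group` and `num_groups` are hypothetical helper names, and the table sizes are made-up inputs chosen to exercise each branch.

```python
def rows_per_group(num_rows: int, table_nbytes: int, target_bytes: int = 32 * 1024 * 1024) -> int:
    """Same expression as the diff: rows per row group for a ~32 MiB uncompressed target."""
    return max(1, min(num_rows, num_rows * target_bytes // max(table_nbytes, 1)))


def num_groups(num_rows: int, table_nbytes: int) -> int:
    """Number of row groups the writer would emit for that row-group size."""
    size = rows_per_group(num_rows, table_nbytes)
    return -(-num_rows // size)  # ceiling division


MIB = 1024 * 1024

# 1,000,000 rows occupying 640 MiB: 32/640 of the rows fit in one group.
print(rows_per_group(1_000_000, 640 * MIB), num_groups(1_000_000, 640 * MIB))

# A table smaller than the target stays a single row group.
print(rows_per_group(400, 1 * MIB), num_groups(400, 1 * MIB))

# Very wide rows (10 rows in 1 GiB) fall back to one row per group.
print(rows_per_group(10, 1024 * MIB), num_groups(10, 1024 * MIB))
```

The `max(..., 1)` guards keep the expression defined for empty tables and for rows that are individually larger than the target.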
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
episode_pool_size=cfg.dataset.streaming_episode_pool_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
@@ -22,6 +22,7 @@ import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
@@ -47,6 +48,92 @@ from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_REMOTE_IO_MAX_RETRIES = 5
|
||||
"""Retry budget for transient hf:// / fsspec / httpx transport errors during streaming video decode.
|
||||
|
||||
Streaming a dataset from an HF bucket/CDN issues many small range requests and occasionally hits a
|
||||
transient transport failure (timeout, dropped connection, 408/5xx). The right response is to rebuild
|
||||
the connection and retry rather than crash the DataLoader worker. Override via
|
||||
``LEROBOT_REMOTE_IO_MAX_RETRIES``; set to ``0`` to disable retries (fail fast).
|
||||
"""
|
||||
|
||||
# Transient transport failures from the hf:// -> fsspec -> httpx stack. We match on text because the
|
||||
# concrete exception types live in optional deps (httpx, huggingface_hub) and vary across versions.
|
||||
# "client has been closed" is the important one: once a shared httpx client is closed by a single
|
||||
# failed read, every subsequent read in that worker fails until the fsspec instance cache is cleared.
|
||||
_RETRYABLE_TRANSPORT_FRAGMENTS = (
|
||||
"client has been closed",
|
||||
"server disconnected",
|
||||
"remoteprotocolerror",
|
||||
"unexpected_eof",
|
||||
"eof occurred in violation of protocol",
|
||||
"connection reset",
|
||||
"connection aborted",
|
||||
"connection broken",
|
||||
"incompleteread",
|
||||
"read operation timed out",
|
||||
"timed out",
|
||||
"request time-out",
|
||||
"408",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
)
|
||||
|
||||
|
||||
def _remote_io_max_retries() -> int:
|
||||
raw = os.environ.get("LEROBOT_REMOTE_IO_MAX_RETRIES")
|
||||
if raw is None:
|
||||
return DEFAULT_REMOTE_IO_MAX_RETRIES
|
||||
try:
|
||||
return max(0, int(raw))
|
||||
except ValueError as e:
|
||||
raise ValueError(f"LEROBOT_REMOTE_IO_MAX_RETRIES must be an integer; got {raw!r}") from e
|
||||
|
||||
|
||||
def _is_retryable_transport_error(exc: BaseException) -> bool:
|
||||
"""True if ``exc`` looks like a transient remote-IO failure worth retrying (vs a real bug)."""
|
||||
text = f"{type(exc).__name__}: {exc}".lower()
|
||||
return any(fragment in text for fragment in _RETRYABLE_TRANSPORT_FRAGMENTS)
|
||||
|
||||
|
||||
def _recover_remote_io(decoder_cache: "VideoDecoderCache", video_path: str) -> None:
|
||||
"""Drop the dead decoder for ``video_path`` and force a fresh fsspec client before a retry.
|
||||
|
||||
fsspec caches one filesystem instance per (protocol, args), and that instance owns the httpx
|
||||
client a failed read may have closed. Clearing the instance cache makes the next ``fsspec.open``
|
||||
build a new client, which is what breaks the "client has been closed" cascade.
|
||||
"""
|
||||
decoder_cache.invalidate(video_path)
|
||||
with contextlib.suppress(Exception):
|
||||
fsspec.AbstractFileSystem.clear_instance_cache()
|
||||
|
||||
|
||||
def _retry_remote_io(operation, on_retry, max_retries: int, base_delay: float = 0.5, max_delay: float = 10.0):
|
||||
"""Run ``operation()``, retrying transient transport errors after ``on_retry()`` + capped backoff.
|
||||
|
||||
Non-transport errors (decode / index / timestamp issues) propagate immediately so real bugs are
|
||||
never masked by retries.
|
||||
"""
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
return operation()
|
||||
except Exception as e:
|
||||
if attempt >= max_retries or not _is_retryable_transport_error(e):
|
||||
raise
|
||||
attempt += 1
|
||||
logger.warning(
|
||||
"Transient remote-IO error (%s: %s); rebuilding connection and retrying (%d/%d).",
|
||||
type(e).__name__,
|
||||
e,
|
||||
attempt,
|
||||
max_retries,
|
||||
)
|
||||
on_retry()
|
||||
time.sleep(min(base_delay * 2 ** (attempt - 1), max_delay))
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
@@ -242,7 +329,12 @@ class VideoDecoderCache:
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int | None | object = _SENTINEL,
|
||||
counters: "torch.Tensor | None" = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
@@ -250,6 +342,18 @@ class VideoDecoderCache:
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
# Decode device for the underlying torchcodec VideoDecoder. "cuda" offloads H.264/H.265 decode to
|
||||
# the GPU's dedicated NVDEC engine (independent of the SMs used for training); requires a
|
||||
# CUDA-enabled torchcodec/FFmpeg build. See https://developer.nvidia.com/video-codec-sdk.
|
||||
self.device = device
|
||||
# Observability counters (cheap, updated under the lock) for benchmarking decoder reuse.
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.evictions = 0
|
||||
# Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place
|
||||
# (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so
|
||||
# treat the aggregate as approximate; the hit-rate ratio is preserved.
|
||||
self._counters = counters
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
@@ -271,11 +375,21 @@ class VideoDecoderCache:
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
self.hits += 1
|
||||
if self._counters is not None:
|
||||
self._counters[0] += 1
|
||||
return entry[0]
|
||||
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
self.misses += 1
|
||||
if self._counters is not None:
|
||||
self._counters[1] += 1
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
@@ -287,6 +401,9 @@ class VideoDecoderCache:
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
self.evictions += 1
|
||||
if self._counters is not None:
|
||||
self._counters[2] += 1
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
@@ -300,11 +417,35 @@ class VideoDecoderCache:
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def invalidate(self, video_path: str) -> None:
|
||||
"""Drop and close the cached decoder for a path whose connection went bad.
|
||||
|
||||
After a transport error the cached ``fsspec`` handle (and the httpx client behind it) is dead;
|
||||
removing the entry forces the next :meth:`get_decoder` to re-open a fresh handle.
|
||||
"""
|
||||
with self._lock:
|
||||
entry = self._cache.pop(str(video_path), None)
|
||||
if entry is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
entry[1].close()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the number of cached decoders."""
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
def stats(self) -> dict[str, int | float]:
|
||||
"""Return reuse counters (hits/misses/evictions, hit rate, current size) for benchmarking."""
|
||||
with self._lock:
|
||||
total = self.hits + self.misses
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"hit_rate": self.hits / total if total else 0.0,
|
||||
"size": len(self._cache),
|
||||
}
|
||||
|
||||
|
||||
class FrameTimestampError(ValueError):
|
||||
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
|
||||
@@ -343,20 +484,24 @@ def decode_video_frames_torchcodec(
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
def _decode_frames():
|
||||
# Both opening the decoder and reading frames go over the network for hf:// paths, so wrap the
|
||||
# whole unit: a transient transport error retries by dropping the dead handle and rebuilding
|
||||
# the connection (see _retry_remote_io / _recover_remote_io) instead of killing the worker.
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
average_fps = decoder.metadata.average_fps
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
return decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
frames_batch = _retry_remote_io(
|
||||
_decode_frames,
|
||||
on_retry=lambda: _recover_remote_io(decoder_cache, str(video_path)),
|
||||
max_retries=_remote_io_max_retries(),
|
||||
)
|
||||
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
|
||||
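The retry helper in the hunks above sleeps `min(base_delay * 2 ** (attempt - 1), max_delay)` after each transient failure and re-raises anything that does not look like a transport error. The sketch below reproduces that control flow under stated assumptions so the delay schedule can be inspected. It is not the library function: the `sleep` parameter, the shortened fragment list, and `make_flaky` are hypothetical additions for illustration.

```python
RETRYABLE_FRAGMENTS = ("connection reset", "timed out", "client has been closed")


def is_retryable(exc: BaseException) -> bool:
    """Text match on the exception type and message, as in the diff (shortened list)."""
    text = f"{type(exc).__name__}: {exc}".lower()
    return any(fragment in text for fragment in RETRYABLE_FRAGMENTS)


def backoff_delays(max_retries: int = 5, base_delay: float = 0.5, max_delay: float = 10.0) -> list[float]:
    """Delay slept after the 1st, 2nd, ... failed attempt."""
    return [min(base_delay * 2 ** (attempt - 1), max_delay) for attempt in range(1, max_retries + 1)]


def retry(operation, on_retry, max_retries, sleep, base_delay=0.5, max_delay=10.0):
    """Run operation(); on a retryable error call on_retry(), back off, and try again."""
    attempt = 0
    while True:
        try:
            return operation()
        except Exception as exc:
            if attempt >= max_retries or not is_retryable(exc):
                raise  # out of budget, or a real bug that must not be masked
            attempt += 1
            on_retry()
            sleep(min(base_delay * 2 ** (attempt - 1), max_delay))


def make_flaky(failures: int):
    """Hypothetical operation that fails `failures` times, then succeeds."""
    calls = {"n": 0}

    def operation():
        calls["n"] += 1
        if calls["n"] <= failures:
            raise ConnectionError("Connection reset by peer")
        return "frames"

    return operation


slept: list[float] = []
result = retry(make_flaky(2), on_retry=lambda: None, max_retries=5, sleep=slept.append)
print(result, slept, backoff_delays())
```

With the default budget of five retries the schedule stays under the 10-second cap; the cap only takes effect from the sixth retry onward.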
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -25,52 +24,6 @@ from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
@@ -120,10 +73,9 @@ def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
"""Each epoch covers every frame exactly once; shuffle reshuffles across epochs."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
@@ -138,25 +90,17 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
repo_id=repo_id, root=local_path, episode_pool_size=4, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
epochs = [[int(frame["index"]) for frame in streaming_ds] for _ in range(n_epochs)]
|
||||
for epoch_indices in epochs:
|
||||
assert sorted(epoch_indices) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
if shuffle:
|
||||
assert epochs[0] != epochs[1], "shuffle did not reshuffle across epochs"
|
||||
assert epochs[0] != list(range(ds_num_frames)), "shuffle left the stream in sequential order"
|
||||
else:
|
||||
assert epochs[0] == epochs[1] == epochs[2], "unshuffled epochs must repeat the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -164,15 +108,11 @@ def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
"""Multi-shard streams keep exactly-once coverage and deterministic per-seed order."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
@@ -187,31 +127,21 @@ def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
def make_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
episode_pool_size=3,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
first = [int(frame["index"]) for frame in make_ds()]
|
||||
again = [int(frame["index"]) for frame in make_ds()]
|
||||
|
||||
assert sorted(first) == list(range(ds_num_frames)), "epoch did not cover every frame once"
|
||||
assert first == again, "same seed must reproduce the same order"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
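The rewritten tests above check three invariants: each epoch yields every frame exactly once, a given seed reproduces the same order, and consecutive epochs differ when shuffling is on. The toy episode-pool shuffler below illustrates those invariants. It is a minimal sketch and not the `StreamingLeRobotDataset` implementation: `pooled_order`, its parameters, and the seed-plus-epoch reseeding are all hypothetical.

```python
import random


def pooled_order(episodes, pool_size, seed, epoch=0, shuffle=True):
    """Yield frame indices by drawing from a bounded pool of open episodes.

    Hypothetical illustration only: `episodes` is a list of lists of frame indices.
    """
    # Assumption: the epoch number is mixed into the seed so each epoch reshuffles.
    rng = random.Random(seed * 1_000_003 + epoch)
    pending = list(episodes)
    if shuffle:
        rng.shuffle(pending)
    pending = [iter(episode) for episode in pending]
    pool = []
    while pending or pool:
        # Refill the pool up to pool_size open episodes.
        while pending and len(pool) < pool_size:
            pool.append(pending.pop(0))
        slot = rng.randrange(len(pool)) if shuffle else 0
        try:
            yield next(pool[slot])
        except StopIteration:
            pool.pop(slot)  # episode exhausted; free its slot


episodes = [list(range(start, start + 10)) for start in range(0, 100, 10)]
epoch_0 = list(pooled_order(episodes, pool_size=4, seed=42, epoch=0))
epoch_0_again = list(pooled_order(episodes, pool_size=4, seed=42, epoch=0))
epoch_1 = list(pooled_order(episodes, pool_size=4, seed=42, epoch=1))
sequential = list(pooled_order(episodes, pool_size=4, seed=42, shuffle=False))
```

Sorting any of the epoch lists gives `list(range(100))`, which is the same coverage check the tests apply with `sorted(epoch_indices) == list(range(ds_num_frames))`.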
@@ -0,0 +1,100 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
import json, sys
from accelerate import PartialState
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset

root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
state = PartialState()
ds = StreamingLeRobotDataset(
    repo_id=repo_id, root=root, shuffle=False, episode_pool_size=8, max_num_shards=8
)
indices = [int(frame["index"]) for frame in ds]
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
    json.dump(payload, f)
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
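The contract this smoke test asserts is that ranks receive pairwise disjoint frame sets whose union is the whole dataset. The sketch below shows that check on a toy shard layout. The strided assignment in `shards_for_rank` is an assumption made purely for illustration; in the test the real assignment comes from `split_dataset_by_node`, and both helper names here are hypothetical.

```python
def shards_for_rank(num_shards, rank, world_size):
    """Toy strided assignment: shard i goes to rank i % world_size (assumption)."""
    return [shard for shard in range(num_shards) if shard % world_size == rank]


def check_rank_partition(rank_sets, total_frames):
    """Return True when ranks are pairwise disjoint and jointly cover all frames."""
    seen = set()
    for frames in rank_sets:
        if seen & frames:
            return False  # a frame was streamed by two ranks
        seen |= frames
    return seen == set(range(total_frames))


num_shards, frames_per_shard, world_size = 8, 20, 2
total_frames = num_shards * frames_per_shard
rank_sets = [
    {
        shard * frames_per_shard + offset
        for shard in shards_for_rank(num_shards, rank, world_size)
        for offset in range(frames_per_shard)
    }
    for rank in range(world_size)
]
partition_ok = check_rank_partition(rank_sets, total_frames)
```

Accumulating into a single `seen` set generalizes the test's two-rank `isdisjoint` check to any world size.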
@@ -0,0 +1,355 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, the episode pool (randomness, coverage, exact deltas), video
|
||||
prefetching, deterministic fast-forward resume, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
    assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)

    monkeypatch.delenv("RANK", raising=False)
    monkeypatch.delenv("WORLD_SIZE", raising=False)
    # No accelerate state, no env -> single process.
    assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)

    monkeypatch.setenv("RANK", "3")
    monkeypatch.setenv("WORLD_SIZE", "4")
    assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
episode_pool_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_fast_forward_resume_is_sample_exact(tmp_path, lerobot_dataset_factory):
    """Resume replays the deterministic stream and continues at the exact sample."""
    repo_id = f"{DUMMY_REPO_ID}-resume"
    total_frames = 100
    _make_local_dataset(
        lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
    )

    def fresh_ds():
        return StreamingLeRobotDataset(
            repo_id=repo_id,
            root=tmp_path / "ds",
            shuffle=True,
            seed=7,
            episode_pool_size=3,
            max_num_shards=1,
        )

    full_epoch = _stream_indices(fresh_ds())
    assert sorted(full_epoch) == list(range(total_frames))

    batches_consumed, batch_size = 5, 4  # 20 samples in
    resumed_ds = fresh_ds()
    resumed_ds.load_state_dict({"batches_consumed": batches_consumed, "batch_size": batch_size})
    resumed = _stream_indices(resumed_ds)

    assert resumed == full_epoch[batches_consumed * batch_size :], (
        "fast-forward resume did not continue at the exact sample"
    )
|
||||
|
||||
|
||||
def test_pool_order_is_deterministic_per_seed(tmp_path, lerobot_dataset_factory):
|
||||
repo_id = f"{DUMMY_REPO_ID}-seeds"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=120)
|
||||
|
||||
def order(seed):
|
||||
return _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=seed,
|
||||
episode_pool_size=4,
|
||||
max_num_shards=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert order(0) == order(0), "same seed must reproduce the same order"
|
||||
assert order(0) != order(1), "different seeds should give different orders"
|
||||
|
||||
|
||||
def test_pool_epochs_reshuffle_and_cover(tmp_path, lerobot_dataset_factory):
|
||||
"""Consecutive passes over the same dataset object reshuffle (epoch advances) but keep coverage."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-epochs"
|
||||
total_frames = 120
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=6, total_frames=total_frames
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=3, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
epoch_0 = _stream_indices(ds)
|
||||
epoch_1 = _stream_indices(ds)
|
||||
assert sorted(epoch_0) == sorted(epoch_1) == list(range(total_frames))
|
||||
assert epoch_0 != epoch_1, "epoch did not reshuffle"
|
||||
|
||||
|
||||
def test_pool_mixes_episodes(tmp_path, lerobot_dataset_factory):
|
||||
"""Early samples should already come from several distinct episodes (the pool's purpose)."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-mix"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, seed=0, episode_pool_size=8, max_num_shards=4
|
||||
)
|
||||
episodes_in_head = {int(frame["episode_index"]) for _, frame in zip(range(20), ds, strict=False)}
|
||||
assert len(episodes_in_head) >= 3, f"pool did not mix episodes: {episodes_in_head}"
|
||||
|
||||
|
||||
def test_video_prefetcher_refcounted_lifecycle(tmp_path):
    from lerobot.datasets.streaming_dataset import _VideoPrefetcher

    remote = tmp_path / "remote"
    (remote / "videos").mkdir(parents=True)
    payload = b"x" * 1024
    (remote / "videos" / "a.mp4").write_bytes(payload)

    prefetcher = _VideoPrefetcher(str(remote), cache_dir=tmp_path / "cache", max_workers=1)
    prefetcher.acquire("videos/a.mp4")
    prefetcher.acquire("videos/a.mp4")  # second pooled episode sharing the file
    local = prefetcher.wait_local("videos/a.mp4")
    assert local is not None and local.read_bytes() == payload

    prefetcher.release("videos/a.mp4")
    assert local.exists(), "file deleted while still referenced"
    prefetcher.release("videos/a.mp4")
    assert not local.exists(), "file not deleted at refcount zero"
    prefetcher.shutdown()
|
||||
|
||||
|
||||
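The lifecycle exercised by `test_video_prefetcher_refcounted_lifecycle` (two acquires of the same file, deletion only after the second release) is ordinary reference counting. The sketch below shows that pattern in isolation. It is not `_VideoPrefetcher`: the class `RefCountedFileCache` is hypothetical, it copies synchronously instead of using a worker pool, and it assumes single-threaded use.

```python
import shutil
import tempfile
from pathlib import Path


class RefCountedFileCache:
    """Hypothetical sketch: copy a file into a cache on first acquire, delete it on last release."""

    def __init__(self, source_root, cache_dir):
        self.source_root = Path(source_root)
        self.cache_dir = Path(cache_dir)
        self._refs = {}  # relative path -> number of live references

    def acquire(self, rel_path):
        count = self._refs.get(rel_path, 0)
        local = self.cache_dir / rel_path
        if count == 0:
            # First user of this file: materialize it in the cache.
            local.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(self.source_root / rel_path, local)
        self._refs[rel_path] = count + 1
        return local

    def release(self, rel_path):
        count = self._refs[rel_path] - 1
        if count == 0:
            # Last user is done: drop the bookkeeping entry and the cached copy.
            del self._refs[rel_path]
            (self.cache_dir / rel_path).unlink()
        else:
            self._refs[rel_path] = count


with tempfile.TemporaryDirectory() as tmp:
    remote = Path(tmp) / "remote"
    (remote / "videos").mkdir(parents=True)
    (remote / "videos" / "a.mp4").write_bytes(b"x" * 1024)

    cache = RefCountedFileCache(remote, Path(tmp) / "cache")
    local = cache.acquire("videos/a.mp4")
    cache.acquire("videos/a.mp4")  # a second consumer sharing the same file
    cache.release("videos/a.mp4")
    still_there = local.exists()  # one reference remains
    cache.release("videos/a.mp4")
    gone = not local.exists()  # refcount reached zero
```

Keying the count on the relative path is what lets several pooled episodes share one video file without any of them deleting it early.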
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, episode_pool_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, episode_pool_size=8, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
|
||||
|
||||
def test_fast_forward_resume_with_dataloader_workers(tmp_path, lerobot_dataset_factory):
|
||||
"""Resume must be exact under num_workers > 0: each worker re-derives its own skip."""
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume-workers"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=120)
|
||||
|
||||
num_workers = 2
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=True,
|
||||
seed=11,
|
||||
episode_pool_size=3,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
def epoch_samples(ds):
|
||||
# batch_size=None yields raw samples; the DataLoader round-robins them across workers,
|
||||
# which is batch_size=1 in the resume arithmetic.
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=num_workers)
|
||||
return [int(sample["index"]) for sample in loader]
|
||||
|
||||
full = epoch_samples(fresh_ds())
|
||||
|
||||
samples_consumed = 17
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict({"batches_consumed": samples_consumed, "batch_size": 1})
|
||||
resumed = epoch_samples(resumed_ds)
|
||||
|
||||
assert resumed == full[samples_consumed:], (
|
||||
"fast-forward resume with DataLoader workers did not continue at the exact sample"
|
||||
)
|
||||
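Both resume tests rely on the same arithmetic: with a deterministic stream, resuming means replaying it and skipping `batches_consumed * batch_size` samples. The sketch below shows that skip on a toy stream. The state keys `batches_consumed` and `batch_size` come from the tests; `deterministic_stream` and `fast_forward` are hypothetical helpers, and how the dataset itself performs the skip (especially per DataLoader worker) is not shown here.

```python
import random
from itertools import islice


def deterministic_stream(num_frames, seed):
    """Toy stand-in for a seeded, shuffled stream of frame indices."""
    order = list(range(num_frames))
    random.Random(seed).shuffle(order)
    return iter(order)


def fast_forward(stream, batches_consumed, batch_size):
    """Skip the samples already consumed, then continue with the rest of the stream."""
    return islice(stream, batches_consumed * batch_size, None)


full_epoch = list(deterministic_stream(100, seed=7))
resumed = list(fast_forward(deterministic_stream(100, seed=7), batches_consumed=5, batch_size=4))
```

This only works because the same seed rebuilds the same order, so the resumed run is exactly the tail `full_epoch[20:]` of the original epoch.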