mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-10 05:09:48 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 42d4788e4a | |||
| 2d1c17d971 | |||
| 7241f029c6 | |||
| 06ddc59913 | |||
| 23c58f5f9e | |||
| b0ab57cedc | |||
| afdc084677 | |||
| a32a2c647b | |||
| 343ecd7980 | |||
| f7c8a526e8 | |||
| 77af66a29c | |||
| 68fa5d80b0 | |||
| d1fc8e298c | |||
| 49755a3d9e |
@@ -0,0 +1,91 @@
|
||||
# Streaming dataloading benchmark
|
||||
|
||||
Measures **dataloading only** (no model) for `StreamingLeRobotDataset`: parquet read + video decode +
|
||||
delta windowing + shuffle. A dummy consumer pulls batches and moves them to the device, so the numbers
|
||||
isolate the data pipeline. Use it to compare sources (Hub vs. storage bucket vs. prewarmed bucket),
|
||||
frame modes, and node counts, and to catch p95/p99 video-decode regressions.
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
|
||||
--mode sarm --batch_size 64 --num_workers 12 --num_batches 200 \
|
||||
--source hub --out_dir benchmarks/streaming/results
|
||||
```
|
||||
|
||||
Multinode (per-node throughput) goes through Accelerate under SLURM:
|
||||
|
||||
```bash
|
||||
sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
```
|
||||
|
||||
## Matrix
|
||||
|
||||
| Axis | Values |
|
||||
| ---------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| Source | `hub` (verify now), `bucket`, `warmed_bucket` (bucket + prewarming; with user's help later) |
|
||||
| Baseline | current `main` `StreamingLeRobotDataset` on Hub streaming |
|
||||
| Nodes | 1 and 2 (per-node throughput should be independent) |
|
||||
| Frame mode | `single` (1 frame, all cameras; target ≥ 120 frames/s/node) · `sarm` (8 steps spaced 1s; target ≥ 320 frames/s/node) |
|
||||
|
||||
`--source` is a label only; the actual source is whatever `--repo_id` / `--root` / `--data_files_root`
|
||||
point at.
|
||||
|
||||
### GPU (NVDEC) decoding
|
||||
|
||||
By default video is decoded on the **CPU** in each DataLoader worker, so throughput is CPU-decode-bound and
|
||||
scales with `--num_workers` (capped by the dataset's `num_shards`). Pass `--video_decode_device cuda` to
|
||||
offload H.264/H.265 decode to the GPU's dedicated **NVDEC** engine, which runs independently of the SMs used
|
||||
for training (see <https://developer.nvidia.com/video-codec-sdk>). This requires a CUDA-enabled torchcodec
|
||||
build, and because CUDA cannot initialize in forked workers the benchmark switches to the `spawn` start
|
||||
method automatically when `--num_workers > 0`.
|
||||
|
||||
```bash
|
||||
# GPU/NVDEC decode, 6 workers, bucket source
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 \
|
||||
--data_files_root hf://buckets/pepijn223/robocasa-stream \
|
||||
--mode sarm --batch_size 64 --num_workers 6 --num_batches 200 \
|
||||
--video_decode_device cuda --source bucket
|
||||
```
|
||||
|
||||
Caveats with `cuda` + many workers: each worker creates its own CUDA context (VRAM overhead) and NVDEC has a
|
||||
limited number of concurrent decode sessions per GPU; if you hit session/IPC limits, reduce `--num_workers`
|
||||
or compare against `--num_workers 0` (single-process NVDEC, which often saturates the decode engine on its
|
||||
own). Result files include the decode device in their name (`..._w6_cuda.json`).
|
||||
|
||||
> **Codec ⇄ NVDEC compatibility (important).** NVDEC can only decode codecs its hardware supports. LeRobot's
|
||||
> **default video codec is AV1** (`VideoEncoderConfig.vcodec = "libsvtav1"`), so most v3 datasets are
|
||||
> AV1-encoded — and the **A100 and H100 compute GPUs have no AV1 NVDEC decoder**
|
||||
> (per NVIDIA's [decode support matrix](https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new));
|
||||
> only Ada (L4/L40/RTX40) and a few Ampere cards (A10/A40/A16) do. On A100/H100, AV1 must be decoded on
|
||||
> **CPU**, or the dataset re-encoded to H.265/H.264 (which those GPUs' NVDEC do support). Run
|
||||
> `diagnose_decode.py --video_decode_device cuda` to check your exact node before relying on `cuda` decode.
|
||||
> A `cuda` torchcodec build also needs an FFmpeg with NVDEC; see
|
||||
> <https://github.com/meta-pytorch/torchcodec#installing-cuda-enabled-torchcodec>.
|
||||
|
||||
Reference data root: bucket sources resolve through `--data_files_root hf://buckets/<owner>/<name>` (metadata
|
||||
still loads from `--repo_id`). The local `single`/`sarm` CPU baselines on this dataset were ~176 / ~212
|
||||
frames/s/node at `--num_workers 3` (3 cameras, fps 20).
|
||||
|
||||
## Metrics emitted (JSON + CSV)
|
||||
|
||||
`frames_per_s_node`, `samples_per_s`, `first_batch_latency_s`, `p50/p95/p99_sample_latency_ms`,
|
||||
`wallclock_s`, and `video_decoder_cache` (`hits`, `misses`, `evictions`, `hit_rate`, `size`). A low
|
||||
cache `hit_rate` with high `p99` is the decoder-thrash signature — raise `--video_decoder_cache_size`
|
||||
or `--buffer_size`, or reduce `num_workers`.
|
||||
|
||||
## Bucket sources & prewarming (manual)
|
||||
|
||||
Prewarming is a **server-side** Hugging Face storage-bucket feature — there is no client script. To
|
||||
benchmark the `warmed_bucket` source:
|
||||
|
||||
1. Attach a storage bucket to the dataset and enable it (see
|
||||
<https://huggingface.co/docs/hub/storage-buckets>). Buckets resolve through `fsspec`, the same as
|
||||
`hf://`, so no code change is needed — point `--repo_id`/`--revision` (or `--root`) at the bucket.
|
||||
2. Enable **prewarming** in the bucket settings and wait for warm-up to complete.
|
||||
3. Run the benchmark with `--source warmed_bucket`. Compare against the cold `--source bucket` and the
|
||||
`--source hub` baseline.
|
||||
|
||||
Manual only — not run in CI.
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataloading-only benchmark for StreamingLeRobotDataset.
|
||||
|
||||
A dummy consumer pulls batches and moves them to the device; no model runs, so the numbers isolate the
|
||||
data pipeline (parquet read + video decode + delta windowing + shuffle). Reports per-node throughput and
|
||||
sample-latency percentiles, plus video-decoder-cache reuse stats, and emits JSON + CSV.
|
||||
|
||||
Frame modes (matching the streaming design targets):
|
||||
- ``single``: one frame, all cameras (target >= 120 frames/s/node).
|
||||
- ``sarm``: an 8-step window spaced 1s (delta over 8s) (target >= 320 frames/s/node).
|
||||
|
||||
Example (stream from the Hub, single node):
|
||||
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id pepijn223/robocasa_pretrain_human300_v4 --mode sarm \
|
||||
--batch_size 64 --num_workers 12 --num_batches 200 --out_dir benchmarks/streaming/results
|
||||
|
||||
Distributed / multinode runs go through Accelerate; see ``slurm/benchmark_streaming_robocasa.sh``. Set
|
||||
``--source`` purely for labeling the output (``hub`` / ``bucket`` / ``warmed_bucket``); the actual source
|
||||
is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarming.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import statistics
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, required=True)
|
||||
parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).")
|
||||
parser.add_argument(
|
||||
"--data_files_root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="fsspec root for bulk data/videos, e.g. hf://buckets/<owner>/<name>. Metadata still loads "
|
||||
"from --repo_id on the Hub. Use for bucket / warmed_bucket sources.",
|
||||
)
|
||||
parser.add_argument("--mode", choices=["single", "sarm"], default="single")
|
||||
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--buffer_size", type=int, default=2000)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--video_decode_device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Decode device passed to torchcodec. 'cuda' offloads decode to the GPU's NVDEC engine "
|
||||
"(needs a CUDA-enabled torchcodec build). With num_workers>0 this forces the 'spawn' start method.",
|
||||
)
|
||||
parser.add_argument("--num_batches", type=int, default=200)
|
||||
parser.add_argument("--warmup_batches", type=int, default=5, help="Excluded from steady-state stats.")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--out_dir", type=str, default="benchmarks/streaming/results")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> StreamingLeRobotDataset:
|
||||
# sarm: an 8-step window spaced 1s => an 8s delta window (the SARM stress case).
|
||||
delta_timestamps = {ACTION: [float(t) for t in range(8)]} if args.mode == "sarm" else None
|
||||
return StreamingLeRobotDataset(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
data_files_root=args.data_files_root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
video_decode_device=args.video_decode_device,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
|
||||
|
||||
def percentile(values: list[float], pct: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
ordered = sorted(values)
|
||||
k = max(0, min(len(ordered) - 1, int(round((pct / 100.0) * (len(ordered) - 1)))))
|
||||
return ordered[k]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
device = torch.device(args.device)
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
|
||||
dataset = build_dataset(args, meta)
|
||||
|
||||
gpu_decode = args.video_decode_device.startswith("cuda")
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
# GPU-decoded frames are already on the GPU, so CPU pinning is irrelevant (and pinning CUDA
|
||||
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
|
||||
pin_memory=device.type == "cuda" and not gpu_decode,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
|
||||
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
frames = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None # wall-clock start of the post-warmup measurement window
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
for i, batch in enumerate(loader):
|
||||
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
|
||||
if i == args.warmup_batches:
|
||||
# Start the steady window here; the slow first batch and the prefetch queue it filled are
|
||||
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
frames += args.batch_size
|
||||
t_prev = now
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
# Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those
|
||||
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
cache_stats = dataset.video_decoder_cache_stats()
|
||||
|
||||
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
|
||||
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
|
||||
# never reported green with NaN/zero numbers.
|
||||
if frames == 0:
|
||||
raise SystemExit(
|
||||
f"FAILED: measured 0 frames over {args.num_batches} requested batches "
|
||||
f"(cache misses={cache_stats.get('misses', 0)}, hits={cache_stats.get('hits', 0)}). "
|
||||
"The data pipeline yielded no usable batches — inspect worker logs for decode errors. "
|
||||
"Try --num_workers 0 to surface the underlying exception directly."
|
||||
)
|
||||
|
||||
results = {
|
||||
"repo_id": args.repo_id,
|
||||
"source": args.source,
|
||||
"mode": args.mode,
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"buffer_size": args.buffer_size,
|
||||
"num_cameras": len(meta.video_keys),
|
||||
"fps": meta.fps,
|
||||
"device": str(device),
|
||||
"video_decode_device": args.video_decode_device,
|
||||
"frames_measured": frames,
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
|
||||
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"samples_per_s": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
"p50_sample_latency_ms": round(statistics.median(sample_latencies_ms), 3)
|
||||
if sample_latencies_ms
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"wallclock_s": round(elapsed, 2),
|
||||
"video_decoder_cache": cache_stats,
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=list(flat))
|
||||
writer.writeheader()
|
||||
writer.writerow(flat)
|
||||
|
||||
print("Command config:", vars(args))
|
||||
print(json.dumps(results, indent=2))
|
||||
print(f"Wrote {out_dir / tag}.json and .csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isolate the streaming video-decode path — no SLURM, no DataLoader, no benchmark loop.
|
||||
|
||||
Reproduces exactly what StreamingLeRobotDataset does for one video (resolve path -> fsspec.open ->
|
||||
torchcodec VideoDecoder -> get one frame) and prints the environment + the first bytes of the handle, so
|
||||
a decode failure ("No valid stream found in input file") can be pinpointed: bad/placeholder bytes vs a
|
||||
torchcodec/ffmpeg build issue vs a device issue.
|
||||
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id pepijn223/robocasa_pretrain_human300_v4
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id … --data_files_root hf://buckets/<o>/<n>
|
||||
python benchmarks/streaming/diagnose_decode.py --repo_id … --video_decode_device cuda
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib.metadata as im
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata
|
||||
|
||||
|
||||
def _version(pkg: str) -> str:
|
||||
try:
|
||||
return im.version(pkg)
|
||||
except Exception:
|
||||
return "MISSING"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--repo_id", required=True)
|
||||
p.add_argument("--data_files_root", default=None, help="e.g. hf://buckets/<owner>/<name>")
|
||||
p.add_argument("--revision", default=None)
|
||||
p.add_argument("--video_decode_device", default="cpu")
|
||||
p.add_argument("--episode", type=int, default=0)
|
||||
args = p.parse_args()
|
||||
|
||||
print("== environment ==")
|
||||
for pkg in ("torchcodec", "av", "huggingface_hub", "hf_xet", "datasets", "fsspec"):
|
||||
print(f" {pkg}: {_version(pkg)}")
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
video_key = meta.video_keys[0]
|
||||
rel_path = meta.get_video_file_path(args.episode, video_key)
|
||||
root = args.data_files_root.rstrip("/") if args.data_files_root else meta.url_root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
print("\n== target ==")
|
||||
print(f" video_key: {video_key}")
|
||||
print(f" video_path: {video_path}")
|
||||
|
||||
print("\n== fsspec handle ==")
|
||||
try:
|
||||
fh = fsspec.open(video_path).__enter__()
|
||||
head = fh.read(32)
|
||||
print(f" first 32 bytes (hex): {head.hex()}")
|
||||
# A valid MP4/MOV has an 'ftyp' box near the start; anything else (HTML/JSON/empty) means the
|
||||
# handle resolved to a placeholder or error page, not the video bytes.
|
||||
looks_mp4 = b"ftyp" in head
|
||||
print(f" looks like MP4 (contains 'ftyp'): {looks_mp4}")
|
||||
if not looks_mp4:
|
||||
print(f" !! first bytes as text: {head[:32]!r}")
|
||||
fh.seek(0)
|
||||
except Exception as e:
|
||||
print(f" !! fsspec.open/read FAILED: {type(e).__name__}: {e}")
|
||||
return
|
||||
|
||||
print("\n== torchcodec VideoDecoder ==")
|
||||
try:
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
decoder = VideoDecoder(fh, seek_mode="approximate", device=args.video_decode_device)
|
||||
md = decoder.metadata
|
||||
print(f" OK: {md.num_frames} frames, {md.average_fps} fps, codec={getattr(md, 'codec', '?')}")
|
||||
frame = decoder.get_frames_at(indices=[0])
|
||||
print(f" decoded frame 0: shape={tuple(frame.data.shape)}, device={frame.data.device}")
|
||||
print("\nDECODE OK — the streaming pipeline can read this video on this machine.")
|
||||
except Exception as e:
|
||||
print(f" !! VideoDecoder FAILED: {type(e).__name__}: {e}")
|
||||
print(
|
||||
"\nDECODE FAILED. If the bytes above look like MP4 (ftyp=True), this is a torchcodec/ffmpeg "
|
||||
"build issue, NOT bad bytes. Common cause for LeRobot v3 datasets: the videos are AV1-encoded "
|
||||
"(see the 'codec' line on a working machine). Then:\n"
|
||||
" - CPU decode needs an ffmpeg built with an AV1 decoder (libdav1d/libaom); a build without it "
|
||||
"reports 'No valid stream found'.\n"
|
||||
" - GPU/NVDEC decode of AV1 is only on AV1-capable NVDEC GPUs: Ada (L4/L40/RTX40) and some "
|
||||
"Ampere (A10/A40/A16). The COMPUTE GPUs A100 and H100 have NO AV1 NVDEC decoder (per NVIDIA's "
|
||||
"support matrix), so no torchcodec build enables cuda decode of AV1 on them.\n"
|
||||
" - 'Unsupported device: cuda (variant: ffmpeg)' instead means torchcodec was built without "
|
||||
"the CUDA backend; install a CUDA-enabled wheel (see README) — but on A100/H100 that still "
|
||||
"won't decode AV1.\n"
|
||||
"Fix: decode on CPU, run NVDEC on an Ada GPU, or re-encode the dataset to H.265/H.264 (which "
|
||||
"A100/H100 NVDEC do support).\n"
|
||||
"If ftyp=False instead, the handle resolved to a placeholder/error page (auth, revision, or Xet "
|
||||
"resolution) rather than the video bytes."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+79
@@ -0,0 +1,79 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Collapse a directory of benchmark JSON results into one comparison table (and a combined CSV).
|
||||
|
||||
python benchmarks/streaming/summarize_results.py benchmarks/streaming/results
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
COLUMNS = [
|
||||
("source", "source"),
|
||||
("mode", "mode"),
|
||||
("video_decode_device", "decode"),
|
||||
("num_workers", "workers"),
|
||||
("batch_size", "bs"),
|
||||
("frames_per_s_node", "frames/s/node"),
|
||||
("first_batch_latency_s", "first_batch_s"),
|
||||
("p50_sample_latency_ms", "p50_ms"),
|
||||
("p95_sample_latency_ms", "p95_ms"),
|
||||
("p99_sample_latency_ms", "p99_ms"),
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "benchmarks/streaming/results")
|
||||
files = sorted(results_dir.rglob("*.json"))
|
||||
if not files:
|
||||
print(f"No JSON results under {results_dir}")
|
||||
return
|
||||
|
||||
rows = []
|
||||
for f in files:
|
||||
d = json.loads(f.read_text())
|
||||
d["hit_rate"] = d.get("video_decoder_cache", {}).get("hit_rate")
|
||||
rows.append(d)
|
||||
|
||||
rows.sort(key=lambda r: (r.get("source", ""), r.get("mode", ""), r.get("video_decode_device", "")))
|
||||
|
||||
headers = [label for _, label in COLUMNS] + ["cache_hit_rate"]
|
||||
widths = {h: len(h) for h in headers}
|
||||
table = []
|
||||
for r in rows:
|
||||
row = {label: r.get(key, "") for key, label in COLUMNS}
|
||||
row["cache_hit_rate"] = r.get("hit_rate", "")
|
||||
table.append(row)
|
||||
for h in headers:
|
||||
widths[h] = max(widths[h], len(str(row[h])))
|
||||
|
||||
line = " ".join(h.ljust(widths[h]) for h in headers)
|
||||
print(line)
|
||||
print(" ".join("-" * widths[h] for h in headers))
|
||||
for row in table:
|
||||
print(" ".join(str(row[h]).ljust(widths[h]) for h in headers))
|
||||
|
||||
combined = results_dir / "summary.csv"
|
||||
with open(combined, "w", newline="") as fh:
|
||||
writer = csv.DictWriter(fh, fieldnames=headers)
|
||||
writer.writeheader()
|
||||
writer.writerows(table)
|
||||
print(f"\nWrote {combined}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,169 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Distributed, resumable streaming training on a large HF-hosted dataset.
|
||||
|
||||
This example shows how to train (or just stress the data pipeline) over a multi-TB dataset that never
|
||||
touches local disk, scaling across GPUs and nodes with Accelerate. It demonstrates the large-scale
|
||||
streaming features of :class:`StreamingLeRobotDataset`:
|
||||
|
||||
- per-rank sharding via ``split_dataset_by_node`` (each GPU streams disjoint data; ``rank``/``world_size``
|
||||
are auto-resolved from the Accelerate state, so nothing needs to be passed explicitly);
|
||||
- DataLoader-worker shard splitting (no duplicate frames within a rank);
|
||||
- resumable streaming via ``dataset.state_dict()`` / ``load_state_dict()`` saved into the checkpoint;
|
||||
- an explicit video-decoder cache size so the working set of open decoders does not thrash.
|
||||
|
||||
Launch with Accelerate (single node, N GPUs):
|
||||
|
||||
accelerate launch --num_processes=8 examples/scaling/train_streaming_multinode.py \
|
||||
--repo_id=pepijn223/robocasa_pretrain_human300_v4 --batch_size=64
|
||||
|
||||
Multinode runs use the same script under SLURM; see ``slurm/train_streaming_robocasa.sh``.
|
||||
|
||||
Pass ``--dummy`` to skip the model entirely and measure pure dataloading throughput.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, default="lerobot/droid_1.0.1")
|
||||
parser.add_argument(
|
||||
"--root", type=str, default=None, help="Local/prewarmed dataset root (else stream from Hub)."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default="outputs/train/streaming_multinode")
|
||||
parser.add_argument("--steps", type=int, default=1000)
|
||||
parser.add_argument("--batch_size", type=int, default=64, help="Per-process batch size.")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--buffer_size", type=int, default=2000, help="Output shuffle-buffer size, in frames."
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument("--n_action_steps", type=int, default=16, help="Action-chunk length (delta horizon).")
|
||||
parser.add_argument("--save_freq", type=int, default=200)
|
||||
parser.add_argument("--log_freq", type=int, default=20)
|
||||
parser.add_argument("--resume_from", type=str, default=None, help="Checkpoint dir to resume from.")
|
||||
parser.add_argument("--dummy", action="store_true", help="Skip the model; measure dataloading only.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_dataloader(
|
||||
args: argparse.Namespace, meta: LeRobotDatasetMetadata
|
||||
) -> tuple[DataLoader, StreamingLeRobotDataset]:
|
||||
# Supervise an action chunk; delta_timestamps drive the SARM-style temporal window.
|
||||
delta_timestamps = {ACTION: [t / meta.fps for t in range(args.n_action_steps)]}
|
||||
# rank / world_size are resolved automatically from the Accelerate state inside the dataset.
|
||||
dataset = StreamingLeRobotDataset(
|
||||
args.repo_id,
|
||||
root=args.root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
)
|
||||
return loader, dataset
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
accelerator = Accelerator()
|
||||
output_dir = Path(args.output_dir)
|
||||
if accelerator.is_main_process:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, root=args.root)
|
||||
loader, dataset = make_dataloader(args, meta)
|
||||
|
||||
if args.dummy:
|
||||
model = optimizer = None
|
||||
else:
|
||||
from lerobot.policies.act import ACTConfig, ACTPolicy
|
||||
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
features = dataset_to_policy_features(meta.features)
|
||||
output_features = {k: ft for k, ft in features.items() if k == ACTION}
|
||||
input_features = {k: ft for k, ft in features.items() if k not in output_features}
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
model = ACTPolicy(cfg)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
|
||||
|
||||
# Resume: restore the dataset's stream position so we don't replay already-seen data. The state holds
|
||||
# plain HF stream dicts + RNG state (not tensors), so weights_only=False is required; the file is a
|
||||
# checkpoint this script wrote itself.
|
||||
if args.resume_from is not None:
|
||||
state = torch.load(Path(args.resume_from) / "dataset_state.pt", weights_only=False) # nosec B614
|
||||
dataset.load_state_dict(state)
|
||||
accelerator.print(f"Resumed dataset stream from {args.resume_from}")
|
||||
|
||||
step = 0
|
||||
frames_seen = 0
|
||||
window_start = time.perf_counter()
|
||||
done = False
|
||||
while not done:
|
||||
for batch in loader:
|
||||
if model is not None:
|
||||
batch = {k: (v.to(accelerator.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
||||
loss, _ = model.forward(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
step += 1
|
||||
frames_seen += args.batch_size
|
||||
if step % args.log_freq == 0:
|
||||
elapsed = time.perf_counter() - window_start
|
||||
fps_per_proc = (args.log_freq * args.batch_size) / max(elapsed, 1e-9)
|
||||
total_fps = fps_per_proc * accelerator.num_processes
|
||||
accelerator.print(
|
||||
f"step {step} | {fps_per_proc:.1f} frames/s/proc | {total_fps:.1f} frames/s total"
|
||||
+ ("" if model is None else f" | loss {loss.item():.3f}")
|
||||
)
|
||||
window_start = time.perf_counter()
|
||||
|
||||
if step % args.save_freq == 0 and accelerator.is_main_process:
|
||||
ckpt = output_dir / f"checkpoint-{step}"
|
||||
ckpt.mkdir(parents=True, exist_ok=True)
|
||||
# Save the dataset stream position alongside the model so a restart resumes mid-stream.
|
||||
torch.save(dataset.state_dict(), ckpt / "dataset_state.pt")
|
||||
if model is not None:
|
||||
accelerator.unwrap_model(model).save_pretrained(ckpt)
|
||||
|
||||
if step >= args.steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
accelerator.print(f"End of training: {step} steps, ~{frames_seen} frames/proc")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench_stream
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=02:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Per-node dataloading benchmark for StreamingLeRobotDataset across 1-2 nodes. Each node runs an
|
||||
# independent dummy-consumer benchmark; per-node throughput should be independent (separate network).
|
||||
# Results are written per (node, source, mode) under --out_dir.
|
||||
#
|
||||
# Submit with: sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
# Override the source label for cold/warm bucket runs: SOURCE=warmed_bucket sbatch slurm/benchmark_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
SOURCE=${SOURCE:-hub}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# One benchmark process per node (each saturates the node's DataLoader workers + network independently).
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
for MODE in single sarm; do
|
||||
python benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--source '"$SOURCE"' \
|
||||
--mode $MODE \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--num_batches 300 \
|
||||
--out_dir '"$OUT_DIR"'/node${SLURM_NODEID}
|
||||
done
|
||||
'
|
||||
Executable
+100
@@ -0,0 +1,100 @@
|
||||
#!/bin/bash
|
||||
# Submit the FULL streaming dataloading-benchmark matrix as isolated single-GPU SLURM jobs.
|
||||
#
|
||||
# sources : hub (Hub streaming) | bucket (cold HF bucket) | warmed_bucket (prewarmed HF bucket)
|
||||
# modes : single (1 frame, all cameras) | sarm (8-step / 8s delta window)
|
||||
# decode : cpu (torchcodec on CPU, scales with workers) | cuda (NVDEC, offloads decode to the GPU)
|
||||
#
|
||||
# => 3 x 2 x 2 = 12 jobs. Each runs in its OWN job (1 node, 1 GPU) so an OOM is isolated and reported
|
||||
# per-job by SLURM (check `sacct -j <id> --format=JobID,State,MaxRSS,ReqMem`). Submit from a login node
|
||||
# inside the repo: bash slurm/run_streaming_matrix.sh
|
||||
#
|
||||
# SERIAL (default 1): chain the jobs with --dependency=afterany so SLURM runs exactly ONE at a time. This
|
||||
# is important for a bandwidth benchmark — concurrent jobs would share the network to the Hub/bucket and
|
||||
# corrupt every throughput number. `afterany` means a failed/OOM'd job does not stall the chain. Set
|
||||
# SERIAL=0 to let the scheduler run them in parallel (only for OOM-isolation testing, not for throughput).
|
||||
#
|
||||
# Knobs (env overrides):
|
||||
# REPO_ID, BUCKET, WARM_BUCKET, OUT_DIR, NUM_BATCHES, TIME, MEM, GPUS, SERIAL
|
||||
# CPU_WORKERS / CPU_BUFFER (cpu-decode jobs) GPU_WORKERS / GPU_BUFFER (cuda-decode jobs, kept low to
|
||||
# bound VRAM + NVDEC sessions). RUN ("python" by default; set RUN="uv run python" if using uv).
|
||||
# SOURCES / MODES / DECODES to run a subset (e.g. SOURCES="hub bucket" DECODES="cpu").
|
||||
# ACCOUNT / PARTITION / QOS passed through to sbatch if set.
|
||||
set -euo pipefail
|
||||
|
||||
REPO_DIR=$(git rev-parse --show-toplevel)
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
BUCKET=${BUCKET:-hf://buckets/pepijn223/robocasa-stream}
|
||||
WARM_BUCKET=${WARM_BUCKET:-hf://buckets/pepijn223/robocasa-stream-warm}
|
||||
OUT_DIR=${OUT_DIR:-benchmarks/streaming/results}
|
||||
NUM_BATCHES=${NUM_BATCHES:-200}
|
||||
TIME=${TIME:-01:00:00}
|
||||
MEM=${MEM:-64G}
|
||||
GPUS=${GPUS:-1}
|
||||
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
|
||||
CPU_WORKERS=${CPU_WORKERS:-8}
|
||||
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
|
||||
CPU_BUFFER=${CPU_BUFFER:-4000}
|
||||
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
|
||||
BATCH_SIZE=${BATCH_SIZE:-64}
|
||||
RUN=${RUN:-python}
|
||||
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
|
||||
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
|
||||
# + datasets (>=4.7) — the default `base` env on many clusters is too old to decode AV1 / lacks CUDA.
|
||||
CONDA_ENV=${CONDA_ENV:-}
|
||||
if [ -n "$CONDA_ENV" ] && [ "$RUN" = "python" ]; then
|
||||
RUN="conda run --no-capture-output -n $CONDA_ENV python"
|
||||
fi
|
||||
|
||||
SOURCES=${SOURCES:-"hub bucket warmed_bucket"}
|
||||
MODES=${MODES:-"single sarm"}
|
||||
DECODES=${DECODES:-"cpu cuda"}
|
||||
|
||||
mkdir -p "$REPO_DIR/logs" "$REPO_DIR/$OUT_DIR"
|
||||
|
||||
data_root_for () {
|
||||
case "$1" in
|
||||
hub) echo "" ;;
|
||||
bucket) echo "$BUCKET" ;;
|
||||
warmed_bucket) echo "$WARM_BUCKET" ;;
|
||||
esac
|
||||
}
|
||||
|
||||
n=0
|
||||
prev_jid=""
|
||||
for SOURCE in $SOURCES; do
|
||||
DATA_ROOT=$(data_root_for "$SOURCE")
|
||||
ROOTFLAG=""
|
||||
[ -n "$DATA_ROOT" ] && ROOTFLAG="--data_files_root $DATA_ROOT"
|
||||
for MODE in $MODES; do
|
||||
for DECODE in $DECODES; do
|
||||
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
|
||||
# Run strictly after the previous job so only one job touches the network at a time.
|
||||
DEPFLAG=""
|
||||
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
|
||||
jid=$(sbatch --parsable \
|
||||
--job-name="bench_${SOURCE}_${MODE}_${DECODE}" \
|
||||
--nodes=1 --ntasks=1 --gpus="$GPUS" --cpus-per-task=$((W + 4)) \
|
||||
--mem="$MEM" --time="$TIME" --output="$REPO_DIR/logs/%x-%j.out" \
|
||||
$DEPFLAG \
|
||||
${ACCOUNT:+--account=$ACCOUNT} ${PARTITION:+--partition=$PARTITION} ${QOS:+--qos=$QOS} \
|
||||
--wrap "cd '$REPO_DIR' && \
|
||||
export TOKENIZERS_PARALLELISM=false && export HF_HOME=\${HF_HOME:-\$SCRATCH/hf_home} && \
|
||||
$RUN benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id $REPO_ID $ROOTFLAG \
|
||||
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
|
||||
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
|
||||
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
|
||||
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
|
||||
prev_jid=$jid
|
||||
n=$((n + 1))
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
|
||||
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
|
||||
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
|
||||
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
|
||||
@@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=stream_robocasa
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --output=logs/%x-%j.out
|
||||
|
||||
# Multinode streaming training over a large HF-hosted RoboCasa dataset (never touches local disk).
|
||||
# Launches examples/scaling/train_streaming_multinode.py with Accelerate. Each rank streams a disjoint
|
||||
# set of shards via split_dataset_by_node (auto-resolved from the Accelerate state), so per-node
|
||||
# throughput scales independently. For an even split, ensure n_shards % (nodes * gpus_per_node) == 0.
|
||||
#
|
||||
# Submit with: sbatch slurm/train_streaming_robocasa.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_ID=${REPO_ID:-pepijn223/robocasa_pretrain_human300_v4}
|
||||
GPUS_PER_NODE=8
|
||||
NUM_PROCESSES=$((SLURM_NNODES * GPUS_PER_NODE))
|
||||
|
||||
# Rendezvous: use the first node in the allocation as the main process.
|
||||
MAIN_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
|
||||
MAIN_PORT=${MAIN_PORT:-29500}
|
||||
|
||||
export HF_HOME=${HF_HOME:-$SCRATCH/hf_home}
|
||||
# Avoid each rank fighting over the tokenizers' internal thread pool.
|
||||
export TOKENIZERS_PARALLELISM=false
|
||||
|
||||
srun --kill-on-bad-exit=1 bash -c '
|
||||
accelerate launch \
|
||||
--num_machines '"$SLURM_NNODES"' \
|
||||
--num_processes '"$NUM_PROCESSES"' \
|
||||
--machine_rank $SLURM_NODEID \
|
||||
--main_process_ip '"$MAIN_ADDR"' \
|
||||
--main_process_port '"$MAIN_PORT"' \
|
||||
--mixed_precision bf16 \
|
||||
--dynamo_backend no \
|
||||
examples/scaling/train_streaming_multinode.py \
|
||||
--repo_id '"$REPO_ID"' \
|
||||
--batch_size 64 \
|
||||
--num_workers 12 \
|
||||
--buffer_size 4000 \
|
||||
--steps 200000 \
|
||||
--save_freq 2000 \
|
||||
--log_freq 50
|
||||
'
|
||||
@@ -39,6 +39,9 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Output shuffle-buffer size (in frames) when streaming. Larger decorrelates samples better at the cost
|
||||
# of host RAM. Ignored when streaming is False.
|
||||
streaming_buffer_size: int = 1000
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
|
||||
@@ -106,7 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
buffer_size=cfg.dataset.streaming_buffer_size,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
@@ -21,6 +24,7 @@ import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
|
||||
@@ -38,6 +42,8 @@ from .video_utils import (
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
@@ -252,6 +258,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
return_uint8: bool = False,
|
||||
rank: int | None = None,
|
||||
world_size: int | None = None,
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
video_decode_device: str = "cpu",
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -272,6 +283,25 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
rank (int | None, optional): This process' rank for distributed (multi-GPU/multi-node) training.
|
||||
Each rank streams a disjoint set of shards via ``split_dataset_by_node``. When omitted, it is
|
||||
resolved from Accelerate (``process_index``) or the ``RANK`` env var, defaulting to 0.
|
||||
world_size (int | None, optional): Total number of distributed processes. When omitted, resolved
|
||||
from Accelerate (``num_processes``) or the ``WORLD_SIZE`` env var, defaulting to 1 (no sharding).
|
||||
For an even per-rank split, ``num_shards % world_size == 0`` should hold.
|
||||
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
|
||||
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
|
||||
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
|
||||
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
|
||||
trees (e.g. an HF storage bucket ``hf://buckets/<owner>/<name>``). When set, parquet and
|
||||
video frames are read from there while metadata still loads from ``repo_id`` on the Hub.
|
||||
Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket
|
||||
sources without copying the (small) metadata.
|
||||
video_decode_device (str, optional): Device for video decoding, passed to the torchcodec
|
||||
``VideoDecoder``. Defaults to ``"cpu"``. Set to ``"cuda"`` to offload H.264/H.265 decode to
|
||||
the GPU's dedicated NVDEC engine (independent of the training SMs), which requires a
|
||||
CUDA-enabled torchcodec build. Note: ``"cuda"`` decode inside ``DataLoader`` workers needs
|
||||
the ``spawn`` start method (CUDA cannot init in forked workers).
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -289,10 +319,21 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
self.max_num_shards = max_num_shards
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
self.video_decode_device = video_decode_device
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
|
||||
# one place the main process can read after iteration (see video_decoder_cache_stats()).
|
||||
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
|
||||
# Resume state captured by load_state_dict() and consumed at the next __iter__.
|
||||
self._resume_state: dict | None = None
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -314,13 +355,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
if self.data_files_root is not None:
|
||||
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
"parquet",
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files=f"{self.data_files_root}/data/*/*.parquet",
|
||||
)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Drop any parquet columns not declared in the dataset's feature contract. Some revisions / sources
|
||||
# (e.g. an unversioned bucket holding `main`) carry extra, possibly variable-length annotation
|
||||
# columns such as `language_events`; left in, they leak into the sample and break default DataLoader
|
||||
# collation across frames of differing length. On a clean revision this is a no-op.
|
||||
known_columns = set(self.meta.features)
|
||||
extra_columns = [c for c in (self.hf_dataset.column_names or []) if c not in known_columns]
|
||||
if extra_columns:
|
||||
self.hf_dataset = self.hf_dataset.remove_columns(extra_columns)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
|
||||
@@ -348,22 +407,99 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_distributed(rank: int | None, world_size: int | None) -> tuple[int, int]:
|
||||
"""Resolve (rank, world_size) for distributed streaming.
|
||||
|
||||
Explicit arguments win. Otherwise prefer an already-initialized Accelerate state, then the
|
||||
``RANK``/``WORLD_SIZE`` env vars set by launchers, and finally fall back to single-process (0, 1).
|
||||
"""
|
||||
if rank is not None and world_size is not None:
|
||||
return rank, world_size
|
||||
|
||||
try:
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if PartialState._shared_state: # only read it if already initialized; never initialize here
|
||||
state = PartialState()
|
||||
return state.process_index, state.num_processes
|
||||
except Exception:
|
||||
logger.debug("Could not resolve distributed state from Accelerate; using env/defaults.")
|
||||
|
||||
env_rank = os.environ.get("RANK")
|
||||
env_world = os.environ.get("WORLD_SIZE")
|
||||
if env_rank is not None and env_world is not None:
|
||||
return int(env_rank), int(env_world)
|
||||
|
||||
return 0, 1
|
||||
|
||||
def _make_video_decoder_cache(self, num_active_shards: int) -> VideoDecoderCache:
|
||||
"""Size the decoder cache to the working set of live shards so it does not thrash.
|
||||
|
||||
Each shard mid-episode keeps one open decoder per camera; with several shards iterated
|
||||
concurrently the working set is ``num_active_shards × num_cameras``. We add one shard worth of
|
||||
margin so the round-robin never evicts a still-live decoder.
|
||||
"""
|
||||
if self.video_decoder_cache_size is not None:
|
||||
return VideoDecoderCache(
|
||||
max_size=self.video_decoder_cache_size,
|
||||
counters=self._cache_counters,
|
||||
device=self.video_decode_device,
|
||||
)
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if num_cameras == 0:
|
||||
return VideoDecoderCache(counters=self._cache_counters, device=self.video_decode_device)
|
||||
return VideoDecoderCache(
|
||||
max_size=(num_active_shards + 1) * num_cameras,
|
||||
counters=self._cache_counters,
|
||||
device=self.video_decode_device,
|
||||
)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
# Distributed correctness: each rank streams a disjoint set of shards (order preserved).
|
||||
ds = self.hf_dataset
|
||||
if self.world_size > 1:
|
||||
ds = split_dataset_by_node(ds, rank=self.rank, world_size=self.world_size)
|
||||
|
||||
num_shards = min(ds.num_shards, self.max_num_shards)
|
||||
shard_indices = list(range(num_shards))
|
||||
|
||||
# DataLoader workers within this rank further split the shards so they don't yield duplicates.
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is not None:
|
||||
shard_indices = shard_indices[worker_info.id :: worker_info.num_workers]
|
||||
|
||||
self.video_decoder_cache = self._make_video_decoder_cache(len(shard_indices))
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
# Best-effort resume: restore RNG + exhausted shards and rewind each shard's HF stream. The
|
||||
# shuffle buffer is re-warmed rather than restored, so resumption is not bit-exact (acceptable
|
||||
# for pretraining); the underlying stream may also skip the few frames Backtrackable read ahead.
|
||||
resume = self._resume_state
|
||||
self._resume_state = None
|
||||
self._exhausted: set[int] = set(resume["exhausted"]) if resume is not None else set()
|
||||
if resume is not None:
|
||||
rng.bit_generator.state = resume["rng"]
|
||||
|
||||
self._shards: dict[int, datasets.IterableDataset] = {}
|
||||
for idx in shard_indices:
|
||||
shard = safe_shard(ds, idx, num_shards)
|
||||
if resume is not None and str(idx) in resume["shards"]:
|
||||
shard.load_state_dict(resume["shards"][str(idx)])
|
||||
self._shards[idx] = shard
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
idx: self._make_backtrackable_dataset(shard)
|
||||
for idx, shard in self._shards.items()
|
||||
if idx not in self._exhausted
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
@@ -389,11 +525,47 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
self._exhausted.add(shard_key)
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
|
||||
|
||||
Must be called after iteration has started (so the shard streams exist). Restore the returned
|
||||
dict with :meth:`load_state_dict` before re-iterating. The shuffle buffer is not captured, so
|
||||
resumption is not bit-exact — see :meth:`__iter__`.
|
||||
"""
|
||||
if not hasattr(self, "_shards"):
|
||||
raise RuntimeError("state_dict() requires the dataset to have been iterated at least once.")
|
||||
return {
|
||||
"shards": {str(idx): shard.state_dict() for idx, shard in self._shards.items()},
|
||||
"exhausted": sorted(self._exhausted),
|
||||
"rng": self.rng.bit_generator.state,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
|
||||
self._resume_state = state_dict
|
||||
|
||||
def video_decoder_cache_stats(self) -> dict[str, int | float]:
|
||||
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
|
||||
|
||||
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
|
||||
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
|
||||
approximate; the ``hit_rate`` ratio is preserved.
|
||||
"""
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
|
||||
total = hits + misses
|
||||
return {
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"evictions": evictions,
|
||||
"hit_rate": round(hits / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
@@ -405,19 +577,23 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
# Dynamically size the windows to exactly cover the requested delta_timestamps (in frames).
|
||||
# This removes the fixed LOOKAHEAD_BACKTRACKTABLE ceiling, which would raise LookAheadError for
|
||||
# long horizons (e.g. a SARM window of 8 steps spaced 1s = ~160 frames @ fps20).
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
lookback = math.floor(min(all_timestamps) * self.fps)
|
||||
lookahead = math.ceil(max(all_timestamps) * self.fps)
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
lookback = 0 if lookback >= 0 else -lookback
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps, dynamic_bounds=True)
|
||||
# Backtrackable.peek_back(n) needs `history >= n + 1`, so reach a frame `lookback` steps back requires
|
||||
# history = lookback + 1. history must be >= 1 and lookahead > 0, so clamp both to at least 1.
|
||||
return Backtrackable(dataset, history=max(1, lookback + 1), lookahead=max(1, lookahead))
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
@@ -473,13 +649,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
# `timestamp` is episode-local (restarts at 0 each episode). The absolute in-file timestamp is
|
||||
# `from_timestamp + timestamp`, applied per camera at decode time (see `_query_videos`), mirroring
|
||||
# the map-style reader. Using `index / fps` here is a dataset-global value that only matches the
|
||||
# file timeline when the whole dataset is a single video (e.g. small test fixtures), and otherwise
|
||||
# decodes out-of-range frames on multi-file v3 datasets.
|
||||
current_ts = float(item["timestamp"])
|
||||
|
||||
# Per-camera episode-local bounds [0, duration]. Query timestamps are clamped into this range so
|
||||
# out-of-episode deltas pad rather than decode against a neighbouring episode in the same file.
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
0.0,
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
|
||||
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
@@ -552,11 +735,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
if self.data_files_root is not None:
|
||||
root = self.data_files_root
|
||||
elif self.streaming and not self.streaming_from_local:
|
||||
root = self.meta.url_root
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
query_ts,
|
||||
shifted_query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
|
||||
@@ -242,7 +242,12 @@ class VideoDecoderCache:
|
||||
|
||||
_SENTINEL: ClassVar[object] = object()
|
||||
|
||||
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||
def __init__(
|
||||
self,
|
||||
max_size: int | None | object = _SENTINEL,
|
||||
counters: "torch.Tensor | None" = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
if max_size is VideoDecoderCache._SENTINEL:
|
||||
max_size = _default_max_cache_size()
|
||||
if max_size is not None and max_size <= 0:
|
||||
@@ -250,6 +255,18 @@ class VideoDecoderCache:
|
||||
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||
self._lock = Lock()
|
||||
# Decode device for the underlying torchcodec VideoDecoder. "cuda" offloads H.264/H.265 decode to
|
||||
# the GPU's dedicated NVDEC engine (independent of the SMs used for training); requires a
|
||||
# CUDA-enabled torchcodec/FFmpeg build. See https://developer.nvidia.com/video-codec-sdk.
|
||||
self.device = device
|
||||
# Observability counters (cheap, updated under the lock) for benchmarking decoder reuse.
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.evictions = 0
|
||||
# Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place
|
||||
# (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so
|
||||
# treat the aggregate as approximate; the hit-rate ratio is preserved.
|
||||
self._counters = counters
|
||||
|
||||
def __contains__(self, video_path: object) -> bool:
|
||||
with self._lock:
|
||||
@@ -271,11 +288,17 @@ class VideoDecoderCache:
|
||||
entry = self._cache.get(video_path)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(video_path)
|
||||
self.hits += 1
|
||||
if self._counters is not None:
|
||||
self._counters[0] += 1
|
||||
return entry[0]
|
||||
|
||||
self.misses += 1
|
||||
if self._counters is not None:
|
||||
self._counters[1] += 1
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
|
||||
except Exception:
|
||||
file_handle.close()
|
||||
raise
|
||||
@@ -287,6 +310,9 @@ class VideoDecoderCache:
|
||||
if self.max_size is not None:
|
||||
while len(self._cache) > self.max_size:
|
||||
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||
self.evictions += 1
|
||||
if self._counters is not None:
|
||||
self._counters[2] += 1
|
||||
with contextlib.suppress(Exception):
|
||||
evicted_handle.close()
|
||||
|
||||
@@ -305,6 +331,18 @@ class VideoDecoderCache:
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
def stats(self) -> dict[str, int | float]:
|
||||
"""Return reuse counters (hits/misses/evictions, hit rate, current size) for benchmarking."""
|
||||
with self._lock:
|
||||
total = self.hits + self.misses
|
||||
return {
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"evictions": self.evictions,
|
||||
"hit_rate": self.hits / total if total else 0.0,
|
||||
"size": len(self._cache),
|
||||
}
|
||||
|
||||
|
||||
class FrameTimestampError(ValueError):
|
||||
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
|
||||
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end distributed streaming smoke test under a real `accelerate launch`.
|
||||
|
||||
Mirrors tests/training/test_multi_gpu.py but runs on CPU and only checks the dataloading contract: with
|
||||
two processes, `split_dataset_by_node` (auto-resolved from the Accelerate state) must give each rank a
|
||||
disjoint set of frames that together cover the dataset. Skips if the environment can't actually spawn
|
||||
>= 2 processes (e.g. local macOS multi-CPU), so it never silently passes as a single process.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("accelerate", reason="accelerate is required (install lerobot[training])")
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
WORKER = """
|
||||
import json, sys
|
||||
from accelerate import PartialState
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
root, repo_id, out_dir = sys.argv[1], sys.argv[2], sys.argv[3]
|
||||
state = PartialState()
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=root, shuffle=False, buffer_size=8, max_num_shards=8
|
||||
)
|
||||
indices = [int(frame["index"]) for frame in ds]
|
||||
payload = {"rank": state.process_index, "world": state.num_processes, "indices": indices}
|
||||
with open(f"{out_dir}/rank_{state.process_index}.json", "w") as f:
|
||||
json.dump(payload, f)
|
||||
"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(shutil.which("accelerate") is None, reason="accelerate CLI not available")
|
||||
def test_accelerate_launch_ranks_are_disjoint(tmp_path, lerobot_dataset_factory):
|
||||
total_frames = 160
|
||||
repo_id = f"{DUMMY_REPO_ID}-acc"
|
||||
root = tmp_path / "ds"
|
||||
lerobot_dataset_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=8,
|
||||
total_frames=total_frames,
|
||||
use_videos=False,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
)
|
||||
|
||||
worker = tmp_path / "worker.py"
|
||||
worker.write_text(WORKER)
|
||||
out_dir = tmp_path / "out"
|
||||
out_dir.mkdir()
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num_processes=2",
|
||||
"--num_machines=1",
|
||||
"--mixed_precision=no",
|
||||
"--dynamo_backend=no",
|
||||
"--cpu",
|
||||
str(worker),
|
||||
str(root),
|
||||
repo_id,
|
||||
str(out_dir),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
|
||||
assert result.returncode == 0, (
|
||||
f"accelerate launch failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
payloads = [json.loads(p.read_text()) for p in sorted(out_dir.glob("rank_*.json"))]
|
||||
if len(payloads) < 2 or any(p["world"] < 2 for p in payloads):
|
||||
pytest.skip("environment did not spawn >= 2 distributed processes (e.g. local macOS multi-CPU)")
|
||||
|
||||
rank_sets = [set(p["indices"]) for p in payloads]
|
||||
assert rank_sets[0].isdisjoint(rank_sets[1]), "ranks streamed overlapping frames under accelerate launch"
|
||||
assert set().union(*rank_sets) == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main([__file__, "-v"]))
|
||||
@@ -0,0 +1,251 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the HF-native large-scale streaming additions: distributed (per-rank) sharding,
|
||||
DataLoader worker splitting, SARM-sized delta windows, resumability, and schema parity."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def _make_local_dataset(factory, root, repo_id, *, total_episodes, total_frames, use_videos=False, **kw):
|
||||
factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=0.001,
|
||||
chunks_size=1,
|
||||
**kw,
|
||||
)
|
||||
|
||||
|
||||
def _stream_indices(ds: StreamingLeRobotDataset) -> list[int]:
|
||||
return [int(frame["index"]) for frame in ds]
|
||||
|
||||
|
||||
def test_resolve_distributed_prefers_explicit_then_env(monkeypatch):
|
||||
assert StreamingLeRobotDataset._resolve_distributed(2, 8) == (2, 8)
|
||||
|
||||
monkeypatch.delenv("RANK", raising=False)
|
||||
monkeypatch.delenv("WORLD_SIZE", raising=False)
|
||||
# No accelerate state, no env -> single process.
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (0, 1)
|
||||
|
||||
monkeypatch.setenv("RANK", "3")
|
||||
monkeypatch.setenv("WORLD_SIZE", "4")
|
||||
assert StreamingLeRobotDataset._resolve_distributed(None, None) == (3, 4)
|
||||
|
||||
|
||||
def test_split_by_node_disjoint_across_ranks(tmp_path, lerobot_dataset_factory):
|
||||
"""Each rank must stream a disjoint set of frames, and the ranks together must cover every frame."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-ranks"
|
||||
total_frames, total_episodes = 200, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
world_size = 2
|
||||
per_rank = []
|
||||
for rank in range(world_size):
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=8,
|
||||
max_num_shards=8,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
per_rank.append(set(_stream_indices(ds)))
|
||||
|
||||
assert per_rank[0].isdisjoint(per_rank[1]), (
|
||||
"ranks streamed overlapping frames (duplicate data across GPUs)"
|
||||
)
|
||||
assert per_rank[0] | per_rank[1] == set(range(total_frames)), "ranks did not jointly cover all frames"
|
||||
|
||||
|
||||
def test_dataloader_workers_no_duplicates_within_rank(tmp_path, lerobot_dataset_factory):
|
||||
"""DataLoader workers within a rank must split shards so no frame is yielded twice."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-workers"
|
||||
total_frames, total_episodes = 120, 8
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory,
|
||||
tmp_path / "ds",
|
||||
repo_id,
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
)
|
||||
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=4
|
||||
)
|
||||
loader = DataLoader(ds, batch_size=None, num_workers=2)
|
||||
indices = [int(batch["index"]) for batch in loader]
|
||||
|
||||
assert len(indices) == len(set(indices)), "DataLoader workers yielded duplicate frames within a rank"
|
||||
|
||||
|
||||
def test_sarm_window_covers_long_horizon_without_padding(tmp_path, lerobot_dataset_factory):
|
||||
"""A delta window longer than the old 100-frame ceiling must fetch real frames, not pad them.
|
||||
|
||||
SARM uses a window of 8 steps spaced 1s (~160 frames @ fps20). Here fps=30, so +5s = 150 frames > 100.
|
||||
"""
|
||||
repo_id = f"{DUMMY_REPO_ID}-sarm"
|
||||
# A single long episode so a +150-frame lookahead is unambiguously inside the episode (the fixture
|
||||
# gives episodes variable lengths, so multi-episode boundaries can't be assumed).
|
||||
episode_frames = 300
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=1, total_frames=episode_frames
|
||||
)
|
||||
|
||||
horizon_s = 5.0 # 150 frames @ fps30, well beyond LOOKAHEAD_BACKTRACKTABLE=100
|
||||
delta_timestamps = {ACTION: [0.0, horizon_s]}
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=tmp_path / "ds",
|
||||
shuffle=False,
|
||||
buffer_size=1,
|
||||
max_num_shards=1,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
horizon_frames = int(round(horizon_s * ds.fps))
|
||||
assert horizon_frames > 100, "test must exceed the old LOOKAHEAD_BACKTRACKTABLE ceiling"
|
||||
checked = 0
|
||||
for frame in ds:
|
||||
idx = int(frame["index"])
|
||||
# The +horizon target is inside the single episode -> it must be a real frame, not padding.
|
||||
if idx + horizon_frames < episode_frames:
|
||||
assert not bool(frame[f"{ACTION}_is_pad"][-1]), (
|
||||
f"frame {idx}: +{horizon_frames} target was padded; long delta window did not reach it"
|
||||
)
|
||||
checked += 1
|
||||
assert checked > 0, "test did not exercise any in-episode long-horizon frame"
|
||||
|
||||
|
||||
def test_state_dict_resume_continues_without_restart(tmp_path, lerobot_dataset_factory):
|
||||
"""state_dict()/load_state_dict() must resume the stream near where it stopped, not from the start."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-resume"
|
||||
total_frames = 100
|
||||
_make_local_dataset(
|
||||
lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=5, total_frames=total_frames
|
||||
)
|
||||
|
||||
def fresh_ds():
|
||||
return StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
ds = fresh_ds()
|
||||
it = iter(ds)
|
||||
stop_after = 40
|
||||
seen_before = [int(next(it)["index"]) for _ in range(stop_after)]
|
||||
state = ds.state_dict()
|
||||
assert set(state) == {"shards", "exhausted", "rng"}
|
||||
|
||||
resumed_ds = fresh_ds()
|
||||
resumed_ds.load_state_dict(state)
|
||||
resumed = _stream_indices(resumed_ds)
|
||||
|
||||
# Resume continues rather than replaying: the full first pass is not re-yielded.
|
||||
assert len(resumed) < total_frames
|
||||
overlap = set(seen_before) & set(resumed)
|
||||
assert len(overlap) <= 2, f"resume re-yielded already-seen frames: {sorted(overlap)}"
|
||||
# Together the two passes cover essentially the whole dataset (a few frames may be dropped by the
|
||||
# ahead-read at the resume boundary -- documented non-bit-exact behaviour).
|
||||
assert len(set(seen_before) | set(resumed)) >= total_frames - 2
|
||||
|
||||
|
||||
def test_schema_parity_with_map_style(tmp_path, lerobot_dataset_factory):
|
||||
"""Streamed samples must have the same keys / shapes / dtypes as map-style LeRobotDataset."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-parity"
|
||||
map_ds = lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=4, total_frames=80, use_videos=True
|
||||
)
|
||||
stream_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=4, max_num_shards=2
|
||||
)
|
||||
|
||||
map_frame = map_ds[0]
|
||||
stream_frame = next(iter(stream_ds))
|
||||
|
||||
assert set(stream_frame) == set(map_frame), set(stream_frame) ^ set(map_frame)
|
||||
for key, value in stream_frame.items():
|
||||
ref = map_frame[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert isinstance(ref, torch.Tensor) and value.shape == ref.shape and value.dtype == ref.dtype, (
|
||||
f"{key}: stream {tuple(value.shape)}/{value.dtype} vs map {tuple(ref.shape)}/{ref.dtype}"
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
assert isinstance(ref, str), f"{key}: {type(value)} vs {type(ref)}"
|
||||
else:
|
||||
# Scalar numerics: streaming yields python floats where map-style yields 0-dim tensors
|
||||
# (a long-standing, accepted difference). Compare by value rather than exact type.
|
||||
assert float(value) == float(ref), f"{key}: {value} vs {ref}"
|
||||
|
||||
|
||||
def test_video_path_resolution_local(tmp_path, lerobot_dataset_factory, monkeypatch):
|
||||
"""For a local (prewarmed) root, video decode must be issued against the local path, not hf://."""
|
||||
import lerobot.datasets.streaming_dataset as sd
|
||||
|
||||
repo_id = f"{DUMMY_REPO_ID}-vpath"
|
||||
lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", repo_id=repo_id, total_episodes=2, total_frames=40, use_videos=True
|
||||
)
|
||||
ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
|
||||
seen_paths = []
|
||||
|
||||
def fake_decode(video_path, query_ts, *args, **kwargs):
|
||||
seen_paths.append(str(video_path))
|
||||
return torch.zeros(len(query_ts), 3, 64, 96)
|
||||
|
||||
monkeypatch.setattr(sd, "decode_video_frames_torchcodec", fake_decode)
|
||||
next(iter(ds))
|
||||
|
||||
assert seen_paths, "no video decode was issued"
|
||||
assert all(str(ds.root) in p and not p.startswith("hf://") for p in seen_paths), seen_paths
|
||||
|
||||
|
||||
def test_shuffle_decorrelates_output_order(tmp_path, lerobot_dataset_factory):
|
||||
"""With shuffle on, streamed frame order must differ from the underlying sequential order."""
|
||||
repo_id = f"{DUMMY_REPO_ID}-shuf"
|
||||
_make_local_dataset(lerobot_dataset_factory, tmp_path / "ds", repo_id, total_episodes=8, total_frames=200)
|
||||
ordered = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=False, buffer_size=1, max_num_shards=1
|
||||
)
|
||||
)
|
||||
shuffled = _stream_indices(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=tmp_path / "ds", shuffle=True, buffer_size=64, max_num_shards=4, seed=0
|
||||
)
|
||||
)
|
||||
assert sorted(shuffled) == sorted(ordered), "shuffling changed the set of frames"
|
||||
assert shuffled != ordered, "shuffle did not decorrelate output order"
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user