Add GOP window range benchmark

This commit is contained in:
Pepijn
2026-06-22 15:10:21 +02:00
parent 9201be92cb
commit 6d6c82eb8c
+321 -3
View File
@@ -18,7 +18,8 @@ import tempfile
import threading
import time
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
import fsspec
@@ -33,6 +34,7 @@ from lerobot.datasets.episode_video_streaming import (
EpisodeVideoManifest,
NativeHTTPRangeFetcher,
assert_hf_hub_range_cache_branch,
make_range_fetcher,
)
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
@@ -50,7 +52,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
parser.add_argument(
"--strategy",
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
choices=("both", "full", "indexed", "remote-decoder", "native-http", "gop-window"),
default="both",
help=argparse.SUPPRESS,
)
@@ -103,6 +105,23 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
)
parser.add_argument(
"--include-gop-window",
action="store_true",
help="Also benchmark random frame GOP/window byte-range fetches from the MP4 sidecar.",
)
parser.add_argument(
"--gop-window-post-frames",
type=int,
default=0,
help="Extra compressed samples after each target frame to include in GOP/window ranges.",
)
parser.add_argument(
"--gop-window-merge-gap-kb",
type=int,
default=0,
help="Merge GOP/window ranges from the same MP4 when the byte gap is at most this many KiB.",
)
parser.add_argument("--decode-workers", type=int, default=1)
parser.add_argument("--prefetch-ahead", type=int, default=8)
parser.add_argument("--frames-per-episode", type=int, default=16)
@@ -158,6 +177,120 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
return total
@dataclass(frozen=True)
class GopWindowRange:
file_path: str
offset: int
length: int
target_frames: int
covered_samples: int
def _sample_bounds_for_episode(manifest: EpisodeVideoManifest, episode_index: int, camera_key: str):
span = manifest.lookup(episode_index, camera_key)
mp4 = manifest.file_lookup(span.file_id).mp4
sample_count = len(mp4.sample_pts)
if sample_count == 0:
raise ValueError(f"{mp4.file_path} contains no indexed samples")
lo = int(np.searchsorted(mp4.sample_pts, span.first_pts, side="left"))
hi = int(np.searchsorted(mp4.sample_pts, span.last_pts, side="right")) - 1
lo = min(max(lo, 0), sample_count - 1)
hi = min(max(hi, lo), sample_count - 1)
return span, mp4, lo, hi
def _byte_range_for_samples(mp4, sample_lo: int, sample_hi: int, *, file_size: int) -> tuple[int, int]:
offsets = mp4.sample_offsets[sample_lo : sample_hi + 1]
sizes = mp4.sample_sizes[sample_lo : sample_hi + 1]
byte_lo = int(offsets.min())
byte_hi = int((offsets + sizes).max())
byte_hi = min(byte_hi, file_size)
return byte_lo, byte_hi - byte_lo
def _gop_window_for_target_sample(
manifest: EpisodeVideoManifest,
episode_index: int,
camera_key: str,
target_sample: int,
*,
post_frames: int,
) -> GopWindowRange:
span = manifest.lookup(episode_index, camera_key)
file_record = manifest.file_lookup(span.file_id)
mp4 = file_record.mp4
sync = mp4.sync_samples[mp4.sync_samples <= target_sample]
sample_lo = int(sync[-1]) if len(sync) else 0
sample_hi = min(max(target_sample + post_frames, sample_lo), span.sample_hi, len(mp4.sample_pts) - 1)
offset, length = _byte_range_for_samples(mp4, sample_lo, sample_hi, file_size=file_record.file_size)
return GopWindowRange(
file_path=file_record.file_path,
offset=offset,
length=length,
target_frames=1,
covered_samples=sample_hi - sample_lo + 1,
)
def _gop_window_ranges(
manifest: EpisodeVideoManifest,
episodes: Sequence[int],
*,
frames_per_episode: int,
seed: int,
post_frames: int,
merge_gap_bytes: int,
) -> tuple[list[GopWindowRange], int, int, int]:
rng = random.Random(seed)
raw: list[GopWindowRange] = []
compressed_target_bytes = 0
covered_samples = 0
for ep in episodes:
for camera_key in manifest.video_keys:
span, mp4, target_lo, target_hi = _sample_bounds_for_episode(manifest, ep, camera_key)
for _ in range(frames_per_episode):
ts = rng.uniform(span.first_pts, max(span.last_pts, span.first_pts))
target = int(np.searchsorted(mp4.sample_pts, ts, side="left"))
target = min(max(target, target_lo), target_hi)
compressed_target_bytes += int(mp4.sample_sizes[target])
window = _gop_window_for_target_sample(
manifest,
ep,
camera_key,
target,
post_frames=post_frames,
)
covered_samples += window.covered_samples
raw.append(window)
merged = _merge_gop_window_ranges(raw, merge_gap_bytes)
return merged, len(raw), compressed_target_bytes, covered_samples
def _merge_gop_window_ranges(ranges: Sequence[GopWindowRange], merge_gap_bytes: int) -> list[GopWindowRange]:
if not ranges:
return []
ordered = sorted(ranges, key=lambda item: (item.file_path, item.offset, item.length))
merged: list[GopWindowRange] = []
current = ordered[0]
for item in ordered[1:]:
current_end = current.offset + current.length
if item.file_path == current.file_path and item.offset <= current_end + merge_gap_bytes:
new_end = max(current_end, item.offset + item.length)
current = GopWindowRange(
file_path=current.file_path,
offset=current.offset,
length=new_end - current.offset,
target_frames=current.target_frames + item.target_frames,
covered_samples=current.covered_samples + item.covered_samples,
)
else:
merged.append(current)
current = item
merged.append(current)
return merged
def _decode_all(
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
) -> float:
@@ -426,6 +559,87 @@ def run_fetch_pool(
return result
def run_gop_window_fetch(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
workers: int,
range_backend: str,
args: argparse.Namespace,
) -> dict[str, float]:
merge_gap_bytes = int(args.gop_window_merge_gap_kb * 1024)
windows, raw_windows, compressed_target_bytes, covered_samples = _gop_window_ranges(
manifest,
episodes,
frames_per_episode=args.frames_per_episode,
seed=args.seed + 2,
post_frames=args.gop_window_post_frames,
merge_gap_bytes=merge_gap_bytes,
)
if not windows:
raise ValueError("No GOP/window ranges were planned")
fetcher = make_range_fetcher(
data_root,
range_backend=range_backend,
workers=workers,
native_http_connections=args.native_http_connections,
native_http_timeout=args.native_http_timeout,
native_http_retries=args.native_http_retries,
)
def fetch_window(window: GopWindowRange) -> int:
payload = fetcher.read_range(window.file_path, window.offset, window.length)
if len(payload) != window.length:
raise OSError(f"Short read for {window.file_path}: expected {window.length}, got {len(payload)}")
return len(payload)
byte_count = sum(window.length for window in windows)
start = time.perf_counter()
done = 0
done_ranges = 0
last_progress = start
try:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = [pool.submit(fetch_window, window) for window in windows]
for future in as_completed(futures):
done += future.result()
done_ranges += 1
now = time.perf_counter()
if args.progress_interval > 0 and now - last_progress >= args.progress_interval:
elapsed = max(now - start, 1e-9)
_log(
"gop_window_progress: "
f"ranges_done={done_ranges}/{len(windows)} "
f"fetched={done / 1024**3:.2f} GiB "
f"fetch={done / elapsed / 1024**2:.1f} MiB/s "
f"elapsed={_format_duration(elapsed)}"
)
last_progress = now
finally:
timings = fetcher.timing_summary() if hasattr(fetcher, "timing_summary") else {}
fetcher.close()
elapsed = time.perf_counter() - start
result = {
"fetch_s": elapsed,
"fetch_mbps": byte_count / elapsed / 1024**2,
"frame_windows_s": raw_windows / elapsed,
"ranges_s": len(windows) / elapsed,
"bytes": float(byte_count),
"raw_windows": float(raw_windows),
"merged_windows": float(len(windows)),
"compressed_target_bytes": float(compressed_target_bytes),
"covered_samples": float(covered_samples),
"avg_mb_range": byte_count / len(windows) / 1024**2,
"avg_kib_frame_window": byte_count / raw_windows / 1024,
"avg_compressed_kib_target": compressed_target_bytes / raw_windows / 1024,
"avg_covered_samples": covered_samples / raw_windows,
}
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
return result
def run_parallel(
manifest: EpisodeVideoManifest,
data_root: str,
@@ -856,6 +1070,80 @@ def run_indexed_strategy(
)
def run_gop_window_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
args: argparse.Namespace,
*,
range_backend: str = "fsspec",
sidecar_path: str | None = None,
) -> None:
_log("starting_strategy: gop-window")
memory_start = _memory_snapshot()
manifest_start = time.perf_counter()
dataset_episode_count = int(meta.total_episodes)
manifest_episode_count = args.manifest_episodes or dataset_episode_count
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
manifest = EpisodeVideoManifest.build(
meta,
data_root,
episode_indices=range(manifest_episode_count),
range_backend=range_backend,
workers=args.workers,
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
sidecar_path=sidecar_path,
)
manifest_s = time.perf_counter() - manifest_start
_log(f"gop-window: manifest_build_s={manifest_s:.2f}")
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
full_episode_bytes = _bytes_for(manifest, episodes)
result = run_gop_window_fetch(manifest, data_root, episodes, args.workers, range_backend, args)
estimated_benchmark_s = benchmark_episode_count * args.frames_per_episode / result["frame_windows_s"]
estimated_dataset_s = dataset_episode_count * args.frames_per_episode / result["frame_windows_s"]
print(f"manifest_build_s: {manifest_s:.2f}")
print("strategy: gop-window")
print(f"range_backend: {range_backend}")
print(f"mp4_sidecar: {sidecar_path or 'none'}")
print(f"data_root: {data_root}")
print(f"dataset_episodes: {dataset_episode_count}")
print(f"benchmark_episodes: {benchmark_episode_count}")
print(f"pool_episodes: {len(episodes)}")
print(f"frames_per_episode: {args.frames_per_episode}")
print(f"gop_window_post_frames: {args.gop_window_post_frames}")
print(f"gop_window_merge_gap_kb: {args.gop_window_merge_gap_kb}")
print(f"sampled_episodes: {episodes}")
print(f"cameras: {manifest.video_keys}")
print()
print(
"| Track | fetch MB/s | frame windows/s | ranges/s | wall s | "
"est benchmark | est full dataset | notes |"
)
print("|---|---:|---:|---:|---:|---:|---:|---|")
print(
f"| GOP/WINDOW FETCH | {result['fetch_mbps']:.1f} | {result['frame_windows_s']:.1f} | "
f"{result['ranges_s']:.1f} | {result['fetch_s']:.2f} | "
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
f"{args.workers} workers, fetch-and-drop, no decoder open/frame decode |"
)
print()
print("| GOP Window Shape | value |")
print("|---|---:|")
print(f"| target frame windows | {result['raw_windows']:.0f} |")
print(f"| fetched byte ranges | {result['merged_windows']:.0f} |")
print(f"| fetched GiB | {result['bytes'] / 1024**3:.2f} |")
print(f"| full episode-pool GiB | {full_episode_bytes / 1024**3:.2f} |")
print(f"| fetched/full episode bytes | {result['bytes'] / full_episode_bytes:.3f} |")
print(f"| avg MiB/range | {result['avg_mb_range']:.3f} |")
print(f"| avg KiB/frame window | {result['avg_kib_frame_window']:.1f} |")
print(f"| avg compressed KiB/target frame | {result['avg_compressed_kib_target']:.1f} |")
print(f"| avg compressed samples/window | {result['avg_covered_samples']:.1f} |")
_print_range_timing_summary(result)
_print_memory_summary(memory_start, _memory_snapshot())
def run_remote_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
@@ -925,6 +1213,15 @@ def main() -> None:
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
if args.include_gop_window:
print()
run_gop_window_strategy(
meta,
data_root,
args,
range_backend=args.range_backend,
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "indexed":
run_indexed_strategy(
@@ -936,6 +1233,15 @@ def main() -> None:
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
if args.include_gop_window:
print()
run_gop_window_strategy(
meta,
data_root,
args,
range_backend=args.range_backend,
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "native-http":
run_indexed_strategy(
@@ -948,7 +1254,16 @@ def main() -> None:
sidecar_path=str(sidecar_path),
)
return
if args.strategy == "both":
if sidecar_path is not None and args.strategy == "gop-window":
run_gop_window_strategy(
meta,
data_root,
args,
range_backend=args.range_backend,
sidecar_path=str(sidecar_path),
)
return
if args.strategy in ("both", "gop-window"):
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
@@ -958,6 +1273,9 @@ def main() -> None:
"uv run --no-sync python scripts/build_mp4_sidecar.py "
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
)
if args.strategy == "gop-window":
print("gop_window_requires_mp4_sidecar: existing per-sample MP4 index sidecar is required")
return
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
print()