diff --git a/scripts/bench_episode_byte_cache.py b/scripts/bench_episode_byte_cache.py index a43d45bf5..7007c6e4d 100644 --- a/scripts/bench_episode_byte_cache.py +++ b/scripts/bench_episode_byte_cache.py @@ -18,7 +18,8 @@ import tempfile import threading import time from collections.abc import Sequence -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass from pathlib import Path import fsspec @@ -33,6 +34,7 @@ from lerobot.datasets.episode_video_streaming import ( EpisodeVideoManifest, NativeHTTPRangeFetcher, assert_hf_hub_range_cache_branch, + make_range_fetcher, ) from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec @@ -50,7 +52,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT) parser.add_argument( "--strategy", - choices=("both", "full", "indexed", "remote-decoder", "native-http"), + choices=("both", "full", "indexed", "remote-decoder", "native-http", "gop-window"), default="both", help=argparse.SUPPRESS, ) @@ -103,6 +105,23 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.", ) + parser.add_argument( + "--include-gop-window", + action="store_true", + help="Also benchmark random frame GOP/window byte-range fetches from the MP4 sidecar.", + ) + parser.add_argument( + "--gop-window-post-frames", + type=int, + default=0, + help="Extra compressed samples after each target frame to include in GOP/window ranges.", + ) + parser.add_argument( + "--gop-window-merge-gap-kb", + type=int, + default=0, + help="Merge GOP/window ranges from the same MP4 when the byte gap is at most this many KiB.", + ) parser.add_argument("--decode-workers", type=int, default=1) parser.add_argument("--prefetch-ahead", type=int, default=8) parser.add_argument("--frames-per-episode", type=int, default=16) @@ -158,6 +177,120 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int: return total +@dataclass(frozen=True) +class GopWindowRange: + file_path: str + offset: int + length: int + target_frames: int + covered_samples: int + + +def _sample_bounds_for_episode(manifest: EpisodeVideoManifest, episode_index: int, camera_key: str): + span = manifest.lookup(episode_index, camera_key) + mp4 = manifest.file_lookup(span.file_id).mp4 + sample_count = len(mp4.sample_pts) + if sample_count == 0: + raise ValueError(f"{mp4.file_path} contains no indexed samples") + lo = int(np.searchsorted(mp4.sample_pts, span.first_pts, side="left")) + hi = int(np.searchsorted(mp4.sample_pts, span.last_pts, side="right")) - 1 + lo = min(max(lo, 0), sample_count - 1) + hi = min(max(hi, lo), sample_count - 1) + return span, mp4, lo, hi + + +def _byte_range_for_samples(mp4, sample_lo: int, sample_hi: int, *, file_size: int) -> tuple[int, int]: + offsets = mp4.sample_offsets[sample_lo : sample_hi + 1] + sizes = mp4.sample_sizes[sample_lo : sample_hi + 1] + byte_lo = int(offsets.min()) + byte_hi = int((offsets + sizes).max()) + byte_hi = min(byte_hi, file_size) + return byte_lo, byte_hi - byte_lo + + +def _gop_window_for_target_sample( + manifest: EpisodeVideoManifest, + episode_index: int, + camera_key: str, + target_sample: int, + *, + post_frames: int, +) -> GopWindowRange: + span = manifest.lookup(episode_index, camera_key) + file_record = manifest.file_lookup(span.file_id) + mp4 = file_record.mp4 + sync = mp4.sync_samples[mp4.sync_samples <= target_sample] + sample_lo = int(sync[-1]) if len(sync) else 0 + sample_hi = min(max(target_sample + post_frames, sample_lo), span.sample_hi, len(mp4.sample_pts) - 1) + offset, length = _byte_range_for_samples(mp4, sample_lo, sample_hi, file_size=file_record.file_size) + return GopWindowRange( + file_path=file_record.file_path, + offset=offset, + length=length, + target_frames=1, + covered_samples=sample_hi - sample_lo + 1, + ) + + +def _gop_window_ranges( + manifest: EpisodeVideoManifest, + episodes: Sequence[int], + *, + frames_per_episode: int, + seed: int, + post_frames: int, + merge_gap_bytes: int, +) -> tuple[list[GopWindowRange], int, int, int]: + rng = random.Random(seed) + raw: list[GopWindowRange] = [] + compressed_target_bytes = 0 + covered_samples = 0 + for ep in episodes: + for camera_key in manifest.video_keys: + span, mp4, target_lo, target_hi = _sample_bounds_for_episode(manifest, ep, camera_key) + for _ in range(frames_per_episode): + ts = rng.uniform(span.first_pts, max(span.last_pts, span.first_pts)) + target = int(np.searchsorted(mp4.sample_pts, ts, side="left")) + target = min(max(target, target_lo), target_hi) + compressed_target_bytes += int(mp4.sample_sizes[target]) + window = _gop_window_for_target_sample( + manifest, + ep, + camera_key, + target, + post_frames=post_frames, + ) + covered_samples += window.covered_samples + raw.append(window) + + merged = _merge_gop_window_ranges(raw, merge_gap_bytes) + return merged, len(raw), compressed_target_bytes, covered_samples + + +def _merge_gop_window_ranges(ranges: Sequence[GopWindowRange], merge_gap_bytes: int) -> list[GopWindowRange]: + if not ranges: + return [] + ordered = sorted(ranges, key=lambda item: (item.file_path, item.offset, item.length)) + merged: list[GopWindowRange] = [] + current = ordered[0] + for item in ordered[1:]: + current_end = current.offset + current.length + if item.file_path == current.file_path and item.offset <= current_end + merge_gap_bytes: + new_end = max(current_end, item.offset + item.length) + current = GopWindowRange( + file_path=current.file_path, + offset=current.offset, + length=new_end - current.offset, + target_frames=current.target_frames + item.target_frames, + covered_samples=current.covered_samples + item.covered_samples, + ) + else: + merged.append(current) + current = item + merged.append(current) + return merged + + def _decode_all( cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int ) -> float: @@ -426,6 +559,87 @@ def run_fetch_pool( return result +def run_gop_window_fetch( + manifest: EpisodeVideoManifest, + data_root: str, + episodes: Sequence[int], + workers: int, + range_backend: str, + args: argparse.Namespace, +) -> dict[str, float]: + merge_gap_bytes = int(args.gop_window_merge_gap_kb * 1024) + windows, raw_windows, compressed_target_bytes, covered_samples = _gop_window_ranges( + manifest, + episodes, + frames_per_episode=args.frames_per_episode, + seed=args.seed + 2, + post_frames=args.gop_window_post_frames, + merge_gap_bytes=merge_gap_bytes, + ) + if not windows: + raise ValueError("No GOP/window ranges were planned") + + fetcher = make_range_fetcher( + data_root, + range_backend=range_backend, + workers=workers, + native_http_connections=args.native_http_connections, + native_http_timeout=args.native_http_timeout, + native_http_retries=args.native_http_retries, + ) + + def fetch_window(window: GopWindowRange) -> int: + payload = fetcher.read_range(window.file_path, window.offset, window.length) + if len(payload) != window.length: + raise OSError(f"Short read for {window.file_path}: expected {window.length}, got {len(payload)}") + return len(payload) + + byte_count = sum(window.length for window in windows) + start = time.perf_counter() + done = 0 + done_ranges = 0 + last_progress = start + try: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = [pool.submit(fetch_window, window) for window in windows] + for future in as_completed(futures): + done += future.result() + done_ranges += 1 + now = time.perf_counter() + if args.progress_interval > 0 and now - last_progress >= args.progress_interval: + elapsed = max(now - start, 1e-9) + _log( + "gop_window_progress: " + f"ranges_done={done_ranges}/{len(windows)} " + f"fetched={done / 1024**3:.2f} GiB " + f"fetch={done / elapsed / 1024**2:.1f} MiB/s " + f"elapsed={_format_duration(elapsed)}" + ) + last_progress = now + finally: + timings = fetcher.timing_summary() if hasattr(fetcher, "timing_summary") else {} + fetcher.close() + + elapsed = time.perf_counter() - start + result = { + "fetch_s": elapsed, + "fetch_mbps": byte_count / elapsed / 1024**2, + "frame_windows_s": raw_windows / elapsed, + "ranges_s": len(windows) / elapsed, + "bytes": float(byte_count), + "raw_windows": float(raw_windows), + "merged_windows": float(len(windows)), + "compressed_target_bytes": float(compressed_target_bytes), + "covered_samples": float(covered_samples), + "avg_mb_range": byte_count / len(windows) / 1024**2, + "avg_kib_frame_window": byte_count / raw_windows / 1024, + "avg_compressed_kib_target": compressed_target_bytes / raw_windows / 1024, + "avg_covered_samples": covered_samples / raw_windows, + } + result.update({key: value for key, value in timings.items() if key.startswith("range_")}) + return result + + def run_parallel( manifest: EpisodeVideoManifest, data_root: str, @@ -856,6 +1070,80 @@ def run_indexed_strategy( ) +def run_gop_window_strategy( + meta: LeRobotDatasetMetadata, + data_root: str, + args: argparse.Namespace, + *, + range_backend: str = "fsspec", + sidecar_path: str | None = None, +) -> None: + _log("starting_strategy: gop-window") + memory_start = _memory_snapshot() + manifest_start = time.perf_counter() + dataset_episode_count = int(meta.total_episodes) + manifest_episode_count = args.manifest_episodes or dataset_episode_count + manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes) + manifest = EpisodeVideoManifest.build( + meta, + data_root, + episode_indices=range(manifest_episode_count), + range_backend=range_backend, + workers=args.workers, + max_probe_bytes=args.max_probe_mb * 1024 * 1024, + sidecar_path=sidecar_path, + ) + manifest_s = time.perf_counter() - manifest_start + _log(f"gop-window: manifest_build_s={manifest_s:.2f}") + + benchmark_episode_count = min(dataset_episode_count, args.num_episodes) + episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed) + full_episode_bytes = _bytes_for(manifest, episodes) + result = run_gop_window_fetch(manifest, data_root, episodes, args.workers, range_backend, args) + estimated_benchmark_s = benchmark_episode_count * args.frames_per_episode / result["frame_windows_s"] + estimated_dataset_s = dataset_episode_count * args.frames_per_episode / result["frame_windows_s"] + + print(f"manifest_build_s: {manifest_s:.2f}") + print("strategy: gop-window") + print(f"range_backend: {range_backend}") + print(f"mp4_sidecar: {sidecar_path or 'none'}") + print(f"data_root: {data_root}") + print(f"dataset_episodes: {dataset_episode_count}") + print(f"benchmark_episodes: {benchmark_episode_count}") + print(f"pool_episodes: {len(episodes)}") + print(f"frames_per_episode: {args.frames_per_episode}") + print(f"gop_window_post_frames: {args.gop_window_post_frames}") + print(f"gop_window_merge_gap_kb: {args.gop_window_merge_gap_kb}") + print(f"sampled_episodes: {episodes}") + print(f"cameras: {manifest.video_keys}") + print() + print( + "| Track | fetch MB/s | frame windows/s | ranges/s | wall s | " + "est benchmark | est full dataset | notes |" + ) + print("|---|---:|---:|---:|---:|---:|---:|---|") + print( + f"| GOP/WINDOW FETCH | {result['fetch_mbps']:.1f} | {result['frame_windows_s']:.1f} | " + f"{result['ranges_s']:.1f} | {result['fetch_s']:.2f} | " + f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | " + f"{args.workers} workers, fetch-and-drop, no decoder open/frame decode |" + ) + print() + print("| GOP Window Shape | value |") + print("|---|---:|") + print(f"| target frame windows | {result['raw_windows']:.0f} |") + print(f"| fetched byte ranges | {result['merged_windows']:.0f} |") + print(f"| fetched GiB | {result['bytes'] / 1024**3:.2f} |") + print(f"| full episode-pool GiB | {full_episode_bytes / 1024**3:.2f} |") + print(f"| fetched/full episode bytes | {result['bytes'] / full_episode_bytes:.3f} |") + print(f"| avg MiB/range | {result['avg_mb_range']:.3f} |") + print(f"| avg KiB/frame window | {result['avg_kib_frame_window']:.1f} |") + print(f"| avg compressed KiB/target frame | {result['avg_compressed_kib_target']:.1f} |") + print(f"| avg compressed samples/window | {result['avg_covered_samples']:.1f} |") + _print_range_timing_summary(result) + _print_memory_summary(memory_start, _memory_snapshot()) + + def run_remote_strategy( meta: LeRobotDatasetMetadata, data_root: str, @@ -925,6 +1213,15 @@ def main() -> None: label=f"indexed-sidecar-{args.range_backend}", sidecar_path=str(sidecar_path), ) + if args.include_gop_window: + print() + run_gop_window_strategy( + meta, + data_root, + args, + range_backend=args.range_backend, + sidecar_path=str(sidecar_path), + ) return if sidecar_path is not None and args.strategy == "indexed": run_indexed_strategy( @@ -936,6 +1233,15 @@ def main() -> None: label=f"indexed-sidecar-{args.range_backend}", sidecar_path=str(sidecar_path), ) + if args.include_gop_window: + print() + run_gop_window_strategy( + meta, + data_root, + args, + range_backend=args.range_backend, + sidecar_path=str(sidecar_path), + ) return if sidecar_path is not None and args.strategy == "native-http": run_indexed_strategy( @@ -948,7 +1254,16 @@ def main() -> None: sidecar_path=str(sidecar_path), ) return - if args.strategy == "both": + if sidecar_path is not None and args.strategy == "gop-window": + run_gop_window_strategy( + meta, + data_root, + args, + range_backend=args.range_backend, + sidecar_path=str(sidecar_path), + ) + return + if args.strategy in ("both", "gop-window"): expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}") print(f"mp4_sidecar_missing_local: {expected_sidecar}") @@ -958,6 +1273,9 @@ def main() -> None: "uv run --no-sync python scripts/build_mp4_sidecar.py " f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}" ) + if args.strategy == "gop-window": + print("gop_window_requires_mp4_sidecar: existing per-sample MP4 index sidecar is required") + return print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online") print()