#!/usr/bin/env python # Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 from __future__ import annotations import argparse import random import tempfile import threading import time from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from pathlib import Path import fsspec import numpy as np import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.episode_video_streaming import ( EpisodeByteCache, EpisodeVideoManifest, NativeHTTPRangeFetcher, assert_hf_hub_range_cache_branch, ) from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset" DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c" DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket" SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.") parser.add_argument("--repo-id", default=DEFAULT_REPO) parser.add_argument("--revision", default=DEFAULT_REVISION) parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT) parser.add_argument( "--strategy", choices=("both", "indexed", "remote-decoder", "native-http"), default="both", help=argparse.SUPPRESS, ) parser.add_argument("--num-episodes", type=int, default=512) parser.add_argument( "--manifest-episodes", type=int, default=None, help="Limit manifest construction to the first N episodes for local smoke tests.", ) parser.add_argument("--pool-size", type=int, default=16) parser.add_argument("--workers", type=int, default=8) parser.add_argument( "--include-decode", action="store_true", help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.", ) parser.add_argument("--decode-workers", type=int, default=1) parser.add_argument("--prefetch-ahead", type=int, default=8) parser.add_argument("--frames-per-episode", type=int, default=16) parser.add_argument("--max-probe-mb", type=int, default=64) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--byte-budget-gb", type=float, default=80) parser.add_argument( "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory." ) parser.add_argument("--no-hub-branch-assert", action="store_true") return parser.parse_args() def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]: rng = random.Random(seed) upper = min(total, requested) if pool_size > upper: raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}") return rng.sample(range(upper), pool_size) def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int): rng = random.Random(seed) out: dict[tuple[int, str], list[float]] = {} for ep in episodes: for camera_key in manifest.video_keys: span = manifest.lookup(ep, camera_key) lo = span.first_pts hi = max(span.last_pts, lo) out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode)) return out def _timestamps_from_meta( meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int ) -> dict[tuple[int, str], list[float]]: rng = random.Random(seed) out: dict[tuple[int, str], list[float]] = {} for ep in episodes: row = meta.episodes[ep] for camera_key in meta.video_keys: lo = float(row[f"videos/{camera_key}/from_timestamp"]) hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo) out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode)) return out def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int: total = 0 for ep in episodes: for camera_key in manifest.video_keys: total += manifest.lookup(ep, camera_key).mdat_length return total def _decode_all( cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int ) -> float: start = time.perf_counter() items = list(timestamps.items()) if decode_workers <= 1: for (ep, camera_key), ts in items: cache.get_frames(ep, camera_key, ts) else: with ThreadPoolExecutor(max_workers=decode_workers) as pool: futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items] for future in futures: future.result() return time.perf_counter() - start def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float: start = time.perf_counter() for ep in episodes: cache.submit_prefetch(ep) for ep in episodes: cache.ensure_ready(ep) return time.perf_counter() - start def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float: if elapsed_s <= 0: return float("inf") return len(episodes) * frames_per_episode / elapsed_s def _log(message: str) -> None: print(message, flush=True) def _root_join(data_root: str, relative_path: str) -> str: if data_root.startswith("hf://"): return f"{data_root.rstrip('/')}/{relative_path}" return str(Path(data_root) / relative_path) def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None: local = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz" if _valid_sidecar(local): return local if local.exists(): print(f"mp4_sidecar_invalid_local: {local}") local.unlink() full_local = SIDECAR_CACHE_DIR / "molmoact2-full.npz" if _valid_sidecar(full_local): return full_local remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz") protocol = "hf" if data_root.startswith("hf://") else "file" fs = fsspec.filesystem(protocol) if not fs.exists(remote): return None local.parent.mkdir(parents=True, exist_ok=True) print(f"downloading_mp4_sidecar: {remote} -> {local}") if data_root.startswith("hf://"): _download_sidecar_native_http( data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz", local ) else: fs.get(remote, str(local)) return local def _valid_sidecar(path: Path) -> bool: if not path.exists(): return False try: with np.load(path, allow_pickle=False) as data: return "manifest_json" in data except Exception: return False def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None: fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16) tmp = local.with_suffix(local.suffix + ".tmp") try: size = fetcher.info_size(relative_path) chunk_size = 16 * 1024 * 1024 ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)] with tmp.open("wb") as out_file: out_file.truncate(size) def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]: offset, length = offset_length return offset, fetcher.read_range(relative_path, offset, length) start = time.perf_counter() done = 0 with ThreadPoolExecutor(max_workers=8) as pool: futures = [pool.submit(read_chunk, item) for item in ranges] with tmp.open("r+b") as rw_file: for future in futures: offset, data = future.result() rw_file.seek(offset) rw_file.write(data) done += len(data) elapsed = max(time.perf_counter() - start, 1e-9) print( f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB " f"({done / elapsed / 1024**2:.1f} MiB/s)", flush=True, ) tmp.replace(local) finally: fetcher.close() class EpisodeParquetReader: def __init__(self, meta: LeRobotDatasetMetadata, data_root: str): self.meta = meta self.data_root = data_root protocol = "hf" if data_root.startswith("hf://") else "file" self.fs = fsspec.filesystem(protocol) self._episode_row_groups = self._build_episode_row_groups() self._table_cache: dict[str, pa.Table] = {} self._cache_lock = threading.Lock() def read_episode(self, episode_index: int) -> None: relative_path = str(self.meta.get_data_file_path(episode_index)) table = self._read_table(relative_path) table.filter(pc.equal(table["episode_index"], episode_index)) def _read_table(self, relative_path: str) -> pa.Table: with self._cache_lock: table = self._table_cache.get(relative_path) if table is not None: return table with self.fs.open( _root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none" ) as f: table = pq.ParquetFile(f).read() with self._cache_lock: return self._table_cache.setdefault(relative_path, table) def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int): return pool.submit(self.read_episode, episode_index) def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float: start = time.perf_counter() if workers <= 1: for ep in episodes: self.read_episode(ep) else: with ThreadPoolExecutor(max_workers=workers) as pool: futures = [pool.submit(self.read_episode, ep) for ep in episodes] for future in futures: future.result() return time.perf_counter() - start def _build_episode_row_groups(self) -> dict[int, int]: counts: dict[tuple[int, int], int] = {} row_groups = {} for ep_idx in range(int(self.meta.total_episodes)): ep = self.meta.episodes[ep_idx] key = (int(ep["data/chunk_index"]), int(ep["data/file_index"])) row_groups[ep_idx] = counts.get(key, 0) counts[key] = row_groups[ep_idx] + 1 return row_groups def run_fetch_pool( manifest: EpisodeVideoManifest, data_root: str, episodes: Sequence[int], byte_budget: int, workers: int, range_backend: str, ) -> dict[str, float]: with EpisodeByteCache( manifest, data_root, byte_budget=byte_budget, workers=workers, range_backend=range_backend, open_decoders=False, ) as cache: elapsed = _fill_cache(cache, episodes) byte_count = _bytes_for(manifest, episodes) episode_mb = byte_count / len(episodes) / 1024**2 return { "fetch_s": elapsed, "fetch_mbps": byte_count / elapsed / 1024**2, "fetch_episodes_s": len(episodes) / elapsed, "episode_mb": episode_mb, "avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2, } def run_parallel( manifest: EpisodeVideoManifest, data_root: str, episodes: Sequence[int], timestamps: dict[tuple[int, str], list[float]], byte_budget: int, workers: int, decode_workers: int, frames_per_episode: int, parquet_reader: EpisodeParquetReader, range_backend: str, ) -> dict[str, float]: with EpisodeByteCache( manifest, data_root, byte_budget=byte_budget, workers=workers, range_backend=range_backend, open_decoders=False, ) as cache: parquet_s = parquet_reader.read_episodes(episodes, workers=workers) fetch_s = _fill_cache(cache, episodes) decoder_start = time.perf_counter() for ep in episodes: for camera_key in manifest.video_keys: cache.get_decoder(ep, camera_key) decoder_s = time.perf_counter() - decoder_start decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers) byte_count = _bytes_for(manifest, episodes) return { "fetch_s": fetch_s, "fetch_mbps": byte_count / fetch_s / 1024**2, "fetch_episodes_s": len(episodes) / fetch_s, "parquet_s": parquet_s, "decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)), "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode), } def run_overlapped( manifest: EpisodeVideoManifest, data_root: str, episodes: Sequence[int], timestamps: dict[tuple[int, str], list[float]], byte_budget: int, workers: int, decode_workers: int, frames_per_episode: int, prefetch_ahead: int, parquet_reader: EpisodeParquetReader, range_backend: str, ) -> dict[str, float]: with EpisodeByteCache( manifest, data_root, byte_budget=byte_budget, workers=workers, range_backend=range_backend, open_decoders=True, ) as cache: start = time.perf_counter() video_wait_decode_s = 0.0 parquet_wait_s = 0.0 parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes)))) parquet_futures = { ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead] } for ep in episodes[:prefetch_ahead]: cache.submit_prefetch(ep) try: for idx, ep in enumerate(episodes): next_idx = idx + prefetch_ahead if next_idx < len(episodes): next_ep = episodes[next_idx] cache.submit_prefetch(next_ep) parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep) parquet_start = time.perf_counter() parquet_futures.pop(ep).result() parquet_wait_s += time.perf_counter() - parquet_start video_start = time.perf_counter() cache.ensure_ready(ep) if decode_workers <= 1: for camera_key in manifest.video_keys: cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)]) else: with ThreadPoolExecutor(max_workers=decode_workers) as pool: futures = [ pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)]) for camera_key in manifest.video_keys ] for future in futures: future.result() video_wait_decode_s += time.perf_counter() - video_start finally: parquet_pool.shutdown(wait=True) elapsed = time.perf_counter() - start return { "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode), "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode), "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode), "wall_s": elapsed, "video_wait_decode_s": video_wait_decode_s, "parquet_wait_s": parquet_wait_s, } _remote_decoder_local = threading.local() def _remote_decoder_cache() -> VideoDecoderCache: cache = getattr(_remote_decoder_local, "cache", None) if cache is None: cache = VideoDecoderCache(max_size=None) _remote_decoder_local.cache = cache return cache def _decode_remote_source( meta: LeRobotDatasetMetadata, data_root: str, episode_index: int, camera_key: str, timestamps: list[float], ): video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key))) return decode_video_frames_torchcodec( video_path, timestamps, tolerance_s=1.0 / float(meta.fps), decoder_cache=_remote_decoder_cache(), return_uint8=True, ) def run_remote_decoder( meta: LeRobotDatasetMetadata, data_root: str, episodes: Sequence[int], timestamps: dict[tuple[int, str], list[float]], *, frames_per_episode: int, decode_workers: int, parquet_reader: EpisodeParquetReader, ) -> dict[str, float]: items = [ (ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys ] start = time.perf_counter() for ep, camera_key, ts in items: if camera_key == meta.video_keys[0]: parquet_reader.read_episode(ep) _decode_remote_source(meta, data_root, ep, camera_key, ts) sequential_s = time.perf_counter() - start start = time.perf_counter() if decode_workers <= 1: for ep, camera_key, ts in items: if camera_key == meta.video_keys[0]: parquet_reader.read_episode(ep) _decode_remote_source(meta, data_root, ep, camera_key, ts) else: with ThreadPoolExecutor(max_workers=decode_workers) as pool: parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes] futures = [ pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts) for ep, camera_key, ts in items ] for future in parquet_futures: future.result() for future in futures: future.result() parallel_s = time.perf_counter() - start return { "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode), "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode), } def run_indexed_strategy( meta: LeRobotDatasetMetadata, data_root: str, args: argparse.Namespace, parquet_reader: EpisodeParquetReader, *, range_backend: str = "fsspec", label: str = "indexed", sidecar_path: str | None = None, ) -> None: _log(f"starting_strategy: {label}") manifest_start = time.perf_counter() manifest_episode_count = args.manifest_episodes or int(meta.total_episodes) manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes) manifest = EpisodeVideoManifest.build( meta, data_root, episode_indices=range(manifest_episode_count), range_backend=range_backend, workers=args.workers, max_probe_bytes=args.max_probe_mb * 1024 * 1024, sidecar_path=sidecar_path, ) manifest_s = time.perf_counter() - manifest_start _log(f"{label}: manifest_build_s={manifest_s:.2f}") episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed) byte_budget = int(args.byte_budget_gb * 1024**3) byte_count = _bytes_for(manifest, episodes) _log( f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track " f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)" ) _log(f"{label}: filling episode byte cache with {args.workers} workers") fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend) print(f"manifest_build_s: {manifest_s:.2f}") print(f"strategy: {label}") print(f"range_backend: {range_backend}") print(f"mp4_sidecar: {sidecar_path or 'none'}") print(f"data_root: {data_root}") print(f"episodes: {episodes}") print(f"cameras: {manifest.video_keys}") print() print("| Track | fetch MB/s | fetch eps/s | wall s | avg MB/camera | notes |") print("|---|---:|---:|---:|---:|---|") print( f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | " f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | " f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |" ) if args.include_decode: timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1) _log(f"{label}: running parallel video fetch + decode-only") parallel = run_parallel( manifest, data_root, episodes, timestamps, byte_budget, args.workers, args.decode_workers, args.frames_per_episode, parquet_reader, range_backend, ) _log(f"{label}: running overlapped end-to-end") overlapped = run_overlapped( manifest, data_root, episodes, timestamps, byte_budget, args.workers, args.decode_workers, args.frames_per_episode, args.prefetch_ahead, parquet_reader, range_backend, ) print( f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | " f"{parallel['fetch_s']:.2f} | {fetch_pool['avg_mb_miss']:.1f} | " f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, " f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |" ) print( f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | {fetch_pool['avg_mb_miss']:.1f} | " f"{overlapped['samples_s']:.1f} samples/s; video+decode " f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |" ) def run_remote_strategy( meta: LeRobotDatasetMetadata, data_root: str, args: argparse.Namespace, parquet_reader: EpisodeParquetReader, ) -> None: _log("starting_strategy: remote-decoder") episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed) timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1) _log("remote-decoder: running direct source MP4 decoder") result = run_remote_decoder( meta, data_root, episodes, timestamps, frames_per_episode=args.frames_per_episode, decode_workers=args.decode_workers, parquet_reader=parquet_reader, ) print("strategy: remote-decoder") print(f"data_root: {data_root}") print(f"episodes: {episodes}") print(f"cameras: {list(meta.video_keys)}") print() print("| Track | samples/s | notes |") print("|---|---:|---|") print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |") print( f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | " f"direct source MP4 decoder, {args.decode_workers} workers |" ) def main() -> None: args = parse_args() data_root = args.data_root if data_root.startswith("hf://") and not args.no_hub_branch_assert: assert_hf_hub_range_cache_branch() meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision) meta.ensure_readable() parquet_reader = EpisodeParquetReader(meta, data_root) manifest_episode_count = args.manifest_episodes or int(meta.total_episodes) manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes) sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count) if sidecar_path is not None: print(f"using_mp4_sidecar: {sidecar_path}") if sidecar_path is not None and args.strategy == "both": if args.include_decode: run_remote_strategy(meta, data_root, args, parquet_reader) print() run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="native-http", label="indexed-native-http-sidecar", sidecar_path=str(sidecar_path), ) print() run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="fsspec", label="indexed-sidecar", sidecar_path=str(sidecar_path), ) return if sidecar_path is not None and args.strategy == "indexed": run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="fsspec", label="indexed-sidecar", sidecar_path=str(sidecar_path), ) return if sidecar_path is not None and args.strategy == "native-http": run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="native-http", label="indexed-native-http-sidecar", sidecar_path=str(sidecar_path), ) return if args.strategy == "both": expected_sidecar = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz" expected_remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz") print(f"mp4_sidecar_missing_local: {expected_sidecar}") print(f"mp4_sidecar_missing_remote: {expected_remote}") print( "build_mp4_sidecar: " f"uv run --no-sync python scripts/build_mp4_sidecar.py --episodes {manifest_episode_count} " f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}" ) print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online") print() if args.strategy in ("both", "indexed"): run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="fsspec", label="indexed", sidecar_path=None, ) if args.strategy == "both": print() if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode): run_remote_strategy(meta, data_root, args, parquet_reader) if args.strategy == "both" and args.include_decode: print() if args.strategy in ("both", "native-http"): run_indexed_strategy( meta, data_root, args, parquet_reader, range_backend="native-http", label="indexed-native-http", sidecar_path=None, ) if __name__ == "__main__": main()