#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
"""Benchmark episode-level streaming mini-MP4 cache.

Strategies (``--strategy``):

* ``both`` (default; ``full`` is an alias) - indexed fsspec track, optional
  remote-decoder comparison (``--include-decode``), and indexed native-HTTP track.
* ``indexed`` - indexed fsspec track only.
* ``remote-decoder`` - decode directly from the source MP4 files.
* ``native-http`` - indexed track using the native HTTP range backend.

When an MP4 sidecar (``molmoact2-full.npz``) is found locally or at
``<data-root>/meta/mp4-sidecars/``, the indexed strategies use it instead of
building MP4 indexes online.
"""

from __future__ import annotations

import argparse
import random
import resource
import sys
import tempfile
import threading
import time
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import fsspec
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.episode_video_streaming import (
    EpisodeByteCache,
    EpisodeVideoManifest,
    NativeHTTPRangeFetcher,
    assert_hf_hub_range_cache_branch,
)
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec

DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# Local cache location for the downloaded MP4 index sidecar.
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
FULL_SIDECAR_NAME = "molmoact2-full.npz"


def parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options."""
    parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
    parser.add_argument("--repo-id", default=DEFAULT_REPO)
    parser.add_argument("--revision", default=DEFAULT_REVISION)
    parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
    parser.add_argument(
        "--strategy",
        choices=("both", "full", "indexed", "remote-decoder", "native-http"),
        default="both",
        help=argparse.SUPPRESS,
    )
    parser.add_argument("--num-episodes", type=int, default=512)
    parser.add_argument(
        "--manifest-episodes",
        type=int,
        default=None,
        help="Limit manifest construction to the first N episodes for local smoke tests.",
    )
    parser.add_argument("--pool-size", type=int, default=16)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument(
        "--include-decode",
        action="store_true",
        help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
    )
    parser.add_argument("--decode-workers", type=int, default=1)
    parser.add_argument("--prefetch-ahead", type=int, default=8)
    parser.add_argument("--frames-per-episode", type=int, default=16)
    parser.add_argument("--max-probe-mb", type=int, default=64)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--byte-budget-gb", type=float, default=80)
    parser.add_argument(
        "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
    )
    parser.add_argument("--no-hub-branch-assert", action="store_true")
    return parser.parse_args()


def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
    """Sample ``pool_size`` distinct episode indices from ``range(min(total, requested))``.

    Deterministic for a given ``seed``.

    Raises:
        ValueError: if ``pool_size`` exceeds the number of available episodes.
    """
    rng = random.Random(seed)
    upper = min(total, requested)
    if pool_size > upper:
        raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
    return rng.sample(range(upper), pool_size)


def _timestamps(
    manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int
) -> dict[tuple[int, str], list[float]]:
    """Draw ``frames_per_episode`` sorted random timestamps per (episode, camera).

    Timestamps are drawn uniformly between the manifest span's ``first_pts`` and
    ``last_pts`` (clamped so the upper bound is never below the lower one).
    """
    rng = random.Random(seed)
    out: dict[tuple[int, str], list[float]] = {}
    for ep in episodes:
        for camera_key in manifest.video_keys:
            span = manifest.lookup(ep, camera_key)
            lo = span.first_pts
            hi = max(span.last_pts, lo)
            out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
    return out


def _timestamps_from_meta(
    meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
) -> dict[tuple[int, str], list[float]]:
    """Like ``_timestamps`` but bounded by the episode metadata's from/to timestamps.

    Used by the remote-decoder strategy, which does not build a manifest.
    """
    rng = random.Random(seed)
    out: dict[tuple[int, str], list[float]] = {}
    for ep in episodes:
        row = meta.episodes[ep]
        for camera_key in meta.video_keys:
            lo = float(row[f"videos/{camera_key}/from_timestamp"])
            hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
            out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
    return out


def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
    """Sum the manifest's ``mdat_length`` over every (episode, camera) pair."""
    total = 0
    for ep in episodes:
        for camera_key in manifest.video_keys:
            total += manifest.lookup(ep, camera_key).mdat_length
    return total


def _decode_all(
    cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
) -> float:
    """Call ``cache.get_frames`` for every (episode, camera) entry; return wall seconds.

    Runs serially when ``decode_workers <= 1``, otherwise on a single shared thread pool.
    """
    start = time.perf_counter()
    items = list(timestamps.items())
    if decode_workers <= 1:
        for (ep, camera_key), ts in items:
            cache.get_frames(ep, camera_key, ts)
    else:
        with ThreadPoolExecutor(max_workers=decode_workers) as pool:
            futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
            for future in futures:
                future.result()
    return time.perf_counter() - start


def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
    """Submit a prefetch for every episode, then ``ensure_ready`` each one; return wall seconds.

    All prefetches are queued before the first wait so the cache's workers can overlap them.
    """
    start = time.perf_counter()
    for ep in episodes:
        cache.submit_prefetch(ep)
    for ep in episodes:
        cache.ensure_ready(ep)
    return time.perf_counter() - start


def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
    """Return ``len(episodes) * frames_per_episode / elapsed_s`` (``inf`` when ``elapsed_s <= 0``)."""
    if elapsed_s <= 0:
        return float("inf")
    return len(episodes) * frames_per_episode / elapsed_s


def _log(message: str) -> None:
    """Print a progress line and flush immediately."""
    print(message, flush=True)


def _format_duration(seconds: float) -> str:
    """Format a duration as seconds, minutes or hours with one decimal place."""
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds / 60:.1f}m"
    return f"{seconds / 3600:.1f}h"


def _current_rss_mib() -> float | None:
    """Return the current resident set size in MiB, or ``None`` if unavailable.

    Reads ``VmRSS`` from ``/proc/self/status`` (value in kB); returns ``None`` when
    that file does not exist (it is Linux-specific).
    """
    status_path = Path("/proc/self/status")
    if not status_path.exists():
        return None
    for line in status_path.read_text().splitlines():
        if line.startswith("VmRSS:"):
            return float(line.split()[1]) / 1024
    return None


def _peak_rss_mib() -> float:
    """Return this process's peak resident set size in MiB."""
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in KiB on Linux/BSD. Decide by
    # platform rather than by magnitude: a magnitude threshold mis-scales small
    # macOS processes and very large Linux ones.
    if sys.platform == "darwin":
        return rss / 1024**2
    return rss / 1024


def _memory_snapshot() -> dict[str, float | None]:
    """Capture current and peak RSS (MiB) under keys ``rss_mib`` and ``peak_rss_mib``."""
    return {"rss_mib": _current_rss_mib(), "peak_rss_mib": _peak_rss_mib()}


def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
    """Print a markdown table of RSS at start/end, their delta, and peak RSS.

    Rows whose values are unavailable (``None``) are omitted.
    """
    start_rss = start["rss_mib"]
    end_rss = end["rss_mib"]
    delta = None if start_rss is None or end_rss is None else end_rss - start_rss
    print()
    print("| Memory | MiB |")
    print("|---|---:|")
    if start_rss is not None:
        print(f"| rss start | {start_rss:.1f} |")
    if end_rss is not None:
        print(f"| rss end | {end_rss:.1f} |")
    if delta is not None:
        print(f"| rss delta | {delta:.1f} |")
    print(f"| peak rss | {end['peak_rss_mib']:.1f} |")


def _root_join(data_root: str, relative_path: str) -> str:
    """Join ``relative_path`` onto ``data_root`` (URL-style for ``hf://`` roots, ``Path`` otherwise)."""
    if data_root.startswith("hf://"):
        return f"{data_root.rstrip('/')}/{relative_path}"
    return str(Path(data_root) / relative_path)


def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Return a local path to a valid MP4 sidecar, downloading it if needed.

    A cached local copy that fails validation is deleted. Returns ``None`` when
    the sidecar does not exist under ``<data_root>/meta/mp4-sidecars/``.
    """
    # Currently unused: the full sidecar is always fetched regardless of episode count.
    _ = manifest_episode_count
    local = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
    if _valid_sidecar(local):
        return local
    if local.exists():
        print(f"mp4_sidecar_invalid_local: {local}")
        local.unlink()
    remote_relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
    remote = _root_join(data_root, remote_relative)
    protocol = "hf" if data_root.startswith("hf://") else "file"
    fs = fsspec.filesystem(protocol)
    if not fs.exists(remote):
        return None
    local.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {remote} -> {local}")
    if data_root.startswith("hf://"):
        _download_sidecar_native_http(data_root, remote_relative, local)
    else:
        fs.get(remote, str(local))
    return local


def _valid_sidecar(path: Path) -> bool:
    """Return True if ``path`` is a loadable ``.npz`` containing a ``manifest_json`` entry."""
    if not path.exists():
        return False
    try:
        with np.load(path, allow_pickle=False) as data:
            return "manifest_json" in data
    except Exception:
        # Deliberately best-effort: any load failure (truncated/corrupt file) means "invalid".
        return False


def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
    """Download ``relative_path`` to ``local`` with parallel 16 MiB HTTP range reads.

    Data is written to ``<local>.tmp`` and atomically moved into place on success.
    On any failure the partial ``.tmp`` file is removed and the error re-raised.
    """
    fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
    tmp = local.with_suffix(local.suffix + ".tmp")
    try:
        size = fetcher.info_size(relative_path)
        chunk_size = 16 * 1024 * 1024
        ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]

        def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
            offset, length = offset_length
            return offset, fetcher.read_range(relative_path, offset, length)

        start = time.perf_counter()
        done = 0
        with tmp.open("wb") as out_file, ThreadPoolExecutor(max_workers=8) as pool:
            # Pre-size the file so chunks can be written at their offsets in any order.
            out_file.truncate(size)
            futures = [pool.submit(read_chunk, item) for item in ranges]
            for future in futures:
                offset, data = future.result()
                out_file.seek(offset)
                out_file.write(data)
                done += len(data)
                elapsed = max(time.perf_counter() - start, 1e-9)
                print(
                    f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
                    f"({done / elapsed / 1024**2:.1f} MiB/s)",
                    flush=True,
                )
        tmp.replace(local)
    except BaseException:
        # Don't leave a partial download behind for the next run.
        tmp.unlink(missing_ok=True)
        raise
    finally:
        fetcher.close()


class EpisodeParquetReader:
    """Read per-episode rows from the dataset's parquet data files.

    Whole parquet files are read once and cached in memory (keyed by relative
    path, guarded by a lock); each episode read then filters the cached table.
    """

    def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
        self.meta = meta
        self.data_root = data_root
        protocol = "hf" if data_root.startswith("hf://") else "file"
        self.fs = fsspec.filesystem(protocol)
        # NOTE(review): not read anywhere in this file; building it walks every episode's metadata.
        self._episode_row_groups = self._build_episode_row_groups()
        self._table_cache: dict[str, pa.Table] = {}
        self._cache_lock = threading.Lock()

    def read_episode(self, episode_index: int) -> None:
        """Load the episode's data file (cached) and filter it to ``episode_index``.

        The filtered table is discarded: callers only time the read+filter work.
        """
        relative_path = str(self.meta.get_data_file_path(episode_index))
        table = self._read_table(relative_path)
        table.filter(pc.equal(table["episode_index"], episode_index))

    def _read_table(self, relative_path: str) -> pa.Table:
        """Return the full parquet table at ``relative_path``, reading it at most once per cache entry.

        The file read happens outside the lock, so concurrent first reads of the same
        file may both run; ``setdefault`` keeps only the first stored table.
        """
        with self._cache_lock:
            table = self._table_cache.get(relative_path)
        if table is not None:
            return table
        with self.fs.open(
            _root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
        ) as f:
            table = pq.ParquetFile(f).read()
        with self._cache_lock:
            return self._table_cache.setdefault(relative_path, table)

    def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
        """Schedule ``read_episode(episode_index)`` on ``pool`` and return its future."""
        return pool.submit(self.read_episode, episode_index)

    def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
        """Read all ``episodes`` (serially if ``workers <= 1``) and return wall seconds."""
        start = time.perf_counter()
        if workers <= 1:
            for ep in episodes:
                self.read_episode(ep)
        else:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                futures = [pool.submit(self.read_episode, ep) for ep in episodes]
                for future in futures:
                    future.result()
        return time.perf_counter() - start

    def _build_episode_row_groups(self) -> dict[int, int]:
        """Map each episode index to its ordinal among episodes sharing the same data file.

        Presumably this equals the episode's row-group index when each episode is
        written as one row group - TODO confirm against the writer.
        """
        counts: dict[tuple[int, int], int] = {}
        row_groups = {}
        for ep_idx in range(int(self.meta.total_episodes)):
            ep = self.meta.episodes[ep_idx]
            key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
            row_groups[ep_idx] = counts.get(key, 0)
            counts[key] = row_groups[ep_idx] + 1
        return row_groups


def run_fetch_pool(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    byte_budget: int,
    workers: int,
    range_backend: str,
) -> dict[str, float]:
    """Fetch-only track: fill an ``EpisodeByteCache`` (decoders not opened) and report throughput.

    Returns overall fetch time/throughput, mean per-camera-job stage timings
    (ms), and any ``range_*`` counters from ``cache.timing_summary()``.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    ) as cache:
        elapsed = _fill_cache(cache, episodes)
        timings = cache.timing_summary()
    byte_count = _bytes_for(manifest, episodes)
    episode_mb = byte_count / len(episodes) / 1024**2
    job_count = max(timings["jobs"], 1.0)
    result = {
        "fetch_s": elapsed,
        "fetch_mbps": byte_count / elapsed / 1024**2,
        "fetch_episodes_s": len(episodes) / elapsed,
        "episode_mb": episode_mb,
        "avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
        "jobs": timings["jobs"],
        "lookup_ms": timings["lookup_s"] * 1000 / job_count,
        "range_fetch_ms": timings["fetch_s"] * 1000 / job_count,
        "synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
        "store_ms": timings["store_s"] * 1000 / job_count,
    }
    result.update({key: value for key, value in timings.items() if key.startswith("range_")})
    return result


def run_parallel(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Phase-by-phase track: parquet reads, byte fetch, decoder opening, then frame decode.

    Each phase runs to completion before the next so their costs can be reported separately.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    ) as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
        fetch_s = _fill_cache(cache, episodes)
        decoder_start = time.perf_counter()
        for ep in episodes:
            for camera_key in manifest.video_keys:
                cache.get_decoder(ep, camera_key)
        decoder_s = time.perf_counter() - decoder_start
        decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
    byte_count = _bytes_for(manifest, episodes)
    return {
        "fetch_s": fetch_s,
        "fetch_mbps": byte_count / fetch_s / 1024**2,
        "fetch_episodes_s": len(episodes) / fetch_s,
        "parquet_s": parquet_s,
        "decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
        "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
    }


def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """End-to-end track: consume episodes in order while prefetching ``prefetch_ahead`` ahead.

    For each episode, waits for its parquet read, then its video bytes, then decodes
    all cameras. Reports overall samples/s plus the time spent blocked on parquet
    vs. video fetch+decode.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
        # One decode pool for the whole run (as in _decode_all) instead of a fresh pool
        # per episode, so thread start-up is not billed to every episode's decode time.
        decode_pool = ThreadPoolExecutor(max_workers=decode_workers) if decode_workers > 1 else None
        try:
            # Prime the pipeline with the first `prefetch_ahead` episodes.
            parquet_futures = {
                ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
            }
            for ep in episodes[:prefetch_ahead]:
                cache.submit_prefetch(ep)
            for idx, ep in enumerate(episodes):
                # Keep the window full: schedule the episode `prefetch_ahead` positions ahead.
                next_idx = idx + prefetch_ahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start
                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_pool is None:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    futures = [
                        decode_pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                        for camera_key in manifest.video_keys
                    ]
                    for future in futures:
                        future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        finally:
            if decode_pool is not None:
                decode_pool.shutdown(wait=True)
            parquet_pool.shutdown(wait=True)
        elapsed = time.perf_counter() - start
    return {
        "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
        "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
        "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
        "wall_s": elapsed,
        "video_wait_decode_s": video_wait_decode_s,
        "parquet_wait_s": parquet_wait_s,
    }


# Per-thread decoder caches for the remote-decoder strategy (one cache per worker thread).
_remote_decoder_local = threading.local()


def _remote_decoder_cache() -> VideoDecoderCache:
    """Return the calling thread's ``VideoDecoderCache``, creating it on first use (unbounded size)."""
    cache = getattr(_remote_decoder_local, "cache", None)
    if cache is None:
        cache = VideoDecoderCache(max_size=None)
        _remote_decoder_local.cache = cache
    return cache


def _decode_remote_source(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episode_index: int,
    camera_key: str,
    timestamps: list[float],
):
    """Decode ``timestamps`` directly from the episode's source video file under ``data_root``.

    Uses a tolerance of one frame period (``1 / meta.fps``) and the thread-local decoder cache.
    """
    video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
    return decode_video_frames_torchcodec(
        video_path,
        timestamps,
        tolerance_s=1.0 / float(meta.fps),
        decoder_cache=_remote_decoder_cache(),
        return_uint8=True,
    )


def run_remote_decoder(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    *,
    frames_per_episode: int,
    decode_workers: int,
    parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
    """Baseline: decode from source MP4s, once sequentially and once with ``decode_workers`` threads.

    NOTE(review): the sequential pass populates the main thread's decoder cache and the
    parquet reader's table cache, so the second pass with ``decode_workers <= 1`` (which
    runs on the main thread again) likely reuses them - confirm this is intended.
    """
    items = [
        (ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
    ]
    start = time.perf_counter()
    for ep, camera_key, ts in items:
        # Read parquet once per episode, alongside its first camera.
        if camera_key == meta.video_keys[0]:
            parquet_reader.read_episode(ep)
        _decode_remote_source(meta, data_root, ep, camera_key, ts)
    sequential_s = time.perf_counter() - start
    start = time.perf_counter()
    if decode_workers <= 1:
        for ep, camera_key, ts in items:
            if camera_key == meta.video_keys[0]:
                parquet_reader.read_episode(ep)
            _decode_remote_source(meta, data_root, ep, camera_key, ts)
    else:
        with ThreadPoolExecutor(max_workers=decode_workers) as pool:
            parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
            futures = [
                pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
                for ep, camera_key, ts in items
            ]
            for future in parquet_futures:
                future.result()
            for future in futures:
                future.result()
    parallel_s = time.perf_counter() - start
    return {
        "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
        "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
    }


def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
    """Print mean per-range-read stage timings (ms) when the backend reported any range reads."""
    range_jobs = fetch_pool.get("range_jobs", 0.0)
    if range_jobs <= 0:
        return
    print()
    print("| Range Read Stage | avg ms/range |")
    print("|---|---:|")
    for key, label in (
        ("range_open_s", "fsspec handle open/lookup"),
        ("range_seek_s", "fsspec seek"),
        ("range_read_s", "fsspec read"),
        ("range_resolve_s", "http URL resolve"),
        ("range_header_s", "http response headers"),
        ("range_first_byte_s", "http first body byte"),
        ("range_body_s", "http body drain"),
    ):
        value = fetch_pool.get(key)
        if value is not None:
            print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
    print(f"| range reads | {range_jobs:.0f} |")
    print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")


def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Build a manifest, run the fetch-only track and (optionally) the decode tracks; print tables.

    The manifest covers the first ``min(manifest_episodes or total, total, num_episodes)``
    episodes, and the episode pool is sampled from that same range so every sampled
    episode is guaranteed to be present in the manifest.
    """
    _log(f"starting_strategy: {label}")
    memory_start = _memory_snapshot()
    manifest_start = time.perf_counter()
    dataset_episode_count = int(meta.total_episodes)
    manifest_episode_count = args.manifest_episodes or dataset_episode_count
    manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")
    benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
    # Sample only from episodes the manifest was built for; with --manifest-episodes
    # smaller than --num-episodes, sampling from the full range could pick episodes
    # the manifest cannot look up. Identical to the full range when the flag is unset.
    episodes = _episode_pool(manifest_episode_count, args.num_episodes, args.pool_size, args.seed)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    _log(
        f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
        f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
    )
    _log(f"{label}: filling episode byte cache with {args.workers} workers")
    fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend)
    estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
    estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"dataset_episodes: {dataset_episode_count}")
    print(f"benchmark_episodes: {benchmark_episode_count}")
    print(f"pool_episodes: {len(episodes)}")
    print(f"sampled_episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()
    print(
        "| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
    )
    print("|---|---:|---:|---:|---:|---:|---:|---|")
    print(
        f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
        f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
        f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
        f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
    )
    print()
    print("| Camera Job Stage | avg ms/job |")
    print("|---|---:|")
    print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
    print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
    print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
    print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
    print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
    _print_range_timing_summary(fetch_pool)
    _print_memory_summary(memory_start, _memory_snapshot())
    if args.include_decode:
        timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
        _log(f"{label}: running parallel video fetch + decode-only")
        parallel = run_parallel(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            parquet_reader,
            range_backend,
        )
        _log(f"{label}: running overlapped end-to-end")
        overlapped = run_overlapped(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            args.prefetch_ahead,
            parquet_reader,
            range_backend,
        )
        print(
            f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
            f"{parallel['fetch_s']:.2f} | "
            f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
            f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
            f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
        )
        print(
            f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"{overlapped['samples_s']:.1f} samples/s; video+decode "
            f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
        )


def run_remote_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
) -> None:
    """Run the direct source-MP4 decoder baseline and print its samples/s table."""
    _log("starting_strategy: remote-decoder")
    episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
    timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
    _log("remote-decoder: running direct source MP4 decoder")
    result = run_remote_decoder(
        meta,
        data_root,
        episodes,
        timestamps,
        frames_per_episode=args.frames_per_episode,
        decode_workers=args.decode_workers,
        parquet_reader=parquet_reader,
    )
    print("strategy: remote-decoder")
    print(f"data_root: {data_root}")
    print(f"episodes: {episodes}")
    print(f"cameras: {list(meta.video_keys)}")
    print()
    print("| Track | samples/s | notes |")
    print("|---|---:|---|")
    print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
    print(
        f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
        f"direct source MP4 decoder, {args.decode_workers} workers |"
    )


def main() -> None:
    """Entry point: load metadata, locate the MP4 sidecar, and dispatch the selected strategy."""
    args = parse_args()
    if args.strategy == "full":
        args.strategy = "both"
    data_root = args.data_root
    if data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()
    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)
    manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
    manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
    sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
    if sidecar_path is not None:
        print(f"using_mp4_sidecar: {sidecar_path}")

    # With a sidecar, every indexed strategy collapses to a single fsspec run over the sidecar index.
    if sidecar_path is not None and args.strategy in ("both", "indexed", "native-http"):
        if args.strategy == "both" and args.include_decode:
            run_remote_strategy(meta, data_root, args, parquet_reader)
            print()
        if args.strategy == "native-http":
            print("using_indexed_sidecar_for_native_http: sidecar mode uses HfFileSystem range reads")
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="fsspec",
            label="indexed-sidecar",
            sidecar_path=str(sidecar_path),
        )
        return

    if args.strategy == "both":
        expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
        expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
        print(f"mp4_sidecar_missing_local: {expected_sidecar}")
        print(f"mp4_sidecar_missing_remote: {expected_remote}")
        print(
            "build_mp4_sidecar: "
            "uv run --no-sync python scripts/build_mp4_sidecar.py "
            f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
        )
        print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
        print()
    if args.strategy in ("both", "indexed"):
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="fsspec",
            label="indexed",
            sidecar_path=None,
        )
    if args.strategy == "both":
        print()
    if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode):
        run_remote_strategy(meta, data_root, args, parquet_reader)
    if args.strategy == "both" and args.include_decode:
        print()
    if args.strategy in ("both", "native-http"):
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="native-http",
            label="indexed-native-http",
            sidecar_path=None,
        )


if __name__ == "__main__":
    main()