diff --git a/pyproject.toml b/pyproject.toml index 0dc86d7ff..91794bf76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -421,6 +421,7 @@ exclude_dirs = [ skips = ["B101", "B311", "B404", "B603", "B615"] [tool.typos] +default.extend-words = { trak = "trak" } default.extend-ignore-re = [ "(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line "(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker: diff --git a/scripts/bench_episode_byte_cache.py b/scripts/bench_episode_byte_cache.py new file mode 100644 index 000000000..34a97efda --- /dev/null +++ b/scripts/bench_episode_byte_cache.py @@ -0,0 +1,724 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import argparse +import random +import tempfile +import threading +import time +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import fsspec +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.parquet as pq + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.episode_video_streaming import ( + EpisodeByteCache, + EpisodeVideoManifest, + NativeHTTPRangeFetcher, + assert_hf_hub_range_cache_branch, +) +from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec + +DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset" +DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c" +DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket" +SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars" + + +def parse_args() -> argparse.Namespace: + parser = 
argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.") + parser.add_argument("--repo-id", default=DEFAULT_REPO) + parser.add_argument("--revision", default=DEFAULT_REVISION) + parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT) + parser.add_argument( + "--strategy", + choices=("both", "indexed", "remote-decoder", "native-http"), + default="both", + help=argparse.SUPPRESS, + ) + parser.add_argument("--num-episodes", type=int, default=512) + parser.add_argument( + "--manifest-episodes", + type=int, + default=None, + help="Limit manifest construction to the first N episodes for local smoke tests.", + ) + parser.add_argument("--pool-size", type=int, default=16) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--decode-workers", type=int, default=1) + parser.add_argument("--prefetch-ahead", type=int, default=8) + parser.add_argument("--frames-per-episode", type=int, default=16) + parser.add_argument("--max-probe-mb", type=int, default=64) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--byte-budget-gb", type=float, default=80) + parser.add_argument( + "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory." 
+ ) + parser.add_argument("--no-hub-branch-assert", action="store_true") + return parser.parse_args() + + +def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]: + rng = random.Random(seed) + upper = min(total, requested) + if pool_size > upper: + raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}") + return rng.sample(range(upper), pool_size) + + +def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int): + rng = random.Random(seed) + out: dict[tuple[int, str], list[float]] = {} + for ep in episodes: + for camera_key in manifest.video_keys: + span = manifest.lookup(ep, camera_key) + lo = span.first_pts + hi = max(span.last_pts, lo) + out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode)) + return out + + +def _timestamps_from_meta( + meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int +) -> dict[tuple[int, str], list[float]]: + rng = random.Random(seed) + out: dict[tuple[int, str], list[float]] = {} + for ep in episodes: + row = meta.episodes[ep] + for camera_key in meta.video_keys: + lo = float(row[f"videos/{camera_key}/from_timestamp"]) + hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo) + out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode)) + return out + + +def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int: + total = 0 + for ep in episodes: + for camera_key in manifest.video_keys: + total += manifest.lookup(ep, camera_key).mdat_length + return total + + +def _decode_all( + cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int +) -> float: + start = time.perf_counter() + items = list(timestamps.items()) + if decode_workers <= 1: + for (ep, camera_key), ts in items: + cache.get_frames(ep, camera_key, ts) + else: + with ThreadPoolExecutor(max_workers=decode_workers) as 
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Locate a usable MP4 sidecar, downloading it from ``data_root`` if needed.

    Lookup order:
      1. ``SIDECAR_CACHE_DIR/molmoact2-<count>.npz`` if it passes ``_valid_sidecar``
         (an invalid file at that path is reported and deleted).
      2. ``SIDECAR_CACHE_DIR/molmoact2-full.npz`` if it passes ``_valid_sidecar``.
      3. ``<data_root>/meta/mp4-sidecars/molmoact2-<count>.npz``, copied into the
         local cache.

    Args:
        data_root: Dataset root; ``hf://`` roots use the native HTTP downloader,
            anything else is treated as a local filesystem path.
        manifest_episode_count: Episode count embedded in the sidecar file name.

    Returns:
        Path to a validated local sidecar, or ``None`` when none is available.
    """
    relative_path = f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz"
    local = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
    if _valid_sidecar(local):
        return local
    if local.exists():
        print(f"mp4_sidecar_invalid_local: {local}")
        local.unlink()
    full_local = SIDECAR_CACHE_DIR / "molmoact2-full.npz"
    if _valid_sidecar(full_local):
        return full_local

    is_hf = data_root.startswith("hf://")
    remote = _root_join(data_root, relative_path)
    fs = fsspec.filesystem("hf" if is_hf else "file")
    if not fs.exists(remote):
        return None
    local.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {remote} -> {local}")
    if is_hf:
        _download_sidecar_native_http(data_root, relative_path, local)
    else:
        fs.get(remote, str(local))

    # Cached files are validated above; hold a fresh download to the same
    # standard so a corrupt remote sidecar falls back to online indexing
    # instead of failing later inside the manifest build.
    if not _valid_sidecar(local):
        print(f"mp4_sidecar_invalid_download: {local}")
        local.unlink(missing_ok=True)
        return None
    return local
"manifest_json" in data + except Exception: + return False + + +def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None: + fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16) + tmp = local.with_suffix(local.suffix + ".tmp") + try: + size = fetcher.info_size(relative_path) + chunk_size = 16 * 1024 * 1024 + ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)] + with tmp.open("wb") as out_file: + out_file.truncate(size) + + def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]: + offset, length = offset_length + return offset, fetcher.read_range(relative_path, offset, length) + + start = time.perf_counter() + done = 0 + with ThreadPoolExecutor(max_workers=8) as pool: + futures = [pool.submit(read_chunk, item) for item in ranges] + with tmp.open("r+b") as rw_file: + for future in futures: + offset, data = future.result() + rw_file.seek(offset) + rw_file.write(data) + done += len(data) + elapsed = max(time.perf_counter() - start, 1e-9) + print( + f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB " + f"({done / elapsed / 1024**2:.1f} MiB/s)", + flush=True, + ) + tmp.replace(local) + finally: + fetcher.close() + + +class EpisodeParquetReader: + def __init__(self, meta: LeRobotDatasetMetadata, data_root: str): + self.meta = meta + self.data_root = data_root + protocol = "hf" if data_root.startswith("hf://") else "file" + self.fs = fsspec.filesystem(protocol) + self._episode_row_groups = self._build_episode_row_groups() + self._table_cache: dict[str, pa.Table] = {} + self._cache_lock = threading.Lock() + + def read_episode(self, episode_index: int) -> None: + relative_path = str(self.meta.get_data_file_path(episode_index)) + table = self._read_table(relative_path) + table.filter(pc.equal(table["episode_index"], episode_index)) + + def _read_table(self, relative_path: str) -> pa.Table: + with self._cache_lock: + table = self._table_cache.get(relative_path) + 
if table is not None: + return table + with self.fs.open( + _root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none" + ) as f: + table = pq.ParquetFile(f).read() + with self._cache_lock: + return self._table_cache.setdefault(relative_path, table) + + def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int): + return pool.submit(self.read_episode, episode_index) + + def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float: + start = time.perf_counter() + if workers <= 1: + for ep in episodes: + self.read_episode(ep) + else: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = [pool.submit(self.read_episode, ep) for ep in episodes] + for future in futures: + future.result() + return time.perf_counter() - start + + def _build_episode_row_groups(self) -> dict[int, int]: + counts: dict[tuple[int, int], int] = {} + row_groups = {} + for ep_idx in range(int(self.meta.total_episodes)): + ep = self.meta.episodes[ep_idx] + key = (int(ep["data/chunk_index"]), int(ep["data/file_index"])) + row_groups[ep_idx] = counts.get(key, 0) + counts[key] = row_groups[ep_idx] + 1 + return row_groups + + +def run_sequential( + manifest: EpisodeVideoManifest, + data_root: str, + episodes: Sequence[int], + byte_budget: int, + parquet_reader: EpisodeParquetReader, + range_backend: str, +) -> dict[str, float]: + with EpisodeByteCache( + manifest, + data_root, + byte_budget=byte_budget, + workers=1, + range_backend=range_backend, + open_decoders=False, + ) as cache: + parquet_s = parquet_reader.read_episodes(episodes, workers=1) + elapsed = _fill_cache(cache, episodes) + byte_count = _bytes_for(manifest, episodes) + episode_mb = byte_count / len(episodes) / 1024**2 + return { + "fetch_s": elapsed, + "fetch_mbps": byte_count / elapsed / 1024**2, + "fetch_episodes_s": len(episodes) / elapsed, + "episode_mb": episode_mb, + "parquet_s": parquet_s, + "avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 
def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the end-to-end track: prefetch ahead while consuming episodes in order.

    For each episode the loop first schedules the video prefetch and parquet
    read of the episode ``prefetch_ahead`` positions later, then waits for the
    current episode's parquet read, then waits for its video bytes and decodes
    the requested timestamps for every camera.

    Args:
        manifest: Video manifest providing ``video_keys``.
        data_root: Dataset root passed to ``EpisodeByteCache``.
        episodes: Episode indices, consumed in the given order.
        timestamps: Decode timestamps keyed by ``(episode, camera_key)``.
        byte_budget: Byte budget passed to ``EpisodeByteCache``.
        workers: Worker count for the cache and upper bound for the parquet pool.
        decode_workers: Decode threads per episode; ``<= 1`` decodes inline.
        frames_per_episode: Frames per episode, used for the samples/s figures.
        prefetch_ahead: How many episodes to schedule ahead; negative values are
            treated as 0.
        parquet_reader: Reader used for the per-episode parquet loads.
        range_backend: Range backend name passed to ``EpisodeByteCache``.

    Returns:
        Throughput (``samples_s``, ``video_samples_s``, ``parquet_samples_s``)
        and timing (``wall_s``, ``video_wait_decode_s``, ``parquet_wait_s``).
    """
    # A negative look-ahead would slice from the wrong end and reschedule
    # already-consumed episodes.
    prefetch_ahead = max(0, prefetch_ahead)
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
        # One decode pool for the whole run: creating an executor per episode
        # would charge thread start-up cost to the timed decode section.
        decode_pool = ThreadPoolExecutor(max_workers=decode_workers) if decode_workers > 1 else None
        try:
            warmup = episodes[:prefetch_ahead]
            parquet_futures = {ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in warmup}
            for ep in warmup:
                cache.submit_prefetch(ep)

            for idx, ep in enumerate(episodes):
                next_idx = idx + prefetch_ahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)

                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start

                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_pool is None:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    futures = [
                        decode_pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                        for camera_key in manifest.video_keys
                    ]
                    for future in futures:
                        future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        finally:
            # Both pools are released on every exit path, including failures
            # during the warm-up submissions; queued work is dropped on error.
            if decode_pool is not None:
                decode_pool.shutdown(wait=True)
            parquet_pool.shutdown(wait=True, cancel_futures=True)
        elapsed = time.perf_counter() - start
        return {
            "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
            "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
            "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
            "wall_s": elapsed,
            "video_wait_decode_s": video_wait_decode_s,
            "parquet_wait_s": parquet_wait_s,
        }
def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Build the video manifest, run the three indexed tracks and print a results table.

    The tracks are: sequential fetch (``run_sequential``), parallel fetch plus
    decode-only (``run_parallel``) and overlapped end-to-end (``run_overlapped``).

    Args:
        meta: Dataset metadata.
        data_root: Dataset root the videos and parquet files are read from.
        args: Parsed CLI arguments (episode counts, workers, seeds, budgets).
        parquet_reader: Shared parquet reader used by every track.
        range_backend: Range backend name forwarded to the manifest and caches.
        label: Strategy label used as a prefix in progress and result output.
        sidecar_path: Optional MP4 sidecar forwarded to the manifest build.
    """
    _log(f"starting_strategy: {label}")
    manifest_start = time.perf_counter()
    manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
    manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")

    # Sample only from episodes the manifest was built for. With
    # --manifest-episodes below --num-episodes, drawing from the full dataset
    # range would select episodes outside range(manifest_episode_count).
    # Without --manifest-episodes the bound equals min(total, num_episodes),
    # so the sampled pool is unchanged.
    episodes = _episode_pool(manifest_episode_count, args.num_episodes, args.pool_size, args.seed)
    timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    _log(
        f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
        f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
    )

    _log(f"{label}: running sequential video fetch")
    sequential = run_sequential(manifest, data_root, episodes, byte_budget, parquet_reader, range_backend)
    _log(f"{label}: running parallel video fetch + decode-only")
    parallel = run_parallel(
        manifest,
        data_root,
        episodes,
        timestamps,
        byte_budget,
        args.workers,
        args.decode_workers,
        args.frames_per_episode,
        parquet_reader,
        range_backend,
    )
    _log(f"{label}: running overlapped end-to-end")
    overlapped = run_overlapped(
        manifest,
        data_root,
        episodes,
        timestamps,
        byte_budget,
        args.workers,
        args.decode_workers,
        args.frames_per_episode,
        args.prefetch_ahead,
        parquet_reader,
        range_backend,
    )

    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()
    print("| Track | fetch MB/s | fetch eps/s | samples/s | avg MB/miss | notes |")
    print("|---|---:|---:|---:|---:|---|")
    print(
        f"| SEQUENTIAL | {sequential['fetch_mbps']:.1f} | {sequential['fetch_episodes_s']:.2f} | - | "
        f"{sequential['avg_mb_miss']:.1f} | 1 worker video fetch, parquet {sequential['parquet_s']:.2f}s |"
    )
    # avg MB/miss depends only on the manifest and episode pool, so the value
    # computed by the sequential track is reused for the other rows.
    print(
        f"| PARALLEL | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
        f"{parallel['decode_samples_s']:.1f} | "
        f"{sequential['avg_mb_miss']:.1f} | decode-only, decoder open "
        f"{parallel['decoder_ms_miss']:.1f} ms/miss, parquet {parallel['parquet_s']:.2f}s |"
    )
    print(
        f"| OVERLAPPED | - | - | {overlapped['samples_s']:.1f} | {sequential['avg_mb_miss']:.1f} | "
        f"end-to-end; video {overlapped['video_samples_s']:.1f} samples/s "
        f"({overlapped['video_wait_decode_s']:.2f}s), parquet {overlapped['parquet_samples_s']:.1f} "
        f"samples/s ({overlapped['parquet_wait_s']:.2f}s) |"
    )
def main() -> None:
    """Parse CLI arguments and run the selected benchmark strategies in order.

    Strategy order is preserved from the original dispatch:
      * sidecar available, ``both``: remote-decoder, indexed native-http, indexed fsspec
      * no sidecar, ``both``: indexed fsspec, remote-decoder, indexed native-http
      * any single strategy: just that strategy
    A blank line is printed between consecutive strategies.
    """
    args = parse_args()
    data_root = args.data_root
    if data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)
    manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
    manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)

    # The remote-decoder strategy never reads the sidecar, so do not look it
    # up (and possibly download it) when that is the only strategy requested.
    sidecar_path = None
    if args.strategy != "remote-decoder":
        sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
    if sidecar_path is not None:
        print(f"using_mp4_sidecar: {sidecar_path}")

    def run_indexed(range_backend: str) -> None:
        # Labels: indexed / indexed-native-http, with a -sidecar suffix when one is used.
        label = "indexed" if range_backend == "fsspec" else "indexed-native-http"
        if sidecar_path is not None:
            label += "-sidecar"
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend=range_backend,
            label=label,
            sidecar_path=None if sidecar_path is None else str(sidecar_path),
        )

    def run_remote() -> None:
        run_remote_strategy(meta, data_root, args, parquet_reader)

    def run_fsspec() -> None:
        run_indexed("fsspec")

    def run_native_http() -> None:
        run_indexed("native-http")

    if args.strategy == "both":
        if sidecar_path is None:
            expected_sidecar = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
            expected_remote = _root_join(
                data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz"
            )
            print(f"mp4_sidecar_missing_local: {expected_sidecar}")
            print(f"mp4_sidecar_missing_remote: {expected_remote}")
            print(
                "build_mp4_sidecar: "
                f"uv run --no-sync python scripts/build_mp4_sidecar.py --episodes {manifest_episode_count} "
                f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
            )
            print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
            print()
            steps = [run_fsspec, run_remote, run_native_http]
        else:
            steps = [run_remote, run_native_http, run_fsspec]
    elif args.strategy == "indexed":
        steps = [run_fsspec]
    elif args.strategy == "native-http":
        steps = [run_native_http]
    else:  # "remote-decoder" is the only remaining argparse choice
        steps = [run_remote]

    for position, step in enumerate(steps):
        if position:
            print()
        step()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +from __future__ import annotations + +import argparse +import time +from pathlib import Path + +import fsspec + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch + +DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset" +DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c" +DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.") + parser.add_argument("--repo-id", default=DEFAULT_REPO) + parser.add_argument("--revision", default=DEFAULT_REVISION) + parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT) + parser.add_argument("--output", required=True) + parser.add_argument("--episodes", type=int, default=None) + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http") + parser.add_argument("--max-probe-mb", type=int, default=64) + parser.add_argument( + "--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars." 
def push_sidecar(local_path: str, data_root: str, episode_count: int) -> list[str]:
    """Upload a sidecar file to ``<data_root>/meta/mp4-sidecars`` on the Hub.

    The file is uploaded under its own name and, when that differs, also under
    ``molmoact2-<episode_count>.npz``. Roots that do not start with ``hf://``
    are skipped.

    Args:
        local_path: Path of the sidecar file on disk.
        data_root: Dataset root; only ``hf://`` roots are pushed to.
        episode_count: Episode count used for the second file name.

    Returns:
        The remote paths written, in upload order; empty for non-``hf://`` roots.
    """
    if data_root.startswith("hf://"):
        source = Path(local_path)
        hub_fs = fsspec.filesystem("hf")
        base = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
        primary = f"{base}/{source.name}"
        counted = f"{base}/molmoact2-{episode_count}.npz"
        # Avoid uploading twice when the file is already named after the count.
        targets = [primary] if counted == primary else [primary, counted]
        for target in targets:
            hub_fs.put(str(source), target)
        return targets
    return []
class ThreadLocalRangeFetcher:
    """Range reader that gives each worker thread independent file handles.

    Every thread keeps its own ``relative_path -> handle`` mapping, so seeks
    and reads on one thread never move another thread's file position. All
    opened handles are also tracked on the instance so ``close()`` can release
    handles opened on worker threads, not just those of the calling thread.
    """

    def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
        """Create a fetcher rooted at ``data_root``.

        Args:
            data_root: ``hf://`` root or local directory.
            block_size: ``block_size`` forwarded to ``fs.open``.
            cache_type: ``cache_type`` forwarded to ``fs.open``.
        """
        self.data_root = str(data_root).rstrip("/")
        protocol = "hf" if self.data_root.startswith("hf://") else "file"
        self.fs = fsspec.filesystem(protocol)
        self.block_size = block_size
        self.cache_type = cache_type
        self._local = threading.local()
        # Cross-thread registry of every handle opened through this fetcher.
        self._registry_lock = threading.Lock()
        self._all_handles: list = []

    def _url(self, relative_path: str) -> str:
        """Return the full location of ``relative_path`` under the data root."""
        if self.data_root.startswith("hf://"):
            return f"{self.data_root}/{relative_path}"
        return str(Path(self.data_root) / relative_path)

    def _handle(self, relative_path: str):
        """Return the calling thread's open handle for ``relative_path``, opening it if needed."""
        handles = getattr(self._local, "handles", None)
        if handles is None:
            handles = {}
            self._local.handles = handles
        current = handles.get(relative_path)
        if current is not None and not getattr(current, "closed", False):
            return current
        handle = self.fs.open(
            self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
        )
        with self._registry_lock:
            if current is not None:
                # Drop the closed handle being replaced so the registry does not grow.
                with contextlib.suppress(ValueError):
                    self._all_handles.remove(current)
            self._all_handles.append(handle)
        handles[relative_path] = handle
        return handle

    def info_size(self, relative_path: str) -> int:
        """Return the size in bytes reported by the filesystem for ``relative_path``."""
        return int(self.fs.info(self._url(relative_path))["size"])

    def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
        """Read up to ``length`` bytes starting at ``offset`` from ``relative_path``."""
        handle = self._handle(relative_path)
        handle.seek(offset)
        return handle.read(length)

    def close(self) -> None:
        """Close every handle opened through this fetcher, on any thread.

        Errors raised while closing are suppressed. Other threads may still
        hold closed handles in their per-thread maps; ``_handle`` detects the
        closed state and reopens on the next read.
        """
        with self._registry_lock:
            to_close, self._all_handles = self._all_handles, []
        for handle in to_close:
            with contextlib.suppress(Exception):
                handle.close()
        handles = getattr(self._local, "handles", None)
        if handles is not None:
            handles.clear()
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections), + follow_redirects=False, + ) + self._resolved_urls: dict[str, str] = {} + self._source_urls: dict[str, str] = {} + self._sizes: dict[str, int] = {} + self._lock = threading.Lock() + + def _cache_key(self, relative_path: str) -> tuple[str, str]: + return self.data_root, relative_path + + def _path(self, relative_path: str) -> str: + return f"{self.data_root}/{relative_path}" + + def _bucket_path(self, relative_path: str) -> str: + if self._bucket_prefix: + return f"{self._bucket_prefix}/{relative_path}" + return relative_path + + def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]: + headers = self.api._build_hf_headers() + if urlparse(request_url).netloc != urlparse(source_url).netloc: + headers.pop("authorization", None) + headers.pop("Authorization", None) + return headers + + def _source_url(self, relative_path: str) -> str: + with self._lock: + source = self._source_urls.get(relative_path) + if source is not None: + return source + key = self._cache_key(relative_path) + with self._GLOBAL_LOCK: + source = self._GLOBAL_SOURCE_URLS.get(key) + if source is None: + if self._bucket_id is not None: + source = ( + f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/" + f"{quote(self._bucket_path(relative_path))}" + ) + else: + if self.fs is None: + raise RuntimeError("HfFileSystem fallback was not initialized") + source = self.fs.url(self._path(relative_path)) + with self._GLOBAL_LOCK: + self._GLOBAL_SOURCE_URLS[key] = source + with self._lock: + self._source_urls[relative_path] = source + return source + + def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str: + with self._lock: + if not refresh and relative_path in self._resolved_urls: + return self._resolved_urls[relative_path] + key = self._cache_key(relative_path) + if not refresh: + with self._GLOBAL_LOCK: + resolved = self._GLOBAL_RESOLVED_URLS.get(key) + size = 
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
    """Read ``length`` bytes starting at ``offset`` with an HTTP Range request.

    Args:
        relative_path: File path relative to ``self.data_root``.
        offset: Zero-based offset of the first byte to read.
        length: Number of bytes to read.

    Returns:
        The response body. The body length is not checked here, so callers
        that need exactly ``length`` bytes must verify ``len()`` themselves.
        A zero ``length`` returns ``b""`` without any network traffic.

    Raises:
        ValueError: If ``offset`` or ``length`` is negative.
        Exception: Whatever ``hf_raise_for_status`` raises for an error response.
    """
    if offset < 0 or length < 0:
        raise ValueError(f"Invalid byte range for {relative_path}: {offset=} {length=}")
    if length == 0:
        # "bytes=N-(N-1)" is not a valid range; a server may ignore it and
        # send the whole file, so answer the empty read locally.
        return b""

    byte_range = f"bytes={offset}-{offset + length - 1}"

    def send(refresh: bool):
        resolved = self._resolve_url(relative_path, refresh=refresh)
        headers = self._headers_for(resolved, self._source_url(relative_path))
        headers["Range"] = byte_range
        return self.client.get(resolved, headers=headers)

    response = send(refresh=False)
    if response.status_code == 403:
        # Retry exactly once with a freshly resolved URL.
        # NOTE(review): presumably a 403 means the cached resolved URL expired - confirm.
        response.close()
        response = send(refresh=True)
    try:
        hf_raise_for_status(response)
        return response.content
    finally:
        response.close()
enumerate(rel_paths)} + if sidecar_path is None: + files = cls._build_file_records( + rel_paths, + data_root, + range_backend=range_backend, + workers=workers, + header_probe_bytes=header_probe_bytes, + max_probe_bytes=max_probe_bytes, + ) + else: + records = cls.load_file_sidecar(sidecar_path) + missing = [path for path in rel_paths if path not in records] + if missing: + raise ValueError( + f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}" + ) + files = [records[path] for path in rel_paths] + + total = int(meta.total_episodes) + num_cameras = len(video_keys) + spans: dict[str, np.ndarray] = { + "file_id": np.zeros((total, num_cameras), dtype=np.int32), + "mdat_offset": np.zeros((total, num_cameras), dtype=np.int64), + "mdat_length": np.zeros((total, num_cameras), dtype=np.int64), + "first_pts": np.zeros((total, num_cameras), dtype=np.float64), + "last_pts": np.zeros((total, num_cameras), dtype=np.float64), + "frame_count": np.zeros((total, num_cameras), dtype=np.int32), + "sample_lo": np.zeros((total, num_cameras), dtype=np.int32), + "sample_hi": np.zeros((total, num_cameras), dtype=np.int32), + "source_start_pts": np.zeros((total, num_cameras), dtype=np.float64), + } + + for ep_idx in episode_indices: + ep = meta.episodes[ep_idx] + for cam_idx, key in enumerate(video_keys): + rel_path = str(meta.get_video_file_path(ep_idx, key)) + file_id = path_to_id[rel_path] + mp4 = files[file_id].mp4 + from_ts = float(ep[f"videos/{key}/from_timestamp"]) + to_ts = float(ep[f"videos/{key}/to_timestamp"]) + sample_slice = mp4.sample_slice( + from_ts, + to_ts, + keyframe_pad_s=keyframe_pad_s, + keyframe_pad_fraction=keyframe_pad_fraction, + file_size=files[file_id].file_size, + ) + spans["file_id"][ep_idx, cam_idx] = file_id + spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset + spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length + spans["first_pts"][ep_idx, cam_idx] = from_ts + spans["last_pts"][ep_idx, cam_idx] = to_ts + 
@staticmethod
def _build_file_records(
    rel_paths: list[str],
    data_root: str | Path,
    *,
    range_backend: str,
    workers: int,
    header_probe_bytes: int,
    max_probe_bytes: int,
) -> list[VideoFileRecord]:
    """Index every file in ``rel_paths`` and return one record per path.

    A range fetcher is created for ``data_root``; each path is sized with
    ``info_size`` and indexed with ``fetch_mp4_index`` on a thread pool of
    ``workers`` threads. Records come back in the order of ``rel_paths``
    and the fetcher is closed whether or not indexing succeeds.
    """
    range_fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)

    def index_one(rel_path: str) -> VideoFileRecord:
        size = range_fetcher.info_size(rel_path)
        parsed = fetch_mp4_index(
            rel_path,
            range_fetcher.read_range,
            file_size=size,
            header_probe_bytes=header_probe_bytes,
            max_probe_bytes=max_probe_bytes,
        )
        return VideoFileRecord(rel_path, size, parsed)

    try:
        with ThreadPoolExecutor(max_workers=workers) as executor:
            records = list(executor.map(index_one, rel_paths))
        return records
    finally:
        range_fetcher.close()
@staticmethod
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
    """Load the per-file MP4 index records stored in a sidecar archive.

    Results are memoised in ``EpisodeVideoManifest._FILE_SIDECAR_CACHE`` under
    the user-expanded path, so repeated calls return the same dict object;
    treat it as read-only.

    Args:
        sidecar_path: Archive path; a leading ``~`` is expanded before both
            the cache lookup and the read.

    Returns:
        Mapping from each record's ``file_path`` to its ``VideoFileRecord``.
    """
    # Use one expanded path for the cache key AND the read, so "~/x.npz"
    # is not cached under one name and opened under another.
    resolved_path = Path(sidecar_path).expanduser()
    cache_key = str(resolved_path)
    with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
        cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
    if cached is not None:
        return cached

    array_names = ("sample_pts", "sample_durations", "sample_sizes", "sample_offsets", "sync_samples")
    records: dict[str, VideoFileRecord] = {}
    with np.load(resolved_path, allow_pickle=False) as data:
        payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
        for file_idx, item in enumerate(payload["files"]):
            # Per-file arrays are stored under "<file index>/<array name>".
            arrays = {name: data[f"{file_idx}/{name}"] for name in array_names}
            mp4 = Mp4Index.from_dict(item["mp4"], arrays)
            records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
    with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
        # If another thread finished loading first, hand out its dict so all
        # callers share a single object.
        return EpisodeVideoManifest._FILE_SIDECAR_CACHE.setdefault(cache_key, records)
def close(self) -> None:
    """Shut down the worker pool, drop all cached entries and close the fetcher.

    The pool is shut down with ``wait=True``, so this blocks until submitted
    fetches have finished. The cache bookkeeping is reset and the fetcher is
    closed even when the shutdown raises, so its resources are not leaked.
    """
    try:
        self._pool.shutdown(wait=True)
    finally:
        with self._lock:
            self._cache.clear()
            self._futures.clear()
            self._bytes = 0
        self.fetcher.close()
camera_key in self.manifest.video_keys: + self.get_bytes(episode_index, camera_key) + + def get_bytes(self, episode_index: int, camera_key: str) -> bytes: + return self._get_entry(episode_index, camera_key)["bytes"] + + def get_decoder(self, episode_index: int, camera_key: str): + entry = self._get_entry(episode_index, camera_key) + decoder = entry.get("decoder") + if decoder is None: + decoder = open_video_decoder(io.BytesIO(entry["bytes"])) + entry["decoder"] = decoder + return decoder + + def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]): + span = self.manifest.lookup(episode_index, camera_key) + local_ts = [ts - span.source_start_pts for ts in timestamps] + decoder = self.get_decoder(episode_index, camera_key) + if hasattr(decoder, "get_frames_played_at"): + return decoder.get_frames_played_at(local_ts).data + metadata = decoder.metadata + fps = getattr(metadata, "average_fps", None) + if fps is None: + duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9) + fps = metadata.num_frames / duration + return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data + + def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]: + key = (episode_index, camera_key) + with self._lock: + if key in self._cache: + future: Future[dict[str, Any]] = Future() + future.set_result(self._cache[key]) + return future + future = self._futures.get(key) + if future is None: + future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key) + self._futures[key] = future + return future + + def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]: + key = (episode_index, camera_key) + with self._lock: + entry = self._cache.get(key) + if entry is not None: + self._cache.move_to_end(key) + return entry + future = self._submit(episode_index, camera_key) + entry = future.result() + with self._lock: + self._futures.pop(key, None) + existing = self._cache.get(key) + if 
existing is not None: + self._cache.move_to_end(key) + return existing + self._cache[key] = entry + self._bytes += len(entry["bytes"]) + self._evict_locked() + return entry + + def _evict_locked(self) -> None: + while self._bytes > self.byte_budget and self._cache: + _key, entry = self._cache.popitem(last=False) + self._bytes -= len(entry["bytes"]) + + def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]: + span = self.manifest.lookup(episode_index, camera_key) + file_record = self.manifest.file_lookup(span.file_id) + payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length) + if len(payload) != span.mdat_length: + raise OSError( + f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}" + ) + mp4_bytes = synthesize_mp4( + file_record.mp4, self.manifest.sample_slice(episode_index, camera_key), payload + ) + entry: dict[str, Any] = {"bytes": mp4_bytes, "decoder": None} + if self.open_decoders: + entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes)) + return entry + + +def open_video_decoder(file_like_or_bytesio, frame_mappings=None): + if frame_mappings is not None: + raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.") + from torchcodec.decoders import VideoDecoder + + return VideoDecoder(file_like_or_bytesio, seek_mode="approximate") + + +def assert_hf_hub_range_cache_branch() -> None: + """Fail unless huggingface_hub was installed from the required range-cache branch.""" + + try: + dist = metadata.distribution("huggingface_hub") + except metadata.PackageNotFoundError as exc: + raise AssertionError("huggingface_hub is not installed") from exc + + candidates = [] + direct_url = dist.read_text("direct_url.json") + if direct_url: + candidates.append(direct_url) + with contextlib.suppress(json.JSONDecodeError): + parsed = json.loads(direct_url) + candidates.append(str(parsed.get("url", ""))) + 
@dataclass
class StageTimer:
    """Running totals for fetch/decode timing and fetched byte volume."""

    # Accumulated fetch wall time in milliseconds.
    fetch_ms: float = 0.0
    # Accumulated decode time in milliseconds; no method in this class updates it.
    decode_ms: float = 0.0
    # Total of all ``byte_count`` values passed to ``record_fetch``.
    bytes_read: int = 0
    # Number of ``record_fetch`` calls.
    misses: int = 0

    def record_fetch(self, start: float, byte_count: int) -> None:
        """Account for one finished fetch.

        Args:
            start: ``time.perf_counter()`` reading taken when the fetch began.
            byte_count: Number of bytes the fetch returned.
        """
        elapsed_seconds = time.perf_counter() - start
        self.fetch_ms += elapsed_seconds * 1000
        self.bytes_read += byte_count
        self.misses += 1
@dataclass(frozen=True)
class Box:
    """Position of one MP4 box: type, header length and byte extent.

    ``start`` and ``end`` bound the whole box including its header, with
    ``end`` exclusive. ``end`` may lie beyond the bytes actually available
    when the box was discovered in a truncated buffer.
    """

    # Four-byte box type, e.g. b"moov".
    type: bytes
    # Offset of the first header byte.
    start: int
    # Header length in bytes (8, or 16 when a 64-bit size is used).
    header_size: int
    # Offset one past the last byte of the box.
    end: int

    @property
    def payload_start(self) -> int:
        """Offset of the first byte after the box header."""
        return self.header_size + self.start

    @property
    def size(self) -> int:
        """Total length of the box in bytes, header included."""
        total_length = self.end - self.start
        return total_length
def fetch_mp4_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    *,
    file_size: int,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
    """Build an ``Mp4Index`` for a file by probing it with ranged reads.

    The file prefix is read with a doubling probe size until it contains an
    ``mdat`` box and a complete ``moov`` box. When growing the prefix cannot
    reveal anything new and an ``mdat`` was seen without a ``moov``, the
    region right after the ``mdat`` is probed instead.

    Args:
        path: File path passed through to ``read_range``.
        read_range: Callable ``(path, offset, length) -> bytes``.
        file_size: Total size of the file in bytes.
        header_probe_bytes: Size of the first prefix probe.
        max_probe_bytes: Upper bound for a single probe.

    Returns:
        The parsed index.

    Raises:
        ValueError: If the file is empty, or no complete ``mdat``/``moov``
            pair can be located within the probe limits.
    """
    if file_size <= 0:
        # Nothing to index; also avoids issuing a zero-length range read.
        raise ValueError(f"Cannot index empty file {path}")

    probe_limit = min(max_probe_bytes, file_size)
    # At least one 8-byte box header per probe: a non-positive initial size
    # would never grow when doubled and the loop would not terminate.
    probe_size = min(max(header_probe_bytes, 8), file_size)
    while True:
        data = read_range(path, 0, probe_size)
        top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
        has_mdat = any(box.type == b"mdat" for box in top)
        has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
        if has_mdat and has_moov:
            return parse_mp4_index(path, data, file_size=file_size)

        # If the last visible box already reaches the probe limit, no larger
        # prefix can expose another top-level box, so stop re-downloading the
        # prefix (typical for files whose moov follows a large mdat).
        prefix_exhausted = probe_size >= probe_limit or (bool(top) and top[-1].end >= probe_limit)
        if prefix_exhausted:
            if has_mdat and not has_moov:
                tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
                if tail_index is not None:
                    return tail_index
            missing = []
            if not has_mdat:
                missing.append("mdat")
            if not has_moov:
                missing.append("moov")
            raise ValueError(
                f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
            )
        probe_size = min(probe_size * 2, max_probe_bytes, file_size)
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
    """Parse an ``Mp4Index`` from a buffer that starts at file offset 0.

    Args:
        path: File path recorded in the index and used in error messages.
        data: Bytes from the start of the file; must contain the whole
            ``moov`` box and at least the header of the ``mdat`` box.
        file_size: Real file size; defaults to ``len(data)``.

    Raises:
        ValueError: If ``moov`` or ``mdat`` is missing, or ``moov`` extends
            past the end of ``data``.
    """
    total_size = len(data) if file_size is None else file_size
    top_level = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
    ftyp_box = _one(top_level, b"ftyp", required=False)
    moov_box = _one(top_level, b"moov")
    mdat_box = _one(top_level, b"mdat")
    if moov_box.end > len(data):
        raise ValueError(f"{path}: moov box is truncated")

    if ftyp_box is None:
        # No ftyp in the source: fall back to a generic one.
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    else:
        ftyp = data[ftyp_box.start : ftyp_box.end]
    return _parse_mp4_index_from_layout(
        path,
        file_size=total_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=data[moov_box.payload_start : moov_box.end],
        mdat_box=mdat_box,
    )
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
    """Assemble a standalone MP4 holding only the samples of ``sample_slice``.

    Args:
        index: Index of the source file.
        sample_slice: Inclusive sample range plus the byte offset that
            ``mdat_payload`` starts at in the source file.
        mdat_payload: Raw bytes read from the source at that offset.

    Returns:
        ``ftyp`` + a freshly built ``moov`` + an ``mdat`` wrapping the payload.

    Raises:
        ValueError: If the sample range is invalid, the slice does not start
            at the lowest referenced sample offset, or the payload is too
            short for the referenced samples.
    """
    first = sample_slice.sample_lo
    stop = sample_slice.sample_hi + 1  # exclusive upper bound
    if not (0 <= first < stop <= len(index.sample_sizes)):
        raise ValueError(f"Invalid sample range [{first}, {stop}) for {index.file_path}")

    window = slice(first, stop)
    sizes = index.sample_sizes[window]
    # Offsets relative to the start of the fetched payload.
    rel_offsets = index.sample_offsets[window] - sample_slice.byte_offset
    if int(rel_offsets.min()) != 0:
        raise ValueError("Sample slice must start at the minimum referenced sample offset")
    if int((rel_offsets + sizes).max()) > len(mdat_payload):
        raise ValueError("Sample slice does not cover all referenced samples")

    durations = index.sample_durations[window]
    in_window = (index.sync_samples >= first) & (index.sync_samples < stop)
    # Re-base sync samples onto the new track and make them 1-based.
    sync = index.sync_samples[in_window] - first + 1

    # Chunk offsets depend on the moov length, so build the moov once to
    # measure it and a second time with the real data offset. The extra 8
    # bytes account for the mdat box header.
    sizing_moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
    header_size = len(index.ftyp) + len(sizing_moov)
    final_moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
    return index.ftyp + final_moov + _box(b"mdat", mdat_payload)
+ + +def _parse_mvhd(payload: bytes) -> tuple[int, int]: + version = payload[0] + if version == 1: + return struct.unpack_from(">IQ", payload, 20) + return struct.unpack_from(">II", payload, 12) + + +def _parse_mdhd(payload: bytes) -> tuple[int, int]: + version = payload[0] + if version == 1: + return struct.unpack_from(">IQ", payload, 20) + return struct.unpack_from(">II", payload, 12) + + +def _parse_tkhd(payload: bytes) -> dict[str, int]: + version = payload[0] + if version == 1: + track_id = struct.unpack_from(">I", payload, 20)[0] + duration = struct.unpack_from(">Q", payload, 28)[0] + width, height = struct.unpack_from(">II", payload, 88) + else: + track_id = struct.unpack_from(">I", payload, 12)[0] + duration = struct.unpack_from(">I", payload, 20)[0] + width, height = struct.unpack_from(">II", payload, 76) + return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16} + + +def _parse_stsd_codec(stsd_body: bytes) -> str: + if len(stsd_body) < 16: + return "unknown" + return stsd_body[12:16].decode("latin1") + + +def _parse_stts(payload: bytes) -> list[tuple[int, int]]: + count = struct.unpack_from(">I", payload, 4)[0] + out = [] + offset = 8 + for _ in range(count): + out.append(struct.unpack_from(">II", payload, offset)) + offset += 8 + return out + + +def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray: + values = np.empty(sample_count, dtype=np.int64) + pos = 0 + for count, delta in entries: + values[pos : pos + count] = delta + pos += count + if pos != sample_count: + raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}") + return values + + +def _parse_stsz(payload: bytes) -> np.ndarray: + sample_size, sample_count = struct.unpack_from(">II", payload, 4) + if sample_size: + return np.full(sample_count, sample_size, dtype=np.int64) + offset = 12 + values = np.empty(sample_count, dtype=np.int64) + for idx in range(sample_count): + values[idx] = 
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
    """Decode the entries of an ``stsc`` full-box payload.

    Args:
        payload: Box payload beginning at the version/flags word. The 32-bit
            entry count is read at byte 4 and the 12-byte entries from byte 8.

    Returns:
        One 3-tuple of unsigned 32-bit integers per entry, in file order.
        ``_sample_offsets`` consumes them as
        ``(first_chunk, samples_per_chunk, description_index)``.

    Raises:
        struct.error: If the payload is shorter than the declared entry count needs.
    """
    (count,) = struct.unpack_from(">I", payload, 4)
    # Decode the whole table with one call instead of one call per entry.
    fields = struct.unpack_from(f">{3 * count}I", payload, 8)
    return [fields[start : start + 3] for start in range(0, len(fields), 3)]


def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
    """Read the chunk offset table from the children of an ``stbl`` payload.

    A ``co64`` child (64-bit entries) takes precedence over an ``stco`` child
    (32-bit entries) when both are present.

    Args:
        stbl: Bytes whose child boxes are enumerated by ``_children``.

    Returns:
        The offsets as an ``int64`` array, in table order.

    Raises:
        ValueError: If neither ``stco`` nor ``co64`` is present.
        struct.error: If the selected table is shorter than its declared count needs.
    """
    stco_payload = None
    co64_payload = None
    for box in _children(stbl, 0, len(stbl)):
        if box.type == b"stco":
            stco_payload = stbl[box.payload_start : box.end]
        elif box.type == b"co64":
            co64_payload = stbl[box.payload_start : box.end]

    if co64_payload is not None:
        table, code = co64_payload, "Q"
    elif stco_payload is not None:
        table, code = stco_payload, "I"
    else:
        raise ValueError("Missing stco/co64 chunk offsets")

    (count,) = struct.unpack_from(">I", table, 4)
    return np.array(struct.unpack_from(f">{count}{code}", table, 8), dtype=np.int64)


def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
    """Return the zero-based sample indices listed by the first ``stss`` child.

    Args:
        stbl: Bytes whose child boxes are enumerated by ``_children``.
        sample_count: Number of samples in the track; only used when no
            ``stss`` child exists.

    Returns:
        An ``int64`` array holding each stored value minus one. When there is
        no ``stss`` child, ``arange(sample_count)`` is returned, i.e. every
        sample index.

    Raises:
        struct.error: If the table is shorter than its declared count needs.
    """
    for box in _children(stbl, 0, len(stbl)):
        if box.type != b"stss":
            continue
        payload = stbl[box.payload_start : box.end]
        (count,) = struct.unpack_from(">I", payload, 4)
        stored = struct.unpack_from(f">{count}I", payload, 8)
        # Stored numbers are one-based; callers index samples from zero.
        return np.array(stored, dtype=np.int64) - 1
    return np.arange(sample_count, dtype=np.int64)


def _sample_offsets(
    stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
) -> np.ndarray:
    """Compute the byte offset of every sample from the chunk layout tables.

    Samples are laid out back to back inside each chunk, so a sample's offset
    is its chunk's offset plus the sizes of the samples before it in that chunk.

    Args:
        stsc: ``(first_chunk, samples_per_chunk, description_index)`` entries;
            chunk numbers are one-based. Each entry applies up to the chunk
            before the next entry's ``first_chunk``; the last entry applies
            through the final chunk.
        chunk_offsets: Byte offset of each chunk.
        sample_sizes: Byte size of each sample, in sample order.

    Returns:
        An ``int64`` array with one offset per entry of ``sample_sizes``.

    Raises:
        ValueError: If ``stsc`` is empty, if it references a chunk number
            outside ``chunk_offsets``, or if it accounts for fewer samples
            than ``sample_sizes`` holds.
    """
    if not stsc:
        raise ValueError("stsc is empty")

    total_samples = len(sample_sizes)
    total_chunks = len(chunk_offsets)
    offsets = np.empty(total_samples, dtype=np.int64)

    # Exclusive end chunk of each run: the next run's first chunk, or one past the last chunk.
    run_ends = [entry[0] for entry in stsc[1:]] + [total_chunks + 1]

    cursor = 0
    for (first_chunk, samples_per_chunk, _desc_idx), run_end in zip(stsc, run_ends):
        for chunk_number in range(first_chunk, run_end):
            if not 1 <= chunk_number <= total_chunks:
                raise ValueError("stsc references a chunk outside stco/co64")
            position = int(chunk_offsets[chunk_number - 1])
            for _ in range(samples_per_chunk):
                if cursor >= total_samples:
                    # Every sample has been placed; remaining chunks are ignored.
                    return offsets
                offsets[cursor] = position
                position += int(sample_sizes[cursor])
                cursor += 1

    if cursor != total_samples:
        raise ValueError(f"stsc describes {cursor} samples, stsz describes {total_samples}")
    return offsets


def _make_moov(
    index: Mp4Index,
    durations: np.ndarray,
    sizes: np.ndarray,
    rel_offsets: np.ndarray,
    sync_samples: np.ndarray,
    *,
    mdat_data_offset: int,
) -> bytes:
    """Assemble a single-track ``moov`` box describing the given samples.

    Args:
        index: Source of ``stsd_body``, ``timescale``, ``track_id``, ``width``
            and ``height`` for the generated boxes.
        durations: Per-sample durations; their sum becomes the duration
            written into ``mvhd``, ``tkhd`` and ``mdhd``.
        sizes: Per-sample byte sizes.
        rel_offsets: Per-sample offsets relative to ``mdat_data_offset``.
        sync_samples: Values written verbatim into ``stss``; the box is
            omitted when this is empty.
        mdat_data_offset: Added to every entry of ``rel_offsets`` to form the
            chunk offsets that are written out.

    Returns:
        The serialized ``moov`` box.
    """
    duration = int(durations.sum())
    chunk_positions = [int(mdat_data_offset + value) for value in rel_offsets]
    # 32-bit stco entries cannot hold offsets above 0xFFFFFFFF; switch to co64 then.
    needs_co64 = any(position > 0xFFFFFFFF for position in chunk_positions)
    offset_box = _co64(chunk_positions) if needs_co64 else _stco(chunk_positions)

    stbl = _box(
        b"stbl",
        _box(b"stsd", index.stsd_body)
        + _stts(durations)
        + _stsc_one_sample_per_chunk(len(sizes))
        + _stsz(sizes)
        + offset_box
        + (_stss(sync_samples) if len(sync_samples) else b""),
    )
    minf = _box(b"minf", _vmhd() + _dinf() + stbl)
    mdia = _box(b"mdia", _mdhd(index.timescale, duration) + _hdlr() + minf)
    trak = _box(b"trak", _tkhd(index.track_id, duration, index.width, index.height) + mdia)
    return _box(b"moov", _mvhd(index.timescale, duration, index.track_id + 1) + trak)


def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
    """Serialize a box whose payload starts with a version byte and 24-bit flags.

    Args:
        typ: Four-byte box type.
        version: Value of the single version byte (0-255).
        flags: Value written as three big-endian bytes.
        payload: Bytes following the version/flags word.

    Returns:
        The serialized box.
    """
    header = bytes([version]) + flags.to_bytes(3, "big")
    return _box(typ, header + payload)


def _box(typ: bytes, payload: bytes) -> bytes:
    """Serialize one box: size, four-byte type, then the payload.

    When the total length does not fit in 32 bits, the size field is written
    as 1 and the real length follows the type as a 64-bit integer.

    Args:
        typ: Box type; must be exactly four bytes.
        payload: Box contents.

    Returns:
        The serialized box.

    Raises:
        ValueError: If ``typ`` is not exactly four bytes long.
    """
    # struct's "4s" would silently truncate or NUL-pad a wrong-length type,
    # producing a box no parser can identify; fail loudly instead.
    if len(typ) != 4:
        raise ValueError(f"MP4 box type must be exactly 4 bytes, got {typ!r}")
    size = len(payload) + 8
    if size <= 0xFFFFFFFF:
        return struct.pack(">I4s", size, typ) + payload
    # The extended header is 8 bytes longer than the compact one.
    return struct.pack(">I4sQ", 1, typ, size + 8) + payload
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
    """Serialize a version-0 ``tkhd`` box with flags set to 7.

    Args:
        track_id: Track identifier.
        duration: Track duration.
        width: Written shifted left by 16 bits into a 32-bit field.
        height: Written shifted left by 16 bits into a 32-bit field.

    Returns:
        The serialized box.
    """
    matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
    pieces = (
        # Two zeroed 32-bit fields, the track id, one zeroed field, then the duration.
        struct.pack(">IIIII", 0, 0, track_id, 0, duration),
        b"\0" * 8,
        struct.pack(">hhhh", 0, 0, 0, 0),
        matrix,
        struct.pack(">II", width << 16, height << 16),
    )
    return _full_box(b"tkhd", 0, 7, b"".join(pieces))


def _mdhd(timescale: int, duration: int) -> bytes:
    """Serialize a version-0 ``mdhd`` box.

    Args:
        timescale: Value of the timescale field.
        duration: Value of the duration field.

    Returns:
        The serialized box.
    """
    # 0x55C4 presumably is the packed ISO-639-2 code "und" - confirm against the spec.
    # The trailing 16-bit zero replaces the two literal NUL bytes.
    body = struct.pack(">IIIIHH", 0, 0, timescale, duration, 0x55C4, 0)
    return _full_box(b"mdhd", 0, 0, body)


def _hdlr() -> bytes:
    """Serialize an ``hdlr`` box with handler type ``vide`` named ``VideoHandler``."""
    body = b"".join((b"\0" * 4, b"vide", b"\0" * 12, b"VideoHandler\0"))
    return _full_box(b"hdlr", 0, 0, body)


def _vmhd() -> bytes:
    """Serialize a ``vmhd`` box with flags set to 1 and four zeroed 16-bit fields."""
    zeroed_fields = struct.pack(">HHHH", 0, 0, 0, 0)
    return _full_box(b"vmhd", 0, 1, zeroed_fields)


def _dinf() -> bytes:
    """Serialize a ``dinf`` box wrapping a ``dref`` with a single ``url `` entry."""
    url_entry = _full_box(b"url ", 0, 1)
    entry_count = struct.pack(">I", 1)
    return _box(b"dinf", _full_box(b"dref", 0, 0, entry_count + url_entry))


def _stts(durations: np.ndarray) -> bytes:
    """Serialize an ``stts`` box, run-length encoding equal consecutive durations.

    Args:
        durations: Per-sample durations; each is converted with ``int``.

    Returns:
        The serialized box holding one ``(count, delta)`` pair per run.
    """
    entries = []
    run_length = 0
    run_delta = None
    for raw in durations.tolist():
        delta = int(raw)
        if run_length and delta == run_delta:
            run_length += 1
            continue
        if run_length:
            entries.append((run_length, run_delta))
        run_length, run_delta = 1, delta
    if run_length:
        # Flush the run that was still open when the input ended.
        entries.append((run_length, run_delta))

    table = b"".join(struct.pack(">II", count, delta) for count, delta in entries)
    return _full_box(b"stts", 0, 0, struct.pack(">I", len(entries)) + table)


def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
    """Serialize an ``stsc`` box with the single entry ``(1, 1, 1)``.

    Args:
        sample_count: Accepted but not used; the table is the same for any count.

    Returns:
        The serialized box.
    """
    entry_count = struct.pack(">I", 1)
    entry = struct.pack(">III", 1, 1, 1)
    return _full_box(b"stsc", 0, 0, entry_count + entry)


def _stsz(sizes: np.ndarray) -> bytes:
    """Serialize an ``stsz`` box listing every sample size explicitly.

    Args:
        sizes: Per-sample byte sizes; each is converted with ``int``.

    Returns:
        The serialized box: a zero 32-bit field, the entry count, then one
        32-bit size per sample.
    """
    header = struct.pack(">II", 0, len(sizes))
    entries = [int(size) for size in sizes.tolist()]
    return _full_box(b"stsz", 0, 0, header + struct.pack(f">{len(entries)}I", *entries))


def _stco(values: list[int]) -> bytes:
    """Serialize an ``stco`` box of unsigned 32-bit chunk offsets.

    Args:
        values: Chunk offsets, each of which must fit in 32 bits.

    Returns:
        The serialized box.
    """
    entry_count = struct.pack(">I", len(values))
    table = struct.pack(f">{len(values)}I", *values)
    return _full_box(b"stco", 0, 0, entry_count + table)


def _co64(values: list[int]) -> bytes:
    """Serialize a ``co64`` box of unsigned 64-bit chunk offsets.

    Args:
        values: Chunk offsets.

    Returns:
        The serialized box.
    """
    entry_count = struct.pack(">I", len(values))
    table = struct.pack(f">{len(values)}Q", *values)
    return _full_box(b"co64", 0, 0, entry_count + table)


def _stss(values: np.ndarray) -> bytes:
    """Serialize an ``stss`` box.

    Args:
        values: Sample numbers, written unchanged after conversion with ``int``.

    Returns:
        The serialized box.
    """
    header = struct.pack(">I", len(values))
    entries = [int(value) for value in values.tolist()]
    return _full_box(b"stss", 0, 0, header + struct.pack(f">{len(entries)}I", *entries))
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
    """A slice over out-of-order chunks spans the lowest offset to the end of the highest sample."""
    index = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))

    window = index.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)

    # Samples sit at 10_000, 10_050 and 10_025 with 10 bytes each, so the span is 10_000..10_060.
    assert window.byte_offset == 10_000
    assert window.byte_length == 60
    assert window.sample_lo == 0
    assert window.sample_hi == 2


def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
    """Offsets in the synthesized file are the originals rebased onto its own mdat payload."""
    source_index = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
    window = source_index.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)

    mini_bytes = synthesize_mp4(source_index, window, b"x" * window.byte_length)
    mini_index = parse_mp4_index("mini.mp4", mini_bytes)

    relative = np.array([0, 50, 25], dtype=np.int64)
    np.testing.assert_array_equal(mini_index.sample_offsets, relative + mini_index.mdat_payload_offset)
    np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))


def test_parser_accepts_co64_chunk_offsets():
    """Chunk offsets stored in a co64 box are parsed to the same values as stco."""
    chunk_positions = [10_000, 10_050, 10_025]

    index = parse_mp4_index("test.mp4", _minimal_mp4(chunk_positions, use_co64=True))

    np.testing.assert_array_equal(index.sample_offsets, np.array([10_000, 10_050, 10_025]))


def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
    """The assertion passes when direct_url.json records the expected requested revision."""
    direct_url = {
        "url": "https://github.com/huggingface/huggingface_hub.git",
        "vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
    }

    class FakeDist:
        def read_text(self, name):
            assert name == "direct_url.json"
            return json.dumps(direct_url)

    def fake_distribution(_):
        return FakeDist()

    monkeypatch.setattr("lerobot.datasets.episode_video_streaming.metadata.distribution", fake_distribution)

    assert_hf_hub_range_cache_branch()


def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
    """The assertion fails when direct_url.json carries no VCS revision information."""
    direct_url = {"url": "https://github.com/huggingface/huggingface_hub.git"}

    class FakeDist:
        def read_text(self, name):
            assert name == "direct_url.json"
            return json.dumps(direct_url)

    def fake_distribution(_):
        return FakeDist()

    monkeypatch.setattr("lerobot.datasets.episode_video_streaming.metadata.distribution", fake_distribution)

    with pytest.raises(AssertionError):
        assert_hf_hub_range_cache_branch()