mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Add episode video streaming byte cache
This commit is contained in:
@@ -421,6 +421,7 @@ exclude_dirs = [
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-words = { trak = "trak" }
|
||||
default.extend-ignore-re = [
|
||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||
|
||||
@@ -0,0 +1,724 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeByteCache,
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=("both", "indexed", "remote-decoder", "native-http"),
|
||||
default="both",
|
||||
help=argparse.SUPPRESS,
|
||||
)
|
||||
parser.add_argument("--num-episodes", type=int, default=512)
|
||||
parser.add_argument(
|
||||
"--manifest-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit manifest construction to the first N episodes for local smoke tests.",
|
||||
)
|
||||
parser.add_argument("--pool-size", type=int, default=16)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--prefetch-ahead", type=int, default=8)
|
||||
parser.add_argument("--frames-per-episode", type=int, default=16)
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--byte-budget-gb", type=float, default=80)
|
||||
parser.add_argument(
|
||||
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span = manifest.lookup(ep, camera_key)
|
||||
lo = span.first_pts
|
||||
hi = max(span.last_pts, lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _timestamps_from_meta(
|
||||
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
||||
) -> dict[tuple[int, str], list[float]]:
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
row = meta.episodes[ep]
|
||||
for camera_key in meta.video_keys:
|
||||
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
||||
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
total = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
total += manifest.lookup(ep, camera_key).mdat_length
|
||||
return total
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
items = list(timestamps.items())
|
||||
if decode_workers <= 1:
|
||||
for (ep, camera_key), ts in items:
|
||||
cache.get_frames(ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
for ep in episodes:
|
||||
cache.ensure_ready(ep)
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
||||
if elapsed_s <= 0:
|
||||
return float("inf")
|
||||
return len(episodes) * frames_per_episode / elapsed_s
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
return str(Path(data_root) / relative_path)
|
||||
|
||||
|
||||
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
|
||||
local = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
|
||||
if _valid_sidecar(local):
|
||||
return local
|
||||
if local.exists():
|
||||
print(f"mp4_sidecar_invalid_local: {local}")
|
||||
local.unlink()
|
||||
full_local = SIDECAR_CACHE_DIR / "molmoact2-full.npz"
|
||||
if _valid_sidecar(full_local):
|
||||
return full_local
|
||||
remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz")
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
fs = fsspec.filesystem(protocol)
|
||||
if not fs.exists(remote):
|
||||
return None
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"downloading_mp4_sidecar: {remote} -> {local}")
|
||||
if data_root.startswith("hf://"):
|
||||
_download_sidecar_native_http(
|
||||
data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz", local
|
||||
)
|
||||
else:
|
||||
fs.get(remote, str(local))
|
||||
return local
|
||||
|
||||
|
||||
def _valid_sidecar(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
with np.load(path, allow_pickle=False) as data:
|
||||
return "manifest_json" in data
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
|
||||
fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
|
||||
tmp = local.with_suffix(local.suffix + ".tmp")
|
||||
try:
|
||||
size = fetcher.info_size(relative_path)
|
||||
chunk_size = 16 * 1024 * 1024
|
||||
ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
|
||||
with tmp.open("wb") as out_file:
|
||||
out_file.truncate(size)
|
||||
|
||||
def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
|
||||
offset, length = offset_length
|
||||
return offset, fetcher.read_range(relative_path, offset, length)
|
||||
|
||||
start = time.perf_counter()
|
||||
done = 0
|
||||
with ThreadPoolExecutor(max_workers=8) as pool:
|
||||
futures = [pool.submit(read_chunk, item) for item in ranges]
|
||||
with tmp.open("r+b") as rw_file:
|
||||
for future in futures:
|
||||
offset, data = future.result()
|
||||
rw_file.seek(offset)
|
||||
rw_file.write(data)
|
||||
done += len(data)
|
||||
elapsed = max(time.perf_counter() - start, 1e-9)
|
||||
print(
|
||||
f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
|
||||
f"({done / elapsed / 1024**2:.1f} MiB/s)",
|
||||
flush=True,
|
||||
)
|
||||
tmp.replace(local)
|
||||
finally:
|
||||
fetcher.close()
|
||||
|
||||
|
||||
class EpisodeParquetReader:
|
||||
def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
|
||||
self.meta = meta
|
||||
self.data_root = data_root
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self._episode_row_groups = self._build_episode_row_groups()
|
||||
self._table_cache: dict[str, pa.Table] = {}
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def read_episode(self, episode_index: int) -> None:
|
||||
relative_path = str(self.meta.get_data_file_path(episode_index))
|
||||
table = self._read_table(relative_path)
|
||||
table.filter(pc.equal(table["episode_index"], episode_index))
|
||||
|
||||
def _read_table(self, relative_path: str) -> pa.Table:
|
||||
with self._cache_lock:
|
||||
table = self._table_cache.get(relative_path)
|
||||
if table is not None:
|
||||
return table
|
||||
with self.fs.open(
|
||||
_root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
|
||||
) as f:
|
||||
table = pq.ParquetFile(f).read()
|
||||
with self._cache_lock:
|
||||
return self._table_cache.setdefault(relative_path, table)
|
||||
|
||||
def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
|
||||
return pool.submit(self.read_episode, episode_index)
|
||||
|
||||
def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
|
||||
start = time.perf_counter()
|
||||
if workers <= 1:
|
||||
for ep in episodes:
|
||||
self.read_episode(ep)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [pool.submit(self.read_episode, ep) for ep in episodes]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
def _build_episode_row_groups(self) -> dict[int, int]:
|
||||
counts: dict[tuple[int, int], int] = {}
|
||||
row_groups = {}
|
||||
for ep_idx in range(int(self.meta.total_episodes)):
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
|
||||
row_groups[ep_idx] = counts.get(key, 0)
|
||||
counts[key] = row_groups[ep_idx] + 1
|
||||
return row_groups
|
||||
|
||||
|
||||
def run_sequential(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
byte_budget: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=1,
|
||||
range_backend=range_backend,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
parquet_s = parquet_reader.read_episodes(episodes, workers=1)
|
||||
elapsed = _fill_cache(cache, episodes)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
return {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / elapsed,
|
||||
"episode_mb": episode_mb,
|
||||
"parquet_s": parquet_s,
|
||||
"avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
|
||||
}
|
||||
|
||||
|
||||
def run_parallel(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
|
||||
fetch_s = _fill_cache(cache, episodes)
|
||||
decoder_start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_decoder(ep, camera_key)
|
||||
decoder_s = time.perf_counter() - decoder_start
|
||||
decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
return {
|
||||
"fetch_s": fetch_s,
|
||||
"fetch_mbps": byte_count / fetch_s / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / fetch_s,
|
||||
"parquet_s": parquet_s,
|
||||
"decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
|
||||
"decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def run_overlapped(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
prefetch_ahead: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=True,
|
||||
) as cache:
|
||||
start = time.perf_counter()
|
||||
video_wait_decode_s = 0.0
|
||||
parquet_wait_s = 0.0
|
||||
parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
|
||||
parquet_futures = {
|
||||
ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
|
||||
}
|
||||
for ep in episodes[:prefetch_ahead]:
|
||||
cache.submit_prefetch(ep)
|
||||
try:
|
||||
for idx, ep in enumerate(episodes):
|
||||
next_idx = idx + prefetch_ahead
|
||||
if next_idx < len(episodes):
|
||||
next_ep = episodes[next_idx]
|
||||
cache.submit_prefetch(next_ep)
|
||||
parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
|
||||
|
||||
parquet_start = time.perf_counter()
|
||||
parquet_futures.pop(ep).result()
|
||||
parquet_wait_s += time.perf_counter() - parquet_start
|
||||
|
||||
video_start = time.perf_counter()
|
||||
cache.ensure_ready(ep)
|
||||
if decode_workers <= 1:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [
|
||||
pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
|
||||
for camera_key in manifest.video_keys
|
||||
]
|
||||
for future in futures:
|
||||
future.result()
|
||||
video_wait_decode_s += time.perf_counter() - video_start
|
||||
finally:
|
||||
parquet_pool.shutdown(wait=True)
|
||||
elapsed = time.perf_counter() - start
|
||||
return {
|
||||
"samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
|
||||
"video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
|
||||
"parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
|
||||
"wall_s": elapsed,
|
||||
"video_wait_decode_s": video_wait_decode_s,
|
||||
"parquet_wait_s": parquet_wait_s,
|
||||
}
|
||||
|
||||
|
||||
_remote_decoder_local = threading.local()
|
||||
|
||||
|
||||
def _remote_decoder_cache() -> VideoDecoderCache:
|
||||
cache = getattr(_remote_decoder_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = VideoDecoderCache(max_size=None)
|
||||
_remote_decoder_local.cache = cache
|
||||
return cache
|
||||
|
||||
|
||||
def _decode_remote_source(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
timestamps: list[float],
|
||||
):
|
||||
video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
|
||||
return decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
timestamps,
|
||||
tolerance_s=1.0 / float(meta.fps),
|
||||
decoder_cache=_remote_decoder_cache(),
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
|
||||
def run_remote_decoder(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
*,
|
||||
frames_per_episode: int,
|
||||
decode_workers: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> dict[str, float]:
|
||||
items = [
|
||||
(ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
|
||||
]
|
||||
|
||||
start = time.perf_counter()
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
sequential_s = time.perf_counter() - start
|
||||
|
||||
start = time.perf_counter()
|
||||
if decode_workers <= 1:
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
|
||||
futures = [
|
||||
pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
|
||||
for ep, camera_key, ts in items
|
||||
]
|
||||
for future in parquet_futures:
|
||||
future.result()
|
||||
for future in futures:
|
||||
future.result()
|
||||
parallel_s = time.perf_counter() - start
|
||||
|
||||
return {
|
||||
"sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
|
||||
"parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def run_indexed_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
*,
|
||||
range_backend: str = "fsspec",
|
||||
label: str = "indexed",
|
||||
sidecar_path: str | None = None,
|
||||
) -> None:
|
||||
_log(f"starting_strategy: {label}")
|
||||
manifest_start = time.perf_counter()
|
||||
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
|
||||
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
|
||||
manifest = EpisodeVideoManifest.build(
|
||||
meta,
|
||||
data_root,
|
||||
episode_indices=range(manifest_episode_count),
|
||||
range_backend=range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
sidecar_path=sidecar_path,
|
||||
)
|
||||
manifest_s = time.perf_counter() - manifest_start
|
||||
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
|
||||
byte_budget = int(args.byte_budget_gb * 1024**3)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
_log(
|
||||
f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
|
||||
f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
|
||||
)
|
||||
|
||||
_log(f"{label}: running sequential video fetch")
|
||||
sequential = run_sequential(manifest, data_root, episodes, byte_budget, parquet_reader, range_backend)
|
||||
_log(f"{label}: running parallel video fetch + decode-only")
|
||||
parallel = run_parallel(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
_log(f"{label}: running overlapped end-to-end")
|
||||
overlapped = run_overlapped(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
args.prefetch_ahead,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
|
||||
print(f"manifest_build_s: {manifest_s:.2f}")
|
||||
print(f"strategy: {label}")
|
||||
print(f"range_backend: {range_backend}")
|
||||
print(f"mp4_sidecar: {sidecar_path or 'none'}")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
print()
|
||||
print("| Track | fetch MB/s | fetch eps/s | samples/s | avg MB/miss | notes |")
|
||||
print("|---|---:|---:|---:|---:|---|")
|
||||
print(
|
||||
f"| SEQUENTIAL | {sequential['fetch_mbps']:.1f} | {sequential['fetch_episodes_s']:.2f} | - | "
|
||||
f"{sequential['avg_mb_miss']:.1f} | 1 worker video fetch, parquet {sequential['parquet_s']:.2f}s |"
|
||||
)
|
||||
print(
|
||||
f"| PARALLEL | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
|
||||
f"{parallel['decode_samples_s']:.1f} | "
|
||||
f"{sequential['avg_mb_miss']:.1f} | decode-only, decoder open "
|
||||
f"{parallel['decoder_ms_miss']:.1f} ms/miss, parquet {parallel['parquet_s']:.2f}s |"
|
||||
)
|
||||
print(
|
||||
f"| OVERLAPPED | - | - | {overlapped['samples_s']:.1f} | {sequential['avg_mb_miss']:.1f} | "
|
||||
f"end-to-end; video {overlapped['video_samples_s']:.1f} samples/s "
|
||||
f"({overlapped['video_wait_decode_s']:.2f}s), parquet {overlapped['parquet_samples_s']:.1f} "
|
||||
f"samples/s ({overlapped['parquet_wait_s']:.2f}s) |"
|
||||
)
|
||||
|
||||
|
||||
def run_remote_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> None:
|
||||
_log("starting_strategy: remote-decoder")
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
|
||||
_log("remote-decoder: running direct source MP4 decoder")
|
||||
result = run_remote_decoder(
|
||||
meta,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
frames_per_episode=args.frames_per_episode,
|
||||
decode_workers=args.decode_workers,
|
||||
parquet_reader=parquet_reader,
|
||||
)
|
||||
print("strategy: remote-decoder")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"episodes: {episodes}")
|
||||
print(f"cameras: {list(meta.video_keys)}")
|
||||
print()
|
||||
print("| Track | samples/s | notes |")
|
||||
print("|---|---:|---|")
|
||||
print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
|
||||
print(
|
||||
f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
|
||||
f"direct source MP4 decoder, {args.decode_workers} workers |"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
data_root = args.data_root
|
||||
if data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
parquet_reader = EpisodeParquetReader(meta, data_root)
|
||||
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
|
||||
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
|
||||
sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
|
||||
|
||||
if sidecar_path is not None:
|
||||
print(f"using_mp4_sidecar: {sidecar_path}")
|
||||
|
||||
if sidecar_path is not None and args.strategy == "both":
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
print()
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
print()
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "indexed":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "native-http":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy == "both":
|
||||
expected_sidecar = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
|
||||
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz")
|
||||
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
|
||||
print(f"mp4_sidecar_missing_remote: {expected_remote}")
|
||||
print(
|
||||
"build_mp4_sidecar: "
|
||||
f"uv run --no-sync python scripts/build_mp4_sidecar.py --episodes {manifest_episode_count} "
|
||||
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
|
||||
)
|
||||
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
|
||||
print()
|
||||
|
||||
if args.strategy in ("both", "indexed"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed",
|
||||
sidecar_path=None,
|
||||
)
|
||||
if args.strategy == "both":
|
||||
print()
|
||||
if args.strategy in ("both", "remote-decoder"):
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
if args.strategy == "both":
|
||||
print()
|
||||
if args.strategy in ("both", "native-http"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http",
|
||||
sidecar_path=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--episodes", type=int, default=None)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument(
|
||||
"--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def push_sidecar(local_path: str, data_root: str, episode_count: int) -> list[str]:
|
||||
if not data_root.startswith("hf://"):
|
||||
return []
|
||||
|
||||
local = Path(local_path)
|
||||
fs = fsspec.filesystem("hf")
|
||||
remote_dir = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
|
||||
remote_paths = [f"{remote_dir}/{local.name}"]
|
||||
alias = f"{remote_dir}/molmoact2-{episode_count}.npz"
|
||||
if alias not in remote_paths:
|
||||
remote_paths.append(alias)
|
||||
|
||||
for remote in remote_paths:
|
||||
fs.put(str(local), remote)
|
||||
return remote_paths
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
total = (
|
||||
int(meta.total_episodes) if args.episodes is None else min(args.episodes, int(meta.total_episodes))
|
||||
)
|
||||
rel_paths = sorted(
|
||||
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in range(total) for key in meta.video_keys}
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
EpisodeVideoManifest.write_file_sidecar(
|
||||
args.output,
|
||||
rel_paths,
|
||||
args.data_root,
|
||||
range_backend=args.range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
print(f"wrote {args.output}")
|
||||
print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")
|
||||
if args.no_push:
|
||||
print("push_skipped: --no-push")
|
||||
else:
|
||||
pushed = push_sidecar(args.output, args.data_root, total)
|
||||
for remote in pushed:
|
||||
print(f"pushed {remote}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,668 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from importlib import metadata
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote, urljoin, urlparse
|
||||
|
||||
import fsspec
|
||||
import httpx
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi, HfFileSystem, constants
|
||||
from huggingface_hub.utils import hf_raise_for_status
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EpisodeVideoSpan:
|
||||
file_id: int
|
||||
mdat_offset: int
|
||||
mdat_length: int
|
||||
first_pts: float
|
||||
last_pts: float
|
||||
frame_count: int
|
||||
sample_lo: int
|
||||
sample_hi: int
|
||||
source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VideoFileRecord:
|
||||
file_path: str
|
||||
file_size: int
|
||||
mp4: Mp4Index
|
||||
|
||||
|
||||
class ThreadLocalRangeFetcher:
|
||||
"""Range reader that gives each worker thread independent file handles."""
|
||||
|
||||
def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
protocol = "hf" if self.data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self.block_size = block_size
|
||||
self.cache_type = cache_type
|
||||
self._local = threading.local()
|
||||
|
||||
def _url(self, relative_path: str) -> str:
|
||||
if self.data_root.startswith("hf://"):
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
return str(Path(self.data_root) / relative_path)
|
||||
|
||||
def _handle(self, relative_path: str):
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
handles = {}
|
||||
self._local.handles = handles
|
||||
handle = handles.get(relative_path)
|
||||
if handle is None or getattr(handle, "closed", False):
|
||||
handle = self.fs.open(
|
||||
self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
|
||||
)
|
||||
handles[relative_path] = handle
|
||||
return handle
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
return int(self.fs.info(self._url(relative_path))["size"])
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
handle = self._handle(relative_path)
|
||||
handle.seek(offset)
|
||||
return handle.read(length)
|
||||
|
||||
def close(self) -> None:
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
return
|
||||
for handle in handles.values():
|
||||
with contextlib.suppress(Exception):
|
||||
handle.close()
|
||||
handles.clear()
|
||||
|
||||
|
||||
class NativeHTTPRangeFetcher:
|
||||
"""Direct pooled HTTP range reader for hf:// paths."""
|
||||
|
||||
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
|
||||
_GLOBAL_LOCK = threading.Lock()
|
||||
|
||||
def __init__(self, data_root: str | Path, *, max_connections: int = 32, timeout: float = 60.0):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
if not self.data_root.startswith("hf://"):
|
||||
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
|
||||
self.api = HfApi()
|
||||
self.fs: HfFileSystem | None = None
|
||||
self._bucket_id: str | None = None
|
||||
self._bucket_prefix = ""
|
||||
if self.data_root.startswith("hf://buckets/"):
|
||||
bucket_root = self.data_root.removeprefix("hf://buckets/")
|
||||
parts = bucket_root.split("/", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid bucket root: {self.data_root}")
|
||||
self._bucket_id = f"{parts[0]}/{parts[1]}"
|
||||
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
|
||||
else:
|
||||
self.fs = HfFileSystem()
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
|
||||
follow_redirects=False,
|
||||
)
|
||||
self._resolved_urls: dict[str, str] = {}
|
||||
self._source_urls: dict[str, str] = {}
|
||||
self._sizes: dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _cache_key(self, relative_path: str) -> tuple[str, str]:
|
||||
return self.data_root, relative_path
|
||||
|
||||
def _path(self, relative_path: str) -> str:
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
|
||||
def _bucket_path(self, relative_path: str) -> str:
|
||||
if self._bucket_prefix:
|
||||
return f"{self._bucket_prefix}/{relative_path}"
|
||||
return relative_path
|
||||
|
||||
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
|
||||
headers = self.api._build_hf_headers()
|
||||
if urlparse(request_url).netloc != urlparse(source_url).netloc:
|
||||
headers.pop("authorization", None)
|
||||
headers.pop("Authorization", None)
|
||||
return headers
|
||||
|
||||
def _source_url(self, relative_path: str) -> str:
|
||||
with self._lock:
|
||||
source = self._source_urls.get(relative_path)
|
||||
if source is not None:
|
||||
return source
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
source = self._GLOBAL_SOURCE_URLS.get(key)
|
||||
if source is None:
|
||||
if self._bucket_id is not None:
|
||||
source = (
|
||||
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
|
||||
f"{quote(self._bucket_path(relative_path))}"
|
||||
)
|
||||
else:
|
||||
if self.fs is None:
|
||||
raise RuntimeError("HfFileSystem fallback was not initialized")
|
||||
source = self.fs.url(self._path(relative_path))
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SOURCE_URLS[key] = source
|
||||
with self._lock:
|
||||
self._source_urls[relative_path] = source
|
||||
return source
|
||||
|
||||
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
|
||||
with self._lock:
|
||||
if not refresh and relative_path in self._resolved_urls:
|
||||
return self._resolved_urls[relative_path]
|
||||
key = self._cache_key(relative_path)
|
||||
if not refresh:
|
||||
with self._GLOBAL_LOCK:
|
||||
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if resolved is not None:
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if size is not None:
|
||||
self._sizes[relative_path] = size
|
||||
return resolved
|
||||
|
||||
source = self._source_url(relative_path)
|
||||
response = self.client.head(source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
location = response.headers.get("Location")
|
||||
resolved = urljoin(source, location) if location else source
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._sizes[relative_path] = int(response.headers["Content-Length"])
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_RESOLVED_URLS[key] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
|
||||
return resolved
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
with self._lock:
|
||||
size = self._sizes.get(relative_path)
|
||||
if size is not None:
|
||||
return size
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if size is not None:
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
return size
|
||||
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
response = self.client.head(
|
||||
resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||
)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
size = int(response.headers["Content-Length"])
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SIZES[key] = size
|
||||
return size
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self.client.get(resolved, headers=headers)
|
||||
if response.status_code == 403:
|
||||
response.close()
|
||||
resolved = self._resolve_url(relative_path, refresh=True)
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self.client.get(resolved, headers=headers)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
return response.content
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
|
||||
def make_range_fetcher(data_root: str | Path, *, range_backend: str, workers: int):
|
||||
if range_backend == "fsspec":
|
||||
return ThreadLocalRangeFetcher(data_root)
|
||||
if range_backend == "native-http":
|
||||
return NativeHTTPRangeFetcher(data_root, max_connections=max(8, workers * 4))
|
||||
raise ValueError(f"Unknown range backend: {range_backend}")
|
||||
|
||||
|
||||
class EpisodeVideoManifest:
|
||||
_FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
|
||||
_FILE_SIDECAR_CACHE_LOCK = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
video_keys: list[str],
|
||||
files: list[VideoFileRecord],
|
||||
spans: dict[str, np.ndarray],
|
||||
):
|
||||
self.video_keys = list(video_keys)
|
||||
self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
|
||||
self.files = files
|
||||
self.spans = spans
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
episode_indices: list[int] | range | None = None,
|
||||
range_backend: str = "fsspec",
|
||||
workers: int = 8,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
keyframe_pad_s: float = 0.1,
|
||||
keyframe_pad_fraction: float = 0.05,
|
||||
sidecar_path: str | Path | None = None,
|
||||
) -> EpisodeVideoManifest:
|
||||
meta.ensure_readable()
|
||||
video_keys = list(meta.video_keys)
|
||||
if episode_indices is None:
|
||||
episode_indices = range(int(meta.total_episodes))
|
||||
rel_paths = sorted(
|
||||
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
|
||||
)
|
||||
path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
|
||||
if sidecar_path is None:
|
||||
files = cls._build_file_records(
|
||||
rel_paths,
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
else:
|
||||
records = cls.load_file_sidecar(sidecar_path)
|
||||
missing = [path for path in rel_paths if path not in records]
|
||||
if missing:
|
||||
raise ValueError(
|
||||
f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
|
||||
)
|
||||
files = [records[path] for path in rel_paths]
|
||||
|
||||
total = int(meta.total_episodes)
|
||||
num_cameras = len(video_keys)
|
||||
spans: dict[str, np.ndarray] = {
|
||||
"file_id": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
|
||||
"mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
|
||||
"first_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
"last_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
"frame_count": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
|
||||
"source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
|
||||
}
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
ep = meta.episodes[ep_idx]
|
||||
for cam_idx, key in enumerate(video_keys):
|
||||
rel_path = str(meta.get_video_file_path(ep_idx, key))
|
||||
file_id = path_to_id[rel_path]
|
||||
mp4 = files[file_id].mp4
|
||||
from_ts = float(ep[f"videos/{key}/from_timestamp"])
|
||||
to_ts = float(ep[f"videos/{key}/to_timestamp"])
|
||||
sample_slice = mp4.sample_slice(
|
||||
from_ts,
|
||||
to_ts,
|
||||
keyframe_pad_s=keyframe_pad_s,
|
||||
keyframe_pad_fraction=keyframe_pad_fraction,
|
||||
file_size=files[file_id].file_size,
|
||||
)
|
||||
spans["file_id"][ep_idx, cam_idx] = file_id
|
||||
spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
|
||||
spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
|
||||
spans["first_pts"][ep_idx, cam_idx] = from_ts
|
||||
spans["last_pts"][ep_idx, cam_idx] = to_ts
|
||||
spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
|
||||
spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
|
||||
spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
|
||||
spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts
|
||||
|
||||
return cls(video_keys=video_keys, files=files, spans=spans)
|
||||
|
||||
@staticmethod
|
||||
def _build_file_records(
|
||||
rel_paths: list[str],
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
header_probe_bytes: int,
|
||||
max_probe_bytes: int,
|
||||
) -> list[VideoFileRecord]:
|
||||
fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
|
||||
|
||||
def build_file(path: str) -> VideoFileRecord:
|
||||
file_size = fetcher.info_size(path)
|
||||
mp4 = fetch_mp4_index(
|
||||
path,
|
||||
fetcher.read_range,
|
||||
file_size=file_size,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
return VideoFileRecord(path, file_size, mp4)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
return list(pool.map(build_file, rel_paths))
|
||||
finally:
|
||||
fetcher.close()
|
||||
|
||||
@classmethod
|
||||
def write_file_sidecar(
|
||||
cls,
|
||||
sidecar_path: str | Path,
|
||||
rel_paths: list[str],
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str = "native-http",
|
||||
workers: int = 8,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
) -> None:
|
||||
records = cls._build_file_records(
|
||||
sorted(set(rel_paths)),
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
header_probe_bytes=header_probe_bytes,
|
||||
max_probe_bytes=max_probe_bytes,
|
||||
)
|
||||
cls.save_file_sidecar(sidecar_path, records)
|
||||
|
||||
@staticmethod
|
||||
def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
|
||||
sidecar_path = Path(sidecar_path)
|
||||
sidecar_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"version": 1,
|
||||
"files": [
|
||||
{"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
|
||||
for record in records
|
||||
],
|
||||
}
|
||||
arrays = {}
|
||||
for file_idx, record in enumerate(records):
|
||||
arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
|
||||
arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
|
||||
arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
|
||||
arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
|
||||
arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
|
||||
np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
|
||||
|
||||
@staticmethod
|
||||
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
|
||||
cache_key = str(Path(sidecar_path).expanduser())
|
||||
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
|
||||
cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
with np.load(sidecar_path, allow_pickle=False) as data:
|
||||
payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
|
||||
records = {}
|
||||
for file_idx, item in enumerate(payload["files"]):
|
||||
arrays = {
|
||||
name: data[f"{file_idx}/{name}"]
|
||||
for name in [
|
||||
"sample_pts",
|
||||
"sample_durations",
|
||||
"sample_sizes",
|
||||
"sample_offsets",
|
||||
"sync_samples",
|
||||
]
|
||||
}
|
||||
mp4 = Mp4Index.from_dict(item["mp4"], arrays)
|
||||
records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
|
||||
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
|
||||
EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
|
||||
return records
|
||||
|
||||
def camera_id(self, camera_key: str) -> int:
|
||||
return self._camera_to_id[camera_key]
|
||||
|
||||
def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
|
||||
cam = self.camera_id(camera_key)
|
||||
return EpisodeVideoSpan(
|
||||
file_id=int(self.spans["file_id"][episode_index, cam]),
|
||||
mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
|
||||
mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
|
||||
first_pts=float(self.spans["first_pts"][episode_index, cam]),
|
||||
last_pts=float(self.spans["last_pts"][episode_index, cam]),
|
||||
frame_count=int(self.spans["frame_count"][episode_index, cam]),
|
||||
sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
|
||||
sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
|
||||
source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
|
||||
)
|
||||
|
||||
def file_lookup(self, file_id: int) -> VideoFileRecord:
|
||||
return self.files[file_id]
|
||||
|
||||
def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
|
||||
return self.files[self.lookup(episode_index, camera_key).file_id].mp4
|
||||
|
||||
def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
|
||||
span = self.lookup(episode_index, camera_key)
|
||||
return Mp4SampleSlice(
|
||||
sample_lo=span.sample_lo,
|
||||
sample_hi=span.sample_hi,
|
||||
byte_offset=span.mdat_offset,
|
||||
byte_length=span.mdat_length,
|
||||
source_start_pts=span.source_start_pts,
|
||||
)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
def __init__(
|
||||
self,
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
byte_budget: int = 80 * 1024**3,
|
||||
workers: int = 8,
|
||||
range_backend: str = "fsspec",
|
||||
open_decoders: bool = True,
|
||||
):
|
||||
self.manifest = manifest
|
||||
self.fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
|
||||
self.byte_budget = byte_budget
|
||||
self.open_decoders = open_decoders
|
||||
self._pool = ThreadPoolExecutor(max_workers=workers)
|
||||
self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
|
||||
self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
|
||||
self._bytes = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def close(self) -> None:
|
||||
self._pool.shutdown(wait=True)
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
self._futures.clear()
|
||||
self._bytes = 0
|
||||
self.fetcher.close()
|
||||
|
||||
def __enter__(self) -> EpisodeByteCache:
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc) -> None:
|
||||
self.close()
|
||||
|
||||
def submit_prefetch(self, episode_index: int) -> None:
|
||||
for camera_key in self.manifest.video_keys:
|
||||
self._submit(episode_index, camera_key)
|
||||
|
||||
def ensure_ready(self, episode_index: int) -> None:
|
||||
for camera_key in self.manifest.video_keys:
|
||||
self.get_bytes(episode_index, camera_key)
|
||||
|
||||
def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
|
||||
return self._get_entry(episode_index, camera_key)["bytes"]
|
||||
|
||||
def get_decoder(self, episode_index: int, camera_key: str):
|
||||
entry = self._get_entry(episode_index, camera_key)
|
||||
decoder = entry.get("decoder")
|
||||
if decoder is None:
|
||||
decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
|
||||
entry["decoder"] = decoder
|
||||
return decoder
|
||||
|
||||
def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
|
||||
span = self.manifest.lookup(episode_index, camera_key)
|
||||
local_ts = [ts - span.source_start_pts for ts in timestamps]
|
||||
decoder = self.get_decoder(episode_index, camera_key)
|
||||
if hasattr(decoder, "get_frames_played_at"):
|
||||
return decoder.get_frames_played_at(local_ts).data
|
||||
metadata = decoder.metadata
|
||||
fps = getattr(metadata, "average_fps", None)
|
||||
if fps is None:
|
||||
duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9)
|
||||
fps = metadata.num_frames / duration
|
||||
return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data
|
||||
|
||||
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
|
||||
key = (episode_index, camera_key)
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
future: Future[dict[str, Any]] = Future()
|
||||
future.set_result(self._cache[key])
|
||||
return future
|
||||
future = self._futures.get(key)
|
||||
if future is None:
|
||||
future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
|
||||
self._futures[key] = future
|
||||
return future
|
||||
|
||||
def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
|
||||
key = (episode_index, camera_key)
|
||||
with self._lock:
|
||||
entry = self._cache.get(key)
|
||||
if entry is not None:
|
||||
self._cache.move_to_end(key)
|
||||
return entry
|
||||
future = self._submit(episode_index, camera_key)
|
||||
entry = future.result()
|
||||
with self._lock:
|
||||
self._futures.pop(key, None)
|
||||
existing = self._cache.get(key)
|
||||
if existing is not None:
|
||||
self._cache.move_to_end(key)
|
||||
return existing
|
||||
self._cache[key] = entry
|
||||
self._bytes += len(entry["bytes"])
|
||||
self._evict_locked()
|
||||
return entry
|
||||
|
||||
def _evict_locked(self) -> None:
|
||||
while self._bytes > self.byte_budget and self._cache:
|
||||
_key, entry = self._cache.popitem(last=False)
|
||||
self._bytes -= len(entry["bytes"])
|
||||
|
||||
def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
|
||||
span = self.manifest.lookup(episode_index, camera_key)
|
||||
file_record = self.manifest.file_lookup(span.file_id)
|
||||
payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
|
||||
if len(payload) != span.mdat_length:
|
||||
raise OSError(
|
||||
f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
|
||||
)
|
||||
mp4_bytes = synthesize_mp4(
|
||||
file_record.mp4, self.manifest.sample_slice(episode_index, camera_key), payload
|
||||
)
|
||||
entry: dict[str, Any] = {"bytes": mp4_bytes, "decoder": None}
|
||||
if self.open_decoders:
|
||||
entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
|
||||
return entry
|
||||
|
||||
|
||||
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
|
||||
if frame_mappings is not None:
|
||||
raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
|
||||
|
||||
|
||||
def assert_hf_hub_range_cache_branch() -> None:
|
||||
"""Fail unless huggingface_hub was installed from the required range-cache branch."""
|
||||
|
||||
try:
|
||||
dist = metadata.distribution("huggingface_hub")
|
||||
except metadata.PackageNotFoundError as exc:
|
||||
raise AssertionError("huggingface_hub is not installed") from exc
|
||||
|
||||
candidates = []
|
||||
direct_url = dist.read_text("direct_url.json")
|
||||
if direct_url:
|
||||
candidates.append(direct_url)
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
parsed = json.loads(direct_url)
|
||||
candidates.append(str(parsed.get("url", "")))
|
||||
candidates.append(str(parsed.get("vcs_info", {}).get("requested_revision", "")))
|
||||
candidates.append(str(parsed.get("vcs_info", {}).get("commit_id", "")))
|
||||
|
||||
text = "\n".join(candidates)
|
||||
if "feat/hffs-cache-cdn-range-reads" not in text:
|
||||
raise AssertionError(
|
||||
"huggingface_hub must be installed from "
|
||||
"git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageTimer:
|
||||
fetch_ms: float = 0.0
|
||||
decode_ms: float = 0.0
|
||||
bytes_read: int = 0
|
||||
misses: int = 0
|
||||
|
||||
def record_fetch(self, start: float, byte_count: int) -> None:
|
||||
self.fetch_ms += (time.perf_counter() - start) * 1000
|
||||
self.bytes_read += byte_count
|
||||
self.misses += 1
|
||||
@@ -0,0 +1,666 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Box:
|
||||
type: bytes
|
||||
start: int
|
||||
header_size: int
|
||||
end: int
|
||||
|
||||
@property
|
||||
def payload_start(self) -> int:
|
||||
return self.start + self.header_size
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return self.end - self.start
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Mp4SampleSlice:
|
||||
sample_lo: int
|
||||
sample_hi: int
|
||||
byte_offset: int
|
||||
byte_length: int
|
||||
source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Mp4Index:
|
||||
file_path: str
|
||||
file_size: int
|
||||
ftyp: bytes
|
||||
moov_offset: int
|
||||
mdat_offset: int
|
||||
mdat_payload_offset: int
|
||||
mdat_payload_size: int
|
||||
faststart: bool
|
||||
codec: str
|
||||
timescale: int
|
||||
duration: int
|
||||
track_id: int
|
||||
width: int
|
||||
height: int
|
||||
stsd_body: bytes
|
||||
sample_pts: np.ndarray
|
||||
sample_durations: np.ndarray
|
||||
sample_sizes: np.ndarray
|
||||
sample_offsets: np.ndarray
|
||||
sync_samples: np.ndarray
|
||||
|
||||
def sample_slice(
|
||||
self,
|
||||
from_ts: float,
|
||||
to_ts: float,
|
||||
*,
|
||||
keyframe_pad_s: float = 0.1,
|
||||
keyframe_pad_fraction: float = 0.05,
|
||||
file_size: int | None = None,
|
||||
) -> Mp4SampleSlice:
|
||||
if to_ts < from_ts:
|
||||
raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
|
||||
if len(self.sample_pts) == 0:
|
||||
raise ValueError(f"{self.file_path} contains no indexed samples")
|
||||
|
||||
pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
|
||||
lo_ts = max(0.0, from_ts - pad)
|
||||
hi_ts = to_ts + pad
|
||||
lo = int(np.searchsorted(self.sample_pts, lo_ts, side="left"))
|
||||
hi = int(np.searchsorted(self.sample_pts, hi_ts, side="right")) - 1
|
||||
lo = min(max(lo, 0), len(self.sample_pts) - 1)
|
||||
hi = min(max(hi, lo), len(self.sample_pts) - 1)
|
||||
|
||||
if len(self.sync_samples):
|
||||
prev_sync = self.sync_samples[self.sync_samples <= lo]
|
||||
if len(prev_sync):
|
||||
lo = int(prev_sync[-1])
|
||||
else:
|
||||
lo = int(self.sync_samples[0])
|
||||
if lo > hi:
|
||||
hi = lo
|
||||
|
||||
offsets = self.sample_offsets[lo : hi + 1]
|
||||
sizes = self.sample_sizes[lo : hi + 1]
|
||||
slice_lo = int(offsets.min())
|
||||
slice_hi = int((offsets + sizes).max())
|
||||
if file_size is not None:
|
||||
slice_hi = min(slice_hi, int(file_size))
|
||||
return Mp4SampleSlice(
|
||||
sample_lo=lo,
|
||||
sample_hi=hi,
|
||||
byte_offset=slice_lo,
|
||||
byte_length=slice_hi - slice_lo,
|
||||
source_start_pts=float(self.sample_pts[lo]),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"file_path": self.file_path,
|
||||
"file_size": self.file_size,
|
||||
"ftyp": self.ftyp.hex(),
|
||||
"moov_offset": self.moov_offset,
|
||||
"mdat_offset": self.mdat_offset,
|
||||
"mdat_payload_offset": self.mdat_payload_offset,
|
||||
"mdat_payload_size": self.mdat_payload_size,
|
||||
"faststart": self.faststart,
|
||||
"codec": self.codec,
|
||||
"timescale": self.timescale,
|
||||
"duration": self.duration,
|
||||
"track_id": self.track_id,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"stsd_body": self.stsd_body.hex(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
|
||||
return cls(
|
||||
file_path=data["file_path"],
|
||||
file_size=int(data["file_size"]),
|
||||
ftyp=bytes.fromhex(data["ftyp"]),
|
||||
moov_offset=int(data["moov_offset"]),
|
||||
mdat_offset=int(data["mdat_offset"]),
|
||||
mdat_payload_offset=int(data["mdat_payload_offset"]),
|
||||
mdat_payload_size=int(data["mdat_payload_size"]),
|
||||
faststart=bool(data["faststart"]),
|
||||
codec=data["codec"],
|
||||
timescale=int(data["timescale"]),
|
||||
duration=int(data["duration"]),
|
||||
track_id=int(data["track_id"]),
|
||||
width=int(data["width"]),
|
||||
height=int(data["height"]),
|
||||
stsd_body=bytes.fromhex(data["stsd_body"]),
|
||||
sample_pts=arrays["sample_pts"],
|
||||
sample_durations=arrays["sample_durations"],
|
||||
sample_sizes=arrays["sample_sizes"],
|
||||
sample_offsets=arrays["sample_offsets"],
|
||||
sync_samples=arrays["sync_samples"],
|
||||
)
|
||||
|
||||
|
||||
def fetch_mp4_index(
|
||||
path: str,
|
||||
read_range: Callable[[str, int, int], bytes],
|
||||
*,
|
||||
file_size: int,
|
||||
header_probe_bytes: int = 4 * 1024 * 1024,
|
||||
max_probe_bytes: int = 64 * 1024 * 1024,
|
||||
) -> Mp4Index:
|
||||
probe_size = min(header_probe_bytes, file_size)
|
||||
while True:
|
||||
data = read_range(path, 0, probe_size)
|
||||
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
|
||||
has_mdat = any(box.type == b"mdat" for box in top)
|
||||
has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
|
||||
if has_mdat and has_moov:
|
||||
return parse_mp4_index(path, data, file_size=file_size)
|
||||
if probe_size >= min(max_probe_bytes, file_size):
|
||||
if has_mdat and not has_moov:
|
||||
tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
|
||||
if tail_index is not None:
|
||||
return tail_index
|
||||
missing = []
|
||||
if not has_mdat:
|
||||
missing.append("mdat")
|
||||
if not has_moov:
|
||||
missing.append("moov")
|
||||
raise ValueError(
|
||||
f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
|
||||
)
|
||||
probe_size = min(probe_size * 2, max_probe_bytes, file_size)
|
||||
|
||||
|
||||
def _fetch_tail_moov_index(
|
||||
path: str,
|
||||
read_range: Callable[[str, int, int], bytes],
|
||||
prefix: bytes,
|
||||
top_boxes: list[Box],
|
||||
file_size: int,
|
||||
max_probe_bytes: int,
|
||||
) -> Mp4Index | None:
|
||||
mdat_box = _one(top_boxes, b"mdat")
|
||||
if mdat_box is None or mdat_box.end >= file_size:
|
||||
return None
|
||||
tail_offset = mdat_box.end
|
||||
tail_length = min(max_probe_bytes, file_size - tail_offset)
|
||||
tail = read_range(path, tail_offset, tail_length)
|
||||
tail_boxes = list(iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True))
|
||||
moov_box = next(
|
||||
(box for box in tail_boxes if box.type == b"moov" and box.end <= tail_offset + len(tail)), None
|
||||
)
|
||||
if moov_box is None:
|
||||
return None
|
||||
ftyp_box = _one(top_boxes, b"ftyp", required=False)
|
||||
ftyp = (
|
||||
prefix[ftyp_box.start : ftyp_box.end]
|
||||
if ftyp_box is not None
|
||||
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
)
|
||||
moov_start = moov_box.payload_start - tail_offset
|
||||
moov_end = moov_box.end - tail_offset
|
||||
return _parse_mp4_index_from_layout(
|
||||
path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_box.start,
|
||||
moov=tail[moov_start:moov_end],
|
||||
mdat_box=mdat_box,
|
||||
)
|
||||
|
||||
|
||||
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
|
||||
if file_size is None:
|
||||
file_size = len(data)
|
||||
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
|
||||
ftyp_box = _one(top, b"ftyp", required=False)
|
||||
moov_box = _one(top, b"moov")
|
||||
mdat_box = _one(top, b"mdat")
|
||||
if moov_box.end > len(data):
|
||||
raise ValueError(f"{path}: moov box is truncated")
|
||||
|
||||
moov = data[moov_box.payload_start : moov_box.end]
|
||||
ftyp = (
|
||||
data[ftyp_box.start : ftyp_box.end]
|
||||
if ftyp_box is not None
|
||||
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
)
|
||||
return _parse_mp4_index_from_layout(
|
||||
path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_box.start,
|
||||
moov=moov,
|
||||
mdat_box=mdat_box,
|
||||
)
|
||||
|
||||
|
||||
def _parse_mp4_index_from_layout(
|
||||
path: str,
|
||||
*,
|
||||
file_size: int,
|
||||
ftyp: bytes,
|
||||
moov_offset: int,
|
||||
moov: bytes,
|
||||
mdat_box: Box,
|
||||
) -> Mp4Index:
|
||||
mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
|
||||
trak_box, trak_payload = _find_video_trak(moov)
|
||||
_ = trak_box
|
||||
tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
|
||||
mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
|
||||
stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])
|
||||
|
||||
stsd = _find_child(stbl, b"stsd")
|
||||
stsd_body = stbl[stsd.payload_start : stsd.end]
|
||||
codec = _parse_stsd_codec(stsd_body)
|
||||
stts = _parse_stts(_payload(stbl, b"stts"))
|
||||
sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
|
||||
stsc = _parse_stsc(_payload(stbl, b"stsc"))
|
||||
chunk_offsets = _parse_chunk_offsets(stbl)
|
||||
sync_samples = _parse_stss(stbl, len(sample_sizes))
|
||||
|
||||
sample_durations = _expand_stts(stts, len(sample_sizes))
|
||||
sample_pts_units = np.empty(len(sample_durations), dtype=np.int64)
|
||||
if len(sample_durations):
|
||||
sample_pts_units[0] = 0
|
||||
if len(sample_durations) > 1:
|
||||
sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
|
||||
sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
|
||||
sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)
|
||||
|
||||
return Mp4Index(
|
||||
file_path=path,
|
||||
file_size=file_size,
|
||||
ftyp=ftyp,
|
||||
moov_offset=moov_offset,
|
||||
mdat_offset=mdat_box.start,
|
||||
mdat_payload_offset=mdat_box.payload_start,
|
||||
mdat_payload_size=mdat_box.end - mdat_box.payload_start
|
||||
if mdat_box.end <= file_size
|
||||
else file_size - mdat_box.payload_start,
|
||||
faststart=moov_offset < mdat_box.start,
|
||||
codec=codec,
|
||||
timescale=mdhd_timescale,
|
||||
duration=mdhd_duration or mvhd_duration,
|
||||
track_id=tkhd["track_id"],
|
||||
width=tkhd["width"],
|
||||
height=tkhd["height"],
|
||||
stsd_body=stsd_body,
|
||||
sample_pts=sample_pts,
|
||||
sample_durations=sample_durations,
|
||||
sample_sizes=sample_sizes,
|
||||
sample_offsets=sample_offsets,
|
||||
sync_samples=sync_samples,
|
||||
)
|
||||
|
||||
|
||||
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
|
||||
lo = sample_slice.sample_lo
|
||||
hi = sample_slice.sample_hi + 1
|
||||
if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
|
||||
raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")
|
||||
|
||||
offsets = index.sample_offsets[lo:hi]
|
||||
sizes = index.sample_sizes[lo:hi]
|
||||
rel_offsets = offsets - sample_slice.byte_offset
|
||||
if int(rel_offsets.min()) != 0:
|
||||
raise ValueError("Sample slice must start at the minimum referenced sample offset")
|
||||
if int((rel_offsets + sizes).max()) > len(mdat_payload):
|
||||
raise ValueError("Sample slice does not cover all referenced samples")
|
||||
|
||||
durations = index.sample_durations[lo:hi]
|
||||
sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1
|
||||
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
|
||||
header_size = len(index.ftyp) + len(moov)
|
||||
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
|
||||
return index.ftyp + moov + _box(b"mdat", mdat_payload)
|
||||
|
||||
|
||||
def iter_boxes(
|
||||
data: bytes,
|
||||
start: int,
|
||||
end: int,
|
||||
*,
|
||||
absolute_base: int = 0,
|
||||
allow_truncated: bool = False,
|
||||
) -> Iterable[Box]:
|
||||
pos = start
|
||||
while pos + 8 <= end:
|
||||
size = struct.unpack_from(">I", data, pos)[0]
|
||||
typ = data[pos + 4 : pos + 8]
|
||||
header_size = 8
|
||||
if size == 1:
|
||||
if pos + 16 > end:
|
||||
break
|
||||
size = struct.unpack_from(">Q", data, pos + 8)[0]
|
||||
header_size = 16
|
||||
elif size == 0:
|
||||
size = end - pos
|
||||
if size < header_size:
|
||||
break
|
||||
box_end = pos + size
|
||||
if box_end > end and not allow_truncated:
|
||||
break
|
||||
yield Box(typ, absolute_base + pos, header_size, absolute_base + box_end)
|
||||
pos = box_end
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
|
||||
for trak in _children(moov, 0, len(moov)):
|
||||
if trak.type != b"trak":
|
||||
continue
|
||||
payload = moov[trak.payload_start : trak.end]
|
||||
hdlr = _find_descendant(payload, [b"mdia", b"hdlr"])
|
||||
if hdlr[8:12] == b"vide":
|
||||
return trak, payload
|
||||
raise ValueError("No video track found")
|
||||
|
||||
|
||||
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
|
||||
current = data
|
||||
for typ in path:
|
||||
box = _find_child(current, typ)
|
||||
current = current[box.payload_start : box.end]
|
||||
return current
|
||||
|
||||
|
||||
def _find_child(data: bytes, typ: bytes) -> Box:
|
||||
for box in _children(data, 0, len(data)):
|
||||
if box.type == typ:
|
||||
return box
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
|
||||
|
||||
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
|
||||
return iter_boxes(data, start, end, absolute_base=0)
|
||||
|
||||
|
||||
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
|
||||
matches = [box for box in boxes if box.type == typ]
|
||||
if not matches and required:
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
return matches[0] if matches else None
|
||||
|
||||
|
||||
def _payload(parent: bytes, typ: bytes) -> bytes:
|
||||
box = _find_child(parent, typ)
|
||||
return parent[box.payload_start : box.end]
|
||||
|
||||
|
||||
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_tkhd(payload: bytes) -> dict[str, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
track_id = struct.unpack_from(">I", payload, 20)[0]
|
||||
duration = struct.unpack_from(">Q", payload, 28)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 88)
|
||||
else:
|
||||
track_id = struct.unpack_from(">I", payload, 12)[0]
|
||||
duration = struct.unpack_from(">I", payload, 20)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 76)
|
||||
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
|
||||
|
||||
|
||||
def _parse_stsd_codec(stsd_body: bytes) -> str:
|
||||
if len(stsd_body) < 16:
|
||||
return "unknown"
|
||||
return stsd_body[12:16].decode("latin1")
|
||||
|
||||
|
||||
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">II", payload, offset))
|
||||
offset += 8
|
||||
return out
|
||||
|
||||
|
||||
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
pos = 0
|
||||
for count, delta in entries:
|
||||
values[pos : pos + count] = delta
|
||||
pos += count
|
||||
if pos != sample_count:
|
||||
raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}")
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsz(payload: bytes) -> np.ndarray:
|
||||
sample_size, sample_count = struct.unpack_from(">II", payload, 4)
|
||||
if sample_size:
|
||||
return np.full(sample_count, sample_size, dtype=np.int64)
|
||||
offset = 12
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
for idx in range(sample_count):
|
||||
values[idx] = struct.unpack_from(">I", payload, offset)[0]
|
||||
offset += 4
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">III", payload, offset))
|
||||
offset += 12
|
||||
return out
|
||||
|
||||
|
||||
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
|
||||
with_stco = None
|
||||
with_co64 = None
|
||||
for box in _children(stbl, 0, len(stbl)):
|
||||
if box.type == b"stco":
|
||||
with_stco = stbl[box.payload_start : box.end]
|
||||
elif box.type == b"co64":
|
||||
with_co64 = stbl[box.payload_start : box.end]
|
||||
if with_co64 is not None:
|
||||
count = struct.unpack_from(">I", with_co64, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">Q", with_co64, 8 + idx * 8)[0] for idx in range(count)], dtype=np.int64
|
||||
)
|
||||
if with_stco is None:
|
||||
raise ValueError("Missing stco/co64 chunk offsets")
|
||||
count = struct.unpack_from(">I", with_stco, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">I", with_stco, 8 + idx * 4)[0] for idx in range(count)], dtype=np.int64
|
||||
)
|
||||
|
||||
|
||||
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
|
||||
for box in _children(stbl, 0, len(stbl)):
|
||||
if box.type == b"stss":
|
||||
payload = stbl[box.payload_start : box.end]
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
return np.array(
|
||||
[struct.unpack_from(">I", payload, 8 + idx * 4)[0] - 1 for idx in range(count)],
|
||||
dtype=np.int64,
|
||||
)
|
||||
return np.arange(sample_count, dtype=np.int64)
|
||||
|
||||
|
||||
def _sample_offsets(
|
||||
stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
|
||||
) -> np.ndarray:
|
||||
if not stsc:
|
||||
raise ValueError("stsc is empty")
|
||||
offsets = np.empty(len(sample_sizes), dtype=np.int64)
|
||||
sample_idx = 0
|
||||
for entry_idx, (first_chunk, samples_per_chunk, _desc_idx) in enumerate(stsc):
|
||||
next_first = stsc[entry_idx + 1][0] if entry_idx + 1 < len(stsc) else len(chunk_offsets) + 1
|
||||
for chunk_number in range(first_chunk, next_first):
|
||||
if chunk_number < 1 or chunk_number > len(chunk_offsets):
|
||||
raise ValueError("stsc references a chunk outside stco/co64")
|
||||
chunk_pos = int(chunk_offsets[chunk_number - 1])
|
||||
for _ in range(samples_per_chunk):
|
||||
if sample_idx >= len(sample_sizes):
|
||||
return offsets
|
||||
offsets[sample_idx] = chunk_pos
|
||||
chunk_pos += int(sample_sizes[sample_idx])
|
||||
sample_idx += 1
|
||||
if sample_idx != len(sample_sizes):
|
||||
raise ValueError(f"stsc describes {sample_idx} samples, stsz describes {len(sample_sizes)}")
|
||||
return offsets
|
||||
|
||||
|
||||
def _make_moov(
|
||||
index: Mp4Index,
|
||||
durations: np.ndarray,
|
||||
sizes: np.ndarray,
|
||||
rel_offsets: np.ndarray,
|
||||
sync_samples: np.ndarray,
|
||||
*,
|
||||
mdat_data_offset: int,
|
||||
) -> bytes:
|
||||
duration = int(durations.sum())
|
||||
stco_values = [int(mdat_data_offset + value) for value in rel_offsets]
|
||||
if any(value > 0xFFFFFFFF for value in stco_values):
|
||||
offset_box = _co64(stco_values)
|
||||
else:
|
||||
offset_box = _stco(stco_values)
|
||||
stbl = _box(
|
||||
b"stbl",
|
||||
_box(b"stsd", index.stsd_body)
|
||||
+ _stts(durations)
|
||||
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||
+ _stsz(sizes)
|
||||
+ offset_box
|
||||
+ (_stss(sync_samples) if len(sync_samples) else b""),
|
||||
)
|
||||
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||
mdia = _box(b"mdia", _mdhd(index.timescale, duration) + _hdlr() + minf)
|
||||
trak = _box(b"trak", _tkhd(index.track_id, duration, index.width, index.height) + mdia)
|
||||
return _box(b"moov", _mvhd(index.timescale, duration, index.track_id + 1) + trak)
|
||||
|
||||
|
||||
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
|
||||
return _box(typ, bytes([version]) + flags.to_bytes(3, "big") + payload)
|
||||
|
||||
|
||||
def _box(typ: bytes, payload: bytes) -> bytes:
|
||||
size = len(payload) + 8
|
||||
if size <= 0xFFFFFFFF:
|
||||
return struct.pack(">I4s", size, typ) + payload
|
||||
return struct.pack(">I4sQ", 1, typ, size + 8) + payload
|
||||
|
||||
|
||||
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
|
||||
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
|
||||
payload = (
|
||||
struct.pack(">IIII", 0, 0, timescale, duration)
|
||||
+ struct.pack(">IHH", 0x00010000, 0x0100, 0)
|
||||
+ b"\0" * 8
|
||||
+ matrix
|
||||
+ b"\0" * 24
|
||||
+ struct.pack(">I", next_track_id)
|
||||
)
|
||||
return _full_box(b"mvhd", 0, 0, payload)
|
||||
|
||||
|
||||
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
|
||||
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
|
||||
payload = (
|
||||
struct.pack(">IIIII", 0, 0, track_id, 0, duration)
|
||||
+ b"\0" * 8
|
||||
+ struct.pack(">hhhh", 0, 0, 0, 0)
|
||||
+ matrix
|
||||
+ struct.pack(">II", width << 16, height << 16)
|
||||
)
|
||||
return _full_box(b"tkhd", 0, 7, payload)
|
||||
|
||||
|
||||
def _mdhd(timescale: int, duration: int) -> bytes:
|
||||
return _full_box(b"mdhd", 0, 0, struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4) + b"\0\0")
|
||||
|
||||
|
||||
def _hdlr() -> bytes:
|
||||
return _full_box(b"hdlr", 0, 0, b"\0" * 4 + b"vide" + b"\0" * 12 + b"VideoHandler\0")
|
||||
|
||||
|
||||
def _vmhd() -> bytes:
|
||||
return _full_box(b"vmhd", 0, 1, struct.pack(">HHHH", 0, 0, 0, 0))
|
||||
|
||||
|
||||
def _dinf() -> bytes:
|
||||
url = _full_box(b"url ", 0, 1)
|
||||
dref = _full_box(b"dref", 0, 0, struct.pack(">I", 1) + url)
|
||||
return _box(b"dinf", dref)
|
||||
|
||||
|
||||
def _stts(durations: np.ndarray) -> bytes:
|
||||
runs = []
|
||||
for duration in durations.tolist():
|
||||
if runs and runs[-1][1] == int(duration):
|
||||
runs[-1][0] += 1
|
||||
else:
|
||||
runs.append([1, int(duration)])
|
||||
payload = struct.pack(">I", len(runs)) + b"".join(
|
||||
struct.pack(">II", count, delta) for count, delta in runs
|
||||
)
|
||||
return _full_box(b"stts", 0, 0, payload)
|
||||
|
||||
|
||||
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
|
||||
return _full_box(b"stsc", 0, 0, struct.pack(">IIII", 1, 1, 1, 1))
|
||||
|
||||
|
||||
def _stsz(sizes: np.ndarray) -> bytes:
|
||||
return _full_box(
|
||||
b"stsz",
|
||||
0,
|
||||
0,
|
||||
struct.pack(">II", 0, len(sizes)) + b"".join(struct.pack(">I", int(size)) for size in sizes.tolist()),
|
||||
)
|
||||
|
||||
|
||||
def _stco(values: list[int]) -> bytes:
|
||||
return _full_box(
|
||||
b"stco", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">I", v) for v in values)
|
||||
)
|
||||
|
||||
|
||||
def _co64(values: list[int]) -> bytes:
|
||||
return _full_box(
|
||||
b"co64", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">Q", v) for v in values)
|
||||
)
|
||||
|
||||
|
||||
def _stss(values: np.ndarray) -> bytes:
|
||||
return _full_box(
|
||||
b"stss",
|
||||
0,
|
||||
0,
|
||||
struct.pack(">I", len(values)) + b"".join(struct.pack(">I", int(value)) for value in values.tolist()),
|
||||
)
|
||||
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
|
||||
from lerobot.datasets.mp4 import (
|
||||
_box,
|
||||
_co64,
|
||||
_dinf,
|
||||
_hdlr,
|
||||
_mdhd,
|
||||
_mvhd,
|
||||
_stco,
|
||||
_stsc_one_sample_per_chunk,
|
||||
_stss,
|
||||
_stsz,
|
||||
_stts,
|
||||
_tkhd,
|
||||
_vmhd,
|
||||
parse_mp4_index,
|
||||
synthesize_mp4,
|
||||
)
|
||||
|
||||
|
||||
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
|
||||
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
sizes = np.array([10, 10, 10], dtype=np.int64)
|
||||
durations = np.array([1000, 1000, 1000], dtype=np.int64)
|
||||
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
|
||||
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
|
||||
stbl = _box(
|
||||
b"stbl",
|
||||
_box(b"stsd", stsd_body)
|
||||
+ _stts(durations)
|
||||
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||
+ _stsz(sizes)
|
||||
+ offsets
|
||||
+ _stss(np.array([1], dtype=np.int64)),
|
||||
)
|
||||
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
|
||||
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
|
||||
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
|
||||
mdat_payload_start = 10_000
|
||||
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
|
||||
assert free_size >= 8
|
||||
free = _box(b"free", b"\0" * (free_size - 8))
|
||||
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
|
||||
|
||||
|
||||
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
assert sample_slice.byte_offset == 10_000
|
||||
assert sample_slice.byte_length == 60
|
||||
assert sample_slice.sample_lo == 0
|
||||
assert sample_slice.sample_hi == 2
|
||||
|
||||
|
||||
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
mini = synthesize_mp4(mp4, sample_slice, b"x" * sample_slice.byte_length)
|
||||
mini_index = parse_mp4_index("mini.mp4", mini)
|
||||
|
||||
expected = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
|
||||
np.testing.assert_array_equal(mini_index.sample_offsets, expected)
|
||||
np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
|
||||
|
||||
|
||||
def test_parser_accepts_co64_chunk_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025], use_co64=True))
|
||||
|
||||
np.testing.assert_array_equal(mp4.sample_offsets, np.array([10_000, 10_050, 10_025]))
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps(
|
||||
{
|
||||
"url": "https://github.com/huggingface/huggingface_hub.git",
|
||||
"vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps({"url": "https://github.com/huggingface/huggingface_hub.git"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert_hf_hub_range_cache_branch()
|
||||
Reference in New Issue
Block a user