mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Add episode video streaming byte cache
This commit is contained in:
@@ -0,0 +1,724 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeByteCache,
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=("both", "indexed", "remote-decoder", "native-http"),
|
||||
default="both",
|
||||
help=argparse.SUPPRESS,
|
||||
)
|
||||
parser.add_argument("--num-episodes", type=int, default=512)
|
||||
parser.add_argument(
|
||||
"--manifest-episodes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit manifest construction to the first N episodes for local smoke tests.",
|
||||
)
|
||||
parser.add_argument("--pool-size", type=int, default=16)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--prefetch-ahead", type=int, default=8)
|
||||
parser.add_argument("--frames-per-episode", type=int, default=16)
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--byte-budget-gb", type=float, default=80)
|
||||
parser.add_argument(
|
||||
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span = manifest.lookup(ep, camera_key)
|
||||
lo = span.first_pts
|
||||
hi = max(span.last_pts, lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _timestamps_from_meta(
|
||||
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
||||
) -> dict[tuple[int, str], list[float]]:
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
row = meta.episodes[ep]
|
||||
for camera_key in meta.video_keys:
|
||||
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
||||
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
total = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
total += manifest.lookup(ep, camera_key).mdat_length
|
||||
return total
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
items = list(timestamps.items())
|
||||
if decode_workers <= 1:
|
||||
for (ep, camera_key), ts in items:
|
||||
cache.get_frames(ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
for ep in episodes:
|
||||
cache.ensure_ready(ep)
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
||||
if elapsed_s <= 0:
|
||||
return float("inf")
|
||||
return len(episodes) * frames_per_episode / elapsed_s
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
return str(Path(data_root) / relative_path)
|
||||
|
||||
|
||||
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
|
||||
local = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
|
||||
if _valid_sidecar(local):
|
||||
return local
|
||||
if local.exists():
|
||||
print(f"mp4_sidecar_invalid_local: {local}")
|
||||
local.unlink()
|
||||
full_local = SIDECAR_CACHE_DIR / "molmoact2-full.npz"
|
||||
if _valid_sidecar(full_local):
|
||||
return full_local
|
||||
remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz")
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
fs = fsspec.filesystem(protocol)
|
||||
if not fs.exists(remote):
|
||||
return None
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"downloading_mp4_sidecar: {remote} -> {local}")
|
||||
if data_root.startswith("hf://"):
|
||||
_download_sidecar_native_http(
|
||||
data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz", local
|
||||
)
|
||||
else:
|
||||
fs.get(remote, str(local))
|
||||
return local
|
||||
|
||||
|
||||
def _valid_sidecar(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
with np.load(path, allow_pickle=False) as data:
|
||||
return "manifest_json" in data
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
|
||||
fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
|
||||
tmp = local.with_suffix(local.suffix + ".tmp")
|
||||
try:
|
||||
size = fetcher.info_size(relative_path)
|
||||
chunk_size = 16 * 1024 * 1024
|
||||
ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
|
||||
with tmp.open("wb") as out_file:
|
||||
out_file.truncate(size)
|
||||
|
||||
def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
|
||||
offset, length = offset_length
|
||||
return offset, fetcher.read_range(relative_path, offset, length)
|
||||
|
||||
start = time.perf_counter()
|
||||
done = 0
|
||||
with ThreadPoolExecutor(max_workers=8) as pool:
|
||||
futures = [pool.submit(read_chunk, item) for item in ranges]
|
||||
with tmp.open("r+b") as rw_file:
|
||||
for future in futures:
|
||||
offset, data = future.result()
|
||||
rw_file.seek(offset)
|
||||
rw_file.write(data)
|
||||
done += len(data)
|
||||
elapsed = max(time.perf_counter() - start, 1e-9)
|
||||
print(
|
||||
f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
|
||||
f"({done / elapsed / 1024**2:.1f} MiB/s)",
|
||||
flush=True,
|
||||
)
|
||||
tmp.replace(local)
|
||||
finally:
|
||||
fetcher.close()
|
||||
|
||||
|
||||
class EpisodeParquetReader:
|
||||
def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
|
||||
self.meta = meta
|
||||
self.data_root = data_root
|
||||
protocol = "hf" if data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self._episode_row_groups = self._build_episode_row_groups()
|
||||
self._table_cache: dict[str, pa.Table] = {}
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
def read_episode(self, episode_index: int) -> None:
|
||||
relative_path = str(self.meta.get_data_file_path(episode_index))
|
||||
table = self._read_table(relative_path)
|
||||
table.filter(pc.equal(table["episode_index"], episode_index))
|
||||
|
||||
def _read_table(self, relative_path: str) -> pa.Table:
|
||||
with self._cache_lock:
|
||||
table = self._table_cache.get(relative_path)
|
||||
if table is not None:
|
||||
return table
|
||||
with self.fs.open(
|
||||
_root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
|
||||
) as f:
|
||||
table = pq.ParquetFile(f).read()
|
||||
with self._cache_lock:
|
||||
return self._table_cache.setdefault(relative_path, table)
|
||||
|
||||
def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
|
||||
return pool.submit(self.read_episode, episode_index)
|
||||
|
||||
def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
|
||||
start = time.perf_counter()
|
||||
if workers <= 1:
|
||||
for ep in episodes:
|
||||
self.read_episode(ep)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [pool.submit(self.read_episode, ep) for ep in episodes]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
def _build_episode_row_groups(self) -> dict[int, int]:
|
||||
counts: dict[tuple[int, int], int] = {}
|
||||
row_groups = {}
|
||||
for ep_idx in range(int(self.meta.total_episodes)):
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
|
||||
row_groups[ep_idx] = counts.get(key, 0)
|
||||
counts[key] = row_groups[ep_idx] + 1
|
||||
return row_groups
|
||||
|
||||
|
||||
def run_sequential(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
byte_budget: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=1,
|
||||
range_backend=range_backend,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
parquet_s = parquet_reader.read_episodes(episodes, workers=1)
|
||||
elapsed = _fill_cache(cache, episodes)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
return {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / elapsed,
|
||||
"episode_mb": episode_mb,
|
||||
"parquet_s": parquet_s,
|
||||
"avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
|
||||
}
|
||||
|
||||
|
||||
def run_parallel(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
|
||||
fetch_s = _fill_cache(cache, episodes)
|
||||
decoder_start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_decoder(ep, camera_key)
|
||||
decoder_s = time.perf_counter() - decoder_start
|
||||
decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
return {
|
||||
"fetch_s": fetch_s,
|
||||
"fetch_mbps": byte_count / fetch_s / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / fetch_s,
|
||||
"parquet_s": parquet_s,
|
||||
"decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
|
||||
"decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def run_overlapped(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
decode_workers: int,
|
||||
frames_per_episode: int,
|
||||
prefetch_ahead: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
range_backend: str,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
data_root,
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
open_decoders=True,
|
||||
) as cache:
|
||||
start = time.perf_counter()
|
||||
video_wait_decode_s = 0.0
|
||||
parquet_wait_s = 0.0
|
||||
parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
|
||||
parquet_futures = {
|
||||
ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
|
||||
}
|
||||
for ep in episodes[:prefetch_ahead]:
|
||||
cache.submit_prefetch(ep)
|
||||
try:
|
||||
for idx, ep in enumerate(episodes):
|
||||
next_idx = idx + prefetch_ahead
|
||||
if next_idx < len(episodes):
|
||||
next_ep = episodes[next_idx]
|
||||
cache.submit_prefetch(next_ep)
|
||||
parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
|
||||
|
||||
parquet_start = time.perf_counter()
|
||||
parquet_futures.pop(ep).result()
|
||||
parquet_wait_s += time.perf_counter() - parquet_start
|
||||
|
||||
video_start = time.perf_counter()
|
||||
cache.ensure_ready(ep)
|
||||
if decode_workers <= 1:
|
||||
for camera_key in manifest.video_keys:
|
||||
cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [
|
||||
pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
|
||||
for camera_key in manifest.video_keys
|
||||
]
|
||||
for future in futures:
|
||||
future.result()
|
||||
video_wait_decode_s += time.perf_counter() - video_start
|
||||
finally:
|
||||
parquet_pool.shutdown(wait=True)
|
||||
elapsed = time.perf_counter() - start
|
||||
return {
|
||||
"samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
|
||||
"video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
|
||||
"parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
|
||||
"wall_s": elapsed,
|
||||
"video_wait_decode_s": video_wait_decode_s,
|
||||
"parquet_wait_s": parquet_wait_s,
|
||||
}
|
||||
|
||||
|
||||
_remote_decoder_local = threading.local()
|
||||
|
||||
|
||||
def _remote_decoder_cache() -> VideoDecoderCache:
|
||||
cache = getattr(_remote_decoder_local, "cache", None)
|
||||
if cache is None:
|
||||
cache = VideoDecoderCache(max_size=None)
|
||||
_remote_decoder_local.cache = cache
|
||||
return cache
|
||||
|
||||
|
||||
def _decode_remote_source(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
timestamps: list[float],
|
||||
):
|
||||
video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
|
||||
return decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
timestamps,
|
||||
tolerance_s=1.0 / float(meta.fps),
|
||||
decoder_cache=_remote_decoder_cache(),
|
||||
return_uint8=True,
|
||||
)
|
||||
|
||||
|
||||
def run_remote_decoder(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
timestamps: dict[tuple[int, str], list[float]],
|
||||
*,
|
||||
frames_per_episode: int,
|
||||
decode_workers: int,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> dict[str, float]:
|
||||
items = [
|
||||
(ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
|
||||
]
|
||||
|
||||
start = time.perf_counter()
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
sequential_s = time.perf_counter() - start
|
||||
|
||||
start = time.perf_counter()
|
||||
if decode_workers <= 1:
|
||||
for ep, camera_key, ts in items:
|
||||
if camera_key == meta.video_keys[0]:
|
||||
parquet_reader.read_episode(ep)
|
||||
_decode_remote_source(meta, data_root, ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
|
||||
futures = [
|
||||
pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
|
||||
for ep, camera_key, ts in items
|
||||
]
|
||||
for future in parquet_futures:
|
||||
future.result()
|
||||
for future in futures:
|
||||
future.result()
|
||||
parallel_s = time.perf_counter() - start
|
||||
|
||||
return {
|
||||
"sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
|
||||
"parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
|
||||
}
|
||||
|
||||
|
||||
def run_indexed_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
*,
|
||||
range_backend: str = "fsspec",
|
||||
label: str = "indexed",
|
||||
sidecar_path: str | None = None,
|
||||
) -> None:
|
||||
_log(f"starting_strategy: {label}")
|
||||
manifest_start = time.perf_counter()
|
||||
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
|
||||
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
|
||||
manifest = EpisodeVideoManifest.build(
|
||||
meta,
|
||||
data_root,
|
||||
episode_indices=range(manifest_episode_count),
|
||||
range_backend=range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
sidecar_path=sidecar_path,
|
||||
)
|
||||
manifest_s = time.perf_counter() - manifest_start
|
||||
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
|
||||
byte_budget = int(args.byte_budget_gb * 1024**3)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
_log(
|
||||
f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
|
||||
f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
|
||||
)
|
||||
|
||||
_log(f"{label}: running sequential video fetch")
|
||||
sequential = run_sequential(manifest, data_root, episodes, byte_budget, parquet_reader, range_backend)
|
||||
_log(f"{label}: running parallel video fetch + decode-only")
|
||||
parallel = run_parallel(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
_log(f"{label}: running overlapped end-to-end")
|
||||
overlapped = run_overlapped(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
args.decode_workers,
|
||||
args.frames_per_episode,
|
||||
args.prefetch_ahead,
|
||||
parquet_reader,
|
||||
range_backend,
|
||||
)
|
||||
|
||||
print(f"manifest_build_s: {manifest_s:.2f}")
|
||||
print(f"strategy: {label}")
|
||||
print(f"range_backend: {range_backend}")
|
||||
print(f"mp4_sidecar: {sidecar_path or 'none'}")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
print()
|
||||
print("| Track | fetch MB/s | fetch eps/s | samples/s | avg MB/miss | notes |")
|
||||
print("|---|---:|---:|---:|---:|---|")
|
||||
print(
|
||||
f"| SEQUENTIAL | {sequential['fetch_mbps']:.1f} | {sequential['fetch_episodes_s']:.2f} | - | "
|
||||
f"{sequential['avg_mb_miss']:.1f} | 1 worker video fetch, parquet {sequential['parquet_s']:.2f}s |"
|
||||
)
|
||||
print(
|
||||
f"| PARALLEL | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
|
||||
f"{parallel['decode_samples_s']:.1f} | "
|
||||
f"{sequential['avg_mb_miss']:.1f} | decode-only, decoder open "
|
||||
f"{parallel['decoder_ms_miss']:.1f} ms/miss, parquet {parallel['parquet_s']:.2f}s |"
|
||||
)
|
||||
print(
|
||||
f"| OVERLAPPED | - | - | {overlapped['samples_s']:.1f} | {sequential['avg_mb_miss']:.1f} | "
|
||||
f"end-to-end; video {overlapped['video_samples_s']:.1f} samples/s "
|
||||
f"({overlapped['video_wait_decode_s']:.2f}s), parquet {overlapped['parquet_samples_s']:.1f} "
|
||||
f"samples/s ({overlapped['parquet_wait_s']:.2f}s) |"
|
||||
)
|
||||
|
||||
|
||||
def run_remote_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> None:
|
||||
_log("starting_strategy: remote-decoder")
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
|
||||
_log("remote-decoder: running direct source MP4 decoder")
|
||||
result = run_remote_decoder(
|
||||
meta,
|
||||
data_root,
|
||||
episodes,
|
||||
timestamps,
|
||||
frames_per_episode=args.frames_per_episode,
|
||||
decode_workers=args.decode_workers,
|
||||
parquet_reader=parquet_reader,
|
||||
)
|
||||
print("strategy: remote-decoder")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"episodes: {episodes}")
|
||||
print(f"cameras: {list(meta.video_keys)}")
|
||||
print()
|
||||
print("| Track | samples/s | notes |")
|
||||
print("|---|---:|---|")
|
||||
print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
|
||||
print(
|
||||
f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
|
||||
f"direct source MP4 decoder, {args.decode_workers} workers |"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
data_root = args.data_root
|
||||
if data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
parquet_reader = EpisodeParquetReader(meta, data_root)
|
||||
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
|
||||
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
|
||||
sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
|
||||
|
||||
if sidecar_path is not None:
|
||||
print(f"using_mp4_sidecar: {sidecar_path}")
|
||||
|
||||
if sidecar_path is not None and args.strategy == "both":
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
print()
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
print()
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "indexed":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "native-http":
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http-sidecar",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy == "both":
|
||||
expected_sidecar = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
|
||||
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz")
|
||||
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
|
||||
print(f"mp4_sidecar_missing_remote: {expected_remote}")
|
||||
print(
|
||||
"build_mp4_sidecar: "
|
||||
f"uv run --no-sync python scripts/build_mp4_sidecar.py --episodes {manifest_episode_count} "
|
||||
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
|
||||
)
|
||||
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
|
||||
print()
|
||||
|
||||
if args.strategy in ("both", "indexed"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="fsspec",
|
||||
label="indexed",
|
||||
sidecar_path=None,
|
||||
)
|
||||
if args.strategy == "both":
|
||||
print()
|
||||
if args.strategy in ("both", "remote-decoder"):
|
||||
run_remote_strategy(meta, data_root, args, parquet_reader)
|
||||
if args.strategy == "both":
|
||||
print()
|
||||
if args.strategy in ("both", "native-http"):
|
||||
run_indexed_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
parquet_reader,
|
||||
range_backend="native-http",
|
||||
label="indexed-native-http",
|
||||
sidecar_path=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
|
||||
parser.add_argument("--repo-id", default=DEFAULT_REPO)
|
||||
parser.add_argument("--revision", default=DEFAULT_REVISION)
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument("--output", required=True)
|
||||
parser.add_argument("--episodes", type=int, default=None)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument(
|
||||
"--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
|
||||
)
|
||||
parser.add_argument("--no-hub-branch-assert", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def push_sidecar(local_path: str, data_root: str, episode_count: int) -> list[str]:
|
||||
if not data_root.startswith("hf://"):
|
||||
return []
|
||||
|
||||
local = Path(local_path)
|
||||
fs = fsspec.filesystem("hf")
|
||||
remote_dir = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
|
||||
remote_paths = [f"{remote_dir}/{local.name}"]
|
||||
alias = f"{remote_dir}/molmoact2-{episode_count}.npz"
|
||||
if alias not in remote_paths:
|
||||
remote_paths.append(alias)
|
||||
|
||||
for remote in remote_paths:
|
||||
fs.put(str(local), remote)
|
||||
return remote_paths
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
|
||||
meta.ensure_readable()
|
||||
total = (
|
||||
int(meta.total_episodes) if args.episodes is None else min(args.episodes, int(meta.total_episodes))
|
||||
)
|
||||
rel_paths = sorted(
|
||||
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in range(total) for key in meta.video_keys}
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
EpisodeVideoManifest.write_file_sidecar(
|
||||
args.output,
|
||||
rel_paths,
|
||||
args.data_root,
|
||||
range_backend=args.range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
print(f"wrote {args.output}")
|
||||
print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")
|
||||
if args.no_push:
|
||||
print("push_skipped: --no-push")
|
||||
else:
|
||||
pushed = push_sidecar(args.output, args.data_root, total)
|
||||
for remote in pushed:
|
||||
print(f"pushed {remote}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user