mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
725 lines
26 KiB
Python
725 lines
26 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import random
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
from collections.abc import Sequence
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
|
|
import fsspec
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
import pyarrow.compute as pc
|
|
import pyarrow.parquet as pq
|
|
|
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
|
from lerobot.datasets.episode_video_streaming import (
|
|
EpisodeByteCache,
|
|
EpisodeVideoManifest,
|
|
NativeHTTPRangeFetcher,
|
|
assert_hf_hub_range_cache_branch,
|
|
)
|
|
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
|
|
|
# Hub dataset repo whose metadata (episodes, video keys, fps) is loaded via LeRobotDatasetMetadata.
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
# Commit hash passed as ``revision`` so the metadata snapshot is pinned.
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
# Root under which the parquet data files and source MP4s are read.
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# Local directory where MP4 index sidecars (.npz) are looked up and downloaded to.
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()).joinpath("lerobot-sidecars")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the benchmark's command-line options and parse ``sys.argv``.

    ``--strategy`` is accepted but hidden from ``--help`` (``argparse.SUPPRESS``).
    ``--in-memory`` is a no-op kept so existing invocations keep working.
    """
    cli = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")

    # Dataset location.
    cli.add_argument("--repo-id", default=DEFAULT_REPO)
    cli.add_argument("--revision", default=DEFAULT_REVISION)
    cli.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
    cli.add_argument(
        "--strategy",
        default="both",
        choices=("both", "indexed", "remote-decoder", "native-http"),
        help=argparse.SUPPRESS,
    )

    # Workload size.
    cli.add_argument("--num-episodes", type=int, default=512)
    cli.add_argument(
        "--manifest-episodes",
        type=int,
        default=None,
        help="Limit manifest construction to the first N episodes for local smoke tests.",
    )

    # Plain integer knobs, registered in the same order as before so --help is unchanged.
    integer_options = (
        ("--pool-size", 16),
        ("--workers", 8),
        ("--decode-workers", 1),
        ("--prefetch-ahead", 8),
        ("--frames-per-episode", 16),
        ("--max-probe-mb", 64),
        ("--seed", 0),
    )
    for flag, default in integer_options:
        cli.add_argument(flag, type=int, default=default)

    cli.add_argument("--byte-budget-gb", type=float, default=80)
    cli.add_argument(
        "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
    )
    cli.add_argument("--no-hub-branch-assert", action="store_true")
    return cli.parse_args()
|
|
|
|
|
|
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
|
rng = random.Random(seed)
|
|
upper = min(total, requested)
|
|
if pool_size > upper:
|
|
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
|
return rng.sample(range(upper), pool_size)
|
|
|
|
|
|
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
|
rng = random.Random(seed)
|
|
out: dict[tuple[int, str], list[float]] = {}
|
|
for ep in episodes:
|
|
for camera_key in manifest.video_keys:
|
|
span = manifest.lookup(ep, camera_key)
|
|
lo = span.first_pts
|
|
hi = max(span.last_pts, lo)
|
|
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
|
return out
|
|
|
|
|
|
def _timestamps_from_meta(
|
|
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
|
) -> dict[tuple[int, str], list[float]]:
|
|
rng = random.Random(seed)
|
|
out: dict[tuple[int, str], list[float]] = {}
|
|
for ep in episodes:
|
|
row = meta.episodes[ep]
|
|
for camera_key in meta.video_keys:
|
|
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
|
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
|
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
|
return out
|
|
|
|
|
|
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
|
total = 0
|
|
for ep in episodes:
|
|
for camera_key in manifest.video_keys:
|
|
total += manifest.lookup(ep, camera_key).mdat_length
|
|
return total
|
|
|
|
|
|
def _decode_all(
|
|
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
|
) -> float:
|
|
start = time.perf_counter()
|
|
items = list(timestamps.items())
|
|
if decode_workers <= 1:
|
|
for (ep, camera_key), ts in items:
|
|
cache.get_frames(ep, camera_key, ts)
|
|
else:
|
|
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
|
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
|
for future in futures:
|
|
future.result()
|
|
return time.perf_counter() - start
|
|
|
|
|
|
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
|
start = time.perf_counter()
|
|
for ep in episodes:
|
|
cache.submit_prefetch(ep)
|
|
for ep in episodes:
|
|
cache.ensure_ready(ep)
|
|
return time.perf_counter() - start
|
|
|
|
|
|
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
|
if elapsed_s <= 0:
|
|
return float("inf")
|
|
return len(episodes) * frames_per_episode / elapsed_s
|
|
|
|
|
|
def _log(message: str) -> None:
|
|
print(message, flush=True)
|
|
|
|
|
|
def _root_join(data_root: str, relative_path: str) -> str:
|
|
if data_root.startswith("hf://"):
|
|
return f"{data_root.rstrip('/')}/{relative_path}"
|
|
return str(Path(data_root) / relative_path)
|
|
|
|
|
|
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Return a local, loadable MP4 index sidecar covering ``manifest_episode_count`` episodes.

    Lookup order:

    1. ``SIDECAR_CACHE_DIR/molmoact2-{N}.npz``. An existing but invalid copy is
       deleted.
    2. ``SIDECAR_CACHE_DIR/molmoact2-full.npz``.
    3. ``{data_root}/meta/mp4-sidecars/molmoact2-{N}.npz``, downloaded into the
       cache directory. ``hf://`` roots use native HTTP range reads; other roots
       use ``fsspec``.

    Returns:
        The local sidecar path. Returns ``None`` when no sidecar exists, or when the
        downloaded file fails ``_valid_sidecar``. In that case the bad download is
        deleted, so the next run retries instead of reusing it.
    """
    filename = f"molmoact2-{manifest_episode_count}.npz"
    local = SIDECAR_CACHE_DIR / filename
    if _valid_sidecar(local):
        return local
    if local.exists():
        print(f"mp4_sidecar_invalid_local: {local}")
        local.unlink()

    full_local = SIDECAR_CACHE_DIR / "molmoact2-full.npz"
    if _valid_sidecar(full_local):
        return full_local

    relative_path = f"meta/mp4-sidecars/{filename}"
    remote = _root_join(data_root, relative_path)
    is_hub = data_root.startswith("hf://")
    fs = fsspec.filesystem("hf" if is_hub else "file")
    if not fs.exists(remote):
        return None

    local.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {remote} -> {local}")
    if is_hub:
        _download_sidecar_native_http(data_root, relative_path, local)
    else:
        fs.get(remote, str(local))

    # Never hand a truncated/corrupt download to the manifest builder; fall back to
    # the no-sidecar path instead.
    if not _valid_sidecar(local):
        print(f"mp4_sidecar_invalid_download: {local}")
        local.unlink(missing_ok=True)
        return None
    return local
|
|
|
|
|
|
def _valid_sidecar(path: Path) -> bool:
|
|
if not path.exists():
|
|
return False
|
|
try:
|
|
with np.load(path, allow_pickle=False) as data:
|
|
return "manifest_json" in data
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
    """Download ``relative_path`` under ``data_root`` to ``local`` via parallel HTTP range reads.

    The remote size is queried first. The file is then split into 16 MiB chunks,
    which 8 threads fetch concurrently. Each chunk is written at its offset into a
    pre-sized ``<local>.tmp`` as soon as it arrives. The temporary file is renamed
    onto ``local`` only after every chunk has been written, so ``local`` never
    holds a partial download.

    Raises:
        OSError: if the fetcher returns a chunk whose length differs from the
            requested length. Writing it would silently corrupt the sidecar.
            Errors raised by the fetcher itself propagate unchanged.

    On any failure, chunks that have not started yet are cancelled and the
    temporary file is removed. The fetcher is always closed.
    """
    from concurrent.futures import as_completed

    chunk_size = 16 * 1024 * 1024
    fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
    tmp = local.with_suffix(local.suffix + ".tmp")
    completed = False
    try:
        size = fetcher.info_size(relative_path)
        ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
        # Pre-size the file so chunks can be written at their offsets in any order.
        with tmp.open("wb") as out_file:
            out_file.truncate(size)

        def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
            offset, length = offset_length
            data = fetcher.read_range(relative_path, offset, length)
            if len(data) != length:
                raise OSError(
                    f"sidecar_download: short read for {relative_path} at offset {offset}: "
                    f"expected {length} bytes, got {len(data)}"
                )
            return offset, data

        start = time.perf_counter()
        done = 0
        with ThreadPoolExecutor(max_workers=8) as pool, tmp.open("r+b") as rw_file:
            futures = [pool.submit(read_chunk, item) for item in ranges]
            try:
                # Write in completion order so one slow early chunk does not hold
                # back writing (and progress reporting) for the chunks behind it.
                for future in as_completed(futures):
                    offset, data = future.result()
                    rw_file.seek(offset)
                    rw_file.write(data)
                    done += len(data)
                    elapsed = max(time.perf_counter() - start, 1e-9)
                    print(
                        f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
                        f"({done / elapsed / 1024**2:.1f} MiB/s)",
                        flush=True,
                    )
            except BaseException:
                # The download is aborted: don't start chunks that are still queued.
                for pending in futures:
                    pending.cancel()
                raise
        tmp.replace(local)
        completed = True
    finally:
        fetcher.close()
        if not completed:
            tmp.unlink(missing_ok=True)
|
|
|
|
|
|
class EpisodeParquetReader:
    """Read one episode's rows from the dataset's parquet data files.

    Each data file is read in full on first use and kept in memory, keyed by its
    path relative to ``data_root``. Later reads of episodes stored in the same file
    only filter the cached table. Files are opened through ``fsspec``: the ``hf``
    protocol for ``hf://`` roots, the local ``file`` protocol otherwise.
    """

    def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
        self.meta = meta
        self.data_root = data_root
        protocol = "hf" if data_root.startswith("hf://") else "file"
        self.fs = fsspec.filesystem(protocol)
        # Backing store for the lazily built ``_episode_row_groups`` property.
        # Building the map touches the metadata row of every episode in the
        # dataset, and nothing in this script reads it, so it is no longer built at
        # construction time.
        self._episode_row_groups_cache: dict[int, int] | None = None
        # Relative data-file path -> fully read table; accessed under _cache_lock.
        self._table_cache: dict[str, pa.Table] = {}
        self._cache_lock = threading.Lock()

    @property
    def _episode_row_groups(self) -> dict[int, int]:
        """Episode index -> ordinal of that episode within its data file (built on first access)."""
        if self._episode_row_groups_cache is None:
            self._episode_row_groups_cache = self._build_episode_row_groups()
        return self._episode_row_groups_cache

    def read_episode(self, episode_index: int) -> None:
        """Load the episode's data file (cached) and select the episode's rows.

        The filtered table is discarded. Presumably the benchmark only cares about
        the cost of reading and filtering, since callers only time this call.
        """
        relative_path = str(self.meta.get_data_file_path(episode_index))
        table = self._read_table(relative_path)
        table.filter(pc.equal(table["episode_index"], episode_index))

    def _read_table(self, relative_path: str) -> pa.Table:
        """Return the full table stored at ``relative_path``, reading it on first use.

        If two threads miss the cache for the same path at the same time, both read
        the file. ``setdefault`` makes both return the first table stored.
        """
        with self._cache_lock:
            table = self._table_cache.get(relative_path)
        if table is not None:
            return table
        with self.fs.open(
            _root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
        ) as f:
            table = pq.ParquetFile(f).read()
        with self._cache_lock:
            return self._table_cache.setdefault(relative_path, table)

    def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
        """Schedule :meth:`read_episode` on ``pool`` and return its future."""
        return pool.submit(self.read_episode, episode_index)

    def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
        """Read every episode in ``episodes`` and return the elapsed wall time in seconds.

        Reads run inline when ``workers <= 1``; otherwise they run on a temporary
        pool of ``workers`` threads. An exception from any read is re-raised.
        """
        start = time.perf_counter()
        if workers <= 1:
            for ep in episodes:
                self.read_episode(ep)
        else:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                futures = [pool.submit(self.read_episode, ep) for ep in episodes]
                for future in futures:
                    future.result()
        return time.perf_counter() - start

    def _build_episode_row_groups(self) -> dict[int, int]:
        """Number the episodes within each ``(data/chunk_index, data/file_index)`` file.

        The value for an episode is the count of lower-indexed episodes that share
        its data file. NOTE(review): this equals the episode's parquet row-group
        index only if each episode is written as its own row group in episode
        order. Confirm against the dataset writer.
        """
        counts: dict[tuple[int, int], int] = {}
        row_groups = {}
        for ep_idx in range(int(self.meta.total_episodes)):
            ep = self.meta.episodes[ep_idx]
            key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
            row_groups[ep_idx] = counts.get(key, 0)
            counts[key] = row_groups[ep_idx] + 1
        return row_groups
|
|
|
|
|
|
def run_sequential(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    byte_budget: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Baseline track: read parquet rows one by one, then fetch video bytes with one worker.

    A fresh ``EpisodeByteCache`` (one fetch worker, decoders not opened) is filled
    with every episode in ``episodes``. Parquet reads happen first and are timed
    separately.

    Returns:
        ``fetch_s`` (video fetch wall time), ``fetch_mbps`` (MiB/s),
        ``fetch_episodes_s``, ``episode_mb`` (MiB of video bytes per episode),
        ``parquet_s``, and ``avg_mb_miss`` (MiB per episode/camera fetch).
    """
    cache_options = {
        "byte_budget": byte_budget,
        "workers": 1,
        "range_backend": range_backend,
        "open_decoders": False,
    }
    with EpisodeByteCache(manifest, data_root, **cache_options) as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=1)
        fetch_s = _fill_cache(cache, episodes)
        total_bytes = _bytes_for(manifest, episodes)
        episode_count = len(episodes)
        mib = 1024**2
        return {
            "fetch_s": fetch_s,
            "fetch_mbps": total_bytes / fetch_s / mib,
            "fetch_episodes_s": episode_count / fetch_s,
            "episode_mb": total_bytes / episode_count / mib,
            "parquet_s": parquet_s,
            "avg_mb_miss": total_bytes / (episode_count * len(manifest.video_keys)) / mib,
        }
|
|
|
|
|
|
def run_parallel(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run four timed phases back to back on a fresh ``EpisodeByteCache``.

    The phases are:

    1. parquet reads on ``workers`` threads;
    2. video fetch with ``workers`` workers;
    3. opening a decoder per (episode, camera) via ``get_decoder``. The cache is
       created with ``open_decoders=False``, so this step is timed on its own.
    4. decoding every sampled frame with ``decode_workers`` threads.

    Returns:
        ``fetch_s``, ``fetch_mbps`` (MiB/s), ``fetch_episodes_s``, ``parquet_s``,
        ``decoder_ms_miss`` (mean ms per ``get_decoder`` call), and
        ``decode_samples_s``.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    ) as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
        fetch_s = _fill_cache(cache, episodes)

        opened_at = time.perf_counter()
        for episode, video_key in ((e, k) for e in episodes for k in manifest.video_keys):
            cache.get_decoder(episode, video_key)
        decoder_s = time.perf_counter() - opened_at

        decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
        total_bytes = _bytes_for(manifest, episodes)
        misses = len(episodes) * len(manifest.video_keys)
        return {
            "fetch_s": fetch_s,
            "fetch_mbps": total_bytes / fetch_s / 1024**2,
            "fetch_episodes_s": len(episodes) / fetch_s,
            "parquet_s": parquet_s,
            "decoder_ms_miss": decoder_s * 1000 / misses,
            "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
        }
|
|
|
|
|
|
def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Measure end-to-end throughput with video prefetch and parquet reads overlapped.

    Uses a fresh ``EpisodeByteCache`` with ``workers`` fetch workers and decoders
    opened. Video prefetches and parquet reads are kept ``prefetch_ahead``
    episodes ahead of the consumer. For each episode in order, the loop:

    1. waits for its parquet read;
    2. waits for its video bytes;
    3. decodes the sampled frames for every camera, inline or on a pool of
       ``decode_workers`` threads.

    Returns:
        ``samples_s`` (end-to-end), ``video_samples_s`` and ``parquet_samples_s``
        (throughput implied by the time spent waiting on each side), ``wall_s``,
        ``video_wait_decode_s``, and ``parquet_wait_s``.
    """
    # A negative look-ahead would make ``episodes[idx + prefetch_ahead]`` index from
    # the end (or out of range) and submit the wrong episodes; treat it as 0.
    prefetch_ahead = max(0, prefetch_ahead)
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
        # Reuse one decode pool for the whole run. Building and tearing down an
        # executor per episode would add thread start-up cost to the video timing.
        decode_pool = ThreadPoolExecutor(max_workers=decode_workers) if decode_workers > 1 else None
        try:
            parquet_futures = {
                ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
            }
            for ep in episodes[:prefetch_ahead]:
                cache.submit_prefetch(ep)

            for idx, ep in enumerate(episodes):
                # Keep the look-ahead window full: schedule the episode that is
                # ``prefetch_ahead`` positions ahead of the one consumed now.
                next_idx = idx + prefetch_ahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)

                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start

                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_pool is None:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    futures = [
                        decode_pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                        for camera_key in manifest.video_keys
                    ]
                    for future in futures:
                        future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        finally:
            parquet_pool.shutdown(wait=True)
            if decode_pool is not None:
                decode_pool.shutdown(wait=True)
        elapsed = time.perf_counter() - start
        return {
            "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
            "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
            "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
            "wall_s": elapsed,
            "video_wait_decode_s": video_wait_decode_s,
            "parquet_wait_s": parquet_wait_s,
        }
|
|
|
|
|
|
_remote_decoder_local = threading.local()
|
|
|
|
|
|
def _remote_decoder_cache() -> VideoDecoderCache:
    """Return the calling thread's decoder cache, creating an unbounded one on first use.

    A stored value of ``None`` is treated the same as "no cache yet".
    """
    try:
        cache = _remote_decoder_local.cache
    except AttributeError:
        cache = None
    if cache is None:
        cache = VideoDecoderCache(max_size=None)
        _remote_decoder_local.cache = cache
    return cache
|
|
|
|
|
|
def _decode_remote_source(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episode_index: int,
    camera_key: str,
    timestamps: list[float],
):
    """Decode ``timestamps`` directly from the episode's source MP4 under ``data_root``.

    Frames are returned as uint8. The timestamp tolerance is one frame period
    (``1 / meta.fps``). Decoders are reused through the calling thread's
    ``_remote_decoder_cache()``.
    """
    relative_video = meta.get_video_file_path(episode_index, camera_key)
    frame_period_s = 1.0 / float(meta.fps)
    return decode_video_frames_torchcodec(
        _root_join(data_root, str(relative_video)),
        timestamps,
        tolerance_s=frame_period_s,
        decoder_cache=_remote_decoder_cache(),
        return_uint8=True,
    )
|
|
|
|
|
|
def run_remote_decoder(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    *,
    frames_per_episode: int,
    decode_workers: int,
    parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
    """Time decoding straight from the source MP4s, without the episode byte cache.

    Two passes are made over every (episode, camera) pair:

    * sequential: in the calling thread, read the episode's parquet rows (once per
      episode, alongside its first camera) and decode its sampled frames;
    * "parallel": the same work on a ``ThreadPoolExecutor`` with
      ``decode_workers`` threads, or inline again when ``decode_workers <= 1``.

    The calling thread's decoder cache is discarded between the passes. The
    inline second pass therefore opens its decoders cold, just as fresh pool
    threads do, instead of reusing decoders opened by the first pass.

    NOTE(review): the parquet reader's table cache is not reset, so second-pass
    parquet reads hit tables cached during the first pass.

    Returns:
        ``sequential_samples_s`` and ``parallel_samples_s``.
    """
    items = [
        (ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
    ]

    def run_inline() -> float:
        """Run all items in this thread and return elapsed seconds."""
        start = time.perf_counter()
        for ep, camera_key, ts in items:
            # One parquet read per episode, issued with its first camera.
            if camera_key == meta.video_keys[0]:
                parquet_reader.read_episode(ep)
            _decode_remote_source(meta, data_root, ep, camera_key, ts)
        return time.perf_counter() - start

    sequential_s = run_inline()

    # ``_remote_decoder_cache`` is per thread, so the pass above may have left this
    # thread's cache holding a decoder for every item. Drop it; the accessor
    # creates a fresh cache when it finds ``None``.
    _remote_decoder_local.cache = None

    if decode_workers <= 1:
        parallel_s = run_inline()
    else:
        start = time.perf_counter()
        with ThreadPoolExecutor(max_workers=decode_workers) as pool:
            parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
            futures = [
                pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
                for ep, camera_key, ts in items
            ]
            for future in parquet_futures:
                future.result()
            for future in futures:
                future.result()
        parallel_s = time.perf_counter() - start

    return {
        "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
        "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
    }
|
|
|
|
|
|
def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Benchmark the indexed episode byte cache and print a markdown results table.

    Steps:

    1. Build an ``EpisodeVideoManifest`` for the first ``manifest_episode_count``
       episodes. The count is ``--manifest-episodes``, capped by the dataset size
       and ``--num-episodes``. A precomputed MP4 sidecar is used when
       ``sidecar_path`` is given.
    2. Sample ``--pool-size`` episodes from that same range.
    3. Run the sequential, parallel, and overlapped tracks over the sample.

    Raises:
        ValueError: if ``--pool-size`` exceeds the number of episodes the manifest covers.
    """
    _log(f"starting_strategy: {label}")
    manifest_start = time.perf_counter()
    manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
    manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")

    # Sample only episodes the manifest was built for. Drawing from
    # min(total_episodes, num_episodes) instead would pick episodes beyond
    # --manifest-episodes that the manifest does not cover. With
    # --manifest-episodes unset both ranges coincide, so the sampled pool is unchanged.
    episodes = _episode_pool(manifest_episode_count, args.num_episodes, args.pool_size, args.seed)
    timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    _log(
        f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
        f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
    )

    _log(f"{label}: running sequential video fetch")
    sequential = run_sequential(manifest, data_root, episodes, byte_budget, parquet_reader, range_backend)
    _log(f"{label}: running parallel video fetch + decode-only")
    parallel = run_parallel(
        manifest,
        data_root,
        episodes,
        timestamps,
        byte_budget,
        args.workers,
        args.decode_workers,
        args.frames_per_episode,
        parquet_reader,
        range_backend,
    )
    _log(f"{label}: running overlapped end-to-end")
    overlapped = run_overlapped(
        manifest,
        data_root,
        episodes,
        timestamps,
        byte_budget,
        args.workers,
        args.decode_workers,
        args.frames_per_episode,
        args.prefetch_ahead,
        parquet_reader,
        range_backend,
    )

    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()
    print("| Track | fetch MB/s | fetch eps/s | samples/s | avg MB/miss | notes |")
    print("|---|---:|---:|---:|---:|---|")
    print(
        f"| SEQUENTIAL | {sequential['fetch_mbps']:.1f} | {sequential['fetch_episodes_s']:.2f} | - | "
        f"{sequential['avg_mb_miss']:.1f} | 1 worker video fetch, parquet {sequential['parquet_s']:.2f}s |"
    )
    print(
        f"| PARALLEL | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
        f"{parallel['decode_samples_s']:.1f} | "
        f"{sequential['avg_mb_miss']:.1f} | decode-only, decoder open "
        f"{parallel['decoder_ms_miss']:.1f} ms/miss, parquet {parallel['parquet_s']:.2f}s |"
    )
    print(
        f"| OVERLAPPED | - | - | {overlapped['samples_s']:.1f} | {sequential['avg_mb_miss']:.1f} | "
        f"end-to-end; video {overlapped['video_samples_s']:.1f} samples/s "
        f"({overlapped['video_wait_decode_s']:.2f}s), parquet {overlapped['parquet_samples_s']:.1f} "
        f"samples/s ({overlapped['parquet_wait_s']:.2f}s) |"
    )
|
|
|
|
|
|
def run_remote_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
) -> None:
    """Benchmark decoding straight from the source MP4s and print a markdown table.

    Episodes are sampled from ``min(total_episodes, --num-episodes)``. Timestamps
    come from episode metadata, so no MP4 manifest or sidecar is needed.
    """
    _log("starting_strategy: remote-decoder")
    episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
    timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
    _log("remote-decoder: running direct source MP4 decoder")
    result = run_remote_decoder(
        meta,
        data_root,
        episodes,
        timestamps,
        frames_per_episode=args.frames_per_episode,
        decode_workers=args.decode_workers,
        parquet_reader=parquet_reader,
    )

    report_lines = (
        "strategy: remote-decoder",
        f"data_root: {data_root}",
        f"episodes: {episodes}",
        f"cameras: {list(meta.video_keys)}",
        "",
        "| Track | samples/s | notes |",
        "|---|---:|---|",
        f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |",
        f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
        f"direct source MP4 decoder, {args.decode_workers} workers |",
    )
    for line in report_lines:
        print(line)
|
|
|
|
|
|
def main() -> None:
    """Entry point: parse CLI args, locate an MP4 sidecar, and run the selected strategies.

    With a sidecar available:

    * ``indexed`` runs the fsspec sidecar variant;
    * ``native-http`` runs the native-HTTP sidecar variant;
    * ``both`` runs remote-decoder followed by both sidecar variants.

    Without a sidecar, ``both`` first prints how to build one. The indexed
    strategies then build their MP4 indexes online. ``remote-decoder`` never uses a
    sidecar, so none is looked up or downloaded for it.
    """
    args = parse_args()
    data_root = args.data_root
    if data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)
    manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
    manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)

    # remote-decoder decodes the source MP4s directly and never reads the sidecar,
    # so skip the (potentially large) sidecar download in that mode.
    sidecar_path = None
    if args.strategy != "remote-decoder":
        sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)

    def run_indexed(range_backend: str, label: str, sidecar: str | None) -> None:
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend=range_backend,
            label=label,
            sidecar_path=sidecar,
        )

    if sidecar_path is not None:
        print(f"using_mp4_sidecar: {sidecar_path}")
        sidecar = str(sidecar_path)
        if args.strategy == "both":
            run_remote_strategy(meta, data_root, args, parquet_reader)
            print()
            run_indexed("native-http", "indexed-native-http-sidecar", sidecar)
            print()
            run_indexed("fsspec", "indexed-sidecar", sidecar)
            return
        if args.strategy == "indexed":
            run_indexed("fsspec", "indexed-sidecar", sidecar)
            return
        if args.strategy == "native-http":
            run_indexed("native-http", "indexed-native-http-sidecar", sidecar)
            return

    if args.strategy == "both":
        expected_sidecar = SIDECAR_CACHE_DIR / f"molmoact2-{manifest_episode_count}.npz"
        expected_remote = _root_join(data_root, f"meta/mp4-sidecars/molmoact2-{manifest_episode_count}.npz")
        print(f"mp4_sidecar_missing_local: {expected_sidecar}")
        print(f"mp4_sidecar_missing_remote: {expected_remote}")
        print(
            "build_mp4_sidecar: "
            f"uv run --no-sync python scripts/build_mp4_sidecar.py --episodes {manifest_episode_count} "
            f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
        )
        print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
        print()

    if args.strategy in ("both", "indexed"):
        run_indexed("fsspec", "indexed", None)
    if args.strategy == "both":
        print()
    if args.strategy in ("both", "remote-decoder"):
        run_remote_strategy(meta, data_root, args, parquet_reader)
    if args.strategy == "both":
        print()
    if args.strategy in ("both", "native-http"):
        run_indexed("native-http", "indexed-native-http", None)
|
|
|
|
|
|
# Run the benchmark only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|