mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
Add GOP window range benchmark
This commit is contained in:
@@ -18,7 +18,8 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
@@ -33,6 +34,7 @@ from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
make_range_fetcher,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
@@ -50,7 +52,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
|
||||
choices=("both", "full", "indexed", "remote-decoder", "native-http", "gop-window"),
|
||||
default="both",
|
||||
help=argparse.SUPPRESS,
|
||||
)
|
||||
@@ -103,6 +105,23 @@ def parse_args() -> argparse.Namespace:
|
||||
action="store_true",
|
||||
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-gop-window",
|
||||
action="store_true",
|
||||
help="Also benchmark random frame GOP/window byte-range fetches from the MP4 sidecar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gop-window-post-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Extra compressed samples after each target frame to include in GOP/window ranges.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gop-window-merge-gap-kb",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Merge GOP/window ranges from the same MP4 when the byte gap is at most this many KiB.",
|
||||
)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--prefetch-ahead", type=int, default=8)
|
||||
parser.add_argument("--frames-per-episode", type=int, default=16)
|
||||
@@ -158,6 +177,120 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
return total
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GopWindowRange:
|
||||
file_path: str
|
||||
offset: int
|
||||
length: int
|
||||
target_frames: int
|
||||
covered_samples: int
|
||||
|
||||
|
||||
def _sample_bounds_for_episode(manifest: EpisodeVideoManifest, episode_index: int, camera_key: str):
|
||||
span = manifest.lookup(episode_index, camera_key)
|
||||
mp4 = manifest.file_lookup(span.file_id).mp4
|
||||
sample_count = len(mp4.sample_pts)
|
||||
if sample_count == 0:
|
||||
raise ValueError(f"{mp4.file_path} contains no indexed samples")
|
||||
lo = int(np.searchsorted(mp4.sample_pts, span.first_pts, side="left"))
|
||||
hi = int(np.searchsorted(mp4.sample_pts, span.last_pts, side="right")) - 1
|
||||
lo = min(max(lo, 0), sample_count - 1)
|
||||
hi = min(max(hi, lo), sample_count - 1)
|
||||
return span, mp4, lo, hi
|
||||
|
||||
|
||||
def _byte_range_for_samples(mp4, sample_lo: int, sample_hi: int, *, file_size: int) -> tuple[int, int]:
|
||||
offsets = mp4.sample_offsets[sample_lo : sample_hi + 1]
|
||||
sizes = mp4.sample_sizes[sample_lo : sample_hi + 1]
|
||||
byte_lo = int(offsets.min())
|
||||
byte_hi = int((offsets + sizes).max())
|
||||
byte_hi = min(byte_hi, file_size)
|
||||
return byte_lo, byte_hi - byte_lo
|
||||
|
||||
|
||||
def _gop_window_for_target_sample(
|
||||
manifest: EpisodeVideoManifest,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
target_sample: int,
|
||||
*,
|
||||
post_frames: int,
|
||||
) -> GopWindowRange:
|
||||
span = manifest.lookup(episode_index, camera_key)
|
||||
file_record = manifest.file_lookup(span.file_id)
|
||||
mp4 = file_record.mp4
|
||||
sync = mp4.sync_samples[mp4.sync_samples <= target_sample]
|
||||
sample_lo = int(sync[-1]) if len(sync) else 0
|
||||
sample_hi = min(max(target_sample + post_frames, sample_lo), span.sample_hi, len(mp4.sample_pts) - 1)
|
||||
offset, length = _byte_range_for_samples(mp4, sample_lo, sample_hi, file_size=file_record.file_size)
|
||||
return GopWindowRange(
|
||||
file_path=file_record.file_path,
|
||||
offset=offset,
|
||||
length=length,
|
||||
target_frames=1,
|
||||
covered_samples=sample_hi - sample_lo + 1,
|
||||
)
|
||||
|
||||
|
||||
def _gop_window_ranges(
|
||||
manifest: EpisodeVideoManifest,
|
||||
episodes: Sequence[int],
|
||||
*,
|
||||
frames_per_episode: int,
|
||||
seed: int,
|
||||
post_frames: int,
|
||||
merge_gap_bytes: int,
|
||||
) -> tuple[list[GopWindowRange], int, int, int]:
|
||||
rng = random.Random(seed)
|
||||
raw: list[GopWindowRange] = []
|
||||
compressed_target_bytes = 0
|
||||
covered_samples = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span, mp4, target_lo, target_hi = _sample_bounds_for_episode(manifest, ep, camera_key)
|
||||
for _ in range(frames_per_episode):
|
||||
ts = rng.uniform(span.first_pts, max(span.last_pts, span.first_pts))
|
||||
target = int(np.searchsorted(mp4.sample_pts, ts, side="left"))
|
||||
target = min(max(target, target_lo), target_hi)
|
||||
compressed_target_bytes += int(mp4.sample_sizes[target])
|
||||
window = _gop_window_for_target_sample(
|
||||
manifest,
|
||||
ep,
|
||||
camera_key,
|
||||
target,
|
||||
post_frames=post_frames,
|
||||
)
|
||||
covered_samples += window.covered_samples
|
||||
raw.append(window)
|
||||
|
||||
merged = _merge_gop_window_ranges(raw, merge_gap_bytes)
|
||||
return merged, len(raw), compressed_target_bytes, covered_samples
|
||||
|
||||
|
||||
def _merge_gop_window_ranges(ranges: Sequence[GopWindowRange], merge_gap_bytes: int) -> list[GopWindowRange]:
|
||||
if not ranges:
|
||||
return []
|
||||
ordered = sorted(ranges, key=lambda item: (item.file_path, item.offset, item.length))
|
||||
merged: list[GopWindowRange] = []
|
||||
current = ordered[0]
|
||||
for item in ordered[1:]:
|
||||
current_end = current.offset + current.length
|
||||
if item.file_path == current.file_path and item.offset <= current_end + merge_gap_bytes:
|
||||
new_end = max(current_end, item.offset + item.length)
|
||||
current = GopWindowRange(
|
||||
file_path=current.file_path,
|
||||
offset=current.offset,
|
||||
length=new_end - current.offset,
|
||||
target_frames=current.target_frames + item.target_frames,
|
||||
covered_samples=current.covered_samples + item.covered_samples,
|
||||
)
|
||||
else:
|
||||
merged.append(current)
|
||||
current = item
|
||||
merged.append(current)
|
||||
return merged
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
@@ -426,6 +559,87 @@ def run_fetch_pool(
|
||||
return result
|
||||
|
||||
|
||||
def run_gop_window_fetch(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
workers: int,
|
||||
range_backend: str,
|
||||
args: argparse.Namespace,
|
||||
) -> dict[str, float]:
|
||||
merge_gap_bytes = int(args.gop_window_merge_gap_kb * 1024)
|
||||
windows, raw_windows, compressed_target_bytes, covered_samples = _gop_window_ranges(
|
||||
manifest,
|
||||
episodes,
|
||||
frames_per_episode=args.frames_per_episode,
|
||||
seed=args.seed + 2,
|
||||
post_frames=args.gop_window_post_frames,
|
||||
merge_gap_bytes=merge_gap_bytes,
|
||||
)
|
||||
if not windows:
|
||||
raise ValueError("No GOP/window ranges were planned")
|
||||
|
||||
fetcher = make_range_fetcher(
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
native_http_connections=args.native_http_connections,
|
||||
native_http_timeout=args.native_http_timeout,
|
||||
native_http_retries=args.native_http_retries,
|
||||
)
|
||||
|
||||
def fetch_window(window: GopWindowRange) -> int:
|
||||
payload = fetcher.read_range(window.file_path, window.offset, window.length)
|
||||
if len(payload) != window.length:
|
||||
raise OSError(f"Short read for {window.file_path}: expected {window.length}, got {len(payload)}")
|
||||
return len(payload)
|
||||
|
||||
byte_count = sum(window.length for window in windows)
|
||||
start = time.perf_counter()
|
||||
done = 0
|
||||
done_ranges = 0
|
||||
last_progress = start
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [pool.submit(fetch_window, window) for window in windows]
|
||||
for future in as_completed(futures):
|
||||
done += future.result()
|
||||
done_ranges += 1
|
||||
now = time.perf_counter()
|
||||
if args.progress_interval > 0 and now - last_progress >= args.progress_interval:
|
||||
elapsed = max(now - start, 1e-9)
|
||||
_log(
|
||||
"gop_window_progress: "
|
||||
f"ranges_done={done_ranges}/{len(windows)} "
|
||||
f"fetched={done / 1024**3:.2f} GiB "
|
||||
f"fetch={done / elapsed / 1024**2:.1f} MiB/s "
|
||||
f"elapsed={_format_duration(elapsed)}"
|
||||
)
|
||||
last_progress = now
|
||||
finally:
|
||||
timings = fetcher.timing_summary() if hasattr(fetcher, "timing_summary") else {}
|
||||
fetcher.close()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
result = {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"frame_windows_s": raw_windows / elapsed,
|
||||
"ranges_s": len(windows) / elapsed,
|
||||
"bytes": float(byte_count),
|
||||
"raw_windows": float(raw_windows),
|
||||
"merged_windows": float(len(windows)),
|
||||
"compressed_target_bytes": float(compressed_target_bytes),
|
||||
"covered_samples": float(covered_samples),
|
||||
"avg_mb_range": byte_count / len(windows) / 1024**2,
|
||||
"avg_kib_frame_window": byte_count / raw_windows / 1024,
|
||||
"avg_compressed_kib_target": compressed_target_bytes / raw_windows / 1024,
|
||||
"avg_covered_samples": covered_samples / raw_windows,
|
||||
}
|
||||
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
|
||||
return result
|
||||
|
||||
|
||||
def run_parallel(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
@@ -856,6 +1070,80 @@ def run_indexed_strategy(
|
||||
)
|
||||
|
||||
|
||||
def run_gop_window_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
*,
|
||||
range_backend: str = "fsspec",
|
||||
sidecar_path: str | None = None,
|
||||
) -> None:
|
||||
_log("starting_strategy: gop-window")
|
||||
memory_start = _memory_snapshot()
|
||||
manifest_start = time.perf_counter()
|
||||
dataset_episode_count = int(meta.total_episodes)
|
||||
manifest_episode_count = args.manifest_episodes or dataset_episode_count
|
||||
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
|
||||
manifest = EpisodeVideoManifest.build(
|
||||
meta,
|
||||
data_root,
|
||||
episode_indices=range(manifest_episode_count),
|
||||
range_backend=range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
sidecar_path=sidecar_path,
|
||||
)
|
||||
manifest_s = time.perf_counter() - manifest_start
|
||||
_log(f"gop-window: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
|
||||
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
|
||||
full_episode_bytes = _bytes_for(manifest, episodes)
|
||||
result = run_gop_window_fetch(manifest, data_root, episodes, args.workers, range_backend, args)
|
||||
estimated_benchmark_s = benchmark_episode_count * args.frames_per_episode / result["frame_windows_s"]
|
||||
estimated_dataset_s = dataset_episode_count * args.frames_per_episode / result["frame_windows_s"]
|
||||
|
||||
print(f"manifest_build_s: {manifest_s:.2f}")
|
||||
print("strategy: gop-window")
|
||||
print(f"range_backend: {range_backend}")
|
||||
print(f"mp4_sidecar: {sidecar_path or 'none'}")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"dataset_episodes: {dataset_episode_count}")
|
||||
print(f"benchmark_episodes: {benchmark_episode_count}")
|
||||
print(f"pool_episodes: {len(episodes)}")
|
||||
print(f"frames_per_episode: {args.frames_per_episode}")
|
||||
print(f"gop_window_post_frames: {args.gop_window_post_frames}")
|
||||
print(f"gop_window_merge_gap_kb: {args.gop_window_merge_gap_kb}")
|
||||
print(f"sampled_episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
print()
|
||||
print(
|
||||
"| Track | fetch MB/s | frame windows/s | ranges/s | wall s | "
|
||||
"est benchmark | est full dataset | notes |"
|
||||
)
|
||||
print("|---|---:|---:|---:|---:|---:|---:|---|")
|
||||
print(
|
||||
f"| GOP/WINDOW FETCH | {result['fetch_mbps']:.1f} | {result['frame_windows_s']:.1f} | "
|
||||
f"{result['ranges_s']:.1f} | {result['fetch_s']:.2f} | "
|
||||
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
|
||||
f"{args.workers} workers, fetch-and-drop, no decoder open/frame decode |"
|
||||
)
|
||||
print()
|
||||
print("| GOP Window Shape | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| target frame windows | {result['raw_windows']:.0f} |")
|
||||
print(f"| fetched byte ranges | {result['merged_windows']:.0f} |")
|
||||
print(f"| fetched GiB | {result['bytes'] / 1024**3:.2f} |")
|
||||
print(f"| full episode-pool GiB | {full_episode_bytes / 1024**3:.2f} |")
|
||||
print(f"| fetched/full episode bytes | {result['bytes'] / full_episode_bytes:.3f} |")
|
||||
print(f"| avg MiB/range | {result['avg_mb_range']:.3f} |")
|
||||
print(f"| avg KiB/frame window | {result['avg_kib_frame_window']:.1f} |")
|
||||
print(f"| avg compressed KiB/target frame | {result['avg_compressed_kib_target']:.1f} |")
|
||||
print(f"| avg compressed samples/window | {result['avg_covered_samples']:.1f} |")
|
||||
_print_range_timing_summary(result)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
|
||||
|
||||
def run_remote_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
@@ -925,6 +1213,15 @@ def main() -> None:
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
if args.include_gop_window:
|
||||
print()
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "indexed":
|
||||
run_indexed_strategy(
|
||||
@@ -936,6 +1233,15 @@ def main() -> None:
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
if args.include_gop_window:
|
||||
print()
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "native-http":
|
||||
run_indexed_strategy(
|
||||
@@ -948,7 +1254,16 @@ def main() -> None:
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy == "both":
|
||||
if sidecar_path is not None and args.strategy == "gop-window":
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy in ("both", "gop-window"):
|
||||
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
|
||||
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
|
||||
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
|
||||
@@ -958,6 +1273,9 @@ def main() -> None:
|
||||
"uv run --no-sync python scripts/build_mp4_sidecar.py "
|
||||
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
|
||||
)
|
||||
if args.strategy == "gop-window":
|
||||
print("gop_window_requires_mp4_sidecar: existing per-sample MP4 index sidecar is required")
|
||||
return
|
||||
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
|
||||
print()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user