@dataclass(frozen=True)
class FrameByteWindow:
    """Byte span of one video file covering the samples fetched for a single frame target.

    Instances are built by ``_frame_window_for_sample``: the span runs from the sync sample chosen by
    ``_previous_sync_sample`` through the target sample.
    """

    file_id: int  # manifest file id of the video file (``span.file_id``)
    file_path: str  # path of that file as stored on the manifest's file record
    byte_offset: int  # smallest sample offset among samples ``sample_lo..sample_hi``
    byte_length: int  # largest ``offset + size`` among those samples, minus ``byte_offset``
    useful_bytes: int  # sum of the sizes of those samples
    sample_lo: int  # first sample index of the window (the chosen sync sample)
    sample_hi: int  # last sample index of the window, ``max(target_sample, sample_lo)``
    target_sample: int  # sample index selected for the requested timestamp


@dataclass(frozen=True)
class CoalescedByteRange:
    """One contiguous range request that covers one or more merged ``FrameByteWindow`` spans."""

    file_id: int  # file id shared by every merged window
    file_path: str  # file path taken from the first (lowest-offset) merged window
    byte_offset: int  # start of the merged range
    byte_length: int  # merged range length, including any gaps that were bridged
    windows: int  # number of frame windows folded into this range
    useful_bytes: int  # sum of the merged windows' ``useful_bytes`` (overlapping windows count twice)


def _previous_sync_sample(sync_samples: np.ndarray, target_sample: int) -> int:
    """Return the last sync sample index that is ``<= target_sample``.

    Falls back to the first sync sample when every sync sample lies after the target, and to
    ``target_sample`` itself when ``sync_samples`` is empty.
    """
    if len(sync_samples) == 0:
        return target_sample
    # Binary search instead of a boolean mask: no per-call temporary array.
    # Assumes sync_samples is ascending -- taking the last masked element relied on the same ordering.
    insert_at = int(np.searchsorted(sync_samples, target_sample, side="right"))
    return int(sync_samples[max(insert_at - 1, 0)])


def _frame_window_for_sample(
    manifest: "EpisodeVideoManifest", episode_index: int, camera_key: str, ts: float
) -> FrameByteWindow:
    """Compute the byte window needed to reach the sample at timestamp ``ts``.

    Args:
        manifest: Manifest used to resolve ``(episode_index, camera_key)`` to a file and its MP4 index.
        episode_index: Episode whose video span is looked up.
        camera_key: Video key of the camera to look up.
        ts: Timestamp compared against the file's ``sample_pts`` table
            (presumably the same time base as ``sample_pts`` -- confirm against the manifest).

    Returns:
        The window from the previous sync sample through the target sample.

    Raises:
        ValueError: If the file has no indexed samples.
    """
    span = manifest.lookup(episode_index, camera_key)
    file_record = manifest.file_lookup(span.file_id)
    mp4 = file_record.mp4
    sample_count = len(mp4.sample_pts)
    if sample_count == 0:
        raise ValueError(f"{file_record.file_path} has no indexed samples")

    # searchsorted yields an index in [0, sample_count]; a timestamp past the last pts maps to the last sample.
    target = min(int(np.searchsorted(mp4.sample_pts, ts, side="left")), sample_count - 1)
    lo = _previous_sync_sample(mp4.sync_samples, target)
    # lo exceeds target when no sync sample precedes it; the window then collapses onto that sync sample.
    hi = max(target, lo)

    offsets = mp4.sample_offsets[lo : hi + 1]
    sizes = mp4.sample_sizes[lo : hi + 1]
    # min/max rather than first/last, so the span is correct even if samples are not stored in index order.
    byte_offset = int(offsets.min())
    byte_end = int((offsets + sizes).max())
    return FrameByteWindow(
        file_id=span.file_id,
        file_path=file_record.file_path,
        byte_offset=byte_offset,
        byte_length=byte_end - byte_offset,
        useful_bytes=int(sizes.sum()),
        sample_lo=lo,
        sample_hi=hi,
        target_sample=target,
    )


def _sample_frame_windows(
    manifest: "EpisodeVideoManifest",
    *,
    benchmark_episode_count: int,
    frame_pool_size: int,
    seed: int,
) -> list[FrameByteWindow]:
    """Draw ``frame_pool_size`` random (episode, camera, timestamp) targets and return their byte windows.

    Episodes are drawn from ``range(benchmark_episode_count)``, cameras from ``manifest.video_keys`` and
    timestamps uniformly between the span's ``first_pts`` and ``last_pts``. The draw is reproducible for
    a given ``seed``.

    Raises:
        ValueError: From ``random.Random.randrange`` when ``benchmark_episode_count <= 0`` and at least
            one target is requested.
    """
    rng = random.Random(seed)
    windows: list[FrameByteWindow] = []
    for _ in range(frame_pool_size):
        # Keep the draw order (episode, camera, timestamp) stable so a seed always yields the same targets.
        episode = rng.randrange(benchmark_episode_count)
        camera_key = rng.choice(manifest.video_keys)
        span = manifest.lookup(episode, camera_key)
        ts = rng.uniform(span.first_pts, max(span.last_pts, span.first_pts))
        windows.append(_frame_window_for_sample(manifest, episode, camera_key, ts))
    return windows


def _coalesce_windows(windows: Sequence[FrameByteWindow], gap_bytes: int) -> list[CoalescedByteRange]:
    """Merge per-frame byte windows of the same file into fewer, larger range requests.

    Windows are grouped by ``file_id`` and sorted by ``byte_offset``; a window joins the current range
    when it starts no more than ``gap_bytes`` past that range's end.

    Args:
        windows: Frame windows to merge; may be empty.
        gap_bytes: Largest gap, in bytes, that may be bridged between two windows. Must be ``>= 0``.

    Returns:
        Merged ranges, ordered by first appearance of each file in ``windows`` and by offset within a file.

    Raises:
        ValueError: If ``gap_bytes`` is negative (that would require windows to overlap before merging).
    """
    if gap_bytes < 0:
        raise ValueError(f"gap_bytes must be >= 0, got {gap_bytes}")

    by_file: dict[int, list[FrameByteWindow]] = {}
    for window in windows:
        by_file.setdefault(window.file_id, []).append(window)

    ranges: list[CoalescedByteRange] = []
    for file_id, file_windows in by_file.items():
        # First pass: split the offset-sorted windows into runs that can share one request.
        groups: list[list[FrameByteWindow]] = []
        group_end = 0
        for window in sorted(file_windows, key=lambda w: w.byte_offset):
            window_end = window.byte_offset + window.byte_length
            if groups and window.byte_offset <= group_end + gap_bytes:
                groups[-1].append(window)
                group_end = max(group_end, window_end)
            else:
                groups.append([window])
                group_end = window_end

        # Second pass: one CoalescedByteRange per run, built in a single place.
        for group in groups:
            start = group[0].byte_offset
            end = max(w.byte_offset + w.byte_length for w in group)
            ranges.append(
                CoalescedByteRange(
                    file_id=file_id,
                    file_path=group[0].file_path,
                    byte_offset=start,
                    byte_length=end - start,
                    windows=len(group),
                    useful_bytes=sum(w.useful_bytes for w in group),
                )
            )
    return ranges
ValueError(f"frame-pool-size must be > 0, got {args.frame_pool_size}") + _log(f"starting_strategy: {label}") + manifest_start = time.perf_counter() + dataset_episode_count = int(meta.total_episodes) + manifest_episode_count = args.manifest_episodes or dataset_episode_count + manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes) + manifest = EpisodeVideoManifest.build( + meta, + data_root, + episode_indices=range(manifest_episode_count), + range_backend=range_backend, + workers=args.workers, + max_probe_bytes=args.max_probe_mb * 1024 * 1024, + sidecar_path=sidecar_path, + ) + manifest_s = time.perf_counter() - manifest_start + _log(f"{label}: manifest_build_s={manifest_s:.2f}") + + benchmark_episode_count = min(dataset_episode_count, args.num_episodes) + window_start = time.perf_counter() + windows = _sample_frame_windows( + manifest, + benchmark_episode_count=benchmark_episode_count, + frame_pool_size=args.frame_pool_size, + seed=args.seed, + ) + window_s = time.perf_counter() - window_start + raw_bytes = sum(window.byte_length for window in windows) + useful_bytes = sum(window.useful_bytes for window in windows) + avg_decode_samples = sum(window.sample_hi - window.sample_lo + 1 for window in windows) / len(windows) + + coalesce_start = time.perf_counter() + coalesced = _coalesce_windows(windows, args.coalesce_gap_kb * 1024) + coalesce_s = time.perf_counter() - coalesce_start + coalesced_bytes = sum(item.byte_length for item in coalesced) + + _log( + f"{label}: fetching {len(coalesced)} coalesced ranges for {len(windows)} random frame targets " + f"({coalesced_bytes / 1024**2:.1f} MiB)" + ) + fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=args.workers) + + def read_range(item: CoalescedByteRange) -> int: + payload = fetcher.read_range(item.file_path, item.byte_offset, item.byte_length) + if len(payload) != item.byte_length: + raise OSError(f"Short read for {item.file_path}: expected 
{item.byte_length}, got {len(payload)}") + return len(payload) + + fetch_start = time.perf_counter() + try: + with ThreadPoolExecutor(max_workers=args.workers) as pool: + fetched_bytes = sum(pool.map(read_range, coalesced)) + finally: + fetcher.close() + fetch_s = time.perf_counter() - fetch_start + + print(f"manifest_build_s: {manifest_s:.2f}") + print(f"strategy: {label}") + print(f"range_backend: {range_backend}") + print(f"mp4_sidecar: {sidecar_path or 'none'}") + print(f"data_root: {data_root}") + print(f"dataset_episodes: {dataset_episode_count}") + print(f"benchmark_episodes: {benchmark_episode_count}") + print(f"frame_targets: {len(windows)}") + print(f"cameras: {manifest.video_keys}") + print(f"coalesce_gap_kb: {args.coalesce_gap_kb}") + print() + print( + "| Track | fetch MB/s | frame targets/s | wall s | raw MiB | coalesced MiB | " + "ranges | avg KiB/range | avg KiB/frame | notes |" + ) + print("|---|---:|---:|---:|---:|---:|---:|---:|---:|---|") + print( + f"| RANDOM FRAME WINDOWS | {fetched_bytes / fetch_s / 1024**2:.1f} | " + f"{len(windows) / fetch_s:.1f} | {fetch_s:.2f} | {raw_bytes / 1024**2:.1f} | " + f"{coalesced_bytes / 1024**2:.1f} | {len(coalesced)} | " + f"{coalesced_bytes / max(len(coalesced), 1) / 1024:.1f} | " + f"{coalesced_bytes / len(windows) / 1024:.1f} | " + f"{args.workers} workers, fetch-only, avg decode window {avg_decode_samples:.2f} samples |" + ) + print() + print("| Local Stage | wall s |") + print("|---|---:|") + print(f"| compute frame windows from sidecar | {window_s:.3f} |") + print(f"| coalesce byte windows | {coalesce_s:.3f} |") + print(f"| raw byte windows | {len(windows)} |") + print(f"| coalesced byte ranges | {len(coalesced)} |") + print(f"| useful sample MiB | {useful_bytes / 1024**2:.1f} |") + + def run_remote_strategy( meta: LeRobotDatasetMetadata, data_root: str, @@ -660,6 +906,8 @@ def run_remote_strategy( def main() -> None: args = parse_args() + if args.strategy == "full": + args.strategy = "both" data_root = 
args.data_root if data_root.startswith("hf://") and not args.no_hub_branch_assert: assert_hf_hub_range_cache_branch() @@ -697,6 +945,15 @@ def main() -> None: label="indexed-sidecar", sidecar_path=str(sidecar_path), ) + print() + run_random_frame_strategy( + meta, + data_root, + args, + range_backend=args.random_frame_backend, + label=f"random-frames-{args.random_frame_backend}-sidecar", + sidecar_path=str(sidecar_path), + ) return if sidecar_path is not None and args.strategy == "indexed": run_indexed_strategy( @@ -720,6 +977,16 @@ def main() -> None: sidecar_path=str(sidecar_path), ) return + if sidecar_path is not None and args.strategy == "random-frames": + run_random_frame_strategy( + meta, + data_root, + args, + range_backend=args.random_frame_backend, + label=f"random-frames-{args.random_frame_backend}-sidecar", + sidecar_path=str(sidecar_path), + ) + return if args.strategy == "both": expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}") @@ -759,6 +1026,17 @@ def main() -> None: label="indexed-native-http", sidecar_path=None, ) + if args.strategy == "both": + print() + if args.strategy in ("both", "random-frames"): + run_random_frame_strategy( + meta, + data_root, + args, + range_backend=args.random_frame_backend, + label=f"random-frames-{args.random_frame_backend}", + sidecar_path=None, + ) if __name__ == "__main__":