diff --git a/scripts/bench_episode_byte_cache.py b/scripts/bench_episode_byte_cache.py index 453f38139..01b14a319 100644 --- a/scripts/bench_episode_byte_cache.py +++ b/scripts/bench_episode_byte_cache.py @@ -11,9 +11,11 @@ from __future__ import annotations import argparse +import json import os import random import resource +import socket import tempfile import threading import time @@ -116,6 +118,9 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--max-probe-mb", type=int, default=64) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--byte-budget-gb", type=float, default=80) + parser.add_argument("--distributed-shard-count", type=int, default=1) + parser.add_argument("--distributed-shard-index", type=int, default=0) + parser.add_argument("--summary-json", default=None) parser.add_argument( "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory." ) @@ -123,12 +128,46 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]: +def _episode_shard( + total: int, + requested: int, + seed: int, + *, + shard_count: int = 1, + shard_index: int = 0, +) -> list[int]: rng = random.Random(seed) upper = min(total, requested) - if pool_size > upper: - raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}") - return rng.sample(range(upper), pool_size) + if shard_count < 1: + raise ValueError(f"distributed-shard-count must be >= 1, got {shard_count}") + if shard_index < 0 or shard_index >= shard_count: + raise ValueError(f"distributed-shard-index must be in [0, {shard_count}), got {shard_index}") + permutation = list(range(upper)) + rng.shuffle(permutation) + return permutation[shard_index::shard_count] + + +def _episode_pool( + total: int, + requested: int, + pool_size: int, + seed: int, + *, + shard_count: int = 1, + shard_index: int = 0, +) -> list[int]: + shard = _episode_shard( + total, + requested, + seed, + shard_count=shard_count, + shard_index=shard_index, + ) + if pool_size > len(shard): + raise ValueError( + f"pool-size={pool_size} exceeds shard episodes={len(shard)} for shard {shard_index}/{shard_count}" + ) + return shard[:pool_size] def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int): @@ -308,16 +347,27 @@ def run_pool_stream_simulation( target_samples_s: float, samples_per_episode: int, prefetch_episodes: int, + shard_count: int, + shard_index: int, + shard_seed: int, batch_size: int, decode_workers: int, seed: int, ) -> dict[str, float]: rng = random.Random(seed) - upper = min(dataset_episode_count, num_episodes) resident = list(resident_episodes) resident_set = set(resident) - candidates = [ep for ep in range(upper) if ep not in resident_set] - rng.shuffle(candidates) + candidates = [ + ep + for ep in _episode_shard( + dataset_episode_count, + num_episodes, + shard_seed, + shard_count=shard_count, + shard_index=shard_index, + ) + if ep not in resident_set + ] replacements = iter(candidates) pending: list[int] = [] @@ -480,6 +530,17 @@ def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | print(f"| peak rss | {end['peak_rss_mib']:.1f} |") +def _write_summary_json(path: str | None, payload: dict) -> None: + if path is None: + return + out = Path(path).expanduser() + out.parent.mkdir(parents=True, exist_ok=True) + tmp = out.with_suffix(out.suffix + ".tmp") + tmp.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") + tmp.replace(out) + print(f"summary_json: {out}") + + def _root_join(data_root: str, relative_path: str) -> str: if data_root.startswith("hf://"): return f"{data_root.rstrip('/')}/{relative_path}" @@ -655,6 +716,9 @@ def run_fetch_pool( target_samples_s=args.target_samples_s, samples_per_episode=args.pool_samples_per_episode, prefetch_episodes=args.stream_prefetch_episodes, + shard_count=args.distributed_shard_count, + shard_index=args.distributed_shard_index, + shard_seed=args.seed, batch_size=args.batch_size, decode_workers=args.decode_workers, seed=args.seed + 4, @@ -1021,7 +1085,14 @@ def run_indexed_strategy( _log(f"{label}: manifest_build_s={manifest_s:.2f}") benchmark_episode_count = min(dataset_episode_count, args.num_episodes) - episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed) + episodes = _episode_pool( + dataset_episode_count, + args.num_episodes, + args.pool_size, + args.seed, + shard_count=args.distributed_shard_count, + shard_index=args.distributed_shard_index, + ) byte_budget = int(args.byte_budget_gb * 1024**3) byte_count = _bytes_for(manifest, episodes) _log( @@ -1051,6 +1122,8 @@ def run_indexed_strategy( print(f"data_root: {data_root}") print(f"dataset_episodes: {dataset_episode_count}") print(f"benchmark_episodes: {benchmark_episode_count}") + print(f"distributed_shard_count: {args.distributed_shard_count}") + print(f"distributed_shard_index: {args.distributed_shard_index}") print(f"pool_episodes: {len(episodes)}") print(f"sampled_episodes: {episodes}") print(f"cameras: {manifest.video_keys}") @@ -1123,7 +1196,40 @@ def run_indexed_strategy( f"| stream min unique episodes/batch | " f"{fetch_pool['pool_stream_stream_min_unique_episodes_per_batch']:.0f} |" ) - _print_memory_summary(memory_start, _memory_snapshot()) + memory_end = _memory_snapshot() + _print_memory_summary(memory_start, memory_end) + summary = { + "hostname": socket.gethostname(), + "strategy": label, + "range_backend": range_backend, + "data_root": data_root, + "mp4_sidecar": sidecar_path, + "dataset_episodes": dataset_episode_count, + "benchmark_episodes": benchmark_episode_count, + "distributed_shard_count": args.distributed_shard_count, + "distributed_shard_index": args.distributed_shard_index, + "pool_episodes": len(episodes), + "sampled_episodes": episodes, + "workers": args.workers, + "decode_workers": args.decode_workers, + "manifest_build_s": manifest_s, + "fetch_bytes": byte_count, + "fetch_gib": byte_count / 1024**3, + "fetch_s": fetch_pool["fetch_s"], + "fetch_mib_s": fetch_pool["fetch_mbps"], + "fetch_episodes_s": fetch_pool["fetch_episodes_s"], + "avg_mb_camera": fetch_pool["avg_mb_miss"], + "range_reads": fetch_pool.get("range_jobs", 0.0), + "range_hffs_get_exceptions": fetch_pool.get("range_hffs_get_exception_attempts", 0.0), + "range_hffs_get_retries": fetch_pool.get("range_hffs_get_retries", 0.0), + "rss_start_mib": memory_start["rss_mib"], + "rss_end_mib": memory_end["rss_mib"], + "peak_rss_mib": memory_end["peak_rss_mib"], + } + for key, value in fetch_pool.items(): + if key.startswith("pool_decode_") or key.startswith("pool_stream_"): + summary[key] = value + _write_summary_json(args.summary_json, summary) if args.include_decode: timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1) @@ -1178,7 +1284,14 @@ def run_remote_strategy( parquet_reader: EpisodeParquetReader, ) -> None: _log("starting_strategy: remote-decoder") - episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed) + episodes = _episode_pool( + int(meta.total_episodes), + args.num_episodes, + args.pool_size, + args.seed, + shard_count=args.distributed_shard_count, + shard_index=args.distributed_shard_index, + ) timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1) _log("remote-decoder: running direct source MP4 decoder") result = run_remote_decoder( diff --git a/scripts/summarize_episode_pool_bench.py b/scripts/summarize_episode_pool_bench.py new file mode 100644 index 000000000..e9308c92f --- /dev/null +++ b/scripts/summarize_episode_pool_bench.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Summarize distributed episode pool benchmark JSON files.") + parser.add_argument("summaries", nargs="+", help="Rank summary JSON files.") + return parser.parse_args() + + +def _load(path: str) -> dict: + return json.loads(Path(path).read_text()) + + +def _fmt(value: float) -> str: + return f"{value:.1f}" + + +def main() -> None: + args = parse_args() + rows = [_load(path) for path in args.summaries] + rows.sort(key=lambda row: int(row.get("distributed_shard_index", 0))) + total_bytes = sum(float(row.get("fetch_bytes", 0.0)) for row in rows) + max_fetch_s = max(float(row.get("fetch_s", 0.0)) for row in rows) + aggregate_mib_s = total_bytes / max_fetch_s / 1024**2 if max_fetch_s > 0 else float("inf") + summed_rank_mib_s = sum(float(row.get("fetch_mib_s", 0.0)) for row in rows) + total_decode_samples_s = sum(float(row.get("pool_decode_training_samples_s", 0.0)) for row in rows) + total_stream_samples_s = sum(float(row.get("pool_stream_actual_samples_s", 0.0)) for row in rows) + kept_up = all(bool(row.get("pool_stream_kept_up", 0.0)) for row in rows) + + print("| Aggregate | value |") + print("|---|---:|") + print(f"| ranks | {len(rows)} |") + print(f"| total fetched GiB | {total_bytes / 1024**3:.2f} |") + print(f"| aggregate fetch MiB/s | {_fmt(aggregate_mib_s)} |") + print(f"| summed rank fetch MiB/s | {_fmt(summed_rank_mib_s)} |") + if total_decode_samples_s: + print(f"| aggregate resident decode samples/s | {_fmt(total_decode_samples_s)} |") + if total_stream_samples_s: + print(f"| aggregate stream samples/s | {_fmt(total_stream_samples_s)} |") + print(f"| all ranks kept up | {'yes' if kept_up else 'no'} |") + + print() + print("| Rank | host | fetch MiB/s | fetch s | GiB | decode samples/s | stream samples/s | kept up |") + print("|---:|---|---:|---:|---:|---:|---:|---|") + for row in rows: + rank = int(row.get("distributed_shard_index", 0)) + print( + f"| {rank} | {row.get('hostname', '')} | " + f"{_fmt(float(row.get('fetch_mib_s', 0.0)))} | " + f"{_fmt(float(row.get('fetch_s', 0.0)))} | " + f"{float(row.get('fetch_gib', 0.0)):.2f} | " + f"{_fmt(float(row.get('pool_decode_training_samples_s', 0.0)))} | " + f"{_fmt(float(row.get('pool_stream_actual_samples_s', 0.0)))} | " + f"{'yes' if row.get('pool_stream_kept_up', 0.0) else 'no'} |" + ) + + +if __name__ == "__main__": + main()