Add distributed episode pool benchmark summaries

This commit is contained in:
Pepijn
2026-06-22 17:08:02 +02:00
parent 9202fcea96
commit f2b5c4a47b
2 changed files with 188 additions and 10 deletions
+123 -10
View File
@@ -11,9 +11,11 @@
from __future__ import annotations
import argparse
import json
import os
import random
import resource
import socket
import tempfile
import threading
import time
@@ -116,6 +118,9 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--max-probe-mb", type=int, default=64)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--byte-budget-gb", type=float, default=80)
parser.add_argument("--distributed-shard-count", type=int, default=1)
parser.add_argument("--distributed-shard-index", type=int, default=0)
parser.add_argument("--summary-json", default=None)
parser.add_argument(
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
)
@@ -123,12 +128,46 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
def _episode_shard(
total: int,
requested: int,
seed: int,
*,
shard_count: int = 1,
shard_index: int = 0,
) -> list[int]:
rng = random.Random(seed)
upper = min(total, requested)
if pool_size > upper:
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
return rng.sample(range(upper), pool_size)
if shard_count < 1:
raise ValueError(f"distributed-shard-count must be >= 1, got {shard_count}")
if shard_index < 0 or shard_index >= shard_count:
raise ValueError(f"distributed-shard-index must be in [0, {shard_count}), got {shard_index}")
permutation = list(range(upper))
rng.shuffle(permutation)
return permutation[shard_index::shard_count]
def _episode_pool(
    total: int,
    requested: int,
    pool_size: int,
    seed: int,
    *,
    shard_count: int = 1,
    shard_index: int = 0,
) -> list[int]:
    """Pick the first ``pool_size`` episodes of this rank's shard.

    The shard is obtained from ``_episode_shard`` using the same ``total``,
    ``requested``, ``seed``, ``shard_count`` and ``shard_index`` arguments.

    Raises:
        ValueError: If ``pool_size`` is larger than the shard.
    """
    candidates = _episode_shard(
        total,
        requested,
        seed,
        shard_count=shard_count,
        shard_index=shard_index,
    )
    available = len(candidates)
    if available < pool_size:
        raise ValueError(
            f"pool-size={pool_size} exceeds shard episodes={available} for shard {shard_index}/{shard_count}"
        )
    return candidates[:pool_size]
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
@@ -308,16 +347,27 @@ def run_pool_stream_simulation(
target_samples_s: float,
samples_per_episode: int,
prefetch_episodes: int,
shard_count: int,
shard_index: int,
shard_seed: int,
batch_size: int,
decode_workers: int,
seed: int,
) -> dict[str, float]:
rng = random.Random(seed)
upper = min(dataset_episode_count, num_episodes)
resident = list(resident_episodes)
resident_set = set(resident)
candidates = [ep for ep in range(upper) if ep not in resident_set]
rng.shuffle(candidates)
candidates = [
ep
for ep in _episode_shard(
dataset_episode_count,
num_episodes,
shard_seed,
shard_count=shard_count,
shard_index=shard_index,
)
if ep not in resident_set
]
replacements = iter(candidates)
pending: list[int] = []
@@ -480,6 +530,17 @@ def _print_memory_summary(start: dict[str, float | None], end: dict[str, float |
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
def _write_summary_json(path: str | None, payload: dict) -> None:
if path is None:
return
out = Path(path).expanduser()
out.parent.mkdir(parents=True, exist_ok=True)
tmp = out.with_suffix(out.suffix + ".tmp")
tmp.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
tmp.replace(out)
print(f"summary_json: {out}")
def _root_join(data_root: str, relative_path: str) -> str:
if data_root.startswith("hf://"):
return f"{data_root.rstrip('/')}/{relative_path}"
@@ -655,6 +716,9 @@ def run_fetch_pool(
target_samples_s=args.target_samples_s,
samples_per_episode=args.pool_samples_per_episode,
prefetch_episodes=args.stream_prefetch_episodes,
shard_count=args.distributed_shard_count,
shard_index=args.distributed_shard_index,
shard_seed=args.seed,
batch_size=args.batch_size,
decode_workers=args.decode_workers,
seed=args.seed + 4,
@@ -1021,7 +1085,14 @@ def run_indexed_strategy(
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
episodes = _episode_pool(
dataset_episode_count,
args.num_episodes,
args.pool_size,
args.seed,
shard_count=args.distributed_shard_count,
shard_index=args.distributed_shard_index,
)
byte_budget = int(args.byte_budget_gb * 1024**3)
byte_count = _bytes_for(manifest, episodes)
_log(
@@ -1051,6 +1122,8 @@ def run_indexed_strategy(
print(f"data_root: {data_root}")
print(f"dataset_episodes: {dataset_episode_count}")
print(f"benchmark_episodes: {benchmark_episode_count}")
print(f"distributed_shard_count: {args.distributed_shard_count}")
print(f"distributed_shard_index: {args.distributed_shard_index}")
print(f"pool_episodes: {len(episodes)}")
print(f"sampled_episodes: {episodes}")
print(f"cameras: {manifest.video_keys}")
@@ -1123,7 +1196,40 @@ def run_indexed_strategy(
f"| stream min unique episodes/batch | "
f"{fetch_pool['pool_stream_stream_min_unique_episodes_per_batch']:.0f} |"
)
_print_memory_summary(memory_start, _memory_snapshot())
memory_end = _memory_snapshot()
_print_memory_summary(memory_start, memory_end)
summary = {
"hostname": socket.gethostname(),
"strategy": label,
"range_backend": range_backend,
"data_root": data_root,
"mp4_sidecar": sidecar_path,
"dataset_episodes": dataset_episode_count,
"benchmark_episodes": benchmark_episode_count,
"distributed_shard_count": args.distributed_shard_count,
"distributed_shard_index": args.distributed_shard_index,
"pool_episodes": len(episodes),
"sampled_episodes": episodes,
"workers": args.workers,
"decode_workers": args.decode_workers,
"manifest_build_s": manifest_s,
"fetch_bytes": byte_count,
"fetch_gib": byte_count / 1024**3,
"fetch_s": fetch_pool["fetch_s"],
"fetch_mib_s": fetch_pool["fetch_mbps"],
"fetch_episodes_s": fetch_pool["fetch_episodes_s"],
"avg_mb_camera": fetch_pool["avg_mb_miss"],
"range_reads": fetch_pool.get("range_jobs", 0.0),
"range_hffs_get_exceptions": fetch_pool.get("range_hffs_get_exception_attempts", 0.0),
"range_hffs_get_retries": fetch_pool.get("range_hffs_get_retries", 0.0),
"rss_start_mib": memory_start["rss_mib"],
"rss_end_mib": memory_end["rss_mib"],
"peak_rss_mib": memory_end["peak_rss_mib"],
}
for key, value in fetch_pool.items():
if key.startswith("pool_decode_") or key.startswith("pool_stream_"):
summary[key] = value
_write_summary_json(args.summary_json, summary)
if args.include_decode:
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
@@ -1178,7 +1284,14 @@ def run_remote_strategy(
parquet_reader: EpisodeParquetReader,
) -> None:
_log("starting_strategy: remote-decoder")
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
episodes = _episode_pool(
int(meta.total_episodes),
args.num_episodes,
args.pool_size,
args.seed,
shard_count=args.distributed_shard_count,
shard_index=args.distributed_shard_index,
)
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
_log("remote-decoder: running direct source MP4 decoder")
result = run_remote_decoder(
+65
View File
@@ -0,0 +1,65 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import json
from pathlib import Path
def parse_args() -> argparse.Namespace:
    """Parse the command line: one or more rank summary JSON paths (``summaries``)."""
    cli = argparse.ArgumentParser(
        description="Summarize distributed episode pool benchmark JSON files.",
    )
    cli.add_argument(
        "summaries",
        help="Rank summary JSON files.",
        nargs="+",
    )
    namespace = cli.parse_args()
    return namespace
def _load(path: str) -> dict:
return json.loads(Path(path).read_text())
def _fmt(value: float) -> str:
return f"{value:.1f}"
def main() -> None:
    """Print an aggregate table and a per-rank table (Markdown) for the given summaries.

    Summaries are ordered by their ``distributed_shard_index``. Missing
    numeric fields count as ``0.0``. Aggregate fetch throughput is total bytes
    divided by the largest per-rank ``fetch_s``, and is reported as ``inf``
    when that time is not positive.
    """
    cli = parse_args()
    ranks = sorted(
        (_load(summary_path) for summary_path in cli.summaries),
        key=lambda entry: int(entry.get("distributed_shard_index", 0)),
    )

    def _metric(entry: dict, key: str) -> float:
        return float(entry.get(key, 0.0))

    def _summed(key: str) -> float:
        return sum(_metric(entry, key) for entry in ranks)

    total_bytes = _summed("fetch_bytes")
    slowest_fetch_s = max(_metric(entry, "fetch_s") for entry in ranks)
    if slowest_fetch_s > 0:
        aggregate_mib_s = total_bytes / slowest_fetch_s / 1024**2
    else:
        aggregate_mib_s = float("inf")
    summed_rank_mib_s = _summed("fetch_mib_s")
    decode_samples_s = _summed("pool_decode_training_samples_s")
    stream_samples_s = _summed("pool_stream_actual_samples_s")
    every_rank_kept_up = all(bool(entry.get("pool_stream_kept_up", 0.0)) for entry in ranks)

    for line in (
        "| Aggregate | value |",
        "|---|---:|",
        f"| ranks | {len(ranks)} |",
        f"| total fetched GiB | {total_bytes / 1024**3:.2f} |",
        f"| aggregate fetch MiB/s | {_fmt(aggregate_mib_s)} |",
        f"| summed rank fetch MiB/s | {_fmt(summed_rank_mib_s)} |",
    ):
        print(line)
    if decode_samples_s:
        print(f"| aggregate resident decode samples/s | {_fmt(decode_samples_s)} |")
    if stream_samples_s:
        print(f"| aggregate stream samples/s | {_fmt(stream_samples_s)} |")
        # NOTE(review): the kept-up row is grouped with the stream row here; confirm
        # against the original indentation.
        verdict = "yes" if every_rank_kept_up else "no"
        print(f"| all ranks kept up | {verdict} |")
    print()

    print("| Rank | host | fetch MiB/s | fetch s | GiB | decode samples/s | stream samples/s | kept up |")
    print("|---:|---|---:|---:|---:|---:|---:|---|")
    for entry in ranks:
        cells = [
            f"{int(entry.get('distributed_shard_index', 0))}",
            f"{entry.get('hostname', '')}",
            _fmt(_metric(entry, "fetch_mib_s")),
            _fmt(_metric(entry, "fetch_s")),
            f"{_metric(entry, 'fetch_gib'):.2f}",
            _fmt(_metric(entry, "pool_decode_training_samples_s")),
            _fmt(_metric(entry, "pool_stream_actual_samples_s")),
            "yes" if entry.get("pool_stream_kept_up", 0.0) else "no",
        ]
        print("| " + " | ".join(cells) + " |")
# Run the summarizer only when this file is executed as a script.
if __name__ == "__main__":
    main()