mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 19:27:08 +00:00
Add distributed episode pool benchmark summaries
This commit is contained in:
@@ -11,9 +11,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import resource
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
@@ -116,6 +118,9 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--max-probe-mb", type=int, default=64)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--byte-budget-gb", type=float, default=80)
|
||||
parser.add_argument("--distributed-shard-count", type=int, default=1)
|
||||
parser.add_argument("--distributed-shard-index", type=int, default=0)
|
||||
parser.add_argument("--summary-json", default=None)
|
||||
parser.add_argument(
|
||||
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
|
||||
)
|
||||
@@ -123,12 +128,46 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
def _episode_shard(
|
||||
total: int,
|
||||
requested: int,
|
||||
seed: int,
|
||||
*,
|
||||
shard_count: int = 1,
|
||||
shard_index: int = 0,
|
||||
) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
if shard_count < 1:
|
||||
raise ValueError(f"distributed-shard-count must be >= 1, got {shard_count}")
|
||||
if shard_index < 0 or shard_index >= shard_count:
|
||||
raise ValueError(f"distributed-shard-index must be in [0, {shard_count}), got {shard_index}")
|
||||
permutation = list(range(upper))
|
||||
rng.shuffle(permutation)
|
||||
return permutation[shard_index::shard_count]
|
||||
|
||||
|
||||
def _episode_pool(
|
||||
total: int,
|
||||
requested: int,
|
||||
pool_size: int,
|
||||
seed: int,
|
||||
*,
|
||||
shard_count: int = 1,
|
||||
shard_index: int = 0,
|
||||
) -> list[int]:
|
||||
shard = _episode_shard(
|
||||
total,
|
||||
requested,
|
||||
seed,
|
||||
shard_count=shard_count,
|
||||
shard_index=shard_index,
|
||||
)
|
||||
if pool_size > len(shard):
|
||||
raise ValueError(
|
||||
f"pool-size={pool_size} exceeds shard episodes={len(shard)} for shard {shard_index}/{shard_count}"
|
||||
)
|
||||
return shard[:pool_size]
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
@@ -308,16 +347,27 @@ def run_pool_stream_simulation(
|
||||
target_samples_s: float,
|
||||
samples_per_episode: int,
|
||||
prefetch_episodes: int,
|
||||
shard_count: int,
|
||||
shard_index: int,
|
||||
shard_seed: int,
|
||||
batch_size: int,
|
||||
decode_workers: int,
|
||||
seed: int,
|
||||
) -> dict[str, float]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(dataset_episode_count, num_episodes)
|
||||
resident = list(resident_episodes)
|
||||
resident_set = set(resident)
|
||||
candidates = [ep for ep in range(upper) if ep not in resident_set]
|
||||
rng.shuffle(candidates)
|
||||
candidates = [
|
||||
ep
|
||||
for ep in _episode_shard(
|
||||
dataset_episode_count,
|
||||
num_episodes,
|
||||
shard_seed,
|
||||
shard_count=shard_count,
|
||||
shard_index=shard_index,
|
||||
)
|
||||
if ep not in resident_set
|
||||
]
|
||||
replacements = iter(candidates)
|
||||
pending: list[int] = []
|
||||
|
||||
@@ -480,6 +530,17 @@ def _print_memory_summary(start: dict[str, float | None], end: dict[str, float |
|
||||
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
|
||||
|
||||
|
||||
def _write_summary_json(path: str | None, payload: dict) -> None:
|
||||
if path is None:
|
||||
return
|
||||
out = Path(path).expanduser()
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = out.with_suffix(out.suffix + ".tmp")
|
||||
tmp.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
|
||||
tmp.replace(out)
|
||||
print(f"summary_json: {out}")
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
@@ -655,6 +716,9 @@ def run_fetch_pool(
|
||||
target_samples_s=args.target_samples_s,
|
||||
samples_per_episode=args.pool_samples_per_episode,
|
||||
prefetch_episodes=args.stream_prefetch_episodes,
|
||||
shard_count=args.distributed_shard_count,
|
||||
shard_index=args.distributed_shard_index,
|
||||
shard_seed=args.seed,
|
||||
batch_size=args.batch_size,
|
||||
decode_workers=args.decode_workers,
|
||||
seed=args.seed + 4,
|
||||
@@ -1021,7 +1085,14 @@ def run_indexed_strategy(
|
||||
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
|
||||
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
|
||||
episodes = _episode_pool(
|
||||
dataset_episode_count,
|
||||
args.num_episodes,
|
||||
args.pool_size,
|
||||
args.seed,
|
||||
shard_count=args.distributed_shard_count,
|
||||
shard_index=args.distributed_shard_index,
|
||||
)
|
||||
byte_budget = int(args.byte_budget_gb * 1024**3)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
_log(
|
||||
@@ -1051,6 +1122,8 @@ def run_indexed_strategy(
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"dataset_episodes: {dataset_episode_count}")
|
||||
print(f"benchmark_episodes: {benchmark_episode_count}")
|
||||
print(f"distributed_shard_count: {args.distributed_shard_count}")
|
||||
print(f"distributed_shard_index: {args.distributed_shard_index}")
|
||||
print(f"pool_episodes: {len(episodes)}")
|
||||
print(f"sampled_episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
@@ -1123,7 +1196,40 @@ def run_indexed_strategy(
|
||||
f"| stream min unique episodes/batch | "
|
||||
f"{fetch_pool['pool_stream_stream_min_unique_episodes_per_batch']:.0f} |"
|
||||
)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
memory_end = _memory_snapshot()
|
||||
_print_memory_summary(memory_start, memory_end)
|
||||
summary = {
|
||||
"hostname": socket.gethostname(),
|
||||
"strategy": label,
|
||||
"range_backend": range_backend,
|
||||
"data_root": data_root,
|
||||
"mp4_sidecar": sidecar_path,
|
||||
"dataset_episodes": dataset_episode_count,
|
||||
"benchmark_episodes": benchmark_episode_count,
|
||||
"distributed_shard_count": args.distributed_shard_count,
|
||||
"distributed_shard_index": args.distributed_shard_index,
|
||||
"pool_episodes": len(episodes),
|
||||
"sampled_episodes": episodes,
|
||||
"workers": args.workers,
|
||||
"decode_workers": args.decode_workers,
|
||||
"manifest_build_s": manifest_s,
|
||||
"fetch_bytes": byte_count,
|
||||
"fetch_gib": byte_count / 1024**3,
|
||||
"fetch_s": fetch_pool["fetch_s"],
|
||||
"fetch_mib_s": fetch_pool["fetch_mbps"],
|
||||
"fetch_episodes_s": fetch_pool["fetch_episodes_s"],
|
||||
"avg_mb_camera": fetch_pool["avg_mb_miss"],
|
||||
"range_reads": fetch_pool.get("range_jobs", 0.0),
|
||||
"range_hffs_get_exceptions": fetch_pool.get("range_hffs_get_exception_attempts", 0.0),
|
||||
"range_hffs_get_retries": fetch_pool.get("range_hffs_get_retries", 0.0),
|
||||
"rss_start_mib": memory_start["rss_mib"],
|
||||
"rss_end_mib": memory_end["rss_mib"],
|
||||
"peak_rss_mib": memory_end["peak_rss_mib"],
|
||||
}
|
||||
for key, value in fetch_pool.items():
|
||||
if key.startswith("pool_decode_") or key.startswith("pool_stream_"):
|
||||
summary[key] = value
|
||||
_write_summary_json(args.summary_json, summary)
|
||||
|
||||
if args.include_decode:
|
||||
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
|
||||
@@ -1178,7 +1284,14 @@ def run_remote_strategy(
|
||||
parquet_reader: EpisodeParquetReader,
|
||||
) -> None:
|
||||
_log("starting_strategy: remote-decoder")
|
||||
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
|
||||
episodes = _episode_pool(
|
||||
int(meta.total_episodes),
|
||||
args.num_episodes,
|
||||
args.pool_size,
|
||||
args.seed,
|
||||
shard_count=args.distributed_shard_count,
|
||||
shard_index=args.distributed_shard_index,
|
||||
)
|
||||
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
|
||||
_log("remote-decoder: running direct source MP4 decoder")
|
||||
result = run_remote_decoder(
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Summarize distributed episode pool benchmark JSON files.")
|
||||
parser.add_argument("summaries", nargs="+", help="Rank summary JSON files.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _load(path: str) -> dict:
|
||||
return json.loads(Path(path).read_text())
|
||||
|
||||
|
||||
def _fmt(value: float) -> str:
|
||||
return f"{value:.1f}"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
rows = [_load(path) for path in args.summaries]
|
||||
rows.sort(key=lambda row: int(row.get("distributed_shard_index", 0)))
|
||||
total_bytes = sum(float(row.get("fetch_bytes", 0.0)) for row in rows)
|
||||
max_fetch_s = max(float(row.get("fetch_s", 0.0)) for row in rows)
|
||||
aggregate_mib_s = total_bytes / max_fetch_s / 1024**2 if max_fetch_s > 0 else float("inf")
|
||||
summed_rank_mib_s = sum(float(row.get("fetch_mib_s", 0.0)) for row in rows)
|
||||
total_decode_samples_s = sum(float(row.get("pool_decode_training_samples_s", 0.0)) for row in rows)
|
||||
total_stream_samples_s = sum(float(row.get("pool_stream_actual_samples_s", 0.0)) for row in rows)
|
||||
kept_up = all(bool(row.get("pool_stream_kept_up", 0.0)) for row in rows)
|
||||
|
||||
print("| Aggregate | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| ranks | {len(rows)} |")
|
||||
print(f"| total fetched GiB | {total_bytes / 1024**3:.2f} |")
|
||||
print(f"| aggregate fetch MiB/s | {_fmt(aggregate_mib_s)} |")
|
||||
print(f"| summed rank fetch MiB/s | {_fmt(summed_rank_mib_s)} |")
|
||||
if total_decode_samples_s:
|
||||
print(f"| aggregate resident decode samples/s | {_fmt(total_decode_samples_s)} |")
|
||||
if total_stream_samples_s:
|
||||
print(f"| aggregate stream samples/s | {_fmt(total_stream_samples_s)} |")
|
||||
print(f"| all ranks kept up | {'yes' if kept_up else 'no'} |")
|
||||
|
||||
print()
|
||||
print("| Rank | host | fetch MiB/s | fetch s | GiB | decode samples/s | stream samples/s | kept up |")
|
||||
print("|---:|---|---:|---:|---:|---:|---:|---|")
|
||||
for row in rows:
|
||||
rank = int(row.get("distributed_shard_index", 0))
|
||||
print(
|
||||
f"| {rank} | {row.get('hostname', '')} | "
|
||||
f"{_fmt(float(row.get('fetch_mib_s', 0.0)))} | "
|
||||
f"{_fmt(float(row.get('fetch_s', 0.0)))} | "
|
||||
f"{float(row.get('fetch_gib', 0.0)):.2f} | "
|
||||
f"{_fmt(float(row.get('pool_decode_training_samples_s', 0.0)))} | "
|
||||
f"{_fmt(float(row.get('pool_stream_actual_samples_s', 0.0)))} | "
|
||||
f"{'yes' if row.get('pool_stream_kept_up', 0.0) else 'no'} |"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user