mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 11:17:02 +00:00
Benchmark random sampling from episode pool
This commit is contained in:
+324
-315
@@ -18,8 +18,7 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
@@ -34,7 +33,6 @@ from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
make_range_fetcher,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
@@ -52,7 +50,7 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=("both", "full", "indexed", "remote-decoder", "native-http", "gop-window"),
|
||||
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
|
||||
default="both",
|
||||
help=argparse.SUPPRESS,
|
||||
)
|
||||
@@ -105,23 +103,13 @@ def parse_args() -> argparse.Namespace:
|
||||
action="store_true",
|
||||
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-gop-window",
|
||||
action="store_true",
|
||||
help="Also benchmark random frame GOP/window byte-range fetches from the MP4 sidecar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gop-window-post-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Extra compressed samples after each target frame to include in GOP/window ranges.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gop-window-merge-gap-kb",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Merge GOP/window ranges from the same MP4 when the byte gap is at most this many KiB.",
|
||||
)
|
||||
parser.add_argument("--include-pool-sampling", action="store_true")
|
||||
parser.add_argument("--pool-random-samples", type=int, default=4096)
|
||||
parser.add_argument("--batch-size", type=int, default=512)
|
||||
parser.add_argument("--target-samples-s", type=float, default=500.0)
|
||||
parser.add_argument("--stream-samples", type=int, default=4096)
|
||||
parser.add_argument("--pool-samples-per-episode", type=int, default=160)
|
||||
parser.add_argument("--stream-prefetch-episodes", type=int, default=16)
|
||||
parser.add_argument("--decode-workers", type=int, default=1)
|
||||
parser.add_argument("--prefetch-ahead", type=int, default=8)
|
||||
parser.add_argument("--frames-per-episode", type=int, default=16)
|
||||
@@ -177,118 +165,56 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
return total
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GopWindowRange:
|
||||
file_path: str
|
||||
offset: int
|
||||
length: int
|
||||
target_frames: int
|
||||
covered_samples: int
|
||||
def _episode_timestamp_window(manifest: EpisodeVideoManifest, episode_index: int) -> tuple[float, float]:
|
||||
spans = [manifest.lookup(episode_index, camera_key) for camera_key in manifest.video_keys]
|
||||
lo = max(span.first_pts for span in spans)
|
||||
hi = min(span.last_pts for span in spans)
|
||||
if hi >= lo:
|
||||
return lo, hi
|
||||
first = spans[0]
|
||||
return first.first_pts, max(first.last_pts, first.first_pts)
|
||||
|
||||
|
||||
def _sample_bounds_for_episode(manifest: EpisodeVideoManifest, episode_index: int, camera_key: str):
|
||||
span = manifest.lookup(episode_index, camera_key)
|
||||
mp4 = manifest.file_lookup(span.file_id).mp4
|
||||
sample_count = len(mp4.sample_pts)
|
||||
if sample_count == 0:
|
||||
raise ValueError(f"{mp4.file_path} contains no indexed samples")
|
||||
lo = int(np.searchsorted(mp4.sample_pts, span.first_pts, side="left"))
|
||||
hi = int(np.searchsorted(mp4.sample_pts, span.last_pts, side="right")) - 1
|
||||
lo = min(max(lo, 0), sample_count - 1)
|
||||
hi = min(max(hi, lo), sample_count - 1)
|
||||
return span, mp4, lo, hi
|
||||
|
||||
|
||||
def _byte_range_for_samples(mp4, sample_lo: int, sample_hi: int, *, file_size: int) -> tuple[int, int]:
|
||||
offsets = mp4.sample_offsets[sample_lo : sample_hi + 1]
|
||||
sizes = mp4.sample_sizes[sample_lo : sample_hi + 1]
|
||||
byte_lo = int(offsets.min())
|
||||
byte_hi = int((offsets + sizes).max())
|
||||
byte_hi = min(byte_hi, file_size)
|
||||
return byte_lo, byte_hi - byte_lo
|
||||
|
||||
|
||||
def _gop_window_for_target_sample(
|
||||
manifest: EpisodeVideoManifest,
|
||||
episode_index: int,
|
||||
camera_key: str,
|
||||
target_sample: int,
|
||||
*,
|
||||
post_frames: int,
|
||||
) -> GopWindowRange:
|
||||
span = manifest.lookup(episode_index, camera_key)
|
||||
file_record = manifest.file_lookup(span.file_id)
|
||||
mp4 = file_record.mp4
|
||||
sync = mp4.sync_samples[mp4.sync_samples <= target_sample]
|
||||
sample_lo = int(sync[-1]) if len(sync) else 0
|
||||
sample_hi = min(max(target_sample + post_frames, sample_lo), span.sample_hi, len(mp4.sample_pts) - 1)
|
||||
offset, length = _byte_range_for_samples(mp4, sample_lo, sample_hi, file_size=file_record.file_size)
|
||||
return GopWindowRange(
|
||||
file_path=file_record.file_path,
|
||||
offset=offset,
|
||||
length=length,
|
||||
target_frames=1,
|
||||
covered_samples=sample_hi - sample_lo + 1,
|
||||
)
|
||||
|
||||
|
||||
def _gop_window_ranges(
|
||||
manifest: EpisodeVideoManifest,
|
||||
episodes: Sequence[int],
|
||||
*,
|
||||
frames_per_episode: int,
|
||||
seed: int,
|
||||
post_frames: int,
|
||||
merge_gap_bytes: int,
|
||||
) -> tuple[list[GopWindowRange], int, int, int]:
|
||||
def _random_training_samples(
|
||||
manifest: EpisodeVideoManifest, episodes: Sequence[int], count: int, seed: int
|
||||
) -> list[tuple[int, float]]:
|
||||
rng = random.Random(seed)
|
||||
raw: list[GopWindowRange] = []
|
||||
compressed_target_bytes = 0
|
||||
covered_samples = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span, mp4, target_lo, target_hi = _sample_bounds_for_episode(manifest, ep, camera_key)
|
||||
for _ in range(frames_per_episode):
|
||||
ts = rng.uniform(span.first_pts, max(span.last_pts, span.first_pts))
|
||||
target = int(np.searchsorted(mp4.sample_pts, ts, side="left"))
|
||||
target = min(max(target, target_lo), target_hi)
|
||||
compressed_target_bytes += int(mp4.sample_sizes[target])
|
||||
window = _gop_window_for_target_sample(
|
||||
manifest,
|
||||
ep,
|
||||
camera_key,
|
||||
target,
|
||||
post_frames=post_frames,
|
||||
)
|
||||
covered_samples += window.covered_samples
|
||||
raw.append(window)
|
||||
|
||||
merged = _merge_gop_window_ranges(raw, merge_gap_bytes)
|
||||
return merged, len(raw), compressed_target_bytes, covered_samples
|
||||
out = []
|
||||
for _ in range(count):
|
||||
ep = rng.choice(episodes)
|
||||
lo, hi = _episode_timestamp_window(manifest, ep)
|
||||
out.append((ep, rng.uniform(lo, max(hi, lo))))
|
||||
return out
|
||||
|
||||
|
||||
def _merge_gop_window_ranges(ranges: Sequence[GopWindowRange], merge_gap_bytes: int) -> list[GopWindowRange]:
|
||||
if not ranges:
|
||||
return []
|
||||
ordered = sorted(ranges, key=lambda item: (item.file_path, item.offset, item.length))
|
||||
merged: list[GopWindowRange] = []
|
||||
current = ordered[0]
|
||||
for item in ordered[1:]:
|
||||
current_end = current.offset + current.length
|
||||
if item.file_path == current.file_path and item.offset <= current_end + merge_gap_bytes:
|
||||
new_end = max(current_end, item.offset + item.length)
|
||||
current = GopWindowRange(
|
||||
file_path=current.file_path,
|
||||
offset=current.offset,
|
||||
length=new_end - current.offset,
|
||||
target_frames=current.target_frames + item.target_frames,
|
||||
covered_samples=current.covered_samples + item.covered_samples,
|
||||
)
|
||||
else:
|
||||
merged.append(current)
|
||||
current = item
|
||||
merged.append(current)
|
||||
return merged
|
||||
def _sampling_randomness(samples: Sequence[tuple[int, float]], *, batch_size: int) -> dict[str, float]:
|
||||
if not samples:
|
||||
return {
|
||||
"sample_count": 0.0,
|
||||
"unique_episodes": 0.0,
|
||||
"unique_episode_fraction": 0.0,
|
||||
"mean_samples_per_used_episode": 0.0,
|
||||
"max_samples_per_episode": 0.0,
|
||||
"mean_unique_episodes_per_batch": 0.0,
|
||||
"min_unique_episodes_per_batch": 0.0,
|
||||
}
|
||||
counts: dict[int, int] = {}
|
||||
for ep, _ts in samples:
|
||||
counts[ep] = counts.get(ep, 0) + 1
|
||||
batch_uniques = [
|
||||
len({ep for ep, _ts in samples[idx : idx + batch_size]})
|
||||
for idx in range(0, len(samples), batch_size)
|
||||
if samples[idx : idx + batch_size]
|
||||
]
|
||||
return {
|
||||
"sample_count": float(len(samples)),
|
||||
"unique_episodes": float(len(counts)),
|
||||
"unique_episode_fraction": len(counts) / len(samples),
|
||||
"mean_samples_per_used_episode": len(samples) / len(counts),
|
||||
"max_samples_per_episode": float(max(counts.values())),
|
||||
"mean_unique_episodes_per_batch": float(np.mean(batch_uniques)),
|
||||
"min_unique_episodes_per_batch": float(min(batch_uniques)),
|
||||
}
|
||||
|
||||
|
||||
def _decode_all(
|
||||
@@ -307,6 +233,181 @@ def _decode_all(
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _decoder_locks(
|
||||
manifest: EpisodeVideoManifest, episodes: Sequence[int]
|
||||
) -> dict[tuple[int, str], threading.Lock]:
|
||||
return {(ep, camera_key): threading.Lock() for ep in episodes for camera_key in manifest.video_keys}
|
||||
|
||||
|
||||
def _open_resident_decoders(
|
||||
cache: EpisodeByteCache, episodes: Sequence[int], *, decode_workers: int
|
||||
) -> tuple[float, int]:
|
||||
items = [(ep, camera_key) for ep in episodes for camera_key in cache.manifest.video_keys]
|
||||
start = time.perf_counter()
|
||||
if decode_workers <= 1:
|
||||
for ep, camera_key in items:
|
||||
cache.get_decoder(ep, camera_key)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_decoder, ep, camera_key) for ep, camera_key in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start, len(items)
|
||||
|
||||
|
||||
def _decode_training_sample(
|
||||
cache: EpisodeByteCache,
|
||||
episode_index: int,
|
||||
timestamp: float,
|
||||
locks: dict[tuple[int, str], threading.Lock],
|
||||
) -> None:
|
||||
for camera_key in cache.manifest.video_keys:
|
||||
with locks[(episode_index, camera_key)]:
|
||||
cache.get_frames(episode_index, camera_key, [timestamp])
|
||||
|
||||
|
||||
def run_pool_random_decode(
|
||||
cache: EpisodeByteCache,
|
||||
episodes: Sequence[int],
|
||||
*,
|
||||
sample_count: int,
|
||||
batch_size: int,
|
||||
decode_workers: int,
|
||||
seed: int,
|
||||
) -> dict[str, float]:
|
||||
samples = _random_training_samples(cache.manifest, episodes, sample_count, seed)
|
||||
touched_episodes = sorted({ep for ep, _ts in samples})
|
||||
decoder_open_s, decoder_count = _open_resident_decoders(
|
||||
cache, touched_episodes, decode_workers=decode_workers
|
||||
)
|
||||
locks = _decoder_locks(cache.manifest, touched_episodes)
|
||||
|
||||
start = time.perf_counter()
|
||||
if decode_workers <= 1:
|
||||
for ep, ts in samples:
|
||||
_decode_training_sample(cache, ep, ts, locks)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(_decode_training_sample, cache, ep, ts, locks) for ep, ts in samples]
|
||||
for future in futures:
|
||||
future.result()
|
||||
decode_s = time.perf_counter() - start
|
||||
|
||||
randomness = _sampling_randomness(samples, batch_size=batch_size)
|
||||
camera_frames = sample_count * len(cache.manifest.video_keys)
|
||||
result = {
|
||||
"decoder_open_s": decoder_open_s,
|
||||
"decoder_count": float(decoder_count),
|
||||
"decoder_open_ms": decoder_open_s * 1000 / max(decoder_count, 1),
|
||||
"decode_s": decode_s,
|
||||
"training_samples_s": sample_count / decode_s if decode_s > 0 else float("inf"),
|
||||
"camera_frames_s": camera_frames / decode_s if decode_s > 0 else float("inf"),
|
||||
"decode_ms_sample": decode_s * 1000 / max(sample_count, 1),
|
||||
"decode_ms_camera_frame": decode_s * 1000 / max(camera_frames, 1),
|
||||
}
|
||||
result.update(randomness)
|
||||
return result
|
||||
|
||||
|
||||
def run_pool_stream_simulation(
|
||||
cache: EpisodeByteCache,
|
||||
resident_episodes: Sequence[int],
|
||||
*,
|
||||
dataset_episode_count: int,
|
||||
num_episodes: int,
|
||||
sample_count: int,
|
||||
target_samples_s: float,
|
||||
samples_per_episode: int,
|
||||
prefetch_episodes: int,
|
||||
batch_size: int,
|
||||
decode_workers: int,
|
||||
seed: int,
|
||||
) -> dict[str, float]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(dataset_episode_count, num_episodes)
|
||||
resident = list(resident_episodes)
|
||||
resident_set = set(resident)
|
||||
candidates = [ep for ep in range(upper) if ep not in resident_set]
|
||||
rng.shuffle(candidates)
|
||||
replacements = iter(candidates)
|
||||
pending: list[int] = []
|
||||
|
||||
def schedule_one() -> bool:
|
||||
try:
|
||||
ep = next(replacements)
|
||||
except StopIteration:
|
||||
return False
|
||||
cache.submit_prefetch(ep)
|
||||
pending.append(ep)
|
||||
return True
|
||||
|
||||
for _ in range(prefetch_episodes):
|
||||
if not schedule_one():
|
||||
break
|
||||
|
||||
locks = _decoder_locks(cache.manifest, resident)
|
||||
sample_period = 1.0 / target_samples_s if target_samples_s > 0 else 0.0
|
||||
refill_wait_s = 0.0
|
||||
deadline_miss_s = 0.0
|
||||
replacement_count = 0
|
||||
decoded_samples: list[tuple[int, float]] = []
|
||||
start = time.perf_counter()
|
||||
next_deadline = start + sample_period
|
||||
|
||||
for idx in range(sample_count):
|
||||
if idx > 0 and samples_per_episode > 0 and idx % samples_per_episode == 0 and pending:
|
||||
new_ep = pending.pop(0)
|
||||
wait_start = time.perf_counter()
|
||||
cache.ensure_ready(new_ep)
|
||||
for camera_key in cache.manifest.video_keys:
|
||||
locks[(new_ep, camera_key)] = threading.Lock()
|
||||
cache.get_decoder(new_ep, camera_key)
|
||||
refill_wait_s += time.perf_counter() - wait_start
|
||||
old_ep = resident.pop(0)
|
||||
resident_set.discard(old_ep)
|
||||
resident.append(new_ep)
|
||||
resident_set.add(new_ep)
|
||||
replacement_count += 1
|
||||
schedule_one()
|
||||
|
||||
ep = rng.choice(resident)
|
||||
lo, hi = _episode_timestamp_window(cache.manifest, ep)
|
||||
ts = rng.uniform(lo, max(hi, lo))
|
||||
_decode_training_sample(cache, ep, ts, locks)
|
||||
decoded_samples.append((ep, ts))
|
||||
|
||||
if sample_period > 0:
|
||||
now = time.perf_counter()
|
||||
if now < next_deadline:
|
||||
time.sleep(next_deadline - now)
|
||||
else:
|
||||
deadline_miss_s += now - next_deadline
|
||||
next_deadline += sample_period
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
result = {
|
||||
"target_samples_s": target_samples_s,
|
||||
"actual_samples_s": sample_count / elapsed if elapsed > 0 else float("inf"),
|
||||
"stream_wall_s": elapsed,
|
||||
"refill_wait_s": refill_wait_s,
|
||||
"deadline_miss_s": deadline_miss_s,
|
||||
"replacements": float(replacement_count),
|
||||
"replacement_episodes_s": replacement_count / elapsed if elapsed > 0 else 0.0,
|
||||
"samples_per_episode": float(samples_per_episode),
|
||||
"prefetch_episodes": float(prefetch_episodes),
|
||||
"kept_up": 1.0
|
||||
if sample_count / elapsed >= target_samples_s * 0.98 and deadline_miss_s < elapsed * 0.02
|
||||
else 0.0,
|
||||
}
|
||||
result.update(
|
||||
{
|
||||
f"stream_{key}": value
|
||||
for key, value in _sampling_randomness(decoded_samples, batch_size=batch_size).items()
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _fill_cache(
|
||||
cache: EpisodeByteCache, episodes: Sequence[int], *, progress_interval: float = 10.0
|
||||
) -> float:
|
||||
@@ -522,6 +623,8 @@ def run_fetch_pool(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
dataset_episode_count: int,
|
||||
benchmark_episode_count: int,
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
range_backend: str,
|
||||
@@ -540,6 +643,34 @@ def run_fetch_pool(
|
||||
) as cache:
|
||||
elapsed = _fill_cache(cache, episodes, progress_interval=args.progress_interval)
|
||||
timings = cache.timing_summary()
|
||||
random_decode = None
|
||||
stream_sim = None
|
||||
if args.include_pool_sampling:
|
||||
_log("pool_sampling: warming resident decoders and decoding random samples")
|
||||
random_decode = run_pool_random_decode(
|
||||
cache,
|
||||
episodes,
|
||||
sample_count=args.pool_random_samples,
|
||||
batch_size=args.batch_size,
|
||||
decode_workers=args.decode_workers,
|
||||
seed=args.seed + 3,
|
||||
)
|
||||
_log(
|
||||
f"pool_stream: consuming {args.target_samples_s:.1f} samples/s while prefetching replacements"
|
||||
)
|
||||
stream_sim = run_pool_stream_simulation(
|
||||
cache,
|
||||
episodes,
|
||||
dataset_episode_count=dataset_episode_count,
|
||||
num_episodes=benchmark_episode_count,
|
||||
sample_count=args.stream_samples,
|
||||
target_samples_s=args.target_samples_s,
|
||||
samples_per_episode=args.pool_samples_per_episode,
|
||||
prefetch_episodes=args.stream_prefetch_episodes,
|
||||
batch_size=args.batch_size,
|
||||
decode_workers=args.decode_workers,
|
||||
seed=args.seed + 4,
|
||||
)
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
job_count = max(timings["jobs"], 1.0)
|
||||
@@ -556,87 +687,10 @@ def run_fetch_pool(
|
||||
"store_ms": timings["store_s"] * 1000 / job_count,
|
||||
}
|
||||
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
|
||||
return result
|
||||
|
||||
|
||||
def run_gop_window_fetch(
|
||||
manifest: EpisodeVideoManifest,
|
||||
data_root: str,
|
||||
episodes: Sequence[int],
|
||||
workers: int,
|
||||
range_backend: str,
|
||||
args: argparse.Namespace,
|
||||
) -> dict[str, float]:
|
||||
merge_gap_bytes = int(args.gop_window_merge_gap_kb * 1024)
|
||||
windows, raw_windows, compressed_target_bytes, covered_samples = _gop_window_ranges(
|
||||
manifest,
|
||||
episodes,
|
||||
frames_per_episode=args.frames_per_episode,
|
||||
seed=args.seed + 2,
|
||||
post_frames=args.gop_window_post_frames,
|
||||
merge_gap_bytes=merge_gap_bytes,
|
||||
)
|
||||
if not windows:
|
||||
raise ValueError("No GOP/window ranges were planned")
|
||||
|
||||
fetcher = make_range_fetcher(
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
native_http_connections=args.native_http_connections,
|
||||
native_http_timeout=args.native_http_timeout,
|
||||
native_http_retries=args.native_http_retries,
|
||||
)
|
||||
|
||||
def fetch_window(window: GopWindowRange) -> int:
|
||||
payload = fetcher.read_range(window.file_path, window.offset, window.length)
|
||||
if len(payload) != window.length:
|
||||
raise OSError(f"Short read for {window.file_path}: expected {window.length}, got {len(payload)}")
|
||||
return len(payload)
|
||||
|
||||
byte_count = sum(window.length for window in windows)
|
||||
start = time.perf_counter()
|
||||
done = 0
|
||||
done_ranges = 0
|
||||
last_progress = start
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||
futures = [pool.submit(fetch_window, window) for window in windows]
|
||||
for future in as_completed(futures):
|
||||
done += future.result()
|
||||
done_ranges += 1
|
||||
now = time.perf_counter()
|
||||
if args.progress_interval > 0 and now - last_progress >= args.progress_interval:
|
||||
elapsed = max(now - start, 1e-9)
|
||||
_log(
|
||||
"gop_window_progress: "
|
||||
f"ranges_done={done_ranges}/{len(windows)} "
|
||||
f"fetched={done / 1024**3:.2f} GiB "
|
||||
f"fetch={done / elapsed / 1024**2:.1f} MiB/s "
|
||||
f"elapsed={_format_duration(elapsed)}"
|
||||
)
|
||||
last_progress = now
|
||||
finally:
|
||||
timings = fetcher.timing_summary() if hasattr(fetcher, "timing_summary") else {}
|
||||
fetcher.close()
|
||||
|
||||
elapsed = time.perf_counter() - start
|
||||
result = {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"frame_windows_s": raw_windows / elapsed,
|
||||
"ranges_s": len(windows) / elapsed,
|
||||
"bytes": float(byte_count),
|
||||
"raw_windows": float(raw_windows),
|
||||
"merged_windows": float(len(windows)),
|
||||
"compressed_target_bytes": float(compressed_target_bytes),
|
||||
"covered_samples": float(covered_samples),
|
||||
"avg_mb_range": byte_count / len(windows) / 1024**2,
|
||||
"avg_kib_frame_window": byte_count / raw_windows / 1024,
|
||||
"avg_compressed_kib_target": compressed_target_bytes / raw_windows / 1024,
|
||||
"avg_covered_samples": covered_samples / raw_windows,
|
||||
}
|
||||
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
|
||||
if random_decode is not None:
|
||||
result.update({f"pool_decode_{key}": value for key, value in random_decode.items()})
|
||||
if stream_sim is not None:
|
||||
result.update({f"pool_stream_{key}": value for key, value in stream_sim.items()})
|
||||
return result
|
||||
|
||||
|
||||
@@ -988,7 +1042,17 @@ def run_indexed_strategy(
|
||||
)
|
||||
|
||||
_log(f"{label}: filling episode byte cache with {args.workers} workers")
|
||||
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
|
||||
fetch_pool = run_fetch_pool(
|
||||
manifest,
|
||||
data_root,
|
||||
episodes,
|
||||
dataset_episode_count,
|
||||
benchmark_episode_count,
|
||||
byte_budget,
|
||||
args.workers,
|
||||
range_backend,
|
||||
args,
|
||||
)
|
||||
estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
|
||||
@@ -1022,6 +1086,55 @@ def run_indexed_strategy(
|
||||
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
|
||||
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
|
||||
_print_range_timing_summary(fetch_pool)
|
||||
if args.include_pool_sampling:
|
||||
print()
|
||||
print("| Resident Pool Decode | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| random training samples | {fetch_pool['pool_decode_sample_count']:.0f} |")
|
||||
print(f"| decoder opens | {fetch_pool['pool_decode_decoder_count']:.0f} |")
|
||||
print(f"| decoder open ms/episode-camera | {fetch_pool['pool_decode_decoder_open_ms']:.3f} |")
|
||||
print(f"| decode wall s | {fetch_pool['pool_decode_decode_s']:.3f} |")
|
||||
print(f"| training samples/s | {fetch_pool['pool_decode_training_samples_s']:.1f} |")
|
||||
print(f"| camera frames/s | {fetch_pool['pool_decode_camera_frames_s']:.1f} |")
|
||||
print(f"| decode ms/training sample | {fetch_pool['pool_decode_decode_ms_sample']:.3f} |")
|
||||
print(f"| decode ms/camera frame | {fetch_pool['pool_decode_decode_ms_camera_frame']:.3f} |")
|
||||
print()
|
||||
print("| Resident Pool Randomness | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| pool episodes | {len(episodes)} |")
|
||||
print(f"| batch size | {args.batch_size} |")
|
||||
print(f"| unique episodes sampled | {fetch_pool['pool_decode_unique_episodes']:.0f} |")
|
||||
print(
|
||||
f"| mean unique episodes/batch | {fetch_pool['pool_decode_mean_unique_episodes_per_batch']:.1f} |"
|
||||
)
|
||||
print(
|
||||
f"| min unique episodes/batch | {fetch_pool['pool_decode_min_unique_episodes_per_batch']:.0f} |"
|
||||
)
|
||||
print(
|
||||
f"| mean samples/used episode | {fetch_pool['pool_decode_mean_samples_per_used_episode']:.2f} |"
|
||||
)
|
||||
print(f"| max samples/episode | {fetch_pool['pool_decode_max_samples_per_episode']:.0f} |")
|
||||
print()
|
||||
print("| Streaming Keep-Up Simulation | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| target samples/s | {fetch_pool['pool_stream_target_samples_s']:.1f} |")
|
||||
print(f"| actual samples/s | {fetch_pool['pool_stream_actual_samples_s']:.1f} |")
|
||||
print(f"| kept up | {'yes' if fetch_pool['pool_stream_kept_up'] else 'no'} |")
|
||||
print(f"| stream wall s | {fetch_pool['pool_stream_stream_wall_s']:.3f} |")
|
||||
print(f"| refill wait s | {fetch_pool['pool_stream_refill_wait_s']:.3f} |")
|
||||
print(f"| deadline miss s | {fetch_pool['pool_stream_deadline_miss_s']:.3f} |")
|
||||
print(f"| replacement episodes | {fetch_pool['pool_stream_replacements']:.0f} |")
|
||||
print(f"| replacement episodes/s | {fetch_pool['pool_stream_replacement_episodes_s']:.2f} |")
|
||||
print(f"| samples per replacement episode | {fetch_pool['pool_stream_samples_per_episode']:.0f} |")
|
||||
print(f"| prefetch replacement episodes | {fetch_pool['pool_stream_prefetch_episodes']:.0f} |")
|
||||
print(
|
||||
f"| stream mean unique episodes/batch | "
|
||||
f"{fetch_pool['pool_stream_stream_mean_unique_episodes_per_batch']:.1f} |"
|
||||
)
|
||||
print(
|
||||
f"| stream min unique episodes/batch | "
|
||||
f"{fetch_pool['pool_stream_stream_min_unique_episodes_per_batch']:.0f} |"
|
||||
)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
|
||||
if args.include_decode:
|
||||
@@ -1070,80 +1183,6 @@ def run_indexed_strategy(
|
||||
)
|
||||
|
||||
|
||||
def run_gop_window_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
args: argparse.Namespace,
|
||||
*,
|
||||
range_backend: str = "fsspec",
|
||||
sidecar_path: str | None = None,
|
||||
) -> None:
|
||||
_log("starting_strategy: gop-window")
|
||||
memory_start = _memory_snapshot()
|
||||
manifest_start = time.perf_counter()
|
||||
dataset_episode_count = int(meta.total_episodes)
|
||||
manifest_episode_count = args.manifest_episodes or dataset_episode_count
|
||||
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
|
||||
manifest = EpisodeVideoManifest.build(
|
||||
meta,
|
||||
data_root,
|
||||
episode_indices=range(manifest_episode_count),
|
||||
range_backend=range_backend,
|
||||
workers=args.workers,
|
||||
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
|
||||
sidecar_path=sidecar_path,
|
||||
)
|
||||
manifest_s = time.perf_counter() - manifest_start
|
||||
_log(f"gop-window: manifest_build_s={manifest_s:.2f}")
|
||||
|
||||
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
|
||||
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
|
||||
full_episode_bytes = _bytes_for(manifest, episodes)
|
||||
result = run_gop_window_fetch(manifest, data_root, episodes, args.workers, range_backend, args)
|
||||
estimated_benchmark_s = benchmark_episode_count * args.frames_per_episode / result["frame_windows_s"]
|
||||
estimated_dataset_s = dataset_episode_count * args.frames_per_episode / result["frame_windows_s"]
|
||||
|
||||
print(f"manifest_build_s: {manifest_s:.2f}")
|
||||
print("strategy: gop-window")
|
||||
print(f"range_backend: {range_backend}")
|
||||
print(f"mp4_sidecar: {sidecar_path or 'none'}")
|
||||
print(f"data_root: {data_root}")
|
||||
print(f"dataset_episodes: {dataset_episode_count}")
|
||||
print(f"benchmark_episodes: {benchmark_episode_count}")
|
||||
print(f"pool_episodes: {len(episodes)}")
|
||||
print(f"frames_per_episode: {args.frames_per_episode}")
|
||||
print(f"gop_window_post_frames: {args.gop_window_post_frames}")
|
||||
print(f"gop_window_merge_gap_kb: {args.gop_window_merge_gap_kb}")
|
||||
print(f"sampled_episodes: {episodes}")
|
||||
print(f"cameras: {manifest.video_keys}")
|
||||
print()
|
||||
print(
|
||||
"| Track | fetch MB/s | frame windows/s | ranges/s | wall s | "
|
||||
"est benchmark | est full dataset | notes |"
|
||||
)
|
||||
print("|---|---:|---:|---:|---:|---:|---:|---|")
|
||||
print(
|
||||
f"| GOP/WINDOW FETCH | {result['fetch_mbps']:.1f} | {result['frame_windows_s']:.1f} | "
|
||||
f"{result['ranges_s']:.1f} | {result['fetch_s']:.2f} | "
|
||||
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
|
||||
f"{args.workers} workers, fetch-and-drop, no decoder open/frame decode |"
|
||||
)
|
||||
print()
|
||||
print("| GOP Window Shape | value |")
|
||||
print("|---|---:|")
|
||||
print(f"| target frame windows | {result['raw_windows']:.0f} |")
|
||||
print(f"| fetched byte ranges | {result['merged_windows']:.0f} |")
|
||||
print(f"| fetched GiB | {result['bytes'] / 1024**3:.2f} |")
|
||||
print(f"| full episode-pool GiB | {full_episode_bytes / 1024**3:.2f} |")
|
||||
print(f"| fetched/full episode bytes | {result['bytes'] / full_episode_bytes:.3f} |")
|
||||
print(f"| avg MiB/range | {result['avg_mb_range']:.3f} |")
|
||||
print(f"| avg KiB/frame window | {result['avg_kib_frame_window']:.1f} |")
|
||||
print(f"| avg compressed KiB/target frame | {result['avg_compressed_kib_target']:.1f} |")
|
||||
print(f"| avg compressed samples/window | {result['avg_covered_samples']:.1f} |")
|
||||
_print_range_timing_summary(result)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
|
||||
|
||||
def run_remote_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
@@ -1213,15 +1252,6 @@ def main() -> None:
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
if args.include_gop_window:
|
||||
print()
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "indexed":
|
||||
run_indexed_strategy(
|
||||
@@ -1233,15 +1263,6 @@ def main() -> None:
|
||||
label=f"indexed-sidecar-{args.range_backend}",
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
if args.include_gop_window:
|
||||
print()
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "native-http":
|
||||
run_indexed_strategy(
|
||||
@@ -1254,16 +1275,7 @@ def main() -> None:
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if sidecar_path is not None and args.strategy == "gop-window":
|
||||
run_gop_window_strategy(
|
||||
meta,
|
||||
data_root,
|
||||
args,
|
||||
range_backend=args.range_backend,
|
||||
sidecar_path=str(sidecar_path),
|
||||
)
|
||||
return
|
||||
if args.strategy in ("both", "gop-window"):
|
||||
if args.strategy == "both":
|
||||
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
|
||||
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
|
||||
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
|
||||
@@ -1273,9 +1285,6 @@ def main() -> None:
|
||||
"uv run --no-sync python scripts/build_mp4_sidecar.py "
|
||||
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
|
||||
)
|
||||
if args.strategy == "gop-window":
|
||||
print("gop_window_requires_mp4_sidecar: existing per-sample MP4 index sidecar is required")
|
||||
return
|
||||
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
|
||||
print()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user