mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-28 05:37:16 +00:00
Log episode cache fill progress
This commit is contained in:
@@ -86,6 +86,12 @@ def parse_args() -> argparse.Namespace:
|
||||
default=120.0,
|
||||
help="Timeout in seconds for native HTTP requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-interval",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="Print episode-pool fill progress every N seconds. Set 0 to disable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-decode",
|
||||
action="store_true",
|
||||
@@ -162,12 +168,31 @@ def _decode_all(
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
||||
def _fill_cache(
|
||||
cache: EpisodeByteCache, episodes: Sequence[int], *, progress_interval: float = 10.0
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
for ep in episodes:
|
||||
last_progress = start
|
||||
for idx, ep in enumerate(episodes, start=1):
|
||||
cache.ensure_ready(ep)
|
||||
now = time.perf_counter()
|
||||
if progress_interval > 0 and now - last_progress >= progress_interval:
|
||||
timings = cache.timing_summary()
|
||||
byte_count = timings.get("range_bytes", 0.0)
|
||||
elapsed = max(now - start, 1e-9)
|
||||
jobs = timings.get("jobs", 0.0)
|
||||
total_jobs = len(episodes) * len(cache.manifest.video_keys)
|
||||
_log(
|
||||
"fill_progress: "
|
||||
f"episodes_ready={idx}/{len(episodes)} "
|
||||
f"camera_jobs={jobs:.0f}/{total_jobs} "
|
||||
f"fetched={byte_count / 1024**3:.2f} GiB "
|
||||
f"fetch={byte_count / elapsed / 1024**2:.1f} MiB/s "
|
||||
f"elapsed={_format_duration(elapsed)}"
|
||||
)
|
||||
last_progress = now
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
@@ -374,7 +399,7 @@ def run_fetch_pool(
|
||||
native_http_retries=args.native_http_retries,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
elapsed = _fill_cache(cache, episodes)
|
||||
elapsed = _fill_cache(cache, episodes, progress_interval=args.progress_interval)
|
||||
timings = cache.timing_summary()
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
@@ -602,6 +627,14 @@ def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
|
||||
if fetch_pool.get("range_failed_requests"):
|
||||
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
|
||||
status_counts = {
|
||||
key.removeprefix("range_status_"): value
|
||||
for key, value in fetch_pool.items()
|
||||
if key.startswith("range_status_")
|
||||
}
|
||||
if status_counts:
|
||||
summary = ", ".join(f"{status}={count:.0f}" for status, count in sorted(status_counts.items()))
|
||||
print(f"| http status counts | {summary} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user