Report range read timing breakdown

This commit is contained in:
Pepijn
2026-06-17 21:20:08 +02:00
parent cdfe192491
commit 04ab43b8d2
2 changed files with 151 additions and 11 deletions
+28 -1
View File
@@ -351,7 +351,7 @@ def run_fetch_pool(
byte_count = _bytes_for(manifest, episodes)
episode_mb = byte_count / len(episodes) / 1024**2
job_count = max(timings["jobs"], 1.0)
return {
result = {
"fetch_s": elapsed,
"fetch_mbps": byte_count / elapsed / 1024**2,
"fetch_episodes_s": len(episodes) / elapsed,
@@ -363,6 +363,8 @@ def run_fetch_pool(
"synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
"store_ms": timings["store_s"] * 1000 / job_count,
}
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
return result
def run_parallel(
@@ -547,6 +549,30 @@ def run_remote_decoder(
}
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
    """Print a Markdown table of per-range read timing averages.

    Nothing is printed unless ``fetch_pool["range_jobs"]`` is present and
    positive. Stage totals are read from ``fetch_pool`` in seconds and shown
    as milliseconds divided by the number of range reads; stages whose key
    is missing (or maps to ``None``) are left out of the table.

    Args:
        fetch_pool: Flat mapping of metric name to value; only the
            ``range_*`` keys are consulted here.
    """
    reads = fetch_pool.get("range_jobs", 0.0)
    if reads <= 0:
        return

    # (metric key, table label) pairs, in the order the rows are printed.
    stages = (
        ("range_open_s", "fsspec handle open/lookup"),
        ("range_seek_s", "fsspec seek"),
        ("range_read_s", "fsspec read"),
        ("range_resolve_s", "http URL resolve"),
        ("range_header_s", "http response headers"),
        ("range_first_byte_s", "http first body byte"),
        ("range_body_s", "http body drain"),
    )

    print()
    print("| Range Read Stage | avg ms/range |")
    print("|---|---:|")
    for stage_key, stage_label in stages:
        total_s = fetch_pool.get(stage_key)
        if total_s is None:
            continue
        avg_ms = total_s * 1000 / reads
        print(f"| {stage_label} | {avg_ms:.3f} |")
    print(f"| range reads | {reads:.0f} |")
    avg_mib = fetch_pool.get("range_bytes", 0.0) / reads / 1024**2
    print(f"| avg MiB/range | {avg_mib:.1f} |")
def run_indexed_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
@@ -618,6 +644,7 @@ def run_indexed_strategy(
print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
_print_range_timing_summary(fetch_pool)
_print_memory_summary(memory_start, _memory_snapshot())
if args.include_decode:
+123 -10
View File
@@ -61,6 +61,14 @@ class ThreadLocalRangeFetcher:
self.block_size = block_size
self.cache_type = cache_type
self._local = threading.local()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_open_s": 0.0,
"range_seek_s": 0.0,
"range_read_s": 0.0,
}
def _url(self, relative_path: str) -> str:
if self.data_root.startswith("hf://"):
@@ -84,9 +92,32 @@ class ThreadLocalRangeFetcher:
return int(self.fs.info(self._url(relative_path))["size"])
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
    """Read ``length`` bytes at ``offset`` and record per-stage timings.

    The handle lookup, the seek and the read are each timed separately with
    ``time.perf_counter`` and the three durations are passed to
    ``self._record_timing`` together with one job and the number of bytes
    actually returned.

    Args:
        relative_path: Path handed to ``self._handle`` to obtain the handle.
        offset: Position passed to ``handle.seek``.
        length: Byte count passed to ``handle.read``.

    Returns:
        Whatever ``handle.read(length)`` produced.
    """
    # NOTE(review): the flattened diff still carried the pre-change
    # ``return handle.read(length)`` in the middle of this body, which made
    # every timing statement below it unreachable. It is dropped here.
    open_start = time.perf_counter()
    handle = self._handle(relative_path)
    open_s = time.perf_counter() - open_start

    seek_start = time.perf_counter()
    handle.seek(offset)
    seek_s = time.perf_counter() - seek_start

    read_start = time.perf_counter()
    data = handle.read(length)
    read_s = time.perf_counter() - read_start

    self._record_timing(
        range_jobs=1.0,
        # Count the bytes actually returned, not the requested length.
        range_bytes=float(len(data)),
        range_open_s=open_s,
        range_seek_s=seek_s,
        range_read_s=read_s,
    )
    return data
def _record_timing(self, **kwargs: float) -> None:
    """Add each keyword value onto the matching running total.

    All additions happen while holding ``self._timing_lock``. Every keyword
    must already be a key of ``self._timing_totals``; an unknown name raises
    ``KeyError`` from the dictionary lookup.
    """
    with self._timing_lock:
        totals = self._timing_totals
        for name in kwargs:
            totals[name] += kwargs[name]
def timing_summary(self) -> dict[str, float]:
    """Return a copy of the accumulated timing totals.

    The copy is taken while holding ``self._timing_lock``, so later updates
    to the totals do not show up in the returned dictionary.
    """
    with self._timing_lock:
        snapshot = dict(self._timing_totals)
    return snapshot
def close(self) -> None:
handles = getattr(self._local, "handles", None)
@@ -149,6 +180,15 @@ class NativeHTTPRangeFetcher:
self._source_urls: dict[str, str] = {}
self._sizes: dict[str, int] = {}
self._lock = threading.Lock()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_resolve_s": 0.0,
"range_header_s": 0.0,
"range_first_byte_s": 0.0,
"range_body_s": 0.0,
}
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
last_exc: Exception | None = None
@@ -270,22 +310,91 @@ class NativeHTTPRangeFetcher:
response.close()
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
    """Fetch one byte range over HTTP and record a timing breakdown.

    The URL is resolved, a ``Range`` header for the inclusive span
    ``offset .. offset + length - 1`` is attached, and the request is
    delegated to ``self._read_range_response``. If that reports status 403,
    the URL is resolved again with ``refresh=True`` and the request is
    issued once more; the timings of both requests are summed.

    Args:
        relative_path: Path passed to ``self._resolve_url`` and
            ``self._source_url``.
        offset: First byte of the requested range.
        length: Number of bytes requested.

    Returns:
        The payload returned by ``self._read_range_response``.

    Raises:
        PermissionError: If the status is still 403 after the URL refresh.
    """
    # NOTE(review): the flattened diff interleaved the pre-change
    # ``self._request("GET", ...)`` calls, a ``response.close()`` and an
    # orphaned ``try:`` with the new code; only the post-change flow is kept.
    range_header = f"bytes={offset}-{offset + length - 1}"

    resolve_start = time.perf_counter()
    resolved = self._resolve_url(relative_path)
    source = self._source_url(relative_path)
    resolve_s = time.perf_counter() - resolve_start

    headers = self._headers_for(resolved, source)
    headers["Range"] = range_header
    payload, status_code, timings = self._read_range_response(resolved, headers)

    if status_code == 403:
        # Resolve the URL again with refresh=True and retry exactly once;
        # the extra resolve time is charged to the same stage.
        refresh_start = time.perf_counter()
        resolved = self._resolve_url(relative_path, refresh=True)
        resolve_s += time.perf_counter() - refresh_start

        headers = self._headers_for(resolved, source)
        headers["Range"] = range_header
        payload, status_code, retry_timings = self._read_range_response(resolved, headers)
        for key, value in retry_timings.items():
            timings[key] += value

        if status_code == 403:
            raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")

    self._record_timing(
        range_jobs=1.0,
        range_bytes=float(len(payload)),
        range_resolve_s=resolve_s,
        **timings,
    )
    return payload
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
    """Call ``_read_range_response_once`` with retries and backoff.

    Up to ``self.max_retries + 1`` attempts are made. Only exceptions listed
    in ``self._RETRYABLE_EXCEPTIONS`` trigger another attempt; anything else
    propagates immediately. Between attempts the method sleeps for
    ``min(0.5 * 2**attempt, 5.0)`` seconds.

    Returns:
        The ``(payload, status_code, timings)`` tuple of the first attempt
        that returns.

    Raises:
        RuntimeError: If no attempt ran and therefore no exception was
            captured (the loop body never executed).
        Exception: The last retryable exception once attempts are exhausted.
    """
    final_attempt = self.max_retries
    failure: Exception | None = None
    attempt = 0
    while attempt <= final_attempt:
        try:
            return self._read_range_response_once(url, headers)
        except self._RETRYABLE_EXCEPTIONS as exc:
            failure = exc
        # Back off only when another attempt is still going to happen.
        if attempt < final_attempt:
            time.sleep(min(0.5 * 2**attempt, 5.0))
        attempt += 1

    if failure is not None:
        raise failure
    raise RuntimeError("HTTP range request failed without an exception")
def _read_range_response_once(
    self, url: str, headers: dict[str, str]
) -> tuple[bytes, int, dict[str, float]]:
    """Issue one streaming GET and time headers, first byte and body.

    Args:
        url: URL passed to ``self.client.stream``.
        headers: Request headers passed through unchanged.

    Returns:
        ``(payload, status_code, timings)``. ``timings`` always has the keys
        ``range_header_s``, ``range_first_byte_s`` and ``range_body_s``.
        For a 403 response the payload is ``b""`` and the two body timings
        are ``0.0``, leaving the decision about the 403 to the caller.
    """
    # NOTE(review): the flattened diff still carried the pre-change
    # ``return response.content`` / ``finally:`` / ``response.close()``
    # lines right after the status check, which would have skipped the
    # streaming loop entirely. They are dropped here; the ``with`` block
    # is what releases the response.
    header_start = time.perf_counter()
    with self.client.stream("GET", url, headers=headers) as response:
        # The stream context is entered once the response object is
        # available, so this interval covers request send + headers.
        header_s = time.perf_counter() - header_start

        if response.status_code == 403:
            # Return the status instead of raising, without reading a body.
            return (
                b"",
                response.status_code,
                {
                    "range_header_s": header_s,
                    "range_first_byte_s": 0.0,
                    "range_body_s": 0.0,
                },
            )
        hf_raise_for_status(response)

        chunks = []
        first_byte_s = 0.0  # stays 0.0 when the body yields no chunks
        first_chunk = True
        body_start = time.perf_counter()
        for chunk in response.iter_bytes():
            if first_chunk:
                first_byte_s = time.perf_counter() - body_start
                first_chunk = False
            chunks.append(chunk)
        # Measured from body_start, so it includes the first-byte wait.
        body_s = time.perf_counter() - body_start

        return (
            b"".join(chunks),
            response.status_code,
            {
                "range_header_s": header_s,
                "range_first_byte_s": first_byte_s,
                "range_body_s": body_s,
            },
        )
def _record_timing(self, **kwargs: float) -> None:
    """Add each keyword value onto the matching running total.

    All additions happen while holding ``self._timing_lock``. Every keyword
    must already be a key of ``self._timing_totals``; an unknown name raises
    ``KeyError`` from the dictionary lookup.
    """
    with self._timing_lock:
        totals = self._timing_totals
        for name in kwargs:
            totals[name] += kwargs[name]
def timing_summary(self) -> dict[str, float]:
    """Return a copy of the accumulated timing totals.

    The copy is taken while holding ``self._timing_lock``, so later updates
    to the totals do not show up in the returned dictionary.
    """
    with self._timing_lock:
        snapshot = dict(self._timing_totals)
    return snapshot
def close(self) -> None:
self.client.close()
@@ -606,7 +715,11 @@ class EpisodeByteCache:
def timing_summary(self) -> dict[str, float]:
    """Return this object's timing totals merged with the fetcher's.

    A copy of ``self._timing_totals`` is taken under ``self._lock``. If
    ``self.fetcher`` exposes a ``timing_summary`` attribute, its result is
    merged on top, so fetcher keys win on a name clash. Fetchers without
    that attribute are tolerated and contribute nothing.
    """
    # NOTE(review): the flattened diff still carried the pre-change
    # ``return dict(self._timing_totals)`` directly under the lock, which
    # made the fetcher merge unreachable. It is dropped here.
    with self._lock:
        summary = dict(self._timing_totals)

    # NOTE(review): the original nesting is not recoverable from the
    # flattened text; the fetcher is queried after releasing our lock,
    # which yields the same merged result — confirm against the real file.
    fetcher_summary = getattr(self.fetcher, "timing_summary", None)
    if fetcher_summary is not None:
        summary.update(fetcher_summary())
    return summary
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
key = (episode_index, camera_key)