mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Report range read timing breakdown
This commit is contained in:
@@ -351,7 +351,7 @@ def run_fetch_pool(
|
||||
byte_count = _bytes_for(manifest, episodes)
|
||||
episode_mb = byte_count / len(episodes) / 1024**2
|
||||
job_count = max(timings["jobs"], 1.0)
|
||||
return {
|
||||
result = {
|
||||
"fetch_s": elapsed,
|
||||
"fetch_mbps": byte_count / elapsed / 1024**2,
|
||||
"fetch_episodes_s": len(episodes) / elapsed,
|
||||
@@ -363,6 +363,8 @@ def run_fetch_pool(
|
||||
"synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
|
||||
"store_ms": timings["store_s"] * 1000 / job_count,
|
||||
}
|
||||
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
|
||||
return result
|
||||
|
||||
|
||||
def run_parallel(
|
||||
@@ -547,6 +549,30 @@ def run_remote_decoder(
|
||||
}
|
||||
|
||||
|
||||
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
range_jobs = fetch_pool.get("range_jobs", 0.0)
|
||||
if range_jobs <= 0:
|
||||
return
|
||||
|
||||
print()
|
||||
print("| Range Read Stage | avg ms/range |")
|
||||
print("|---|---:|")
|
||||
for key, label in (
|
||||
("range_open_s", "fsspec handle open/lookup"),
|
||||
("range_seek_s", "fsspec seek"),
|
||||
("range_read_s", "fsspec read"),
|
||||
("range_resolve_s", "http URL resolve"),
|
||||
("range_header_s", "http response headers"),
|
||||
("range_first_byte_s", "http first body byte"),
|
||||
("range_body_s", "http body drain"),
|
||||
):
|
||||
value = fetch_pool.get(key)
|
||||
if value is not None:
|
||||
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
|
||||
def run_indexed_strategy(
|
||||
meta: LeRobotDatasetMetadata,
|
||||
data_root: str,
|
||||
@@ -618,6 +644,7 @@ def run_indexed_strategy(
|
||||
print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
|
||||
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
|
||||
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
|
||||
_print_range_timing_summary(fetch_pool)
|
||||
_print_memory_summary(memory_start, _memory_snapshot())
|
||||
|
||||
if args.include_decode:
|
||||
|
||||
@@ -61,6 +61,14 @@ class ThreadLocalRangeFetcher:
|
||||
self.block_size = block_size
|
||||
self.cache_type = cache_type
|
||||
self._local = threading.local()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_open_s": 0.0,
|
||||
"range_seek_s": 0.0,
|
||||
"range_read_s": 0.0,
|
||||
}
|
||||
|
||||
def _url(self, relative_path: str) -> str:
|
||||
if self.data_root.startswith("hf://"):
|
||||
@@ -84,9 +92,32 @@ class ThreadLocalRangeFetcher:
|
||||
return int(self.fs.info(self._url(relative_path))["size"])
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
open_start = time.perf_counter()
|
||||
handle = self._handle(relative_path)
|
||||
open_s = time.perf_counter() - open_start
|
||||
seek_start = time.perf_counter()
|
||||
handle.seek(offset)
|
||||
return handle.read(length)
|
||||
seek_s = time.perf_counter() - seek_start
|
||||
read_start = time.perf_counter()
|
||||
data = handle.read(length)
|
||||
read_s = time.perf_counter() - read_start
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(data)),
|
||||
range_open_s=open_s,
|
||||
range_seek_s=seek_s,
|
||||
range_read_s=read_s,
|
||||
)
|
||||
return data
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] += value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._timing_lock:
|
||||
return dict(self._timing_totals)
|
||||
|
||||
def close(self) -> None:
|
||||
handles = getattr(self._local, "handles", None)
|
||||
@@ -149,6 +180,15 @@ class NativeHTTPRangeFetcher:
|
||||
self._source_urls: dict[str, str] = {}
|
||||
self._sizes: dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_resolve_s": 0.0,
|
||||
"range_header_s": 0.0,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
}
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
last_exc: Exception | None = None
|
||||
@@ -270,22 +310,91 @@ class NativeHTTPRangeFetcher:
|
||||
response.close()
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
resolve_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
resolve_s = time.perf_counter() - resolve_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self._request("GET", resolved, headers=headers)
|
||||
if response.status_code == 403:
|
||||
response.close()
|
||||
payload, status_code, timings = self._read_range_response(resolved, headers)
|
||||
if status_code == 403:
|
||||
refresh_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path, refresh=True)
|
||||
resolve_s += time.perf_counter() - refresh_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self._request("GET", resolved, headers=headers)
|
||||
try:
|
||||
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
|
||||
for key, value in retry_timings.items():
|
||||
timings[key] += value
|
||||
if status_code == 403:
|
||||
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(payload)),
|
||||
range_resolve_s=resolve_s,
|
||||
**timings,
|
||||
)
|
||||
return payload
|
||||
|
||||
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._read_range_response_once(url, headers)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP range request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _read_range_response_once(
|
||||
self, url: str, headers: dict[str, str]
|
||||
) -> tuple[bytes, int, dict[str, float]]:
|
||||
header_start = time.perf_counter()
|
||||
with self.client.stream("GET", url, headers=headers) as response:
|
||||
header_s = time.perf_counter() - header_start
|
||||
if response.status_code == 403:
|
||||
return (
|
||||
b"",
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
},
|
||||
)
|
||||
hf_raise_for_status(response)
|
||||
return response.content
|
||||
finally:
|
||||
response.close()
|
||||
chunks = []
|
||||
first_byte_s = 0.0
|
||||
first_chunk = True
|
||||
body_start = time.perf_counter()
|
||||
for chunk in response.iter_bytes():
|
||||
if first_chunk:
|
||||
first_byte_s = time.perf_counter() - body_start
|
||||
first_chunk = False
|
||||
chunks.append(chunk)
|
||||
body_s = time.perf_counter() - body_start
|
||||
return (
|
||||
b"".join(chunks),
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": first_byte_s,
|
||||
"range_body_s": body_s,
|
||||
},
|
||||
)
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] += value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._timing_lock:
|
||||
return dict(self._timing_totals)
|
||||
|
||||
def close(self) -> None:
|
||||
self.client.close()
|
||||
@@ -606,7 +715,11 @@ class EpisodeByteCache:
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._lock:
|
||||
return dict(self._timing_totals)
|
||||
summary = dict(self._timing_totals)
|
||||
fetcher_summary = getattr(self.fetcher, "timing_summary", None)
|
||||
if fetcher_summary is not None:
|
||||
summary.update(fetcher_summary())
|
||||
return summary
|
||||
|
||||
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
|
||||
key = (episode_index, camera_key)
|
||||
|
||||
Reference in New Issue
Block a user