def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
    """Print a Markdown table of average per-range read timings.

    Each known ``range_*_s`` total found in ``fetch_pool`` (seconds) is
    divided by ``range_jobs`` and reported in milliseconds per range read.
    Stages whose key is absent are left out of the table. Nothing is printed
    when ``range_jobs`` is missing or not positive.

    Args:
        fetch_pool: Mapping of timing totals; only the ``range_*`` keys are
            read here.
    """
    reads = fetch_pool.get("range_jobs", 0.0)
    if reads <= 0:
        return

    # Insertion order of this mapping is the row order of the table.
    stage_labels = {
        "range_open_s": "fsspec handle open/lookup",
        "range_seek_s": "fsspec seek",
        "range_read_s": "fsspec read",
        "range_resolve_s": "http URL resolve",
        "range_header_s": "http response headers",
        "range_first_byte_s": "http first body byte",
        "range_body_s": "http body drain",
    }

    print()
    print("| Range Read Stage | avg ms/range |")
    print("|---|---:|")
    for stage, label in stage_labels.items():
        total_s = fetch_pool.get(stage)
        if total_s is None:
            continue
        avg_ms = total_s * 1000 / reads
        print(f"| {label} | {avg_ms:.3f} |")
    print(f"| range reads | {reads:.0f} |")
    total_bytes = fetch_pool.get("range_bytes", 0.0)
    avg_mib = total_bytes / reads / 1024**2
    print(f"| avg MiB/range | {avg_mib:.1f} |")
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
    """Read ``length`` bytes at ``offset`` from ``relative_path`` and time each step.

    The handle lookup, the seek and the read are timed separately with
    ``time.perf_counter`` and passed to ``self._record_timing`` together with
    a job count of one and the number of bytes actually returned.

    Args:
        relative_path: Path handed to ``self._handle`` to obtain the file handle.
        offset: Position passed to the handle's ``seek``.
        length: Size passed to the handle's ``read``.

    Returns:
        The bytes returned by the handle's ``read`` call.
    """
    clock = time.perf_counter

    started = clock()
    stream = self._handle(relative_path)
    open_elapsed = clock() - started

    started = clock()
    stream.seek(offset)
    seek_elapsed = clock() - started

    started = clock()
    payload = stream.read(length)
    read_elapsed = clock() - started

    # range_bytes counts what was actually read, not what was requested.
    sample = {
        "range_jobs": 1.0,
        "range_bytes": float(len(payload)),
        "range_open_s": open_elapsed,
        "range_seek_s": seek_elapsed,
        "range_read_s": read_elapsed,
    }
    self._record_timing(**sample)
    return payload
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
    """Fetch ``length`` bytes at ``offset`` with an HTTP ``Range`` request.

    The URL is resolved via ``self._resolve_url`` / ``self._source_url`` and
    the request is issued through ``self._read_range_response``. If the
    reported status is 403, the URL is resolved again with ``refresh=True``
    and the request is repeated once; the stage timings of both attempts are
    summed. Timings are recorded through ``self._record_timing`` only when a
    payload is returned.

    Args:
        relative_path: Path used to resolve the request URL.
        offset: First byte of the requested range.
        length: Number of bytes requested; the range end is
            ``offset + length - 1``.

    Returns:
        The payload returned by ``self._read_range_response``.

    Raises:
        PermissionError: If the request still reports 403 after the refresh.
    """
    tick = time.perf_counter()
    target = self._resolve_url(relative_path)
    source = self._source_url(relative_path)
    resolve_elapsed = time.perf_counter() - tick

    def fetch(url: str):
        # Headers are rebuilt per attempt because they depend on the resolved URL.
        request_headers = self._headers_for(url, source)
        request_headers["Range"] = f"bytes={offset}-{offset + length - 1}"
        return self._read_range_response(url, request_headers)

    body, status, stage_times = fetch(target)
    if status == 403:
        tick = time.perf_counter()
        target = self._resolve_url(relative_path, refresh=True)
        resolve_elapsed += time.perf_counter() - tick
        body, status, second_pass = fetch(target)
        for stage, spent in second_pass.items():
            stage_times[stage] += spent
        if status == 403:
            raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")

    self._record_timing(
        range_jobs=1.0,
        range_bytes=float(len(body)),
        range_resolve_s=resolve_elapsed,
        **stage_times,
    )
    return body
def _read_range_response_once(
    self, url: str, headers: dict[str, str]
) -> tuple[bytes, int, dict[str, float]]:
    """Issue one streamed GET and time the headers, first chunk and full body.

    A 403 response is returned to the caller without reading the body or
    raising, with the first-byte and body timings left at zero. Any other
    status is passed through ``hf_raise_for_status`` before the body is read.

    Args:
        url: URL to request.
        headers: Request headers passed to ``self.client.stream``.

    Returns:
        ``(payload, status_code, timings)`` where ``timings`` holds
        ``range_header_s``, ``range_first_byte_s`` and ``range_body_s`` in
        seconds. Both body timings are measured from the moment body
        iteration starts, not from the start of the request.
    """

    def stage_times(header: float, first_byte: float = 0.0, body: float = 0.0) -> dict[str, float]:
        return {
            "range_header_s": header,
            "range_first_byte_s": first_byte,
            "range_body_s": body,
        }

    exhausted = object()  # sentinel: the body iterator produced no chunk at all
    request_started = time.perf_counter()
    with self.client.stream("GET", url, headers=headers) as response:
        header_elapsed = time.perf_counter() - request_started
        if response.status_code == 403:
            return b"", response.status_code, stage_times(header_elapsed)
        hf_raise_for_status(response)

        pieces = []
        first_byte_elapsed = 0.0
        body_started = time.perf_counter()
        body_iter = iter(response.iter_bytes())
        head = next(body_iter, exhausted)
        if head is not exhausted:
            first_byte_elapsed = time.perf_counter() - body_started
            pieces.append(head)
            pieces.extend(body_iter)
        body_elapsed = time.perf_counter() - body_started
        return (
            b"".join(pieces),
            response.status_code,
            stage_times(header_elapsed, first_byte_elapsed, body_elapsed),
        )