diff --git a/scripts/bench_episode_byte_cache.py b/scripts/bench_episode_byte_cache.py index 3b47656f6..ac817b4ae 100644 --- a/scripts/bench_episode_byte_cache.py +++ b/scripts/bench_episode_byte_cache.py @@ -68,6 +68,24 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--pool-size", type=int, default=16) parser.add_argument("--workers", type=int, default=8) + parser.add_argument( + "--native-http-connections", + type=int, + default=None, + help="Max HTTP connections for --range-backend native-http. Defaults to --workers.", + ) + parser.add_argument( + "--native-http-retries", + type=int, + default=8, + help="Retries per native HTTP range request.", + ) + parser.add_argument( + "--native-http-timeout", + type=float, + default=120.0, + help="Timeout in seconds for native HTTP requests.", + ) parser.add_argument( "--include-decode", action="store_true", @@ -343,6 +361,7 @@ def run_fetch_pool( byte_budget: int, workers: int, range_backend: str, + args: argparse.Namespace, ) -> dict[str, float]: with EpisodeByteCache( manifest, @@ -350,6 +369,9 @@ def run_fetch_pool( byte_budget=byte_budget, workers=workers, range_backend=range_backend, + native_http_connections=args.native_http_connections, + native_http_timeout=args.native_http_timeout, + native_http_retries=args.native_http_retries, open_decoders=False, ) as cache: elapsed = _fill_cache(cache, episodes) @@ -571,10 +593,15 @@ def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None: ("range_header_s", "http response headers"), ("range_first_byte_s", "http first body byte"), ("range_body_s", "http body drain"), + ("range_retry_sleep_s", "http retry sleep"), ): value = fetch_pool.get(key) if value is not None: print(f"| {label} | {value * 1000 / range_jobs:.3f} |") + if "range_retry_attempts" in fetch_pool: + print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |") + if fetch_pool.get("range_failed_requests"): + print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |") print(f"| range reads | {range_jobs:.0f} |") print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |") @@ -617,7 +644,7 @@ def run_indexed_strategy( ) _log(f"{label}: filling episode byte cache with {args.workers} workers") - fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend) + fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args) estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"] estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"] diff --git a/src/lerobot/datasets/episode_video_streaming.py b/src/lerobot/datasets/episode_video_streaming.py index 245f7cabd..21564896c 100644 --- a/src/lerobot/datasets/episode_video_streaming.py +++ b/src/lerobot/datasets/episode_video_streaming.py @@ -188,6 +188,9 @@ class NativeHTTPRangeFetcher: "range_header_s": 0.0, "range_first_byte_s": 0.0, "range_body_s": 0.0, + "range_retry_attempts": 0.0, + "range_retry_sleep_s": 0.0, + "range_failed_requests": 0.0, } def _request(self, method: str, url: str, **kwargs) -> httpx.Response: @@ -338,14 +341,27 @@ class NativeHTTPRangeFetcher: def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]: last_exc: Exception | None = None + retry_attempts = 0.0 + retry_sleep_s = 0.0 for attempt in range(self.max_retries + 1): try: - return self._read_range_response_once(url, headers) + payload, status_code, timings = self._read_range_response_once(url, headers) + timings["range_retry_attempts"] = retry_attempts + timings["range_retry_sleep_s"] = retry_sleep_s + return payload, status_code, timings except self._RETRYABLE_EXCEPTIONS as exc: last_exc = exc if attempt >= self.max_retries: break - time.sleep(min(0.5 * 2**attempt, 5.0)) + retry_attempts += 1.0 + sleep_s = min(0.5 * 2**attempt, 5.0) + retry_sleep_s += sleep_s + time.sleep(sleep_s) + self._record_timing( + range_failed_requests=1.0, + range_retry_attempts=retry_attempts, + range_retry_sleep_s=retry_sleep_s, + ) if last_exc is None: raise RuntimeError("HTTP range request failed without an exception") raise last_exc @@ -400,11 +416,25 @@ class NativeHTTPRangeFetcher: self.client.close() -def make_range_fetcher(data_root: str | Path, *, range_backend: str, workers: int): +def make_range_fetcher( + data_root: str | Path, + *, + range_backend: str, + workers: int, + native_http_connections: int | None = None, + native_http_timeout: float = 60.0, + native_http_retries: int = 4, +): if range_backend == "fsspec": return ThreadLocalRangeFetcher(data_root) if range_backend == "native-http": - return NativeHTTPRangeFetcher(data_root, max_connections=max(8, workers * 4)) + max_connections = native_http_connections or max(8, workers) + return NativeHTTPRangeFetcher( + data_root, + max_connections=max_connections, + timeout=native_http_timeout, + max_retries=native_http_retries, + ) raise ValueError(f"Unknown range backend: {range_backend}") @@ -648,10 +678,20 @@ class EpisodeByteCache: byte_budget: int = 80 * 1024**3, workers: int = 8, range_backend: str = "fsspec", + native_http_connections: int | None = None, + native_http_timeout: float = 60.0, + native_http_retries: int = 4, open_decoders: bool = True, ): self.manifest = manifest - self.fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers) + self.fetcher = make_range_fetcher( + data_root, + range_backend=range_backend, + workers=workers, + native_http_connections=native_http_connections, + native_http_timeout=native_http_timeout, + native_http_retries=native_http_retries, + ) self.byte_budget = byte_budget self.open_decoders = open_decoders self._pool = ThreadPoolExecutor(max_workers=workers)