From 34d0495d0325db5090055035ed1967b3613b031a Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 17 Jun 2026 20:19:54 +0200 Subject: [PATCH] Retry transient native HTTP range failures --- .../datasets/episode_video_streaming.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/src/lerobot/datasets/episode_video_streaming.py b/src/lerobot/datasets/episode_video_streaming.py index f40a62849..aeda64f68 100644 --- a/src/lerobot/datasets/episode_video_streaming.py +++ b/src/lerobot/datasets/episode_video_streaming.py @@ -106,10 +106,27 @@ class NativeHTTPRangeFetcher: _GLOBAL_SIZES: dict[tuple[str, str], int] = {} _GLOBAL_LOCK = threading.Lock() - def __init__(self, data_root: str | Path, *, max_connections: int = 32, timeout: float = 60.0): + _RETRYABLE_EXCEPTIONS = ( + httpx.ConnectError, + httpx.ConnectTimeout, + httpx.ReadError, + httpx.ReadTimeout, + httpx.RemoteProtocolError, + httpx.PoolTimeout, + ) + + def __init__( + self, + data_root: str | Path, + *, + max_connections: int = 32, + timeout: float = 60.0, + max_retries: int = 4, + ): self.data_root = str(data_root).rstrip("/") if not self.data_root.startswith("hf://"): raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots") + self.max_retries = max_retries self.api = HfApi() self.fs: HfFileSystem | None = None self._bucket_id: str | None = None @@ -133,6 +150,20 @@ class NativeHTTPRangeFetcher: self._sizes: dict[str, int] = {} self._lock = threading.Lock() + def _request(self, method: str, url: str, **kwargs) -> httpx.Response: + last_exc: Exception | None = None + for attempt in range(self.max_retries + 1): + try: + return self.client.request(method, url, **kwargs) + except self._RETRYABLE_EXCEPTIONS as exc: + last_exc = exc + if attempt >= self.max_retries: + break + time.sleep(min(0.5 * 2**attempt, 5.0)) + if last_exc is None: + raise RuntimeError("HTTP request failed without an exception") + raise last_exc + def _cache_key(self, relative_path: str) -> tuple[str, str]: return self.data_root, relative_path @@ -192,7 +223,7 @@ class NativeHTTPRangeFetcher: return resolved source = self._source_url(relative_path) - response = self.client.head(source, headers=self.api._build_hf_headers(), follow_redirects=False) + response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False) try: hf_raise_for_status(response) location = response.headers.get("Location") @@ -224,8 +255,8 @@ class NativeHTTPRangeFetcher: resolved = self._resolve_url(relative_path) source = self._source_url(relative_path) - response = self.client.head( - resolved, headers=self._headers_for(resolved, source), follow_redirects=True + response = self._request( + "HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True ) try: hf_raise_for_status(response) @@ -243,13 +274,13 @@ class NativeHTTPRangeFetcher: source = self._source_url(relative_path) headers = self._headers_for(resolved, source) headers["Range"] = f"bytes={offset}-{offset + length - 1}" - response = self.client.get(resolved, headers=headers) + response = self._request("GET", resolved, headers=headers) if response.status_code == 403: response.close() resolved = self._resolve_url(relative_path, refresh=True) headers = self._headers_for(resolved, source) headers["Range"] = f"bytes={offset}-{offset + length - 1}" - response = self.client.get(resolved, headers=headers) + response = self._request("GET", resolved, headers=headers) try: hf_raise_for_status(response) return response.content