Retry transient native HTTP range failures

This commit is contained in:
Pepijn
2026-06-17 20:19:54 +02:00
parent 834c282631
commit 34d0495d03
@@ -106,10 +106,27 @@ class NativeHTTPRangeFetcher:
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
_GLOBAL_LOCK = threading.Lock()
def __init__(self, data_root: str | Path, *, max_connections: int = 32, timeout: float = 60.0):
# httpx exception classes that `_request` catches and retries with backoff.
# Any exception type not listed here propagates to the caller on the first
# failure.
_RETRYABLE_EXCEPTIONS = (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
httpx.PoolTimeout,
)
def __init__(
self,
data_root: str | Path,
*,
max_connections: int = 32,
timeout: float = 60.0,
max_retries: int = 4,
):
self.data_root = str(data_root).rstrip("/")
if not self.data_root.startswith("hf://"):
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
self.max_retries = max_retries
self.api = HfApi()
self.fs: HfFileSystem | None = None
self._bucket_id: str | None = None
@@ -133,6 +150,20 @@ class NativeHTTPRangeFetcher:
self._sizes: dict[str, int] = {}
self._lock = threading.Lock()
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
    """Send one HTTP request, retrying failures listed in ``_RETRYABLE_EXCEPTIONS``.

    Up to ``self.max_retries`` additional attempts are made after the first
    one. Between attempts the call sleeps ``0.5 * 2**attempt`` seconds,
    capped at 5 seconds. Only raised exceptions trigger a retry; a response
    is returned to the caller without its status being inspected here.

    Args:
        method: HTTP method name passed to ``self.client.request``.
        url: Target URL passed to ``self.client.request``.
        **kwargs: Forwarded unchanged to ``self.client.request``.

    Returns:
        The response from the first attempt that does not raise.

    Raises:
        Any exception in ``_RETRYABLE_EXCEPTIONS`` once all attempts are
        used up (the exception from the final attempt), or immediately any
        exception type that is not in that tuple.
    """
    # A negative ``max_retries`` used to make the loop run zero times, which
    # skipped the request entirely and raised a misleading RuntimeError.
    # Clamp so that at least one attempt is always made.
    retries = max(0, self.max_retries)
    for attempt in range(retries + 1):
        try:
            return self.client.request(method, url, **kwargs)
        except self._RETRYABLE_EXCEPTIONS:
            if attempt >= retries:
                # Out of attempts: surface the last failure unchanged.
                raise
            time.sleep(min(0.5 * 2**attempt, 5.0))
    # Unreachable: every iteration either returns or raises.
    raise RuntimeError("HTTP request failed without an exception")
def _cache_key(self, relative_path: str) -> tuple[str, str]:
    """Return the ``(data_root, relative_path)`` pair for ``relative_path``.

    NOTE(review): presumably used as the key into the class-level caches
    typed ``dict[tuple[str, str], ...]`` -- confirm against the callers.
    """
    key = (self.data_root, relative_path)
    return key
@@ -192,7 +223,7 @@ class NativeHTTPRangeFetcher:
return resolved
source = self._source_url(relative_path)
response = self.client.head(source, headers=self.api._build_hf_headers(), follow_redirects=False)
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
try:
hf_raise_for_status(response)
location = response.headers.get("Location")
@@ -224,8 +255,8 @@ class NativeHTTPRangeFetcher:
resolved = self._resolve_url(relative_path)
source = self._source_url(relative_path)
response = self.client.head(
resolved, headers=self._headers_for(resolved, source), follow_redirects=True
response = self._request(
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
)
try:
hf_raise_for_status(response)
@@ -243,13 +274,13 @@ class NativeHTTPRangeFetcher:
source = self._source_url(relative_path)
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
response = self.client.get(resolved, headers=headers)
response = self._request("GET", resolved, headers=headers)
if response.status_code == 403:
response.close()
resolved = self._resolve_url(relative_path, refresh=True)
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
response = self.client.get(resolved, headers=headers)
response = self._request("GET", resolved, headers=headers)
try:
hf_raise_for_status(response)
return response.content