mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Retry transient native HTTP range failures
This commit is contained in:
@@ -106,10 +106,27 @@ class NativeHTTPRangeFetcher:
|
||||
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
|
||||
_GLOBAL_LOCK = threading.Lock()
|
||||
|
||||
def __init__(self, data_root: str | Path, *, max_connections: int = 32, timeout: float = 60.0):
|
||||
_RETRYABLE_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.PoolTimeout,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
max_connections: int = 32,
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 4,
|
||||
):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
if not self.data_root.startswith("hf://"):
|
||||
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
|
||||
self.max_retries = max_retries
|
||||
self.api = HfApi()
|
||||
self.fs: HfFileSystem | None = None
|
||||
self._bucket_id: str | None = None
|
||||
@@ -133,6 +150,20 @@ class NativeHTTPRangeFetcher:
|
||||
self._sizes: dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self.client.request(method, url, **kwargs)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _cache_key(self, relative_path: str) -> tuple[str, str]:
|
||||
return self.data_root, relative_path
|
||||
|
||||
@@ -192,7 +223,7 @@ class NativeHTTPRangeFetcher:
|
||||
return resolved
|
||||
|
||||
source = self._source_url(relative_path)
|
||||
response = self.client.head(source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
location = response.headers.get("Location")
|
||||
@@ -224,8 +255,8 @@ class NativeHTTPRangeFetcher:
|
||||
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
response = self.client.head(
|
||||
resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||
response = self._request(
|
||||
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||
)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
@@ -243,13 +274,13 @@ class NativeHTTPRangeFetcher:
|
||||
source = self._source_url(relative_path)
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self.client.get(resolved, headers=headers)
|
||||
response = self._request("GET", resolved, headers=headers)
|
||||
if response.status_code == 403:
|
||||
response.close()
|
||||
resolved = self._resolve_url(relative_path, refresh=True)
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
response = self.client.get(resolved, headers=headers)
|
||||
response = self._request("GET", resolved, headers=headers)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
return response.content
|
||||
|
||||
Reference in New Issue
Block a user