mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Tune native HTTP range diagnostics
This commit is contained in:
@@ -68,6 +68,24 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument("--pool-size", type=int, default=16)
|
||||
parser.add_argument("--workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--native-http-connections",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--native-http-retries",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Retries per native HTTP range request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--native-http-timeout",
|
||||
type=float,
|
||||
default=120.0,
|
||||
help="Timeout in seconds for native HTTP requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-decode",
|
||||
action="store_true",
|
||||
@@ -343,6 +361,7 @@ def run_fetch_pool(
|
||||
byte_budget: int,
|
||||
workers: int,
|
||||
range_backend: str,
|
||||
args: argparse.Namespace,
|
||||
) -> dict[str, float]:
|
||||
with EpisodeByteCache(
|
||||
manifest,
|
||||
@@ -350,6 +369,9 @@ def run_fetch_pool(
|
||||
byte_budget=byte_budget,
|
||||
workers=workers,
|
||||
range_backend=range_backend,
|
||||
native_http_connections=args.native_http_connections,
|
||||
native_http_timeout=args.native_http_timeout,
|
||||
native_http_retries=args.native_http_retries,
|
||||
open_decoders=False,
|
||||
) as cache:
|
||||
elapsed = _fill_cache(cache, episodes)
|
||||
@@ -571,10 +593,15 @@ def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
("range_header_s", "http response headers"),
|
||||
("range_first_byte_s", "http first body byte"),
|
||||
("range_body_s", "http body drain"),
|
||||
("range_retry_sleep_s", "http retry sleep"),
|
||||
):
|
||||
value = fetch_pool.get(key)
|
||||
if value is not None:
|
||||
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
|
||||
if "range_retry_attempts" in fetch_pool:
|
||||
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
|
||||
if fetch_pool.get("range_failed_requests"):
|
||||
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
@@ -617,7 +644,7 @@ def run_indexed_strategy(
|
||||
)
|
||||
|
||||
_log(f"{label}: filling episode byte cache with {args.workers} workers")
|
||||
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend)
|
||||
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
|
||||
estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
|
||||
|
||||
|
||||
@@ -188,6 +188,9 @@ class NativeHTTPRangeFetcher:
|
||||
"range_header_s": 0.0,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
"range_retry_attempts": 0.0,
|
||||
"range_retry_sleep_s": 0.0,
|
||||
"range_failed_requests": 0.0,
|
||||
}
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
@@ -338,14 +341,27 @@ class NativeHTTPRangeFetcher:
|
||||
|
||||
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
|
||||
last_exc: Exception | None = None
|
||||
retry_attempts = 0.0
|
||||
retry_sleep_s = 0.0
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._read_range_response_once(url, headers)
|
||||
payload, status_code, timings = self._read_range_response_once(url, headers)
|
||||
timings["range_retry_attempts"] = retry_attempts
|
||||
timings["range_retry_sleep_s"] = retry_sleep_s
|
||||
return payload, status_code, timings
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||
retry_attempts += 1.0
|
||||
sleep_s = min(0.5 * 2**attempt, 5.0)
|
||||
retry_sleep_s += sleep_s
|
||||
time.sleep(sleep_s)
|
||||
self._record_timing(
|
||||
range_failed_requests=1.0,
|
||||
range_retry_attempts=retry_attempts,
|
||||
range_retry_sleep_s=retry_sleep_s,
|
||||
)
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP range request failed without an exception")
|
||||
raise last_exc
|
||||
@@ -400,11 +416,25 @@ class NativeHTTPRangeFetcher:
|
||||
self.client.close()
|
||||
|
||||
|
||||
def make_range_fetcher(data_root: str | Path, *, range_backend: str, workers: int):
|
||||
def make_range_fetcher(
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
):
|
||||
if range_backend == "fsspec":
|
||||
return ThreadLocalRangeFetcher(data_root)
|
||||
if range_backend == "native-http":
|
||||
return NativeHTTPRangeFetcher(data_root, max_connections=max(8, workers * 4))
|
||||
max_connections = native_http_connections or max(8, workers)
|
||||
return NativeHTTPRangeFetcher(
|
||||
data_root,
|
||||
max_connections=max_connections,
|
||||
timeout=native_http_timeout,
|
||||
max_retries=native_http_retries,
|
||||
)
|
||||
raise ValueError(f"Unknown range backend: {range_backend}")
|
||||
|
||||
|
||||
@@ -648,10 +678,20 @@ class EpisodeByteCache:
|
||||
byte_budget: int = 80 * 1024**3,
|
||||
workers: int = 8,
|
||||
range_backend: str = "fsspec",
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
open_decoders: bool = True,
|
||||
):
|
||||
self.manifest = manifest
|
||||
self.fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
|
||||
self.fetcher = make_range_fetcher(
|
||||
data_root,
|
||||
range_backend=range_backend,
|
||||
workers=workers,
|
||||
native_http_connections=native_http_connections,
|
||||
native_http_timeout=native_http_timeout,
|
||||
native_http_retries=native_http_retries,
|
||||
)
|
||||
self.byte_budget = byte_budget
|
||||
self.open_decoders = open_decoders
|
||||
self._pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
Reference in New Issue
Block a user