Merge remote episode-pool work into the full pool rewrite

The remote commit (2ab71231c) added an opt-in episode pool, deferred
decode in the legacy buffer path, decode/fetch timing instrumentation,
remote-IO retries (video_utils), and 32MB row-group writing
(dataset_tools). The pool rewrite on this side makes the episode pool
the only iteration path (with prefetch-on-admit, per-consumer seeding,
worker-exact fast-forward resume), so streaming_dataset.py resolves to
the rewrite with the remote instrumentation ported into it:

- 5-slot shared counters + timing_stats() (decode_s_total/fetch_s_total)
- fetch timed around episode admission, decode timed around emission
- benchmark/slurm keep the remote updates, with episode_pool_size as the
  knob (buffer_size deprecated and ignored)

video_utils retries and dataset_tools row groups are taken unchanged.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-11 15:17:04 +02:00
5 changed files with 296 additions and 45 deletions
+135 -22
View File
@@ -36,7 +36,9 @@ is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarm
import argparse
import csv
import json
import os
import statistics
import threading
import time
from pathlib import Path
@@ -47,6 +49,60 @@ from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
def _tree_rss_bytes() -> int:
"""Sum RSS of this process and all its descendants via /proc (Linux only; 0 elsewhere).
DataLoader workers are separate processes, so the parent's own RSS misses most of the pipeline's
memory. Walking the process tree captures the real footprint (parquet buffers + decoders + shuffle).
"""
try:
children: dict[int, list[int]] = {}
for entry in os.listdir("/proc"):
if not entry.isdigit():
continue
try:
with open(f"/proc/{entry}/stat") as f:
ppid = int(f.read().split(") ", 1)[1].split()[1])
children.setdefault(ppid, []).append(int(entry))
except (OSError, ValueError, IndexError):
pass
total, stack = 0, [os.getpid()]
while stack:
cur = stack.pop()
try:
with open(f"/proc/{cur}/statm") as f:
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
except (OSError, ValueError, IndexError):
pass
stack.extend(children.get(cur, []))
return total
except OSError:
return 0
class PeakRSSSampler:
"""Background thread tracking peak process-tree RSS for the duration of the `with` block."""
def __init__(self, interval_s: float = 0.5):
self.interval_s = interval_s
self.peak_bytes = 0
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
def _run(self) -> None:
while not self._stop.is_set():
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
self._stop.wait(self.interval_s)
def __enter__(self) -> "PeakRSSSampler":
self._thread.start()
return self
def __exit__(self, *exc) -> None:
self._stop.set()
self._thread.join(timeout=2)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, required=True)
@@ -62,8 +118,29 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--episode_pool_size", type=int, default=64)
parser.add_argument(
"--prefetch_factor",
type=int,
default=2,
help="DataLoader batches prefetched per worker. Higher hides IO/decode latency but raises RAM "
"(prefetch_factor x num_workers x batch_size decoded frames held in flight). Ignored if num_workers=0.",
)
parser.add_argument("--buffer_size", type=int, default=None, help="Deprecated; ignored.")
parser.add_argument(
"--max_num_shards",
type=int,
default=16,
help="Cap on concurrently-open stream shards. Each open shard holds ~one parquet row group in "
"RAM; reading from an hf:// bucket buffers ~5x more per shard than hf:// datasets, so lower this "
"(e.g. to num_workers) for bucket sources to avoid OOM. All data is still covered via re-sharding.",
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument(
"--episode_pool_size",
type=int,
default=64,
help="Whole episodes each consumer keeps open to shuffle across (the randomness knob).",
)
parser.add_argument(
"--video_decode_device",
type=str,
@@ -86,9 +163,10 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
root=args.root,
data_files_root=args.data_files_root,
delta_timestamps=delta_timestamps,
episode_pool_size=args.episode_pool_size,
max_num_shards=args.max_num_shards,
video_decoder_cache_size=args.video_decoder_cache_size,
video_decode_device=args.video_decode_device,
episode_pool_size=args.episode_pool_size,
tolerance_s=1e-3,
)
@@ -116,37 +194,43 @@ def main() -> None:
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
pin_memory=device.type == "cuda" and not gpu_decode,
drop_last=True,
prefetch_factor=2 if args.num_workers > 0 else None,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
)
sample_latencies_ms: list[float] = []
episodes_per_batch: list[int] = [] # shuffle-randomness proxy: distinct episodes within a batch
frames = 0
first_batch_latency_s = None
steady_start = None # wall-clock start of the post-warmup measurement window
t_start = time.perf_counter()
t_prev = t_start
for i, batch in enumerate(loader):
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
with PeakRSSSampler() as rss:
for i, batch in enumerate(loader):
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
t_prev = now
if i + 1 >= args.num_batches:
break
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
ep = batch.get("episode_index")
if torch.is_tensor(ep):
episodes_per_batch.append(int(torch.unique(ep).numel()))
t_prev = now
if i + 1 >= args.num_batches:
break
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
now = time.perf_counter()
elapsed = now - t_start
@@ -154,6 +238,14 @@ def main() -> None:
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
cache_stats = dataset.video_decoder_cache_stats()
timing = dataset.timing_stats() # cumulative decode/fetch seconds summed across workers
# Image (camera frame) resolution as decoded, e.g. [C, H, W]. Read from the dataset feature contract.
image_shape = list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
# Decode/fetch overlap in wall-clock (workers run in parallel), so normalize against the total worker
# budget (num_workers x wallclock) to express each stage as a fraction of available worker time.
worker_budget_s = max(args.num_workers, 1) * elapsed
decode_pct = round(100 * timing["decode_s_total"] / worker_budget_s, 1) if worker_budget_s else None
fetch_pct = round(100 * timing["fetch_s_total"] / worker_budget_s, 1) if worker_budget_s else None
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
@@ -172,11 +264,22 @@ def main() -> None:
"mode": args.mode,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"prefetch_factor": args.prefetch_factor if args.num_workers > 0 else None,
"buffer_size": args.buffer_size,
"episode_pool_size": args.episode_pool_size,
"episodes_per_batch_mean": round(statistics.mean(episodes_per_batch), 1)
if episodes_per_batch
else None,
# Fraction of a batch that is distinct episodes; ~1.0 ≈ map-style uniform, low ≈ correlated.
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
if episodes_per_batch
else None,
"num_cameras": len(meta.video_keys),
"image_shape": image_shape,
"fps": meta.fps,
"device": str(device),
"video_decode_device": args.video_decode_device,
"peak_rss_gb": peak_rss_gb,
"frames_measured": frames,
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
@@ -186,13 +289,23 @@ def main() -> None:
else None,
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
"total_time_s": round(elapsed, 2),
"steady_time_s": round(steady_elapsed_s, 2),
"wallclock_s": round(elapsed, 2),
"decode_s_total": timing["decode_s_total"],
"fetch_s_total": timing["fetch_s_total"],
"decode_pct_worker_time": decode_pct,
"fetch_pct_worker_time": fetch_pct,
"video_decoder_cache": cache_stats,
}
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
pool_tag = f"_ep{args.episode_pool_size}" if args.episode_pool_size else ""
tag = (
f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}"
f"_pf{args.prefetch_factor}{pool_tag}_{args.video_decode_device}"
)
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
+11 -4
View File
@@ -34,9 +34,14 @@ GPUS=${GPUS:-1}
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
CPU_WORKERS=${CPU_WORKERS:-8}
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
CPU_BUFFER=${CPU_BUFFER:-4000}
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
CPU_BUFFER=${CPU_BUFFER:-64} # episode pool size (whole episodes per consumer; tabular-only RAM)
GPU_BUFFER=${GPU_BUFFER:-32} # smaller episode pool bounds in-flight decoded frames
# Cap concurrently-open stream shards. Each open shard holds ~one parquet row group in RAM, and reading
# from an hf:// bucket buffers ~5x more per shard than hf:// datasets (~1.2GB vs ~0.26GB). So for bucket
# sources default to num_workers (1 shard/worker); hub keeps 16. Override globally with MAX_SHARDS.
MAX_SHARDS=${MAX_SHARDS:-}
BATCH_SIZE=${BATCH_SIZE:-64}
PREFETCH=${PREFETCH:-2} # DataLoader batches prefetched per worker (higher = more throughput + RAM)
RUN=${RUN:-python}
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
@@ -69,6 +74,7 @@ for SOURCE in $SOURCES; do
for MODE in $MODES; do
for DECODE in $DECODES; do
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
if [ -n "$MAX_SHARDS" ]; then S=$MAX_SHARDS; elif [ "$SOURCE" = hub ]; then S=16; else S=$W; fi
# Run strictly after the previous job so only one job touches the network at a time.
DEPFLAG=""
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
@@ -83,7 +89,8 @@ for SOURCE in $SOURCES; do
$RUN benchmarks/streaming/benchmark_streaming.py \
--repo_id $REPO_ID $ROOTFLAG \
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
--batch_size $BATCH_SIZE --num_workers $W --episode_pool_size $B \
--batch_size $BATCH_SIZE --num_workers $W --prefetch_factor $PREFETCH \
--episode_pool_size $B --max_num_shards $S \
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
@@ -96,5 +103,5 @@ done
echo
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_pf<prefetch>_<decode>.{json,csv}"
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
+11 -2
View File
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
ep_dataset = embed_images(ep_dataset)
table = ep_dataset.with_format("arrow")[:]
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
writer.write_table(table)
# Emit several row groups with a page index instead of one giant row group. A single row group forces
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
# groups bounds per-shard memory while keeping groups large enough to scan
# efficiently; the page index lets readers skip to the pages they need.
target_row_group_bytes = 32 * 1024 * 1024
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
)
writer.write_table(table, row_group_size=row_group_size)
writer.close()
+21 -6
View File
@@ -16,6 +16,7 @@
import logging
import os
import shutil
import time
from collections.abc import Callable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
@@ -175,7 +176,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
revision: str | None = None,
force_cache_sync: bool = False,
streaming: bool = True,
episode_pool_size: int = 64,
episode_pool_size: int | None = 64,
buffer_size: int | None = None,
max_num_shards: int | None = None,
seed: int = 42,
@@ -254,7 +255,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"StreamingLeRobotDataset: `buffer_size` is deprecated and ignored; "
"use `episode_pool_size` (whole episodes, not frames)."
)
self.episode_pool_size = max(1, episode_pool_size)
self.episode_pool_size = max(1, episode_pool_size) if episode_pool_size else 64
self.max_num_shards = max_num_shards
self._return_uint8 = return_uint8
@@ -268,9 +269,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self._prefetcher: _VideoPrefetcher | None = None
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
# one place the main process can read after iteration (see video_decoder_cache_stats()).
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
# Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate
# decoder-cache stats and component timings into one place the main process can read after
# iteration (see video_decoder_cache_stats() / timing_stats()).
self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_()
# Deterministic fast-forward resume (see load_state_dict): per-consumer epoch counter and
# number of samples still to skip.
self._epoch = 0
@@ -479,11 +481,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
admitted = 0
while len(pool) < self.episode_pool_size and streams:
stream = streams[next_stream % len(streams)]
fetch_start = time.perf_counter_ns()
try:
ep_idx, rows = next(stream)
except StopIteration:
streams.remove(stream)
continue
finally:
self._cache_counters[4] += time.perf_counter_ns() - fetch_start
next_stream += 1
episode = self._admit_episode(ep_idx, rows, prefetcher)
pool.append(episode)
@@ -567,7 +572,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
total = hits + misses
return {
"hits": hits,
@@ -576,6 +581,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def timing_stats(self) -> dict[str, float]:
"""Cumulative seconds spent in video decode and episode (tabular) fetch, summed across
DataLoader workers via the shared counter tensor. These overlap in wall-clock (workers run
in parallel), so compare them to ``num_workers x wallclock`` for time fractions.
"""
decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist())
return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)}
def _make_pool_sample(self, episode: _PooledEpisode, frame_pos: int) -> dict:
"""Assemble a full training sample for one pooled frame (tabular slices + video decode)."""
rows = episode.rows
@@ -603,7 +616,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
query_timestamps = self._get_query_timestamps(
current_ts, self.delta_indices, episode_boundaries_ts
)
decode_start = time.perf_counter_ns()
video_frames = self._query_videos(query_timestamps, ep_idx)
self._cache_counters[3] += time.perf_counter_ns() - decode_start
if self.image_transforms is not None:
for cam in self.meta.camera_keys:
+118 -11
View File
@@ -22,6 +22,7 @@ import queue
import shutil
import tempfile
import threading
import time
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
@@ -47,6 +48,92 @@ from lerobot.utils.import_utils import get_safe_default_video_backend
logger = logging.getLogger(__name__)
DEFAULT_REMOTE_IO_MAX_RETRIES = 5
"""Retry budget for transient hf:// / fsspec / httpx transport errors during streaming video decode.
Streaming a dataset from an HF bucket/CDN issues many small range requests and occasionally hits a
transient transport failure (timeout, dropped connection, 408/5xx). The right response is to rebuild
the connection and retry rather than crash the DataLoader worker. Override via
``LEROBOT_REMOTE_IO_MAX_RETRIES``; set to ``0`` to disable retries (fail fast).
"""
# Transient transport failures from the hf:// -> fsspec -> httpx stack. We match on text because the
# concrete exception types live in optional deps (httpx, huggingface_hub) and vary across versions.
# "client has been closed" is the important one: once a shared httpx client is closed by a single
# failed read, every subsequent read in that worker fails until the fsspec instance cache is cleared.
_RETRYABLE_TRANSPORT_FRAGMENTS = (
"client has been closed",
"server disconnected",
"remoteprotocolerror",
"unexpected_eof",
"eof occurred in violation of protocol",
"connection reset",
"connection aborted",
"connection broken",
"incompleteread",
"read operation timed out",
"timed out",
"request time-out",
"408",
"502",
"503",
"504",
)
def _remote_io_max_retries() -> int:
raw = os.environ.get("LEROBOT_REMOTE_IO_MAX_RETRIES")
if raw is None:
return DEFAULT_REMOTE_IO_MAX_RETRIES
try:
return max(0, int(raw))
except ValueError as e:
raise ValueError(f"LEROBOT_REMOTE_IO_MAX_RETRIES must be an integer; got {raw!r}") from e
def _is_retryable_transport_error(exc: BaseException) -> bool:
"""True if ``exc`` looks like a transient remote-IO failure worth retrying (vs a real bug)."""
text = f"{type(exc).__name__}: {exc}".lower()
return any(fragment in text for fragment in _RETRYABLE_TRANSPORT_FRAGMENTS)
def _recover_remote_io(decoder_cache: "VideoDecoderCache", video_path: str) -> None:
"""Drop the dead decoder for ``video_path`` and force a fresh fsspec client before a retry.
fsspec caches one filesystem instance per (protocol, args), and that instance owns the httpx
client a failed read may have closed. Clearing the instance cache makes the next ``fsspec.open``
build a new client, which is what breaks the "client has been closed" cascade.
"""
decoder_cache.invalidate(video_path)
with contextlib.suppress(Exception):
fsspec.AbstractFileSystem.clear_instance_cache()
def _retry_remote_io(operation, on_retry, max_retries: int, base_delay: float = 0.5, max_delay: float = 10.0):
"""Run ``operation()``, retrying transient transport errors after ``on_retry()`` + capped backoff.
Non-transport errors (decode / index / timestamp issues) propagate immediately so real bugs are
never masked by retries.
"""
attempt = 0
while True:
try:
return operation()
except Exception as e:
if attempt >= max_retries or not _is_retryable_transport_error(e):
raise
attempt += 1
logger.warning(
"Transient remote-IO error (%s: %s); rebuilding connection and retrying (%d/%d).",
type(e).__name__,
e,
attempt,
max_retries,
)
on_retry()
time.sleep(min(base_delay * 2 ** (attempt - 1), max_delay))
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
@@ -296,7 +383,11 @@ class VideoDecoderCache:
self.misses += 1
if self._counters is not None:
self._counters[1] += 1
file_handle = fsspec.open(video_path).__enter__()
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
# mostly-sequential reads torchcodec issues.
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
except Exception:
@@ -326,6 +417,18 @@ class VideoDecoderCache:
file_handle.close()
self._cache.clear()
def invalidate(self, video_path: str) -> None:
"""Drop and close the cached decoder for a path whose connection went bad.
After a transport error the cached ``fsspec`` handle (and the httpx client behind it) is dead;
removing the entry forces the next :meth:`get_decoder` to re-open a fresh handle.
"""
with self._lock:
entry = self._cache.pop(str(video_path), None)
if entry is not None:
with contextlib.suppress(Exception):
entry[1].close()
def size(self) -> int:
"""Return the number of cached decoders."""
with self._lock:
@@ -381,20 +484,24 @@ def decode_video_frames_torchcodec(
if decoder_cache is None:
decoder_cache = _default_decoder_cache
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
def _decode_frames():
# Both opening the decoder and reading frames go over the network for hf:// paths, so wrap the
# whole unit: a transient transport error retries by dropping the dead handle and rebuilding
# the connection (see _retry_remote_io / _recover_remote_io) instead of killing the worker.
decoder = decoder_cache.get_decoder(str(video_path))
average_fps = decoder.metadata.average_fps
frame_indices = [round(ts * average_fps) for ts in timestamps]
return decoder.get_frames_at(indices=frame_indices)
frames_batch = _retry_remote_io(
_decode_frames,
on_retry=lambda: _recover_remote_io(decoder_cache, str(video_path)),
max_retries=_remote_io_max_retries(),
)
loaded_ts = []
loaded_frames = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
loaded_frames.append(frame)
loaded_ts.append(pts.item())