feat(streaming): defer video decode, episode-pool shuffle, and remote-IO retries

- streaming_dataset: defer torchcodec decode until a sample leaves the shuffle
  buffer (buffer now holds ~KB tabular rows, not MB of pixels) and add an opt-in
  episode-pool shuffle (episode_pool_size) with exact in-episode delta lookups;
  expose decode/fetch timing_stats.
- video_utils: retry transient hf:///fsspec/httpx transport errors during
  streaming decode (LEROBOT_REMOTE_IO_MAX_RETRIES).
- dataset_tools: write multiple ~32MB row groups with a page index to bound
  per-shard streaming memory.
- benchmarks/slurm: streaming benchmark + matrix submitter updates.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-11 10:08:28 +00:00
parent 42d4788e4a
commit 2ab71231cd
5 changed files with 472 additions and 63 deletions
+137 -20
View File
@@ -36,7 +36,9 @@ is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarm
import argparse
import csv
import json
import os
import statistics
import threading
import time
from pathlib import Path
@@ -47,6 +49,60 @@ from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
from lerobot.utils.constants import ACTION
def _tree_rss_bytes() -> int:
"""Sum RSS of this process and all its descendants via /proc (Linux only; 0 elsewhere).
DataLoader workers are separate processes, so the parent's own RSS misses most of the pipeline's
memory. Walking the process tree captures the real footprint (parquet buffers + decoders + shuffle).
"""
try:
children: dict[int, list[int]] = {}
for entry in os.listdir("/proc"):
if not entry.isdigit():
continue
try:
with open(f"/proc/{entry}/stat") as f:
ppid = int(f.read().split(") ", 1)[1].split()[1])
children.setdefault(ppid, []).append(int(entry))
except (OSError, ValueError, IndexError):
pass
total, stack = 0, [os.getpid()]
while stack:
cur = stack.pop()
try:
with open(f"/proc/{cur}/statm") as f:
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
except (OSError, ValueError, IndexError):
pass
stack.extend(children.get(cur, []))
return total
except OSError:
return 0
class PeakRSSSampler:
"""Background thread tracking peak process-tree RSS for the duration of the `with` block."""
def __init__(self, interval_s: float = 0.5):
self.interval_s = interval_s
self.peak_bytes = 0
self._stop = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
def _run(self) -> None:
while not self._stop.is_set():
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
self._stop.wait(self.interval_s)
def __enter__(self) -> "PeakRSSSampler":
self._thread.start()
return self
def __exit__(self, *exc) -> None:
self._stop.set()
self._thread.join(timeout=2)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, required=True)
@@ -62,8 +118,30 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument(
"--prefetch_factor",
type=int,
default=2,
help="DataLoader batches prefetched per worker. Higher hides IO/decode latency but raises RAM "
"(prefetch_factor x num_workers x batch_size decoded frames held in flight). Ignored if num_workers=0.",
)
parser.add_argument("--buffer_size", type=int, default=2000)
parser.add_argument(
"--max_num_shards",
type=int,
default=16,
help="Cap on concurrently-open stream shards. Each open shard holds ~one parquet row group in "
"RAM; reading from an hf:// bucket buffers ~5x more per shard than hf:// datasets, so lower this "
"(e.g. to num_workers) for bucket sources to avoid OOM. All data is still covered via re-sharding.",
)
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
parser.add_argument(
"--episode_pool_size",
type=int,
default=None,
help="A3 shuffle: keep this many full episodes live and sample frames uniformly across them "
"(mixing radius = this many episodes). Unset = default per-shard reservoir shuffle.",
)
parser.add_argument(
"--video_decode_device",
type=str,
@@ -87,8 +165,10 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
data_files_root=args.data_files_root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
max_num_shards=args.max_num_shards,
video_decoder_cache_size=args.video_decoder_cache_size,
video_decode_device=args.video_decode_device,
episode_pool_size=args.episode_pool_size,
tolerance_s=1e-3,
)
@@ -116,37 +196,43 @@ def main() -> None:
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
pin_memory=device.type == "cuda" and not gpu_decode,
drop_last=True,
prefetch_factor=2 if args.num_workers > 0 else None,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
)
sample_latencies_ms: list[float] = []
episodes_per_batch: list[int] = [] # shuffle-randomness proxy: distinct episodes within a batch
frames = 0
first_batch_latency_s = None
steady_start = None # wall-clock start of the post-warmup measurement window
t_start = time.perf_counter()
t_prev = t_start
for i, batch in enumerate(loader):
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
with PeakRSSSampler() as rss:
for i, batch in enumerate(loader):
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
for value in batch.values():
if torch.is_tensor(value):
value.to(device, non_blocking=device.type == "cuda")
now = time.perf_counter()
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
t_prev = now
if i + 1 >= args.num_batches:
break
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
ep = batch.get("episode_index")
if torch.is_tensor(ep):
episodes_per_batch.append(int(torch.unique(ep).numel()))
t_prev = now
if i + 1 >= args.num_batches:
break
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
now = time.perf_counter()
elapsed = now - t_start
@@ -154,6 +240,16 @@ def main() -> None:
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
cache_stats = dataset.video_decoder_cache_stats()
timing = dataset.timing_stats() # cumulative decode/fetch seconds summed across workers
# Image (camera frame) resolution as decoded, e.g. [C, H, W]. Read from the dataset feature contract.
image_shape = (
list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
)
# Decode/fetch overlap in wall-clock (workers run in parallel), so normalize against the total worker
# budget (num_workers x wallclock) to express each stage as a fraction of available worker time.
worker_budget_s = max(args.num_workers, 1) * elapsed
decode_pct = round(100 * timing["decode_s_total"] / worker_budget_s, 1) if worker_budget_s else None
fetch_pct = round(100 * timing["fetch_s_total"] / worker_budget_s, 1) if worker_budget_s else None
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
@@ -172,11 +268,22 @@ def main() -> None:
"mode": args.mode,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"prefetch_factor": args.prefetch_factor if args.num_workers > 0 else None,
"buffer_size": args.buffer_size,
"episode_pool_size": args.episode_pool_size,
"episodes_per_batch_mean": round(statistics.mean(episodes_per_batch), 1)
if episodes_per_batch
else None,
# Fraction of a batch that is distinct episodes; ~1.0 ≈ map-style uniform, low ≈ correlated.
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
if episodes_per_batch
else None,
"num_cameras": len(meta.video_keys),
"image_shape": image_shape,
"fps": meta.fps,
"device": str(device),
"video_decode_device": args.video_decode_device,
"peak_rss_gb": peak_rss_gb,
"frames_measured": frames,
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
@@ -186,13 +293,23 @@ def main() -> None:
else None,
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
"total_time_s": round(elapsed, 2),
"steady_time_s": round(steady_elapsed_s, 2),
"wallclock_s": round(elapsed, 2),
"decode_s_total": timing["decode_s_total"],
"fetch_s_total": timing["fetch_s_total"],
"decode_pct_worker_time": decode_pct,
"fetch_pct_worker_time": fetch_pct,
"video_decoder_cache": cache_stats,
}
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
pool_tag = f"_ep{args.episode_pool_size}" if args.episode_pool_size else ""
tag = (
f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}"
f"_pf{args.prefetch_factor}{pool_tag}_{args.video_decode_device}"
)
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
+10 -3
View File
@@ -34,9 +34,14 @@ GPUS=${GPUS:-1}
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
CPU_WORKERS=${CPU_WORKERS:-8}
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
CPU_BUFFER=${CPU_BUFFER:-4000}
CPU_BUFFER=${CPU_BUFFER:-2000} # shuffle buffer dominates worker RAM (buffer_size x num_workers decoded frames)
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
# Cap concurrently-open stream shards. Each open shard holds ~one parquet row group in RAM, and reading
# from an hf:// bucket buffers ~5x more per shard than hf:// datasets (~1.2GB vs ~0.26GB). So for bucket
# sources default to num_workers (1 shard/worker); hub keeps 16. Override globally with MAX_SHARDS.
MAX_SHARDS=${MAX_SHARDS:-}
BATCH_SIZE=${BATCH_SIZE:-64}
PREFETCH=${PREFETCH:-2} # DataLoader batches prefetched per worker (higher = more throughput + RAM)
RUN=${RUN:-python}
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
@@ -69,6 +74,7 @@ for SOURCE in $SOURCES; do
for MODE in $MODES; do
for DECODE in $DECODES; do
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
if [ -n "$MAX_SHARDS" ]; then S=$MAX_SHARDS; elif [ "$SOURCE" = hub ]; then S=16; else S=$W; fi
# Run strictly after the previous job so only one job touches the network at a time.
DEPFLAG=""
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
@@ -83,7 +89,8 @@ for SOURCE in $SOURCES; do
$RUN benchmarks/streaming/benchmark_streaming.py \
--repo_id $REPO_ID $ROOTFLAG \
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
--batch_size $BATCH_SIZE --num_workers $W --prefetch_factor $PREFETCH \
--buffer_size $B --max_num_shards $S \
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
@@ -96,5 +103,5 @@ done
echo
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_pf<prefetch>_<decode>.{json,csv}"
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
+11 -2
View File
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
ep_dataset = embed_images(ep_dataset)
table = ep_dataset.with_format("arrow")[:]
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
writer.write_table(table)
# Emit several row groups with a page index instead of one giant row group. A single row group forces
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
# groups bounds per-shard memory while keeping groups large enough to scan
# efficiently; the page index lets readers skip to the pages they need.
target_row_group_bytes = 32 * 1024 * 1024
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
writer = pq.ParquetWriter(
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
)
writer.write_table(table, row_group_size=row_group_size)
writer.close()
+196 -27
View File
@@ -16,6 +16,7 @@
import logging
import math
import os
import time
from collections import deque
from collections.abc import Callable, Generator, Iterable, Iterator
from pathlib import Path
@@ -263,6 +264,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
video_decode_device: str = "cpu",
episode_pool_size: int | None = None,
):
"""Initialize a StreamingLeRobotDataset.
@@ -326,12 +328,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.video_decoder_cache_size = video_decoder_cache_size
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
self.video_decode_device = video_decode_device
# A3 shuffle: when set, iterate by keeping this many full episodes live in memory and sampling
# frames uniformly across them (mixing radius = episode_pool_size episodes), instead of the
# default per-shard reservoir. Tabular deltas become exact in-episode index lookups (no
# Backtrackable). Trades video-decode locality for much stronger shuffle.
self.episode_pool_size = episode_pool_size
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
# one place the main process can read after iteration (see video_decoder_cache_stats()).
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
# Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate
# decoder-cache stats and component timings into one place the main process can read after
# iteration (see video_decoder_cache_stats() / timing_stats()).
self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_()
# Resume state captured by load_state_dict() and consumed at the next __iter__.
self._resume_state: dict | None = None
@@ -494,6 +502,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
shard.load_state_dict(resume["shards"][str(idx)])
self._shards[idx] = shard
# A3 episode-pool shuffle (opt-in): sample frames uniformly across many fully-loaded episodes.
if self.episode_pool_size:
shard_iters = {
idx: iter(self._shards[idx]) for idx in shard_indices if idx not in self._exhausted
}
yield from self._iter_episode_pool(shard_iters, rng)
return
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
idx_to_backtrack_dataset = {
@@ -506,6 +522,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# the logic is to add 2 levels of randomness:
# (1) sample one shard at random from the ones available, and
# (2) sample one frame from the shard sampled at (1)
# Buffer entries are (partial, video_spec): undecoded tabular rows. Video is decoded by
# _attach_video only when a sample leaves the buffer, keeping peak memory ~prefetch-bounded.
frames_buffer = []
while available_shards := list(idx_to_backtrack_dataset.keys()):
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
@@ -515,7 +533,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
for frame in self.make_frame(backtrack_dataset):
if len(frames_buffer) == self.buffer_size:
i = next(buffer_indices_generator) # samples a element from the buffer
yield frames_buffer[i]
yield self._attach_video(*frames_buffer[i]) # decode just-in-time on the way out
frames_buffer[i] = frame
else:
frames_buffer.append(frame)
@@ -527,9 +545,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
self._exhausted.add(shard_key)
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames (decoding each).
rng.shuffle(frames_buffer)
yield from frames_buffer
for partial, video_spec in frames_buffer:
yield self._attach_video(partial, video_spec)
def state_dict(self) -> dict:
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
@@ -557,7 +576,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
total = hits + misses
return {
"hits": hits,
@@ -566,6 +585,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def timing_stats(self) -> dict[str, float]:
"""Cumulative seconds spent in video decode and parquet/sample fetch, summed across DataLoader
workers via the shared counter tensor. These overlap in wall-clock (workers run in parallel), so
compare them to ``num_workers x wallclock`` not to wallclock directly to get time fractions.
"""
decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist())
return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)}
def _get_window_steps(
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
) -> tuple[int, int]:
@@ -640,8 +667,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return padding_mask
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator"""
"""Build a frame's tabular content and defer the video decode.
Yields a ``(partial, video_spec)`` pair: ``partial`` holds all non-video fields (tabular
features, tabular delta windows + padding, task); ``video_spec`` carries what
:meth:`_attach_video` needs to decode the camera frames just-in-time at yield time. Deferring
the decode keeps the shuffle reservoir holding ~KB tabular rows instead of multi-MB decoded
images, which collapses peak memory.
"""
_t0 = time.perf_counter_ns()
item = next(dataset_iterator)
self._cache_counters[4] += time.perf_counter_ns() - _t0 # parquet/sample fetch time
item = item_to_torch(item)
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
@@ -673,29 +709,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
updates.append(query_result)
updates.append(padding)
# Load video frames, when needed
# Defer the (memory-heavy) video decode: capture only what _attach_video needs to decode the
# camera frames at yield time, so the shuffle buffer holds ~KB tabular rows, not MB of pixels.
video_spec = None
if len(self.meta.video_keys) > 0:
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
# Some timestamps might not result available considering the episode's boundaries
# Some timestamps might not be available considering the episode's boundaries
query_timestamps = self._get_query_timestamps(
current_ts, self.delta_indices, episode_boundaries_ts
)
video_frames = self._query_videos(query_timestamps, ep_idx)
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
video_frames[cam] = self.image_transforms(video_frames[cam])
updates.append(video_frames)
if self.delta_indices is not None:
# We always return the same number of frames. Unavailable frames are padded.
padding_mask = self._get_video_frame_padding_mask(
video_frames, query_timestamps, original_timestamps
)
updates.append(padding_mask)
video_spec = (query_timestamps, original_timestamps, ep_idx)
result = item.copy()
for update in updates:
@@ -703,7 +726,151 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield result
yield result, video_spec
def _attach_video(self, result: dict, video_spec: tuple | None) -> dict:
"""Decode the camera frames for a buffered sample and merge them in (counterpart to make_frame).
This is where torchcodec decode actually runs on one sample at a time as it leaves the shuffle
buffer so peak memory is bounded by the prefetch queue rather than ``buffer_size`` decoded frames.
"""
if video_spec is None:
return result
query_timestamps, original_timestamps, ep_idx = video_spec
video_frames = self._query_videos(query_timestamps, ep_idx)
if self.image_transforms is not None:
for cam in self.meta.camera_keys:
video_frames[cam] = self.image_transforms(video_frames[cam])
result.update(video_frames)
if self.delta_indices is not None:
# We always return the same number of frames. Unavailable frames are padded.
padding_mask = self._get_video_frame_padding_mask(
video_frames, query_timestamps, original_timestamps
)
result.update(padding_mask)
return result
@staticmethod
def _ep_id(raw_item: dict) -> int:
"""Episode index of a raw (pre-torch) HF stream row, coerced to a plain int."""
return int(np.asarray(raw_item["episode_index"]).reshape(-1)[0])
def _read_one_episode(self, sid: int, shard_iters: dict, carry: dict) -> list[dict] | None:
"""Read one full episode (contiguous rows) from a shard iterator, or None if exhausted.
Episodes are contiguous in the stream, so we read until ``episode_index`` changes and stash the
first row of the next episode in ``carry`` to start the following read.
"""
it = shard_iters[sid]
first = carry[sid]
carry[sid] = None
if first is None:
first = next(it, None)
if first is None:
return None
ep = self._ep_id(first)
rows = [first]
for row in it:
if self._ep_id(row) != ep:
carry[sid] = row # belongs to the next episode; start there next time
break
rows.append(row)
return rows
def _make_frame_from_episode(self, ep_rows: list[dict], p: int) -> tuple[dict, tuple | None]:
"""Build ``(partial, video_spec)`` for frame ``p`` of a fully-loaded episode (A3).
All temporal neighbors live in ``ep_rows``, so tabular delta windows are exact index lookups
with correct episode-boundary padding no Backtrackable, no lookahead pre-read. Video is still
decoded just-in-time by :meth:`_attach_video`.
"""
item = ep_rows[p]
ep_idx = item["episode_index"]
current_ts = float(item["timestamp"])
length = len(ep_rows)
updates = []
if self.delta_indices is not None:
query_result, padding = {}, {}
for key, deltas in self.delta_indices.items():
if key in self.meta.video_keys:
continue # visual frames are decoded separately
frames, is_pad = [], []
for d in deltas:
q = p + d
clamped = min(max(q, 0), length - 1) # out-of-episode neighbors pad to the boundary
frames.append(ep_rows[clamped][key])
is_pad.append(q != clamped)
query_result[key] = torch.stack(frames)
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
updates.append(query_result)
updates.append(padding)
video_spec = None
if len(self.meta.video_keys) > 0:
episode_boundaries_ts = {
key: (
0.0,
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
)
for key in self.meta.video_keys
}
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
query_timestamps = self._get_query_timestamps(
current_ts, self.delta_indices, episode_boundaries_ts
)
video_spec = (query_timestamps, original_timestamps, ep_idx)
result = item.copy()
for update in updates:
result.update(update)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
return result, video_spec
def _iter_episode_pool(self, shard_iters: dict, rng: np.random.Generator) -> Iterator[dict]:
"""A3 shuffle: keep ``episode_pool_size`` full episodes live and sample frames uniformly across
them. Each episode costs ~one sequential read (IO-cheap); the mixing radius is the pool size.
``tickets`` holds one (slot, frame_pos) entry per live, not-yet-emitted frame; swap-remove gives
O(1) uniform sampling without replacement. When an episode drains it is evicted and a fresh one
is read in, keeping the pool full.
"""
carry = {sid: None for sid in shard_iters}
live = set(shard_iters)
pool: dict[int, dict] = {} # slot -> {"rows": [...], "remaining": int}
tickets: list[tuple[int, int]] = []
next_slot = 0
def load_episode() -> bool:
nonlocal next_slot
while live:
sid = int(rng.choice(tuple(live)))
rows = self._read_one_episode(sid, shard_iters, carry)
if rows is None:
live.discard(sid)
continue
ep_rows = [item_to_torch(r) for r in rows]
pool[next_slot] = {"rows": ep_rows, "remaining": len(ep_rows)}
tickets.extend((next_slot, p) for p in range(len(ep_rows)))
next_slot += 1
return True
return False
while len(pool) < self.episode_pool_size and load_episode():
pass
while tickets:
i = int(rng.integers(len(tickets)))
slot, p = tickets[i]
tickets[i] = tickets[-1] # swap-remove: O(1) sampling without replacement
tickets.pop()
partial, video_spec = self._make_frame_from_episode(pool[slot]["rows"], p)
yield self._attach_video(partial, video_spec)
pool[slot]["remaining"] -= 1
if pool[slot]["remaining"] == 0:
del pool[slot] # free the episode's frames
load_episode() # refill to keep the pool (and mixing radius) full
def _get_query_timestamps(
self,
@@ -745,6 +912,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
else:
root = self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
_t0 = time.perf_counter_ns()
frames = decode_video_frames_torchcodec(
video_path,
shifted_query_ts,
@@ -752,6 +920,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
)
self._cache_counters[3] += time.perf_counter_ns() - _t0 # video decode time
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+118 -11
View File
@@ -22,6 +22,7 @@ import queue
import shutil
import tempfile
import threading
import time
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field
@@ -47,6 +48,92 @@ from lerobot.utils.import_utils import get_safe_default_video_backend
logger = logging.getLogger(__name__)
DEFAULT_REMOTE_IO_MAX_RETRIES = 5
"""Retry budget for transient hf:// / fsspec / httpx transport errors during streaming video decode.
Streaming a dataset from an HF bucket/CDN issues many small range requests and occasionally hits a
transient transport failure (timeout, dropped connection, 408/5xx). The right response is to rebuild
the connection and retry rather than crash the DataLoader worker. Override via
``LEROBOT_REMOTE_IO_MAX_RETRIES``; set to ``0`` to disable retries (fail fast).
"""
# Transient transport failures from the hf:// -> fsspec -> httpx stack. We match on text because the
# concrete exception types live in optional deps (httpx, huggingface_hub) and vary across versions.
# "client has been closed" is the important one: once a shared httpx client is closed by a single
# failed read, every subsequent read in that worker fails until the fsspec instance cache is cleared.
_RETRYABLE_TRANSPORT_FRAGMENTS = (
"client has been closed",
"server disconnected",
"remoteprotocolerror",
"unexpected_eof",
"eof occurred in violation of protocol",
"connection reset",
"connection aborted",
"connection broken",
"incompleteread",
"read operation timed out",
"timed out",
"request time-out",
"408",
"502",
"503",
"504",
)
def _remote_io_max_retries() -> int:
raw = os.environ.get("LEROBOT_REMOTE_IO_MAX_RETRIES")
if raw is None:
return DEFAULT_REMOTE_IO_MAX_RETRIES
try:
return max(0, int(raw))
except ValueError as e:
raise ValueError(f"LEROBOT_REMOTE_IO_MAX_RETRIES must be an integer; got {raw!r}") from e
def _is_retryable_transport_error(exc: BaseException) -> bool:
"""True if ``exc`` looks like a transient remote-IO failure worth retrying (vs a real bug)."""
text = f"{type(exc).__name__}: {exc}".lower()
return any(fragment in text for fragment in _RETRYABLE_TRANSPORT_FRAGMENTS)
def _recover_remote_io(decoder_cache: "VideoDecoderCache", video_path: str) -> None:
"""Drop the dead decoder for ``video_path`` and force a fresh fsspec client before a retry.
fsspec caches one filesystem instance per (protocol, args), and that instance owns the httpx
client a failed read may have closed. Clearing the instance cache makes the next ``fsspec.open``
build a new client, which is what breaks the "client has been closed" cascade.
"""
decoder_cache.invalidate(video_path)
with contextlib.suppress(Exception):
fsspec.AbstractFileSystem.clear_instance_cache()
def _retry_remote_io(operation, on_retry, max_retries: int, base_delay: float = 0.5, max_delay: float = 10.0):
"""Run ``operation()``, retrying transient transport errors after ``on_retry()`` + capped backoff.
Non-transport errors (decode / index / timestamp issues) propagate immediately so real bugs are
never masked by retries.
"""
attempt = 0
while True:
try:
return operation()
except Exception as e:
if attempt >= max_retries or not _is_retryable_transport_error(e):
raise
attempt += 1
logger.warning(
"Transient remote-IO error (%s: %s); rebuilding connection and retrying (%d/%d).",
type(e).__name__,
e,
attempt,
max_retries,
)
on_retry()
time.sleep(min(base_delay * 2 ** (attempt - 1), max_delay))
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
@@ -296,7 +383,11 @@ class VideoDecoderCache:
self.misses += 1
if self._counters is not None:
self._counters[1] += 1
file_handle = fsspec.open(video_path).__enter__()
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
# mostly-sequential reads torchcodec issues.
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
except Exception:
@@ -326,6 +417,18 @@ class VideoDecoderCache:
file_handle.close()
self._cache.clear()
def invalidate(self, video_path: str) -> None:
"""Drop and close the cached decoder for a path whose connection went bad.
After a transport error the cached ``fsspec`` handle (and the httpx client behind it) is dead;
removing the entry forces the next :meth:`get_decoder` to re-open a fresh handle.
"""
with self._lock:
entry = self._cache.pop(str(video_path), None)
if entry is not None:
with contextlib.suppress(Exception):
entry[1].close()
def size(self) -> int:
"""Return the number of cached decoders."""
with self._lock:
@@ -381,20 +484,24 @@ def decode_video_frames_torchcodec(
if decoder_cache is None:
decoder_cache = _default_decoder_cache
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
def _decode_frames():
# Both opening the decoder and reading frames go over the network for hf:// paths, so wrap the
# whole unit: a transient transport error retries by dropping the dead handle and rebuilding
# the connection (see _retry_remote_io / _recover_remote_io) instead of killing the worker.
decoder = decoder_cache.get_decoder(str(video_path))
average_fps = decoder.metadata.average_fps
frame_indices = [round(ts * average_fps) for ts in timestamps]
return decoder.get_frames_at(indices=frame_indices)
frames_batch = _retry_remote_io(
_decode_frames,
on_retry=lambda: _recover_remote_io(decoder_cache, str(video_path)),
max_retries=_remote_io_max_retries(),
)
loaded_ts = []
loaded_frames = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
loaded_frames.append(frame)
loaded_ts.append(pts.item())