From 2ab71231cdab2a4eef9f8f563e919501a76782af Mon Sep 17 00:00:00 2001 From: pepijn Date: Thu, 11 Jun 2026 10:08:28 +0000 Subject: [PATCH] feat(streaming): defer video decode, episode-pool shuffle, and remote-IO retries - streaming_dataset: defer torchcodec decode until a sample leaves the shuffle buffer (buffer now holds ~KB tabular rows, not MB of pixels) and add an opt-in episode-pool shuffle (episode_pool_size) with exact in-episode delta lookups; expose decode/fetch timing_stats. - video_utils: retry transient hf:///fsspec/httpx transport errors during streaming decode (LEROBOT_REMOTE_IO_MAX_RETRIES). - dataset_tools: write multiple ~32MB row groups with a page index to bound per-shard streaming memory. - benchmarks/slurm: streaming benchmark + matrix submitter updates. Co-authored-by: Cursor --- benchmarks/streaming/benchmark_streaming.py | 157 ++++++++++++-- slurm/run_streaming_matrix.sh | 13 +- src/lerobot/datasets/dataset_tools.py | 13 +- src/lerobot/datasets/streaming_dataset.py | 223 +++++++++++++++++--- src/lerobot/datasets/video_utils.py | 129 ++++++++++- 5 files changed, 472 insertions(+), 63 deletions(-) diff --git a/benchmarks/streaming/benchmark_streaming.py b/benchmarks/streaming/benchmark_streaming.py index f36f9b0e1..8643f670d 100644 --- a/benchmarks/streaming/benchmark_streaming.py +++ b/benchmarks/streaming/benchmark_streaming.py @@ -36,7 +36,9 @@ is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarm import argparse import csv import json +import os import statistics +import threading import time from pathlib import Path @@ -47,6 +49,60 @@ from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset from lerobot.utils.constants import ACTION +def _tree_rss_bytes() -> int: + """Sum RSS of this process and all its descendants via /proc (Linux only; 0 elsewhere). + + DataLoader workers are separate processes, so the parent's own RSS misses most of the pipeline's + memory. Walking the process tree captures the real footprint (parquet buffers + decoders + shuffle). + """ + try: + children: dict[int, list[int]] = {} + for entry in os.listdir("/proc"): + if not entry.isdigit(): + continue + try: + with open(f"/proc/{entry}/stat") as f: + ppid = int(f.read().split(") ", 1)[1].split()[1]) + children.setdefault(ppid, []).append(int(entry)) + except (OSError, ValueError, IndexError): + pass + total, stack = 0, [os.getpid()] + while stack: + cur = stack.pop() + try: + with open(f"/proc/{cur}/statm") as f: + total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE") + except (OSError, ValueError, IndexError): + pass + stack.extend(children.get(cur, [])) + return total + except OSError: + return 0 + + +class PeakRSSSampler: + """Background thread tracking peak process-tree RSS for the duration of the `with` block.""" + + def __init__(self, interval_s: float = 0.5): + self.interval_s = interval_s + self.peak_bytes = 0 + self._stop = threading.Event() + self._thread = threading.Thread(target=self._run, daemon=True) + + def _run(self) -> None: + while not self._stop.is_set(): + self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes()) + self._stop.wait(self.interval_s) + + def __enter__(self) -> "PeakRSSSampler": + self._thread.start() + return self + + def __exit__(self, *exc) -> None: + self._stop.set() + self._thread.join(timeout=2) + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--repo_id", type=str, required=True) @@ -62,8 +118,30 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.") parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument( + "--prefetch_factor", + type=int, + default=2, + help="DataLoader batches prefetched per worker. Higher hides IO/decode latency but raises RAM " + "(prefetch_factor x num_workers x batch_size decoded frames held in flight). Ignored if num_workers=0.", + ) parser.add_argument("--buffer_size", type=int, default=2000) + parser.add_argument( + "--max_num_shards", + type=int, + default=16, + help="Cap on concurrently-open stream shards. Each open shard holds ~one parquet row group in " + "RAM; reading from an hf:// bucket buffers ~5x more per shard than hf:// datasets, so lower this " + "(e.g. to num_workers) for bucket sources to avoid OOM. All data is still covered via re-sharding.", + ) parser.add_argument("--video_decoder_cache_size", type=int, default=None) + parser.add_argument( + "--episode_pool_size", + type=int, + default=None, + help="A3 shuffle: keep this many full episodes live and sample frames uniformly across them " + "(mixing radius = this many episodes). Unset = default per-shard reservoir shuffle.", + ) parser.add_argument( "--video_decode_device", type=str, @@ -87,8 +165,10 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str data_files_root=args.data_files_root, delta_timestamps=delta_timestamps, buffer_size=args.buffer_size, + max_num_shards=args.max_num_shards, video_decoder_cache_size=args.video_decoder_cache_size, video_decode_device=args.video_decode_device, + episode_pool_size=args.episode_pool_size, tolerance_s=1e-3, ) @@ -116,37 +196,43 @@ def main() -> None: # tensors errors). Pin only when decode is on CPU and we copy to a CUDA device. pin_memory=device.type == "cuda" and not gpu_decode, drop_last=True, - prefetch_factor=2 if args.num_workers > 0 else None, + prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None, # CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method. multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None, ) sample_latencies_ms: list[float] = [] + episodes_per_batch: list[int] = [] # shuffle-randomness proxy: distinct episodes within a batch frames = 0 first_batch_latency_s = None steady_start = None # wall-clock start of the post-warmup measurement window t_start = time.perf_counter() t_prev = t_start - for i, batch in enumerate(loader): - # Dummy consume: move tensors to the device, mimicking what a real trainer would do. - for value in batch.values(): - if torch.is_tensor(value): - value.to(device, non_blocking=device.type == "cuda") - now = time.perf_counter() - if first_batch_latency_s is None: - first_batch_latency_s = now - t_start + with PeakRSSSampler() as rss: + for i, batch in enumerate(loader): + # Dummy consume: move tensors to the device, mimicking what a real trainer would do. + for value in batch.values(): + if torch.is_tensor(value): + value.to(device, non_blocking=device.type == "cuda") + now = time.perf_counter() + if first_batch_latency_s is None: + first_batch_latency_s = now - t_start - if i == args.warmup_batches: - # Start the steady window here; the slow first batch and the prefetch queue it filled are - # excluded so throughput reflects sustained production, not draining a pre-filled queue. - steady_start = now - elif i > args.warmup_batches: - sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0) - frames += args.batch_size - t_prev = now - if i + 1 >= args.num_batches: - break + if i == args.warmup_batches: + # Start the steady window here; the slow first batch and the prefetch queue it filled are + # excluded so throughput reflects sustained production, not draining a pre-filled queue. + steady_start = now + elif i > args.warmup_batches: + sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0) + frames += args.batch_size + ep = batch.get("episode_index") + if torch.is_tensor(ep): + episodes_per_batch.append(int(torch.unique(ep).numel())) + t_prev = now + if i + 1 >= args.num_batches: + break + peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None now = time.perf_counter() elapsed = now - t_start @@ -154,6 +240,16 @@ def main() -> None: # gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x. steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed cache_stats = dataset.video_decoder_cache_stats() + timing = dataset.timing_stats() # cumulative decode/fetch seconds summed across workers + # Image (camera frame) resolution as decoded, e.g. [C, H, W]. Read from the dataset feature contract. + image_shape = ( + list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None + ) + # Decode/fetch overlap in wall-clock (workers run in parallel), so normalize against the total worker + # budget (num_workers x wallclock) to express each stage as a fraction of available worker time. + worker_budget_s = max(args.num_workers, 1) * elapsed + decode_pct = round(100 * timing["decode_s_total"] / worker_budget_s, 1) if worker_budget_s else None + fetch_pct = round(100 * timing["fetch_s_total"] / worker_budget_s, 1) if worker_budget_s else None # A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode # error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is @@ -172,11 +268,22 @@ def main() -> None: "mode": args.mode, "batch_size": args.batch_size, "num_workers": args.num_workers, + "prefetch_factor": args.prefetch_factor if args.num_workers > 0 else None, "buffer_size": args.buffer_size, + "episode_pool_size": args.episode_pool_size, + "episodes_per_batch_mean": round(statistics.mean(episodes_per_batch), 1) + if episodes_per_batch + else None, + # Fraction of a batch that is distinct episodes; ~1.0 ≈ map-style uniform, low ≈ correlated. + "shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3) + if episodes_per_batch + else None, "num_cameras": len(meta.video_keys), + "image_shape": image_shape, "fps": meta.fps, "device": str(device), "video_decode_device": args.video_decode_device, + "peak_rss_gb": peak_rss_gb, "frames_measured": frames, "first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4), "frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0, @@ -186,13 +293,23 @@ def main() -> None: else None, "p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3), "p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3), + "total_time_s": round(elapsed, 2), + "steady_time_s": round(steady_elapsed_s, 2), "wallclock_s": round(elapsed, 2), + "decode_s_total": timing["decode_s_total"], + "fetch_s_total": timing["fetch_s_total"], + "decode_pct_worker_time": decode_pct, + "fetch_pct_worker_time": fetch_pct, "video_decoder_cache": cache_stats, } out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) - tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}" + pool_tag = f"_ep{args.episode_pool_size}" if args.episode_pool_size else "" + tag = ( + f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}" + f"_pf{args.prefetch_factor}{pool_tag}_{args.video_decode_device}" + ) (out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2)) flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()} with open(out_dir / f"{tag}.csv", "w", newline="") as f: diff --git a/slurm/run_streaming_matrix.sh b/slurm/run_streaming_matrix.sh index a33e181fc..98319eeed 100755 --- a/slurm/run_streaming_matrix.sh +++ b/slurm/run_streaming_matrix.sh @@ -34,9 +34,14 @@ GPUS=${GPUS:-1} SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement) CPU_WORKERS=${CPU_WORKERS:-8} GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session -CPU_BUFFER=${CPU_BUFFER:-4000} +CPU_BUFFER=${CPU_BUFFER:-2000} # shuffle buffer dominates worker RAM (buffer_size x num_workers decoded frames) GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory +# Cap concurrently-open stream shards. Each open shard holds ~one parquet row group in RAM, and reading +# from an hf:// bucket buffers ~5x more per shard than hf:// datasets (~1.2GB vs ~0.26GB). So for bucket +# sources default to num_workers (1 shard/worker); hub keeps 16. Override globally with MAX_SHARDS. +MAX_SHARDS=${MAX_SHARDS:-} BATCH_SIZE=${BATCH_SIZE:-64} +PREFETCH=${PREFETCH:-2} # DataLoader batches prefetched per worker (higher = more throughput + RAM) RUN=${RUN:-python} # CONDA_ENV= runs each job via `conda run -n ` (no activation needed inside the dash --wrap; # --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11) @@ -69,6 +74,7 @@ for SOURCE in $SOURCES; do for MODE in $MODES; do for DECODE in $DECODES; do if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi + if [ -n "$MAX_SHARDS" ]; then S=$MAX_SHARDS; elif [ "$SOURCE" = hub ]; then S=16; else S=$W; fi # Run strictly after the previous job so only one job touches the network at a time. DEPFLAG="" if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi @@ -83,7 +89,8 @@ for SOURCE in $SOURCES; do $RUN benchmarks/streaming/benchmark_streaming.py \ --repo_id $REPO_ID $ROOTFLAG \ --mode $MODE --source $SOURCE --video_decode_device $DECODE \ - --batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \ + --batch_size $BATCH_SIZE --num_workers $W --prefetch_factor $PREFETCH \ + --buffer_size $B --max_num_shards $S \ --num_batches $NUM_BATCHES --out_dir $OUT_DIR") jid=${jid%%;*} # strip ';cluster' suffix on federated setups echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}" @@ -96,5 +103,5 @@ done echo echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))." echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)" -echo "Results: $OUT_DIR/__bs${BATCH_SIZE}_w_.{json,csv}" +echo "Results: $OUT_DIR/__bs${BATCH_SIZE}_w_pf_.{json,csv}" echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR" diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index adbb841c4..d3d3c1716 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) - ep_dataset = embed_images(ep_dataset) table = ep_dataset.with_format("arrow")[:] - writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True) - writer.write_table(table) + # Emit several row groups with a page index instead of one giant row group. A single row group forces + # streaming readers to materialize the whole file's columns per open shard; with random-access streaming + # (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed + # groups bounds per-shard memory while keeping groups large enough to scan + # efficiently; the page index lets readers skip to the pages they need. + target_row_group_bytes = 32 * 1024 * 1024 + row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1))) + writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True + ) + writer.write_table(table, row_group_size=row_group_size) writer.close() diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 7cf61a6ed..14a2a427b 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -16,6 +16,7 @@ import logging import math import os +import time from collections import deque from collections.abc import Callable, Generator, Iterable, Iterator from pathlib import Path @@ -263,6 +264,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): video_decoder_cache_size: int | None = None, data_files_root: str | None = None, video_decode_device: str = "cpu", + episode_pool_size: int | None = None, ): """Initialize a StreamingLeRobotDataset. @@ -326,12 +328,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.video_decoder_cache_size = video_decoder_cache_size self.data_files_root = data_files_root.rstrip("/") if data_files_root else None self.video_decode_device = video_decode_device + # A3 shuffle: when set, iterate by keeping this many full episodes live in memory and sampling + # frames uniformly across them (mixing radius = episode_pool_size episodes), instead of the + # default per-shard reservoir. Tabular deltas become exact in-episode index lookups (no + # Backtrackable). Trades video-decode locality for much stronger shuffle. + self.episode_pool_size = episode_pool_size # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None - # Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into - # one place the main process can read after iteration (see video_decoder_cache_stats()). - self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_() + # Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate + # decoder-cache stats and component timings into one place the main process can read after + # iteration (see video_decoder_cache_stats() / timing_stats()). + self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_() # Resume state captured by load_state_dict() and consumed at the next __iter__. self._resume_state: dict | None = None @@ -494,6 +502,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): shard.load_state_dict(resume["shards"][str(idx)]) self._shards[idx] = shard + # A3 episode-pool shuffle (opt-in): sample frames uniformly across many fully-loaded episodes. + if self.episode_pool_size: + shard_iters = { + idx: iter(self._shards[idx]) for idx in shard_indices if idx not in self._exhausted + } + yield from self._iter_episode_pool(shard_iters, rng) + return + buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size) idx_to_backtrack_dataset = { @@ -506,6 +522,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # the logic is to add 2 levels of randomness: # (1) sample one shard at random from the ones available, and # (2) sample one frame from the shard sampled at (1) + # Buffer entries are (partial, video_spec): undecoded tabular rows. Video is decoded by + # _attach_video only when a sample leaves the buffer, keeping peak memory ~prefetch-bounded. frames_buffer = [] while available_shards := list(idx_to_backtrack_dataset.keys()): shard_key = next(self._infinite_generator_over_elements(rng, available_shards)) @@ -515,7 +533,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): for frame in self.make_frame(backtrack_dataset): if len(frames_buffer) == self.buffer_size: i = next(buffer_indices_generator) # samples a element from the buffer - yield frames_buffer[i] + yield self._attach_video(*frames_buffer[i]) # decode just-in-time on the way out frames_buffer[i] = frame else: frames_buffer.append(frame) @@ -527,9 +545,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard self._exhausted.add(shard_key) - # Once shards are all exhausted, shuffle the buffer and yield the remaining frames + # Once shards are all exhausted, shuffle the buffer and yield the remaining frames (decoding each). rng.shuffle(frames_buffer) - yield from frames_buffer + for partial, video_spec in frames_buffer: + yield self._attach_video(partial, video_spec) def state_dict(self) -> dict: """Capture resume state: per-shard HF stream position, exhausted shards, and RNG state. @@ -557,7 +576,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as approximate; the ``hit_rate`` ratio is preserved. """ - hits, misses, evictions = (int(x) for x in self._cache_counters.tolist()) + hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist()) total = hits + misses return { "hits": hits, @@ -566,6 +585,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): "hit_rate": round(hits / total, 4) if total else 0.0, } + def timing_stats(self) -> dict[str, float]: + """Cumulative seconds spent in video decode and parquet/sample fetch, summed across DataLoader + workers via the shared counter tensor. These overlap in wall-clock (workers run in parallel), so + compare them to ``num_workers x wallclock`` — not to wallclock directly — to get time fractions. + """ + decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist()) + return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)} + def _get_window_steps( self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False ) -> tuple[int, int]: @@ -640,8 +667,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return padding_mask def make_frame(self, dataset_iterator: Backtrackable) -> Generator: - """Makes a frame starting from a dataset iterator""" + """Build a frame's tabular content and defer the video decode. + + Yields a ``(partial, video_spec)`` pair: ``partial`` holds all non-video fields (tabular + features, tabular delta windows + padding, task); ``video_spec`` carries what + :meth:`_attach_video` needs to decode the camera frames just-in-time at yield time. Deferring + the decode keeps the shuffle reservoir holding ~KB tabular rows instead of multi-MB decoded + images, which collapses peak memory. + """ + _t0 = time.perf_counter_ns() item = next(dataset_iterator) + self._cache_counters[4] += time.perf_counter_ns() - _t0 # parquet/sample fetch time item = item_to_torch(item) updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features) @@ -673,29 +709,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): updates.append(query_result) updates.append(padding) - # Load video frames, when needed + # Defer the (memory-heavy) video decode: capture only what _attach_video needs to decode the + # camera frames at yield time, so the shuffle buffer holds ~KB tabular rows, not MB of pixels. + video_spec = None if len(self.meta.video_keys) > 0: original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) - - # Some timestamps might not result available considering the episode's boundaries + # Some timestamps might not be available considering the episode's boundaries query_timestamps = self._get_query_timestamps( current_ts, self.delta_indices, episode_boundaries_ts ) - video_frames = self._query_videos(query_timestamps, ep_idx) - - if self.image_transforms is not None: - image_keys = self.meta.camera_keys - for cam in image_keys: - video_frames[cam] = self.image_transforms(video_frames[cam]) - - updates.append(video_frames) - - if self.delta_indices is not None: - # We always return the same number of frames. Unavailable frames are padded. - padding_mask = self._get_video_frame_padding_mask( - video_frames, query_timestamps, original_timestamps - ) - updates.append(padding_mask) + video_spec = (query_timestamps, original_timestamps, ep_idx) result = item.copy() for update in updates: @@ -703,7 +726,151 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): result["task"] = self.meta.tasks.iloc[item["task_index"]].name - yield result + yield result, video_spec + + def _attach_video(self, result: dict, video_spec: tuple | None) -> dict: + """Decode the camera frames for a buffered sample and merge them in (counterpart to make_frame). + + This is where torchcodec decode actually runs — on one sample at a time as it leaves the shuffle + buffer — so peak memory is bounded by the prefetch queue rather than ``buffer_size`` decoded frames. + """ + if video_spec is None: + return result + query_timestamps, original_timestamps, ep_idx = video_spec + video_frames = self._query_videos(query_timestamps, ep_idx) + if self.image_transforms is not None: + for cam in self.meta.camera_keys: + video_frames[cam] = self.image_transforms(video_frames[cam]) + result.update(video_frames) + if self.delta_indices is not None: + # We always return the same number of frames. Unavailable frames are padded. + padding_mask = self._get_video_frame_padding_mask( + video_frames, query_timestamps, original_timestamps + ) + result.update(padding_mask) + return result + + @staticmethod + def _ep_id(raw_item: dict) -> int: + """Episode index of a raw (pre-torch) HF stream row, coerced to a plain int.""" + return int(np.asarray(raw_item["episode_index"]).reshape(-1)[0]) + + def _read_one_episode(self, sid: int, shard_iters: dict, carry: dict) -> list[dict] | None: + """Read one full episode (contiguous rows) from a shard iterator, or None if exhausted. + + Episodes are contiguous in the stream, so we read until ``episode_index`` changes and stash the + first row of the next episode in ``carry`` to start the following read. + """ + it = shard_iters[sid] + first = carry[sid] + carry[sid] = None + if first is None: + first = next(it, None) + if first is None: + return None + ep = self._ep_id(first) + rows = [first] + for row in it: + if self._ep_id(row) != ep: + carry[sid] = row # belongs to the next episode; start there next time + break + rows.append(row) + return rows + + def _make_frame_from_episode(self, ep_rows: list[dict], p: int) -> tuple[dict, tuple | None]: + """Build ``(partial, video_spec)`` for frame ``p`` of a fully-loaded episode (A3). + + All temporal neighbors live in ``ep_rows``, so tabular delta windows are exact index lookups + with correct episode-boundary padding — no Backtrackable, no lookahead pre-read. Video is still + decoded just-in-time by :meth:`_attach_video`. + """ + item = ep_rows[p] + ep_idx = item["episode_index"] + current_ts = float(item["timestamp"]) + length = len(ep_rows) + + updates = [] + if self.delta_indices is not None: + query_result, padding = {}, {} + for key, deltas in self.delta_indices.items(): + if key in self.meta.video_keys: + continue # visual frames are decoded separately + frames, is_pad = [], [] + for d in deltas: + q = p + d + clamped = min(max(q, 0), length - 1) # out-of-episode neighbors pad to the boundary + frames.append(ep_rows[clamped][key]) + is_pad.append(q != clamped) + query_result[key] = torch.stack(frames) + padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad) + updates.append(query_result) + updates.append(padding) + + video_spec = None + if len(self.meta.video_keys) > 0: + episode_boundaries_ts = { + key: ( + 0.0, + self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"] + - self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"], + ) + for key in self.meta.video_keys + } + original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) + query_timestamps = self._get_query_timestamps( + current_ts, self.delta_indices, episode_boundaries_ts + ) + video_spec = (query_timestamps, original_timestamps, ep_idx) + + result = item.copy() + for update in updates: + result.update(update) + result["task"] = self.meta.tasks.iloc[item["task_index"]].name + return result, video_spec + + def _iter_episode_pool(self, shard_iters: dict, rng: np.random.Generator) -> Iterator[dict]: + """A3 shuffle: keep ``episode_pool_size`` full episodes live and sample frames uniformly across + them. Each episode costs ~one sequential read (IO-cheap); the mixing radius is the pool size. + + ``tickets`` holds one (slot, frame_pos) entry per live, not-yet-emitted frame; swap-remove gives + O(1) uniform sampling without replacement. When an episode drains it is evicted and a fresh one + is read in, keeping the pool full. + """ + carry = {sid: None for sid in shard_iters} + live = set(shard_iters) + pool: dict[int, dict] = {} # slot -> {"rows": [...], "remaining": int} + tickets: list[tuple[int, int]] = [] + next_slot = 0 + + def load_episode() -> bool: + nonlocal next_slot + while live: + sid = int(rng.choice(tuple(live))) + rows = self._read_one_episode(sid, shard_iters, carry) + if rows is None: + live.discard(sid) + continue + ep_rows = [item_to_torch(r) for r in rows] + pool[next_slot] = {"rows": ep_rows, "remaining": len(ep_rows)} + tickets.extend((next_slot, p) for p in range(len(ep_rows))) + next_slot += 1 + return True + return False + + while len(pool) < self.episode_pool_size and load_episode(): + pass + + while tickets: + i = int(rng.integers(len(tickets))) + slot, p = tickets[i] + tickets[i] = tickets[-1] # swap-remove: O(1) sampling without replacement + tickets.pop() + partial, video_spec = self._make_frame_from_episode(pool[slot]["rows"], p) + yield self._attach_video(partial, video_spec) + pool[slot]["remaining"] -= 1 + if pool[slot]["remaining"] == 0: + del pool[slot] # free the episode's frames + load_episode() # refill to keep the pool (and mixing radius) full def _get_query_timestamps( self, @@ -745,6 +912,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): else: root = self.root video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" + _t0 = time.perf_counter_ns() frames = decode_video_frames_torchcodec( video_path, shifted_query_ts, @@ -752,6 +920,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): decoder_cache=self.video_decoder_cache, return_uint8=self._return_uint8, ) + self._cache_counters[3] += time.perf_counter_ns() - _t0 # video decode time item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index c0b8ebd1d..ee801e052 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -22,6 +22,7 @@ import queue import shutil import tempfile import threading +import time import warnings from collections import OrderedDict from dataclasses import asdict, dataclass, field @@ -47,6 +48,92 @@ from lerobot.utils.import_utils import get_safe_default_video_backend logger = logging.getLogger(__name__) +DEFAULT_REMOTE_IO_MAX_RETRIES = 5 +"""Retry budget for transient hf:// / fsspec / httpx transport errors during streaming video decode. + +Streaming a dataset from an HF bucket/CDN issues many small range requests and occasionally hits a +transient transport failure (timeout, dropped connection, 408/5xx). The right response is to rebuild +the connection and retry rather than crash the DataLoader worker. Override via +``LEROBOT_REMOTE_IO_MAX_RETRIES``; set to ``0`` to disable retries (fail fast). +""" + +# Transient transport failures from the hf:// -> fsspec -> httpx stack. We match on text because the +# concrete exception types live in optional deps (httpx, huggingface_hub) and vary across versions. +# "client has been closed" is the important one: once a shared httpx client is closed by a single +# failed read, every subsequent read in that worker fails until the fsspec instance cache is cleared. +_RETRYABLE_TRANSPORT_FRAGMENTS = ( + "client has been closed", + "server disconnected", + "remoteprotocolerror", + "unexpected_eof", + "eof occurred in violation of protocol", + "connection reset", + "connection aborted", + "connection broken", + "incompleteread", + "read operation timed out", + "timed out", + "request time-out", + "408", + "502", + "503", + "504", +) + + +def _remote_io_max_retries() -> int: + raw = os.environ.get("LEROBOT_REMOTE_IO_MAX_RETRIES") + if raw is None: + return DEFAULT_REMOTE_IO_MAX_RETRIES + try: + return max(0, int(raw)) + except ValueError as e: + raise ValueError(f"LEROBOT_REMOTE_IO_MAX_RETRIES must be an integer; got {raw!r}") from e + + +def _is_retryable_transport_error(exc: BaseException) -> bool: + """True if ``exc`` looks like a transient remote-IO failure worth retrying (vs a real bug).""" + text = f"{type(exc).__name__}: {exc}".lower() + return any(fragment in text for fragment in _RETRYABLE_TRANSPORT_FRAGMENTS) + + +def _recover_remote_io(decoder_cache: "VideoDecoderCache", video_path: str) -> None: + """Drop the dead decoder for ``video_path`` and force a fresh fsspec client before a retry. + + fsspec caches one filesystem instance per (protocol, args), and that instance owns the httpx + client a failed read may have closed. Clearing the instance cache makes the next ``fsspec.open`` + build a new client, which is what breaks the "client has been closed" cascade. + """ + decoder_cache.invalidate(video_path) + with contextlib.suppress(Exception): + fsspec.AbstractFileSystem.clear_instance_cache() + + +def _retry_remote_io(operation, on_retry, max_retries: int, base_delay: float = 0.5, max_delay: float = 10.0): + """Run ``operation()``, retrying transient transport errors after ``on_retry()`` + capped backoff. + + Non-transport errors (decode / index / timestamp issues) propagate immediately so real bugs are + never masked by retries. + """ + attempt = 0 + while True: + try: + return operation() + except Exception as e: + if attempt >= max_retries or not _is_retryable_transport_error(e): + raise + attempt += 1 + logger.warning( + "Transient remote-IO error (%s: %s); rebuilding connection and retrying (%d/%d).", + type(e).__name__, + e, + attempt, + max_retries, + ) + on_retry() + time.sleep(min(base_delay * 2 ** (attempt - 1), max_delay)) + + def decode_video_frames( video_path: Path | str, timestamps: list[float], @@ -296,7 +383,11 @@ class VideoDecoderCache: self.misses += 1 if self._counters is not None: self._counters[1] += 1 - file_handle = fsspec.open(video_path).__enter__() + # Bound per-handle buffering: with many decoders kept open at once (one per camera per active + # shard, across all workers), the default fsspec read cache balloons RAM on remote backends + # like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the + # mostly-sequential reads torchcodec issues. + file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__() try: decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device) except Exception: @@ -326,6 +417,18 @@ class VideoDecoderCache: file_handle.close() self._cache.clear() + def invalidate(self, video_path: str) -> None: + """Drop and close the cached decoder for a path whose connection went bad. + + After a transport error the cached ``fsspec`` handle (and the httpx client behind it) is dead; + removing the entry forces the next :meth:`get_decoder` to re-open a fresh handle. + """ + with self._lock: + entry = self._cache.pop(str(video_path), None) + if entry is not None: + with contextlib.suppress(Exception): + entry[1].close() + def size(self) -> int: """Return the number of cached decoders.""" with self._lock: @@ -381,20 +484,24 @@ def decode_video_frames_torchcodec( if decoder_cache is None: decoder_cache = _default_decoder_cache - # Use cached decoder instead of creating new one each time - decoder = decoder_cache.get_decoder(str(video_path)) + def _decode_frames(): + # Both opening the decoder and reading frames go over the network for hf:// paths, so wrap the + # whole unit: a transient transport error retries by dropping the dead handle and rebuilding + # the connection (see _retry_remote_io / _recover_remote_io) instead of killing the worker. + decoder = decoder_cache.get_decoder(str(video_path)) + average_fps = decoder.metadata.average_fps + frame_indices = [round(ts * average_fps) for ts in timestamps] + return decoder.get_frames_at(indices=frame_indices) + + frames_batch = _retry_remote_io( + _decode_frames, + on_retry=lambda: _recover_remote_io(decoder_cache, str(video_path)), + max_retries=_remote_io_max_retries(), + ) loaded_ts = [] loaded_frames = [] - # get metadata for frame information - metadata = decoder.metadata - average_fps = metadata.average_fps - # convert timestamps to frame indices - frame_indices = [round(ts * average_fps) for ts in timestamps] - # retrieve frames based on indices - frames_batch = decoder.get_frames_at(indices=frame_indices) - for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True): loaded_frames.append(frame) loaded_ts.append(pts.item())