mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-12 14:09:51 +00:00
feat(streaming): defer video decode, episode-pool shuffle, and remote-IO retries
- streaming_dataset: defer torchcodec decode until a sample leaves the shuffle buffer (buffer now holds ~KB tabular rows, not MB of pixels) and add an opt-in episode-pool shuffle (episode_pool_size) with exact in-episode delta lookups; expose decode/fetch timing_stats. - video_utils: retry transient hf:///fsspec/httpx transport errors during streaming decode (LEROBOT_REMOTE_IO_MAX_RETRIES). - dataset_tools: write multiple ~32MB row groups with a page index to bound per-shard streaming memory. - benchmarks/slurm: streaming benchmark + matrix submitter updates. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -36,7 +36,9 @@ is whatever ``--repo_id``/``--root`` point at. See the README for bucket prewarm
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@@ -47,6 +49,60 @@ from lerobot.datasets import LeRobotDatasetMetadata, StreamingLeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def _tree_rss_bytes() -> int:
|
||||
"""Sum RSS of this process and all its descendants via /proc (Linux only; 0 elsewhere).
|
||||
|
||||
DataLoader workers are separate processes, so the parent's own RSS misses most of the pipeline's
|
||||
memory. Walking the process tree captures the real footprint (parquet buffers + decoders + shuffle).
|
||||
"""
|
||||
try:
|
||||
children: dict[int, list[int]] = {}
|
||||
for entry in os.listdir("/proc"):
|
||||
if not entry.isdigit():
|
||||
continue
|
||||
try:
|
||||
with open(f"/proc/{entry}/stat") as f:
|
||||
ppid = int(f.read().split(") ", 1)[1].split()[1])
|
||||
children.setdefault(ppid, []).append(int(entry))
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
total, stack = 0, [os.getpid()]
|
||||
while stack:
|
||||
cur = stack.pop()
|
||||
try:
|
||||
with open(f"/proc/{cur}/statm") as f:
|
||||
total += int(f.read().split()[1]) * os.sysconf("SC_PAGE_SIZE")
|
||||
except (OSError, ValueError, IndexError):
|
||||
pass
|
||||
stack.extend(children.get(cur, []))
|
||||
return total
|
||||
except OSError:
|
||||
return 0
|
||||
|
||||
|
||||
class PeakRSSSampler:
|
||||
"""Background thread tracking peak process-tree RSS for the duration of the `with` block."""
|
||||
|
||||
def __init__(self, interval_s: float = 0.5):
|
||||
self.interval_s = interval_s
|
||||
self.peak_bytes = 0
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
self.peak_bytes = max(self.peak_bytes, _tree_rss_bytes())
|
||||
self._stop.wait(self.interval_s)
|
||||
|
||||
def __enter__(self) -> "PeakRSSSampler":
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._stop.set()
|
||||
self._thread.join(timeout=2)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--repo_id", type=str, required=True)
|
||||
@@ -62,8 +118,30 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--prefetch_factor",
|
||||
type=int,
|
||||
default=2,
|
||||
help="DataLoader batches prefetched per worker. Higher hides IO/decode latency but raises RAM "
|
||||
"(prefetch_factor x num_workers x batch_size decoded frames held in flight). Ignored if num_workers=0.",
|
||||
)
|
||||
parser.add_argument("--buffer_size", type=int, default=2000)
|
||||
parser.add_argument(
|
||||
"--max_num_shards",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Cap on concurrently-open stream shards. Each open shard holds ~one parquet row group in "
|
||||
"RAM; reading from an hf:// bucket buffers ~5x more per shard than hf:// datasets, so lower this "
|
||||
"(e.g. to num_workers) for bucket sources to avoid OOM. All data is still covered via re-sharding.",
|
||||
)
|
||||
parser.add_argument("--video_decoder_cache_size", type=int, default=None)
|
||||
parser.add_argument(
|
||||
"--episode_pool_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="A3 shuffle: keep this many full episodes live and sample frames uniformly across them "
|
||||
"(mixing radius = this many episodes). Unset = default per-shard reservoir shuffle.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_decode_device",
|
||||
type=str,
|
||||
@@ -87,8 +165,10 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
|
||||
data_files_root=args.data_files_root,
|
||||
delta_timestamps=delta_timestamps,
|
||||
buffer_size=args.buffer_size,
|
||||
max_num_shards=args.max_num_shards,
|
||||
video_decoder_cache_size=args.video_decoder_cache_size,
|
||||
video_decode_device=args.video_decode_device,
|
||||
episode_pool_size=args.episode_pool_size,
|
||||
tolerance_s=1e-3,
|
||||
)
|
||||
|
||||
@@ -116,37 +196,43 @@ def main() -> None:
|
||||
# tensors errors). Pin only when decode is on CPU and we copy to a CUDA device.
|
||||
pin_memory=device.type == "cuda" and not gpu_decode,
|
||||
drop_last=True,
|
||||
prefetch_factor=2 if args.num_workers > 0 else None,
|
||||
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
|
||||
# CUDA cannot initialize in forked workers; NVDEC decode in workers needs the spawn start method.
|
||||
multiprocessing_context="spawn" if gpu_decode and args.num_workers > 0 else None,
|
||||
)
|
||||
|
||||
sample_latencies_ms: list[float] = []
|
||||
episodes_per_batch: list[int] = [] # shuffle-randomness proxy: distinct episodes within a batch
|
||||
frames = 0
|
||||
first_batch_latency_s = None
|
||||
steady_start = None # wall-clock start of the post-warmup measurement window
|
||||
|
||||
t_start = time.perf_counter()
|
||||
t_prev = t_start
|
||||
for i, batch in enumerate(loader):
|
||||
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
with PeakRSSSampler() as rss:
|
||||
for i, batch in enumerate(loader):
|
||||
# Dummy consume: move tensors to the device, mimicking what a real trainer would do.
|
||||
for value in batch.values():
|
||||
if torch.is_tensor(value):
|
||||
value.to(device, non_blocking=device.type == "cuda")
|
||||
now = time.perf_counter()
|
||||
if first_batch_latency_s is None:
|
||||
first_batch_latency_s = now - t_start
|
||||
|
||||
if i == args.warmup_batches:
|
||||
# Start the steady window here; the slow first batch and the prefetch queue it filled are
|
||||
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
frames += args.batch_size
|
||||
t_prev = now
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
if i == args.warmup_batches:
|
||||
# Start the steady window here; the slow first batch and the prefetch queue it filled are
|
||||
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
|
||||
steady_start = now
|
||||
elif i > args.warmup_batches:
|
||||
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
|
||||
frames += args.batch_size
|
||||
ep = batch.get("episode_index")
|
||||
if torch.is_tensor(ep):
|
||||
episodes_per_batch.append(int(torch.unique(ep).numel()))
|
||||
t_prev = now
|
||||
if i + 1 >= args.num_batches:
|
||||
break
|
||||
peak_rss_gb = round(rss.peak_bytes / 1e9, 2) if rss.peak_bytes else None
|
||||
|
||||
now = time.perf_counter()
|
||||
elapsed = now - t_start
|
||||
@@ -154,6 +240,16 @@ def main() -> None:
|
||||
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
|
||||
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
|
||||
cache_stats = dataset.video_decoder_cache_stats()
|
||||
timing = dataset.timing_stats() # cumulative decode/fetch seconds summed across workers
|
||||
# Image (camera frame) resolution as decoded, e.g. [C, H, W]. Read from the dataset feature contract.
|
||||
image_shape = (
|
||||
list(meta.features[meta.video_keys[0]]["shape"]) if meta.video_keys else None
|
||||
)
|
||||
# Decode/fetch overlap in wall-clock (workers run in parallel), so normalize against the total worker
|
||||
# budget (num_workers x wallclock) to express each stage as a fraction of available worker time.
|
||||
worker_budget_s = max(args.num_workers, 1) * elapsed
|
||||
decode_pct = round(100 * timing["decode_s_total"] / worker_budget_s, 1) if worker_budget_s else None
|
||||
fetch_pct = round(100 * timing["fetch_s_total"] / worker_budget_s, 1) if worker_budget_s else None
|
||||
|
||||
# A 0-frame run is a failure, not a 0-throughput result: the pipeline produced no batches (decode
|
||||
# error swallowed in workers, all batches dropped by drop_last, etc.). Exit non-zero so the job is
|
||||
@@ -172,11 +268,22 @@ def main() -> None:
|
||||
"mode": args.mode,
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
"prefetch_factor": args.prefetch_factor if args.num_workers > 0 else None,
|
||||
"buffer_size": args.buffer_size,
|
||||
"episode_pool_size": args.episode_pool_size,
|
||||
"episodes_per_batch_mean": round(statistics.mean(episodes_per_batch), 1)
|
||||
if episodes_per_batch
|
||||
else None,
|
||||
# Fraction of a batch that is distinct episodes; ~1.0 ≈ map-style uniform, low ≈ correlated.
|
||||
"shuffle_randomness_frac": round(statistics.mean(episodes_per_batch) / args.batch_size, 3)
|
||||
if episodes_per_batch
|
||||
else None,
|
||||
"num_cameras": len(meta.video_keys),
|
||||
"image_shape": image_shape,
|
||||
"fps": meta.fps,
|
||||
"device": str(device),
|
||||
"video_decode_device": args.video_decode_device,
|
||||
"peak_rss_gb": peak_rss_gb,
|
||||
"frames_measured": frames,
|
||||
"first_batch_latency_s": round(first_batch_latency_s or float("nan"), 4),
|
||||
"frames_per_s_node": round(frames / steady_elapsed_s, 2) if steady_elapsed_s else 0.0,
|
||||
@@ -186,13 +293,23 @@ def main() -> None:
|
||||
else None,
|
||||
"p95_sample_latency_ms": round(percentile(sample_latencies_ms, 95), 3),
|
||||
"p99_sample_latency_ms": round(percentile(sample_latencies_ms, 99), 3),
|
||||
"total_time_s": round(elapsed, 2),
|
||||
"steady_time_s": round(steady_elapsed_s, 2),
|
||||
"wallclock_s": round(elapsed, 2),
|
||||
"decode_s_total": timing["decode_s_total"],
|
||||
"fetch_s_total": timing["fetch_s_total"],
|
||||
"decode_pct_worker_time": decode_pct,
|
||||
"fetch_pct_worker_time": fetch_pct,
|
||||
"video_decoder_cache": cache_stats,
|
||||
}
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}_{args.video_decode_device}"
|
||||
pool_tag = f"_ep{args.episode_pool_size}" if args.episode_pool_size else ""
|
||||
tag = (
|
||||
f"{args.source}_{args.mode}_bs{args.batch_size}_w{args.num_workers}"
|
||||
f"_pf{args.prefetch_factor}{pool_tag}_{args.video_decode_device}"
|
||||
)
|
||||
(out_dir / f"{tag}.json").write_text(json.dumps(results, indent=2))
|
||||
flat = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in results.items()}
|
||||
with open(out_dir / f"{tag}.csv", "w", newline="") as f:
|
||||
|
||||
@@ -34,9 +34,14 @@ GPUS=${GPUS:-1}
|
||||
SERIAL=${SERIAL:-1} # 1 = run one job at a time (correct for bandwidth measurement)
|
||||
CPU_WORKERS=${CPU_WORKERS:-8}
|
||||
GPU_WORKERS=${GPU_WORKERS:-2} # low on purpose: each cuda worker holds a CUDA context + NVDEC session
|
||||
CPU_BUFFER=${CPU_BUFFER:-4000}
|
||||
CPU_BUFFER=${CPU_BUFFER:-2000} # shuffle buffer dominates worker RAM (buffer_size x num_workers decoded frames)
|
||||
GPU_BUFFER=${GPU_BUFFER:-1000} # smaller buffer bounds on-GPU frame memory
|
||||
# Cap concurrently-open stream shards. Each open shard holds ~one parquet row group in RAM, and reading
|
||||
# from an hf:// bucket buffers ~5x more per shard than hf:// datasets (~1.2GB vs ~0.26GB). So for bucket
|
||||
# sources default to num_workers (1 shard/worker); hub keeps 16. Override globally with MAX_SHARDS.
|
||||
MAX_SHARDS=${MAX_SHARDS:-}
|
||||
BATCH_SIZE=${BATCH_SIZE:-64}
|
||||
PREFETCH=${PREFETCH:-2} # DataLoader batches prefetched per worker (higher = more throughput + RAM)
|
||||
RUN=${RUN:-python}
|
||||
# CONDA_ENV=<name> runs each job via `conda run -n <name>` (no activation needed inside the dash --wrap;
|
||||
# --no-capture-output streams logs live). Set this to a conda env that has a MODERN torchcodec (>=0.11)
|
||||
@@ -69,6 +74,7 @@ for SOURCE in $SOURCES; do
|
||||
for MODE in $MODES; do
|
||||
for DECODE in $DECODES; do
|
||||
if [ "$DECODE" = cpu ]; then W=$CPU_WORKERS; B=$CPU_BUFFER; else W=$GPU_WORKERS; B=$GPU_BUFFER; fi
|
||||
if [ -n "$MAX_SHARDS" ]; then S=$MAX_SHARDS; elif [ "$SOURCE" = hub ]; then S=16; else S=$W; fi
|
||||
# Run strictly after the previous job so only one job touches the network at a time.
|
||||
DEPFLAG=""
|
||||
if [ "$SERIAL" = 1 ] && [ -n "$prev_jid" ]; then DEPFLAG="--dependency=afterany:$prev_jid"; fi
|
||||
@@ -83,7 +89,8 @@ for SOURCE in $SOURCES; do
|
||||
$RUN benchmarks/streaming/benchmark_streaming.py \
|
||||
--repo_id $REPO_ID $ROOTFLAG \
|
||||
--mode $MODE --source $SOURCE --video_decode_device $DECODE \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --buffer_size $B \
|
||||
--batch_size $BATCH_SIZE --num_workers $W --prefetch_factor $PREFETCH \
|
||||
--buffer_size $B --max_num_shards $S \
|
||||
--num_batches $NUM_BATCHES --out_dir $OUT_DIR")
|
||||
jid=${jid%%;*} # strip ';cluster' suffix on federated setups
|
||||
echo "submitted job $jid bench_${SOURCE}_${MODE}_${DECODE}${DEPFLAG:+ (after $prev_jid)}"
|
||||
@@ -96,5 +103,5 @@ done
|
||||
echo
|
||||
echo "Submitted $n jobs ($([ "$SERIAL" = 1 ] && echo 'serial chain — one runs at a time' || echo 'parallel'))."
|
||||
echo "Watch: squeue -u \$USER (later jobs show reason '(Dependency)' until their turn)"
|
||||
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_<decode>.{json,csv}"
|
||||
echo "Results: $OUT_DIR/<source>_<mode>_bs${BATCH_SIZE}_w<workers>_pf<prefetch>_<decode>.{json,csv}"
|
||||
echo "Summarize when done: $RUN benchmarks/streaming/summarize_results.py $OUT_DIR"
|
||||
|
||||
@@ -945,8 +945,17 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
|
||||
writer.write_table(table)
|
||||
# Emit several row groups with a page index instead of one giant row group. A single row group forces
|
||||
# streaming readers to materialize the whole file's columns per open shard; with random-access streaming
|
||||
# (shuffle + delta windows) across many workers x shards that dominates RAM. Targeting ~32MB-uncompressed
|
||||
# groups bounds per-shard memory while keeping groups large enough to scan
|
||||
# efficiently; the page index lets readers skip to the pages they need.
|
||||
target_row_group_bytes = 32 * 1024 * 1024
|
||||
row_group_size = max(1, min(table.num_rows, table.num_rows * target_row_group_bytes // max(table.nbytes, 1)))
|
||||
writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True, write_page_index=True
|
||||
)
|
||||
writer.write_table(table, row_group_size=row_group_size)
|
||||
writer.close()
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator
|
||||
from pathlib import Path
|
||||
@@ -263,6 +264,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
video_decode_device: str = "cpu",
|
||||
episode_pool_size: int | None = None,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -326,12 +328,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
self.video_decode_device = video_decode_device
|
||||
# A3 shuffle: when set, iterate by keeping this many full episodes live in memory and sampling
|
||||
# frames uniformly across them (mixing radius = episode_pool_size episodes), instead of the
|
||||
# default per-shard reservoir. Tabular deltas become exact in-episode index lookups (no
|
||||
# Backtrackable). Trades video-decode locality for much stronger shuffle.
|
||||
self.episode_pool_size = episode_pool_size
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
|
||||
# one place the main process can read after iteration (see video_decoder_cache_stats()).
|
||||
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
|
||||
# Shared [hits, misses, evictions, decode_ns, fetch_ns] tensor so DataLoader workers aggregate
|
||||
# decoder-cache stats and component timings into one place the main process can read after
|
||||
# iteration (see video_decoder_cache_stats() / timing_stats()).
|
||||
self._cache_counters = torch.zeros(5, dtype=torch.int64).share_memory_()
|
||||
# Resume state captured by load_state_dict() and consumed at the next __iter__.
|
||||
self._resume_state: dict | None = None
|
||||
|
||||
@@ -494,6 +502,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
shard.load_state_dict(resume["shards"][str(idx)])
|
||||
self._shards[idx] = shard
|
||||
|
||||
# A3 episode-pool shuffle (opt-in): sample frames uniformly across many fully-loaded episodes.
|
||||
if self.episode_pool_size:
|
||||
shard_iters = {
|
||||
idx: iter(self._shards[idx]) for idx in shard_indices if idx not in self._exhausted
|
||||
}
|
||||
yield from self._iter_episode_pool(shard_iters, rng)
|
||||
return
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
@@ -506,6 +522,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
# Buffer entries are (partial, video_spec): undecoded tabular rows. Video is decoded by
|
||||
# _attach_video only when a sample leaves the buffer, keeping peak memory ~prefetch-bounded.
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
@@ -515,7 +533,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
yield self._attach_video(*frames_buffer[i]) # decode just-in-time on the way out
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
@@ -527,9 +545,10 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
self._exhausted.add(shard_key)
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames (decoding each).
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
for partial, video_spec in frames_buffer:
|
||||
yield self._attach_video(partial, video_spec)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
"""Capture resume state: per-shard HF stream position, exhausted shards, and RNG state.
|
||||
@@ -557,7 +576,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
|
||||
approximate; the ``hit_rate`` ratio is preserved.
|
||||
"""
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
|
||||
hits, misses, evictions = (int(x) for x in self._cache_counters[:3].tolist())
|
||||
total = hits + misses
|
||||
return {
|
||||
"hits": hits,
|
||||
@@ -566,6 +585,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"hit_rate": round(hits / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
def timing_stats(self) -> dict[str, float]:
|
||||
"""Cumulative seconds spent in video decode and parquet/sample fetch, summed across DataLoader
|
||||
workers via the shared counter tensor. These overlap in wall-clock (workers run in parallel), so
|
||||
compare them to ``num_workers x wallclock`` — not to wallclock directly — to get time fractions.
|
||||
"""
|
||||
decode_ns, fetch_ns = (int(x) for x in self._cache_counters[3:5].tolist())
|
||||
return {"decode_s_total": round(decode_ns / 1e9, 2), "fetch_s_total": round(fetch_ns / 1e9, 2)}
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
@@ -640,8 +667,17 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
return padding_mask
|
||||
|
||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
"""Build a frame's tabular content and defer the video decode.
|
||||
|
||||
Yields a ``(partial, video_spec)`` pair: ``partial`` holds all non-video fields (tabular
|
||||
features, tabular delta windows + padding, task); ``video_spec`` carries what
|
||||
:meth:`_attach_video` needs to decode the camera frames just-in-time at yield time. Deferring
|
||||
the decode keeps the shuffle reservoir holding ~KB tabular rows instead of multi-MB decoded
|
||||
images, which collapses peak memory.
|
||||
"""
|
||||
_t0 = time.perf_counter_ns()
|
||||
item = next(dataset_iterator)
|
||||
self._cache_counters[4] += time.perf_counter_ns() - _t0 # parquet/sample fetch time
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
@@ -673,29 +709,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
# Defer the (memory-heavy) video decode: capture only what _attach_video needs to decode the
|
||||
# camera frames at yield time, so the shuffle buffer holds ~KB tabular rows, not MB of pixels.
|
||||
video_spec = None
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
# Some timestamps might not be available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
video_spec = (query_timestamps, original_timestamps, ep_idx)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
@@ -703,7 +726,151 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
yield result, video_spec
|
||||
|
||||
def _attach_video(self, result: dict, video_spec: tuple | None) -> dict:
|
||||
"""Decode the camera frames for a buffered sample and merge them in (counterpart to make_frame).
|
||||
|
||||
This is where torchcodec decode actually runs — on one sample at a time as it leaves the shuffle
|
||||
buffer — so peak memory is bounded by the prefetch queue rather than ``buffer_size`` decoded frames.
|
||||
"""
|
||||
if video_spec is None:
|
||||
return result
|
||||
query_timestamps, original_timestamps, ep_idx = video_spec
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.meta.camera_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
result.update(video_frames)
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
result.update(padding_mask)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _ep_id(raw_item: dict) -> int:
|
||||
"""Episode index of a raw (pre-torch) HF stream row, coerced to a plain int."""
|
||||
return int(np.asarray(raw_item["episode_index"]).reshape(-1)[0])
|
||||
|
||||
def _read_one_episode(self, sid: int, shard_iters: dict, carry: dict) -> list[dict] | None:
|
||||
"""Read one full episode (contiguous rows) from a shard iterator, or None if exhausted.
|
||||
|
||||
Episodes are contiguous in the stream, so we read until ``episode_index`` changes and stash the
|
||||
first row of the next episode in ``carry`` to start the following read.
|
||||
"""
|
||||
it = shard_iters[sid]
|
||||
first = carry[sid]
|
||||
carry[sid] = None
|
||||
if first is None:
|
||||
first = next(it, None)
|
||||
if first is None:
|
||||
return None
|
||||
ep = self._ep_id(first)
|
||||
rows = [first]
|
||||
for row in it:
|
||||
if self._ep_id(row) != ep:
|
||||
carry[sid] = row # belongs to the next episode; start there next time
|
||||
break
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def _make_frame_from_episode(self, ep_rows: list[dict], p: int) -> tuple[dict, tuple | None]:
|
||||
"""Build ``(partial, video_spec)`` for frame ``p`` of a fully-loaded episode (A3).
|
||||
|
||||
All temporal neighbors live in ``ep_rows``, so tabular delta windows are exact index lookups
|
||||
with correct episode-boundary padding — no Backtrackable, no lookahead pre-read. Video is still
|
||||
decoded just-in-time by :meth:`_attach_video`.
|
||||
"""
|
||||
item = ep_rows[p]
|
||||
ep_idx = item["episode_index"]
|
||||
current_ts = float(item["timestamp"])
|
||||
length = len(ep_rows)
|
||||
|
||||
updates = []
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = {}, {}
|
||||
for key, deltas in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
frames, is_pad = [], []
|
||||
for d in deltas:
|
||||
q = p + d
|
||||
clamped = min(max(q, 0), length - 1) # out-of-episode neighbors pad to the boundary
|
||||
frames.append(ep_rows[clamped][key])
|
||||
is_pad.append(q != clamped)
|
||||
query_result[key] = torch.stack(frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
video_spec = None
|
||||
if len(self.meta.video_keys) > 0:
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
0.0,
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"]
|
||||
- self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_spec = (query_timestamps, original_timestamps, ep_idx)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
return result, video_spec
|
||||
|
||||
def _iter_episode_pool(self, shard_iters: dict, rng: np.random.Generator) -> Iterator[dict]:
|
||||
"""A3 shuffle: keep ``episode_pool_size`` full episodes live and sample frames uniformly across
|
||||
them. Each episode costs ~one sequential read (IO-cheap); the mixing radius is the pool size.
|
||||
|
||||
``tickets`` holds one (slot, frame_pos) entry per live, not-yet-emitted frame; swap-remove gives
|
||||
O(1) uniform sampling without replacement. When an episode drains it is evicted and a fresh one
|
||||
is read in, keeping the pool full.
|
||||
"""
|
||||
carry = {sid: None for sid in shard_iters}
|
||||
live = set(shard_iters)
|
||||
pool: dict[int, dict] = {} # slot -> {"rows": [...], "remaining": int}
|
||||
tickets: list[tuple[int, int]] = []
|
||||
next_slot = 0
|
||||
|
||||
def load_episode() -> bool:
|
||||
nonlocal next_slot
|
||||
while live:
|
||||
sid = int(rng.choice(tuple(live)))
|
||||
rows = self._read_one_episode(sid, shard_iters, carry)
|
||||
if rows is None:
|
||||
live.discard(sid)
|
||||
continue
|
||||
ep_rows = [item_to_torch(r) for r in rows]
|
||||
pool[next_slot] = {"rows": ep_rows, "remaining": len(ep_rows)}
|
||||
tickets.extend((next_slot, p) for p in range(len(ep_rows)))
|
||||
next_slot += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
while len(pool) < self.episode_pool_size and load_episode():
|
||||
pass
|
||||
|
||||
while tickets:
|
||||
i = int(rng.integers(len(tickets)))
|
||||
slot, p = tickets[i]
|
||||
tickets[i] = tickets[-1] # swap-remove: O(1) sampling without replacement
|
||||
tickets.pop()
|
||||
partial, video_spec = self._make_frame_from_episode(pool[slot]["rows"], p)
|
||||
yield self._attach_video(partial, video_spec)
|
||||
pool[slot]["remaining"] -= 1
|
||||
if pool[slot]["remaining"] == 0:
|
||||
del pool[slot] # free the episode's frames
|
||||
load_episode() # refill to keep the pool (and mixing radius) full
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
@@ -745,6 +912,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
_t0 = time.perf_counter_ns()
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
@@ -752,6 +920,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
self._cache_counters[3] += time.perf_counter_ns() - _t0 # video decode time
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import asdict, dataclass, field
|
||||
@@ -47,6 +48,92 @@ from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_REMOTE_IO_MAX_RETRIES = 5
|
||||
"""Retry budget for transient hf:// / fsspec / httpx transport errors during streaming video decode.
|
||||
|
||||
Streaming a dataset from an HF bucket/CDN issues many small range requests and occasionally hits a
|
||||
transient transport failure (timeout, dropped connection, 408/5xx). The right response is to rebuild
|
||||
the connection and retry rather than crash the DataLoader worker. Override via
|
||||
``LEROBOT_REMOTE_IO_MAX_RETRIES``; set to ``0`` to disable retries (fail fast).
|
||||
"""
|
||||
|
||||
# Transient transport failures from the hf:// -> fsspec -> httpx stack. We match on text because the
|
||||
# concrete exception types live in optional deps (httpx, huggingface_hub) and vary across versions.
|
||||
# "client has been closed" is the important one: once a shared httpx client is closed by a single
|
||||
# failed read, every subsequent read in that worker fails until the fsspec instance cache is cleared.
|
||||
_RETRYABLE_TRANSPORT_FRAGMENTS = (
|
||||
"client has been closed",
|
||||
"server disconnected",
|
||||
"remoteprotocolerror",
|
||||
"unexpected_eof",
|
||||
"eof occurred in violation of protocol",
|
||||
"connection reset",
|
||||
"connection aborted",
|
||||
"connection broken",
|
||||
"incompleteread",
|
||||
"read operation timed out",
|
||||
"timed out",
|
||||
"request time-out",
|
||||
"408",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
)
|
||||
|
||||
|
||||
def _remote_io_max_retries() -> int:
|
||||
raw = os.environ.get("LEROBOT_REMOTE_IO_MAX_RETRIES")
|
||||
if raw is None:
|
||||
return DEFAULT_REMOTE_IO_MAX_RETRIES
|
||||
try:
|
||||
return max(0, int(raw))
|
||||
except ValueError as e:
|
||||
raise ValueError(f"LEROBOT_REMOTE_IO_MAX_RETRIES must be an integer; got {raw!r}") from e
|
||||
|
||||
|
||||
def _is_retryable_transport_error(exc: BaseException) -> bool:
|
||||
"""True if ``exc`` looks like a transient remote-IO failure worth retrying (vs a real bug)."""
|
||||
text = f"{type(exc).__name__}: {exc}".lower()
|
||||
return any(fragment in text for fragment in _RETRYABLE_TRANSPORT_FRAGMENTS)
|
||||
|
||||
|
||||
def _recover_remote_io(decoder_cache: "VideoDecoderCache", video_path: str) -> None:
|
||||
"""Drop the dead decoder for ``video_path`` and force a fresh fsspec client before a retry.
|
||||
|
||||
fsspec caches one filesystem instance per (protocol, args), and that instance owns the httpx
|
||||
client a failed read may have closed. Clearing the instance cache makes the next ``fsspec.open``
|
||||
build a new client, which is what breaks the "client has been closed" cascade.
|
||||
"""
|
||||
decoder_cache.invalidate(video_path)
|
||||
with contextlib.suppress(Exception):
|
||||
fsspec.AbstractFileSystem.clear_instance_cache()
|
||||
|
||||
|
||||
def _retry_remote_io(operation, on_retry, max_retries: int, base_delay: float = 0.5, max_delay: float = 10.0):
|
||||
"""Run ``operation()``, retrying transient transport errors after ``on_retry()`` + capped backoff.
|
||||
|
||||
Non-transport errors (decode / index / timestamp issues) propagate immediately so real bugs are
|
||||
never masked by retries.
|
||||
"""
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
return operation()
|
||||
except Exception as e:
|
||||
if attempt >= max_retries or not _is_retryable_transport_error(e):
|
||||
raise
|
||||
attempt += 1
|
||||
logger.warning(
|
||||
"Transient remote-IO error (%s: %s); rebuilding connection and retrying (%d/%d).",
|
||||
type(e).__name__,
|
||||
e,
|
||||
attempt,
|
||||
max_retries,
|
||||
)
|
||||
on_retry()
|
||||
time.sleep(min(base_delay * 2 ** (attempt - 1), max_delay))
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
@@ -296,7 +383,11 @@ class VideoDecoderCache:
|
||||
self.misses += 1
|
||||
if self._counters is not None:
|
||||
self._counters[1] += 1
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
# Bound per-handle buffering: with many decoders kept open at once (one per camera per active
|
||||
# shard, across all workers), the default fsspec read cache balloons RAM on remote backends
|
||||
# like hf:// buckets. A small readahead cache caps each handle's footprint without hurting the
|
||||
# mostly-sequential reads torchcodec issues.
|
||||
file_handle = fsspec.open(video_path, cache_type="readahead", block_size=2**20).__enter__()
|
||||
try:
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate", device=self.device)
|
||||
except Exception:
|
||||
@@ -326,6 +417,18 @@ class VideoDecoderCache:
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def invalidate(self, video_path: str) -> None:
|
||||
"""Drop and close the cached decoder for a path whose connection went bad.
|
||||
|
||||
After a transport error the cached ``fsspec`` handle (and the httpx client behind it) is dead;
|
||||
removing the entry forces the next :meth:`get_decoder` to re-open a fresh handle.
|
||||
"""
|
||||
with self._lock:
|
||||
entry = self._cache.pop(str(video_path), None)
|
||||
if entry is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
entry[1].close()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the number of cached decoders."""
|
||||
with self._lock:
|
||||
@@ -381,20 +484,24 @@ def decode_video_frames_torchcodec(
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
def _decode_frames():
|
||||
# Both opening the decoder and reading frames go over the network for hf:// paths, so wrap the
|
||||
# whole unit: a transient transport error retries by dropping the dead handle and rebuilding
|
||||
# the connection (see _retry_remote_io / _recover_remote_io) instead of killing the worker.
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
average_fps = decoder.metadata.average_fps
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
return decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
frames_batch = _retry_remote_io(
|
||||
_decode_frames,
|
||||
on_retry=lambda: _recover_remote_io(decoder_cache, str(video_path)),
|
||||
max_retries=_remote_io_max_retries(),
|
||||
)
|
||||
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
|
||||
Reference in New Issue
Block a user