feat(streaming): wallclock benchmark throughput, cross-worker cache stats, bucket source

- benchmark: frames_per_s_node now measures sustained wall-clock throughput over the
  post-warmup window. The previous metric summed inter-batch gaps, which collapse to ~0
  under async prefetch (consumer drains a pre-filled queue) and overstated throughput ~100x.
- VideoDecoderCache gains an optional shared [hits, misses, evictions] counter tensor;
  StreamingLeRobotDataset.video_decoder_cache_stats() aggregates it across DataLoader
  workers (lock-free, approximate; hit_rate preserved). Fixes empty cache stats with workers.
- StreamingLeRobotDataset.data_files_root: read bulk data/ + videos/ from an fsspec root
  (e.g. hf://buckets/<owner>/<name>) while metadata still loads from repo_id. Enables
  bucket / prewarmed-bucket benchmark sources without copying metadata. Exposed as
  benchmark --data_files_root.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-09 15:25:44 +02:00
parent 77af66a29c
commit f7c8a526e8
3 changed files with 85 additions and 18 deletions
+21 -6
View File
@@ -51,6 +51,13 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--repo_id", type=str, required=True)
parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).")
parser.add_argument(
"--data_files_root",
type=str,
default=None,
help="fsspec root for bulk data/videos, e.g. hf://buckets/<owner>/<name>. Metadata still loads "
"from --repo_id on the Hub. Use for bucket / warmed_bucket sources.",
)
parser.add_argument("--mode", choices=["single", "sarm"], default="single")
parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.")
parser.add_argument("--batch_size", type=int, default=64)
@@ -70,6 +77,7 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str
return StreamingLeRobotDataset(
args.repo_id,
root=args.root,
data_files_root=args.data_files_root,
delta_timestamps=delta_timestamps,
buffer_size=args.buffer_size,
video_decoder_cache_size=args.video_decoder_cache_size,
@@ -103,6 +111,7 @@ def main() -> None:
sample_latencies_ms: list[float] = []
frames = 0
first_batch_latency_s = None
steady_start = None # wall-clock start of the post-warmup measurement window
t_start = time.perf_counter()
t_prev = t_start
@@ -115,17 +124,23 @@ def main() -> None:
if first_batch_latency_s is None:
first_batch_latency_s = now - t_start
if i >= args.warmup_batches:
per_sample_ms = (now - t_prev) / args.batch_size * 1000.0
sample_latencies_ms.append(per_sample_ms)
if i == args.warmup_batches:
# Start the steady window here; the slow first batch and the prefetch queue it filled are
# excluded so throughput reflects sustained production, not draining a pre-filled queue.
steady_start = now
elif i > args.warmup_batches:
sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0)
frames += args.batch_size
t_prev = now
if i + 1 >= args.num_batches:
break
elapsed = time.perf_counter() - t_start
steady_elapsed_s = sum(sample_latencies_ms) / 1000.0
cache_stats = dataset.video_decoder_cache.stats() if dataset.video_decoder_cache is not None else {}
now = time.perf_counter()
elapsed = now - t_start
# Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those
# gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x.
steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed
cache_stats = dataset.video_decoder_cache_stats()
results = {
"repo_id": args.repo_id,
+53 -11
View File
@@ -261,6 +261,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
rank: int | None = None,
world_size: int | None = None,
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
):
"""Initialize a StreamingLeRobotDataset.
@@ -290,6 +291,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain.
When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working
set of live decoders never thrashes. See :class:`VideoDecoderCache`.
data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/``
trees (e.g. an HF storage bucket ``hf://buckets/<owner>/<name>``). When set, parquet and
video frames are read from there while metadata still loads from ``repo_id`` on the Hub.
Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket
sources without copying the (small) metadata.
"""
super().__init__()
self.repo_id = repo_id
@@ -312,9 +318,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
self.video_decoder_cache_size = video_decoder_cache_size
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
# Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into
# one place the main process can read after iteration (see video_decoder_cache_stats()).
self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_()
# Resume state captured by load_state_dict() and consumed at the next __iter__.
self._resume_state: dict | None = None
@@ -338,13 +348,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.delta_timestamps = delta_timestamps
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
self.hf_dataset: datasets.IterableDataset = load_dataset(
self.repo_id if not self.streaming_from_local else str(self.root),
split="train",
streaming=self.streaming,
data_files="data/*/*.parquet",
revision=self.revision,
)
if self.data_files_root is not None:
# Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub.
self.hf_dataset: datasets.IterableDataset = load_dataset(
"parquet",
split="train",
streaming=self.streaming,
data_files=f"{self.data_files_root}/data/*/*.parquet",
)
else:
self.hf_dataset = load_dataset(
self.repo_id if not self.streaming_from_local else str(self.root),
split="train",
streaming=self.streaming,
data_files="data/*/*.parquet",
revision=self.revision,
)
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
@@ -406,11 +425,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
margin so the round-robin never evicts a still-live decoder.
"""
if self.video_decoder_cache_size is not None:
return VideoDecoderCache(max_size=self.video_decoder_cache_size)
return VideoDecoderCache(max_size=self.video_decoder_cache_size, counters=self._cache_counters)
num_cameras = len(self.meta.video_keys)
if num_cameras == 0:
return VideoDecoderCache()
return VideoDecoderCache(max_size=(num_active_shards + 1) * num_cameras)
return VideoDecoderCache(counters=self._cache_counters)
return VideoDecoderCache(
max_size=(num_active_shards + 1) * num_cameras, counters=self._cache_counters
)
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
# The current sequential iteration is a bottleneck. A producer-consumer pattern
@@ -507,6 +528,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``."""
self._resume_state = state_dict
def video_decoder_cache_stats(self) -> dict[str, int | float]:
"""Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor.
Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums
hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as
approximate; the ``hit_rate`` ratio is preserved.
"""
hits, misses, evictions = (int(x) for x in self._cache_counters.tolist())
total = hits + misses
return {
"hits": hits,
"misses": misses,
"evictions": evictions,
"hit_rate": round(hits / total, 4) if total else 0.0,
}
def _get_window_steps(
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
) -> tuple[int, int]:
@@ -679,7 +716,12 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
if self.data_files_root is not None:
root = self.data_files_root
elif self.streaming and not self.streaming_from_local:
root = self.meta.url_root
else:
root = self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path,
+11 -1
View File
@@ -242,7 +242,7 @@ class VideoDecoderCache:
_SENTINEL: ClassVar[object] = object()
def __init__(self, max_size: int | None | object = _SENTINEL):
def __init__(self, max_size: int | None | object = _SENTINEL, counters: "torch.Tensor | None" = None):
if max_size is VideoDecoderCache._SENTINEL:
max_size = _default_max_cache_size()
if max_size is not None and max_size <= 0:
@@ -254,6 +254,10 @@ class VideoDecoderCache:
self.hits = 0
self.misses = 0
self.evictions = 0
# Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place
# (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so
# treat the aggregate as approximate; the hit-rate ratio is preserved.
self._counters = counters
def __contains__(self, video_path: object) -> bool:
with self._lock:
@@ -276,9 +280,13 @@ class VideoDecoderCache:
if entry is not None:
self._cache.move_to_end(video_path)
self.hits += 1
if self._counters is not None:
self._counters[0] += 1
return entry[0]
self.misses += 1
if self._counters is not None:
self._counters[1] += 1
file_handle = fsspec.open(video_path).__enter__()
try:
decoder = VideoDecoder(file_handle, seek_mode="approximate")
@@ -294,6 +302,8 @@ class VideoDecoderCache:
while len(self._cache) > self.max_size:
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
self.evictions += 1
if self._counters is not None:
self._counters[2] += 1
with contextlib.suppress(Exception):
evicted_handle.close()