diff --git a/benchmarks/streaming/benchmark_streaming.py b/benchmarks/streaming/benchmark_streaming.py index 187b1770e..e0f981088 100644 --- a/benchmarks/streaming/benchmark_streaming.py +++ b/benchmarks/streaming/benchmark_streaming.py @@ -51,6 +51,13 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--repo_id", type=str, required=True) parser.add_argument("--root", type=str, default=None, help="Local/prewarmed root (else stream from Hub).") + parser.add_argument( + "--data_files_root", + type=str, + default=None, + help="fsspec root for bulk data/videos, e.g. hf://buckets//. Metadata still loads " + "from --repo_id on the Hub. Use for bucket / warmed_bucket sources.", + ) parser.add_argument("--mode", choices=["single", "sarm"], default="single") parser.add_argument("--source", type=str, default="hub", help="Label only: hub | bucket | warmed_bucket.") parser.add_argument("--batch_size", type=int, default=64) @@ -70,6 +77,7 @@ def build_dataset(args: argparse.Namespace, meta: LeRobotDatasetMetadata) -> Str return StreamingLeRobotDataset( args.repo_id, root=args.root, + data_files_root=args.data_files_root, delta_timestamps=delta_timestamps, buffer_size=args.buffer_size, video_decoder_cache_size=args.video_decoder_cache_size, @@ -103,6 +111,7 @@ def main() -> None: sample_latencies_ms: list[float] = [] frames = 0 first_batch_latency_s = None + steady_start = None # wall-clock start of the post-warmup measurement window t_start = time.perf_counter() t_prev = t_start @@ -115,17 +124,23 @@ def main() -> None: if first_batch_latency_s is None: first_batch_latency_s = now - t_start - if i >= args.warmup_batches: - per_sample_ms = (now - t_prev) / args.batch_size * 1000.0 - sample_latencies_ms.append(per_sample_ms) + if i == args.warmup_batches: + # Start the steady window here; the slow first batch and the prefetch queue it filled are + # excluded so throughput reflects sustained production, not draining a pre-filled queue. + steady_start = now + elif i > args.warmup_batches: + sample_latencies_ms.append((now - t_prev) / args.batch_size * 1000.0) frames += args.batch_size t_prev = now if i + 1 >= args.num_batches: break - elapsed = time.perf_counter() - t_start - steady_elapsed_s = sum(sample_latencies_ms) / 1000.0 - cache_stats = dataset.video_decoder_cache.stats() if dataset.video_decoder_cache is not None else {} + now = time.perf_counter() + elapsed = now - t_start + # Wall-clock throughput over the steady window. NOT sum(inter-batch gaps): under async prefetch those + # gaps collapse to ~0 (the consumer drains a pre-filled queue) and overstate throughput by ~100x. + steady_elapsed_s = (now - steady_start) if steady_start is not None else elapsed + cache_stats = dataset.video_decoder_cache_stats() results = { "repo_id": args.repo_id, diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 1746c9a4c..9a368a957 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -261,6 +261,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): rank: int | None = None, world_size: int | None = None, video_decoder_cache_size: int | None = None, + data_files_root: str | None = None, ): """Initialize a StreamingLeRobotDataset. @@ -290,6 +291,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): video_decoder_cache_size (int | None, optional): Max number of open video decoders to retain. When omitted, it defaults to ``(concurrent active shards + 1) × num_cameras`` so the working set of live decoders never thrashes. See :class:`VideoDecoderCache`. + data_files_root (str | None, optional): fsspec root holding the bulk ``data/`` and ``videos/`` + trees (e.g. an HF storage bucket ``hf://buckets//``). When set, parquet and + video frames are read from there while metadata still loads from ``repo_id`` on the Hub. + Resolves through fsspec exactly like ``hf://``; use it to benchmark bucket / prewarmed-bucket + sources without copying the (small) metadata. """ super().__init__() self.repo_id = repo_id @@ -312,9 +318,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.rank, self.world_size = self._resolve_distributed(rank, world_size) self.video_decoder_cache_size = video_decoder_cache_size + self.data_files_root = data_files_root.rstrip("/") if data_files_root else None # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None + # Shared [hits, misses, evictions] tensor so DataLoader workers aggregate decoder-cache stats into + # one place the main process can read after iteration (see video_decoder_cache_stats()). + self._cache_counters = torch.zeros(3, dtype=torch.int64).share_memory_() # Resume state captured by load_state_dict() and consumed at the next __iter__. self._resume_state: dict | None = None @@ -338,13 +348,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.delta_timestamps = delta_timestamps self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - self.hf_dataset: datasets.IterableDataset = load_dataset( - self.repo_id if not self.streaming_from_local else str(self.root), - split="train", - streaming=self.streaming, - data_files="data/*/*.parquet", - revision=self.revision, - ) + if self.data_files_root is not None: + # Bulk data lives in an fsspec root (e.g. an HF storage bucket); metadata stays on the Hub. + self.hf_dataset: datasets.IterableDataset = load_dataset( + "parquet", + split="train", + streaming=self.streaming, + data_files=f"{self.data_files_root}/data/*/*.parquet", + ) + else: + self.hf_dataset = load_dataset( + self.repo_id if not self.streaming_from_local else str(self.root), + split="train", + streaming=self.streaming, + data_files="data/*/*.parquet", + revision=self.revision, + ) self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) @@ -406,11 +425,13 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): margin so the round-robin never evicts a still-live decoder. """ if self.video_decoder_cache_size is not None: - return VideoDecoderCache(max_size=self.video_decoder_cache_size) + return VideoDecoderCache(max_size=self.video_decoder_cache_size, counters=self._cache_counters) num_cameras = len(self.meta.video_keys) if num_cameras == 0: - return VideoDecoderCache() - return VideoDecoderCache(max_size=(num_active_shards + 1) * num_cameras) + return VideoDecoderCache(counters=self._cache_counters) + return VideoDecoderCache( + max_size=(num_active_shards + 1) * num_cameras, counters=self._cache_counters + ) # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading. # The current sequential iteration is a bottleneck. A producer-consumer pattern @@ -507,6 +528,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """Stage resume state captured by :meth:`state_dict`; applied at the next ``__iter__``.""" self._resume_state = state_dict + def video_decoder_cache_stats(self) -> dict[str, int | float]: + """Decoder-cache reuse aggregated across DataLoader workers via the shared counter tensor. + + Unlike ``self.video_decoder_cache.stats()`` (which only reflects the main process), this sums + hits/misses/evictions over every worker. Counts are lock-free across processes, so treat them as + approximate; the ``hit_rate`` ratio is preserved. + """ + hits, misses, evictions = (int(x) for x in self._cache_counters.tolist()) + total = hits + misses + return { + "hits": hits, + "misses": misses, + "evictions": evictions, + "hit_rate": round(hits / total, 4) if total else 0.0, + } + def _get_window_steps( self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False ) -> tuple[int, int]: @@ -679,7 +716,12 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset. from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"] shifted_query_ts = [from_timestamp + ts for ts in query_ts] - root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root + if self.data_files_root is not None: + root = self.data_files_root + elif self.streaming and not self.streaming_from_local: + root = self.meta.url_root + else: + root = self.root video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" frames = decode_video_frames_torchcodec( video_path, diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index a56015d4f..1ed1b909c 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -242,7 +242,7 @@ class VideoDecoderCache: _SENTINEL: ClassVar[object] = object() - def __init__(self, max_size: int | None | object = _SENTINEL): + def __init__(self, max_size: int | None | object = _SENTINEL, counters: "torch.Tensor | None" = None): if max_size is VideoDecoderCache._SENTINEL: max_size = _default_max_cache_size() if max_size is not None and max_size <= 0: @@ -254,6 +254,10 @@ class VideoDecoderCache: self.hits = 0 self.misses = 0 self.evictions = 0 + # Optional shared [hits, misses, evictions] tensor so DataLoader workers aggregate into one place + # (the per-worker `self.*` ints are invisible to the main process). Lock-free across processes, so + # treat the aggregate as approximate; the hit-rate ratio is preserved. + self._counters = counters def __contains__(self, video_path: object) -> bool: with self._lock: @@ -276,9 +280,13 @@ class VideoDecoderCache: if entry is not None: self._cache.move_to_end(video_path) self.hits += 1 + if self._counters is not None: + self._counters[0] += 1 return entry[0] self.misses += 1 + if self._counters is not None: + self._counters[1] += 1 file_handle = fsspec.open(video_path).__enter__() try: decoder = VideoDecoder(file_handle, seek_mode="approximate") @@ -294,6 +302,8 @@ class VideoDecoderCache: while len(self._cache) > self.max_size: _evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False) self.evictions += 1 + if self._counters is not None: + self._counters[2] += 1 with contextlib.suppress(Exception): evicted_handle.close()