From bc876949ffb18378ef78f98cf520db1942eea4ab Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Thu, 2 Jul 2026 15:41:01 +0200
Subject: [PATCH] perf(streaming): parallel per-camera range fetch in the episode byte cache

The cluster benchmark showed fetch-bound throughput: resident decode
1824 samples/s vs stream keep-up 693 samples/s (target 1000), with fetch
at ~465 MiB/s aggregate (~233 MiB/s per rank, i.e. ~4 effective HTTPS
streams).

Fixes:
- One prefetch future per (episode, camera) instead of per episode:
  cameras no longer fetch back-to-back on a single thread, so the worker
  pool converts directly into concurrent range GETs.
- Default fetch workers 4 -> 16, exposed as video_fetch_workers on
  StreamingLeRobotDataset for sweeps.
- RangeFetcher uses fs.cat_file (one ranged GET per fetch, no
  open/seek/read layering) and resolves any fsspec URL, so S3-compatible
  stores (e.g. Backblaze B2 via s3://) work identically to hf://.

Also fixed in passing: a latent deadlock on the payload-cache hit path
(_get_or_build_decoder re-acquired the non-reentrant lock; never
triggered so far because payload hits are rare), and episode_byte_cache
no longer imports private torchcodec symbols at module import time (they
vary across torchcodec versions and broke the module on macOS wheels).

New unit tests (decoder layer stubbed): cameras fetch in parallel
(wall-clock bound), error propagation through ensure_ready/get_decoder,
cache-hit deadlock regression, cat_file range correctness.

Local Hub microbench shows +46% aggregate at 16 vs 4 workers on a
residential link that saturates at ~15 MiB/s; the real before/after
needs the cluster benchmark, where per-stream throughput, not link
bandwidth, is the bottleneck.

Co-Authored-By: Claude Fable 5 --- src/lerobot/datasets/episode_byte_cache.py | 89 ++++++------ src/lerobot/datasets/streaming_dataset.py | 15 ++- tests/datasets/test_episode_byte_cache.py | 150 +++++++++++++++++++++ 3 files changed, 207 insertions(+), 47 deletions(-) create mode 100644 tests/datasets/test_episode_byte_cache.py diff --git a/src/lerobot/datasets/episode_byte_cache.py b/src/lerobot/datasets/episode_byte_cache.py index bf7aefd16..1268b74f6 100644 --- a/src/lerobot/datasets/episode_byte_cache.py +++ b/src/lerobot/datasets/episode_byte_cache.py @@ -14,7 +14,6 @@ import fsspec from .byte_index import EpisodeByteIndex, EpisodeSliceLookup from .mp4_episode_slice import SparseMp4Reader -from .torchcodec_utils import open_video_decoder logger = logging.getLogger(__name__) @@ -61,23 +60,24 @@ class CacheStats: @dataclass class _EpisodeEntry: decoders: dict[str, Any] = field(default_factory=dict) - ready: threading.Event = field(default_factory=threading.Event) + futures: dict[str, Future] = field(default_factory=dict) error: Exception | None = None class RangeFetcher: - """Sequential byte-range GETs via fsspec.""" + """Byte-range GETs via fsspec, one request per range (no open/seek/read layering).""" def __init__(self, path: str): - self.path = path - self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file") + # Resolve any fsspec URL (hf://, s3://, gs://, plain local paths, ...), so S3-compatible + # stores (e.g. Backblaze B2 via s3://) work identically to the Hub. + self._fs, self.path = fsspec.core.url_to_fs(path) def fetch(self, lo: int, hi: int) -> bytes: if hi < lo: return b"" - with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f: - f.seek(lo) - return f.read(hi - lo + 1) + # cat_file issues a single ranged GET (end-exclusive); fs.open would add a metadata + # round-trip and buffered-read layering per fetch. 
+ return self._fs.cat_file(self.path, start=lo, end=hi + 1) class EpisodeByteCache: @@ -91,7 +91,7 @@ class EpisodeByteCache: max_bytes: int, *, data_root: str, - max_prefetch_workers: int = 4, + max_prefetch_workers: int = 16, ): if max_bytes <= 0: raise ValueError(f"max_bytes must be positive; got {max_bytes}") @@ -106,7 +106,6 @@ class EpisodeByteCache: self._episodes: dict[int, _EpisodeEntry] = {} self._stats = CacheStats() self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers) - self._futures: dict[int, Future] = {} @property def stats(self) -> CacheStats: @@ -114,48 +113,49 @@ class EpisodeByteCache: return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__}) def submit_prefetch(self, ep_idx: int) -> None: + # One future per (episode, camera): an episode's cameras fetch in parallel instead of + # back-to-back on one thread, so the worker pool converts directly into concurrent + # range GETs (the fetch throughput lever). with self._lock: - if ep_idx in self._episodes or ep_idx in self._futures: + if ep_idx in self._episodes: return + entry = _EpisodeEntry() + self._episodes[ep_idx] = entry self._stats.prefetch_submitted += 1 - fut = self._executor.submit(self._prefetch_episode, ep_idx) - self._futures[ep_idx] = fut + for cam in self.byte_index.video_keys: + entry.futures[cam] = self._executor.submit(self._prefetch_camera, ep_idx, cam, entry) + + def _prefetch_camera(self, ep_idx: int, cam: str, entry: _EpisodeEntry) -> None: + try: + entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam) + except Exception as exc: + entry.error = exc def ensure_ready(self, ep_idx: int) -> None: - with self._lock: - fut = self._futures.pop(ep_idx, None) - if fut is not None: - with self._lock: - self._stats.prefetch_waits += 1 - fut.result() entry = self._episodes.get(ep_idx) if entry is None: raise KeyError(f"episode {ep_idx} not prefetched") + pending = [f for f in entry.futures.values() if not f.done()] + if pending: + 
with self._lock: + self._stats.prefetch_waits += 1 + for fut in entry.futures.values(): + fut.result() if entry.error is not None: raise entry.error - entry.ready.wait() def get_decoder(self, ep_idx: int, video_key: str) -> Any: entry = self._episodes[ep_idx] + fut = entry.futures.get(video_key) + if fut is not None: + fut.result() if entry.error is not None: raise entry.error - entry.ready.wait() return entry.decoders[video_key] def close(self) -> None: self._executor.shutdown(wait=False, cancel_futures=True) - def _prefetch_episode(self, ep_idx: int) -> None: - entry = _EpisodeEntry() - self._episodes[ep_idx] = entry - try: - for cam in self.byte_index.video_keys: - entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam) - except Exception as exc: - entry.error = exc - finally: - entry.ready.set() - def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any: key = (ep_idx, cam) with self._lock: @@ -163,12 +163,15 @@ class EpisodeByteCache: if cached is not None: self._cache.move_to_end(key) self._stats.hits += 1 - payload, _ = cached - t0 = time.perf_counter() - dec = self._decoder_from_payload(payload, ep_idx, cam) - with self._lock: - self._stats.buffer_hit_decoder_s += time.perf_counter() - t0 - return dec + if cached is not None: + # Build the decoder outside the lock: self._lock is non-reentrant, and decoding + # while holding it would also serialize every other fetch thread. 
+ payload, _ = cached + t0 = time.perf_counter() + dec = self._decoder_from_payload(payload, ep_idx, cam) + with self._lock: + self._stats.buffer_hit_decoder_s += time.perf_counter() - t0 + return dec payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam) @@ -240,9 +243,11 @@ class EpisodeByteCache: self._stats.bytes_fetched += len(header) return header - def _decoder_from_payload( - self, payload: SparseMp4Reader, ep_idx: int, cam: str - ) -> Any: + def _decoder_from_payload(self, payload: SparseMp4Reader, ep_idx: int, cam: str) -> Any: + # Lazy import: torchcodec_utils touches private torchcodec symbols that vary across + # torchcodec versions; importing this module must not require them. + from .torchcodec_utils import open_video_decoder + payload.seek(0) mappings = self.byte_index.custom_frame_mappings(ep_idx, cam) return open_video_decoder(payload, frame_mappings=mappings) @@ -252,7 +257,7 @@ class EpisodeByteCache: end = float(dec.metadata.end_stream_seconds) duration = max(0.01, end - begin) for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3): - dec.get_frames_played_at([ts]).data + _ = dec.get_frames_played_at([ts]).data def _rewind_payload(self, payload: SparseMp4Reader) -> None: payload.seek(0) diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 6718fde6f..4537e48d3 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -129,6 +129,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): byte_index_build_in_memory: bool | None = None, byte_index_workers: int = 8, byte_index_max_episodes: int | None = None, + video_fetch_workers: int = 16, ): """Initialize a StreamingLeRobotDataset. @@ -188,6 +189,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): if the sidecar parquet is missing on disk. byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds. 
byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests). + video_fetch_workers (int, optional): Concurrent byte-range fetch threads per consumer + feeding the episode byte cache. Each episode's cameras fetch in parallel, so this + converts directly into concurrent range GETs — the fetch-throughput knob. """ super().__init__() self.repo_id = repo_id @@ -230,6 +234,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.byte_index_build_in_memory = byte_index_build_in_memory self.byte_index_workers = byte_index_workers self.byte_index_max_episodes = byte_index_max_episodes + self.video_fetch_workers = video_fetch_workers self._episode_byte_cache = None self._byte_index = None self._data_root = None @@ -256,7 +261,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): data_root = self._resolve_data_root() index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index") - sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists() + sidecar_exists = (index_dir / "files.parquet").exists() and ( + index_dir / "episodes.parquet" + ).exists() build_in_memory = ( self.byte_index_build_in_memory if self.byte_index_build_in_memory is not None @@ -529,10 +536,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128)) def _use_episode_byte_cache(self) -> bool: - return ( - self.video_byte_cache_gb not in (None, 0) - and self.data_files_root is not None - ) + return self.video_byte_cache_gb not in (None, 0) and self.data_files_root is not None def _make_episode_byte_cache(self): from .episode_byte_cache import EpisodeByteCache @@ -544,6 +548,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self._byte_index, max_bytes, data_root=self._data_root, + max_prefetch_workers=self.video_fetch_workers, ) def _submit_episode_prefetch(self, 
episode_batch: dict[str, list[list]]) -> dict[str, list[list]]: diff --git a/tests/datasets/test_episode_byte_cache.py b/tests/datasets/test_episode_byte_cache.py new file mode 100644 index 000000000..12830b2a0 --- /dev/null +++ b/tests/datasets/test_episode_byte_cache.py @@ -0,0 +1,150 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for EpisodeByteCache fetch concurrency (decoder layer stubbed out).""" + +import threading +import time +from dataclasses import dataclass + +import pytest + +from lerobot.datasets.episode_byte_cache import EpisodeByteCache, RangeFetcher + + +@dataclass +class _FakeLookup: + file_id: int + mdat_offset: int + mdat_length: int + + +@dataclass +class _FakeFileInfo: + file_path: str + file_size: int + header_length: int + + +class _FakeByteIndex: + """Two cameras, one file per (episode, cam); fetch delay is injectable.""" + + def __init__(self, tmp_path, num_episodes=4, mdat_len=64): + self.video_keys = ["cam0", "cam1"] + self._files = {} + self._lookups = {} + file_id = 0 + for ep in range(num_episodes): + for cam in self.video_keys: + path = tmp_path / f"ep{ep}_{cam}.bin" + payload = bytes([ep]) * 16 + bytes(range(64)) * ((mdat_len + 63) // 64) + path.write_bytes(payload) + self._files[file_id] = _FakeFileInfo(str(path), len(payload), header_length=16) + self._lookups[(ep, cam)] = _FakeLookup(file_id, mdat_offset=16, mdat_length=mdat_len) + file_id += 1 + + 
def lookup(self, ep_idx, cam): + return self._lookups[(ep_idx, cam)] + + def file_lookup(self, file_id): + return self._files[file_id] + + def custom_frame_mappings(self, ep_idx, cam): + return None + + +class _SlowFetchCache(EpisodeByteCache): + """Stub decode/validation; add a per-fetch delay to observe fetch parallelism.""" + + fetch_delay_s = 0.15 + + def _fetch_manifest_slice(self, ep_idx, cam): + time.sleep(self.fetch_delay_s) + return f"payload-{ep_idx}-{cam}", 32, f"decoder-{ep_idx}-{cam}" + + def _decoder_from_payload(self, payload, ep_idx, cam): + return f"decoder-{ep_idx}-{cam}" + + +def _make_cache(tmp_path, **kwargs): + index = _FakeByteIndex(tmp_path) + return _SlowFetchCache(index, max_bytes=10_000_000, data_root=str(tmp_path), **kwargs) + + +def test_cameras_fetch_in_parallel(tmp_path): + """An episode's cameras must not fetch back-to-back on one thread.""" + cache = _make_cache(tmp_path, max_prefetch_workers=8) + start = time.perf_counter() + for ep in range(4): + cache.submit_prefetch(ep) + for ep in range(4): + cache.ensure_ready(ep) + elapsed = time.perf_counter() - start + # 4 episodes x 2 cams x 0.15s = 1.2s sequential; 8 workers -> one wave ~0.15s. 
+ assert elapsed < 0.6, f"fetches serialized: {elapsed:.2f}s for 8 fetches on 8 workers" + assert cache.get_decoder(2, "cam1") == "decoder-2-cam1" + cache.close() + + +def test_prefetch_error_propagates(tmp_path): + cache = _make_cache(tmp_path, max_prefetch_workers=2) + + def boom(ep_idx, cam): + raise RuntimeError("fetch failed") + + cache._fetch_manifest_slice = boom + cache.submit_prefetch(0) + with pytest.raises(RuntimeError, match="fetch failed"): + cache.ensure_ready(0) + with pytest.raises(RuntimeError, match="fetch failed"): + cache.get_decoder(0, "cam0") + cache.close() + + +def test_payload_cache_hit_does_not_deadlock(tmp_path): + """Regression: the hit path used to re-acquire the non-reentrant lock (deadlock).""" + cache = _make_cache(tmp_path, max_prefetch_workers=2) + cache._cache[(0, "cam0")] = ("payload-0-cam0", 32) + cache._bytes_used = 32 + + result = {} + + def hit(): + result["dec"] = cache._get_or_build_decoder(0, "cam0") + + thread = threading.Thread(target=hit, daemon=True) + thread.start() + thread.join(timeout=5) + assert not thread.is_alive(), "cache-hit path deadlocked" + assert result["dec"] == "decoder-0-cam0" + assert cache.stats.hits == 1 + cache.close() + + +def test_range_fetcher_cat_file_correctness(tmp_path): + payload = bytes(range(256)) * 4 + path = tmp_path / "blob.bin" + path.write_bytes(payload) + fetcher = RangeFetcher(str(path)) + assert fetcher.fetch(0, 15) == payload[0:16] + assert fetcher.fetch(100, 355) == payload[100:356] + assert fetcher.fetch(len(payload) - 4, len(payload) - 1) == payload[-4:] + assert fetcher.fetch(10, 9) == b"" + + +def test_ensure_ready_unknown_episode_raises(tmp_path): + cache = _make_cache(tmp_path) + with pytest.raises(KeyError): + cache.ensure_ready(99) + cache.close()