perf(streaming): parallel per-camera range fetch in the episode byte cache

The cluster benchmark showed fetch-bound throughput: resident decode
1824 samples/s vs stream keep-up 693 (target 1000), with fetch at ~465
MiB/s aggregate (~233/rank, i.e. ~4 effective HTTPS streams). Fixes:

- One prefetch future per (episode, camera) instead of per episode:
  cameras no longer fetch back-to-back on a single thread, so the worker
  pool converts directly into concurrent range GETs.
- Default fetch workers 4 -> 16, exposed as video_fetch_workers on
  StreamingLeRobotDataset for sweeps.
- RangeFetcher uses fs.cat_file (one ranged GET per fetch, no
  open/seek/read layering) and resolves any fsspec URL, so S3-compatible
  stores (e.g. Backblaze B2 via s3://) work identically to hf://.

Also fixed in passing: a latent deadlock on the payload-cache hit path
(_get_or_build_decoder re-acquired the non-reentrant lock; unhit so far
because payload hits are rare), and episode_byte_cache no longer imports
private torchcodec symbols at module import time (they vary across
torchcodec versions and broke the module on macOS wheels).

New unit tests (decoder layer stubbed): cameras fetch in parallel
(wall-clock bound), error propagation through ensure_ready/get_decoder,
cache-hit deadlock regression, cat_file range correctness. Local Hub
microbench shows +46% aggregate at 16 vs 4 workers on a residential
link that saturates at ~15 MiB/s; the real before/after needs the
cluster benchmark where per-stream throughput, not the link, binds.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-07-02 15:41:01 +02:00
parent 7b6f4f2b11
commit bc876949ff
3 changed files with 207 additions and 47 deletions
+47 -42
View File
@@ -14,7 +14,6 @@ import fsspec
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
from .mp4_episode_slice import SparseMp4Reader
from .torchcodec_utils import open_video_decoder
logger = logging.getLogger(__name__)
@@ -61,23 +60,24 @@ class CacheStats:
@dataclass
class _EpisodeEntry:
decoders: dict[str, Any] = field(default_factory=dict)
ready: threading.Event = field(default_factory=threading.Event)
futures: dict[str, Future] = field(default_factory=dict)
error: Exception | None = None
class RangeFetcher:
"""Sequential byte-range GETs via fsspec."""
"""Byte-range GETs via fsspec, one request per range (no open/seek/read layering)."""
def __init__(self, path: str):
self.path = path
self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file")
# Resolve any fsspec URL (hf://, s3://, gs://, plain local paths, ...), so S3-compatible
# stores (e.g. Backblaze B2 via s3://) work identically to the Hub.
self._fs, self.path = fsspec.core.url_to_fs(path)
def fetch(self, lo: int, hi: int) -> bytes:
if hi < lo:
return b""
with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f:
f.seek(lo)
return f.read(hi - lo + 1)
# cat_file issues a single ranged GET (end-exclusive); fs.open would add a metadata
# round-trip and buffered-read layering per fetch.
return self._fs.cat_file(self.path, start=lo, end=hi + 1)
class EpisodeByteCache:
@@ -91,7 +91,7 @@ class EpisodeByteCache:
max_bytes: int,
*,
data_root: str,
max_prefetch_workers: int = 4,
max_prefetch_workers: int = 16,
):
if max_bytes <= 0:
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
@@ -106,7 +106,6 @@ class EpisodeByteCache:
self._episodes: dict[int, _EpisodeEntry] = {}
self._stats = CacheStats()
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
self._futures: dict[int, Future] = {}
@property
def stats(self) -> CacheStats:
@@ -114,48 +113,49 @@ class EpisodeByteCache:
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
def submit_prefetch(self, ep_idx: int) -> None:
# One future per (episode, camera): an episode's cameras fetch in parallel instead of
# back-to-back on one thread, so the worker pool converts directly into concurrent
# range GETs (the fetch throughput lever).
with self._lock:
if ep_idx in self._episodes or ep_idx in self._futures:
if ep_idx in self._episodes:
return
entry = _EpisodeEntry()
self._episodes[ep_idx] = entry
self._stats.prefetch_submitted += 1
fut = self._executor.submit(self._prefetch_episode, ep_idx)
self._futures[ep_idx] = fut
for cam in self.byte_index.video_keys:
entry.futures[cam] = self._executor.submit(self._prefetch_camera, ep_idx, cam, entry)
def _prefetch_camera(self, ep_idx: int, cam: str, entry: _EpisodeEntry) -> None:
try:
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
except Exception as exc:
entry.error = exc
def ensure_ready(self, ep_idx: int) -> None:
with self._lock:
fut = self._futures.pop(ep_idx, None)
if fut is not None:
with self._lock:
self._stats.prefetch_waits += 1
fut.result()
entry = self._episodes.get(ep_idx)
if entry is None:
raise KeyError(f"episode {ep_idx} not prefetched")
pending = [f for f in entry.futures.values() if not f.done()]
if pending:
with self._lock:
self._stats.prefetch_waits += 1
for fut in entry.futures.values():
fut.result()
if entry.error is not None:
raise entry.error
entry.ready.wait()
def get_decoder(self, ep_idx: int, video_key: str) -> Any:
entry = self._episodes[ep_idx]
fut = entry.futures.get(video_key)
if fut is not None:
fut.result()
if entry.error is not None:
raise entry.error
entry.ready.wait()
return entry.decoders[video_key]
def close(self) -> None:
self._executor.shutdown(wait=False, cancel_futures=True)
def _prefetch_episode(self, ep_idx: int) -> None:
entry = _EpisodeEntry()
self._episodes[ep_idx] = entry
try:
for cam in self.byte_index.video_keys:
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
except Exception as exc:
entry.error = exc
finally:
entry.ready.set()
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
key = (ep_idx, cam)
with self._lock:
@@ -163,12 +163,15 @@ class EpisodeByteCache:
if cached is not None:
self._cache.move_to_end(key)
self._stats.hits += 1
payload, _ = cached
t0 = time.perf_counter()
dec = self._decoder_from_payload(payload, ep_idx, cam)
with self._lock:
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
return dec
if cached is not None:
# Build the decoder outside the lock: self._lock is non-reentrant, and decoding
# while holding it would also serialize every other fetch thread.
payload, _ = cached
t0 = time.perf_counter()
dec = self._decoder_from_payload(payload, ep_idx, cam)
with self._lock:
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
return dec
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
@@ -240,9 +243,11 @@ class EpisodeByteCache:
self._stats.bytes_fetched += len(header)
return header
def _decoder_from_payload(
self, payload: SparseMp4Reader, ep_idx: int, cam: str
) -> Any:
def _decoder_from_payload(self, payload: SparseMp4Reader, ep_idx: int, cam: str) -> Any:
# Lazy import: torchcodec_utils touches private torchcodec symbols that vary across
# torchcodec versions; importing this module must not require them.
from .torchcodec_utils import open_video_decoder
payload.seek(0)
mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
return open_video_decoder(payload, frame_mappings=mappings)
@@ -252,7 +257,7 @@ class EpisodeByteCache:
end = float(dec.metadata.end_stream_seconds)
duration = max(0.01, end - begin)
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
dec.get_frames_played_at([ts]).data
_ = dec.get_frames_played_at([ts]).data
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
payload.seek(0)
+10 -5
View File
@@ -129,6 +129,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
byte_index_build_in_memory: bool | None = None,
byte_index_workers: int = 8,
byte_index_max_episodes: int | None = None,
video_fetch_workers: int = 16,
):
"""Initialize a StreamingLeRobotDataset.
@@ -188,6 +189,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
if the sidecar parquet is missing on disk.
byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds.
byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests).
video_fetch_workers (int, optional): Concurrent byte-range fetch threads per consumer
feeding the episode byte cache. Each episode's cameras fetch in parallel, so this
converts directly into concurrent range GETs the fetch-throughput knob.
"""
super().__init__()
self.repo_id = repo_id
@@ -230,6 +234,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.byte_index_build_in_memory = byte_index_build_in_memory
self.byte_index_workers = byte_index_workers
self.byte_index_max_episodes = byte_index_max_episodes
self.video_fetch_workers = video_fetch_workers
self._episode_byte_cache = None
self._byte_index = None
self._data_root = None
@@ -256,7 +261,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
data_root = self._resolve_data_root()
index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index")
sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists()
sidecar_exists = (index_dir / "files.parquet").exists() and (
index_dir / "episodes.parquet"
).exists()
build_in_memory = (
self.byte_index_build_in_memory
if self.byte_index_build_in_memory is not None
@@ -529,10 +536,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
def _use_episode_byte_cache(self) -> bool:
return (
self.video_byte_cache_gb not in (None, 0)
and self.data_files_root is not None
)
return self.video_byte_cache_gb not in (None, 0) and self.data_files_root is not None
def _make_episode_byte_cache(self):
from .episode_byte_cache import EpisodeByteCache
@@ -544,6 +548,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self._byte_index,
max_bytes,
data_root=self._data_root,
max_prefetch_workers=self.video_fetch_workers,
)
def _submit_episode_prefetch(self, episode_batch: dict[str, list[list]]) -> dict[str, list[list]]:
+150
View File
@@ -0,0 +1,150 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for EpisodeByteCache fetch concurrency (decoder layer stubbed out)."""
import threading
import time
from dataclasses import dataclass
import pytest
from lerobot.datasets.episode_byte_cache import EpisodeByteCache, RangeFetcher
@dataclass
class _FakeLookup:
file_id: int
mdat_offset: int
mdat_length: int
@dataclass
class _FakeFileInfo:
file_path: str
file_size: int
header_length: int
class _FakeByteIndex:
"""Two cameras, one file per (episode, cam); fetch delay is injectable."""
def __init__(self, tmp_path, num_episodes=4, mdat_len=64):
self.video_keys = ["cam0", "cam1"]
self._files = {}
self._lookups = {}
file_id = 0
for ep in range(num_episodes):
for cam in self.video_keys:
path = tmp_path / f"ep{ep}_{cam}.bin"
payload = bytes([ep]) * 16 + bytes(range(64)) * ((mdat_len + 63) // 64)
path.write_bytes(payload)
self._files[file_id] = _FakeFileInfo(str(path), len(payload), header_length=16)
self._lookups[(ep, cam)] = _FakeLookup(file_id, mdat_offset=16, mdat_length=mdat_len)
file_id += 1
def lookup(self, ep_idx, cam):
return self._lookups[(ep_idx, cam)]
def file_lookup(self, file_id):
return self._files[file_id]
def custom_frame_mappings(self, ep_idx, cam):
return None
class _SlowFetchCache(EpisodeByteCache):
"""Stub decode/validation; add a per-fetch delay to observe fetch parallelism."""
fetch_delay_s = 0.15
def _fetch_manifest_slice(self, ep_idx, cam):
time.sleep(self.fetch_delay_s)
return f"payload-{ep_idx}-{cam}", 32, f"decoder-{ep_idx}-{cam}"
def _decoder_from_payload(self, payload, ep_idx, cam):
return f"decoder-{ep_idx}-{cam}"
def _make_cache(tmp_path, **kwargs):
index = _FakeByteIndex(tmp_path)
return _SlowFetchCache(index, max_bytes=10_000_000, data_root=str(tmp_path), **kwargs)
def test_cameras_fetch_in_parallel(tmp_path):
"""An episode's cameras must not fetch back-to-back on one thread."""
cache = _make_cache(tmp_path, max_prefetch_workers=8)
start = time.perf_counter()
for ep in range(4):
cache.submit_prefetch(ep)
for ep in range(4):
cache.ensure_ready(ep)
elapsed = time.perf_counter() - start
# 4 episodes x 2 cams x 0.15s = 1.2s sequential; 8 workers -> one wave ~0.15s.
assert elapsed < 0.6, f"fetches serialized: {elapsed:.2f}s for 8 fetches on 8 workers"
assert cache.get_decoder(2, "cam1") == "decoder-2-cam1"
cache.close()
def test_prefetch_error_propagates(tmp_path):
cache = _make_cache(tmp_path, max_prefetch_workers=2)
def boom(ep_idx, cam):
raise RuntimeError("fetch failed")
cache._fetch_manifest_slice = boom
cache.submit_prefetch(0)
with pytest.raises(RuntimeError, match="fetch failed"):
cache.ensure_ready(0)
with pytest.raises(RuntimeError, match="fetch failed"):
cache.get_decoder(0, "cam0")
cache.close()
def test_payload_cache_hit_does_not_deadlock(tmp_path):
"""Regression: the hit path used to re-acquire the non-reentrant lock (deadlock)."""
cache = _make_cache(tmp_path, max_prefetch_workers=2)
cache._cache[(0, "cam0")] = ("payload-0-cam0", 32)
cache._bytes_used = 32
result = {}
def hit():
result["dec"] = cache._get_or_build_decoder(0, "cam0")
thread = threading.Thread(target=hit, daemon=True)
thread.start()
thread.join(timeout=5)
assert not thread.is_alive(), "cache-hit path deadlocked"
assert result["dec"] == "decoder-0-cam0"
assert cache.stats.hits == 1
cache.close()
def test_range_fetcher_cat_file_correctness(tmp_path):
payload = bytes(range(256)) * 4
path = tmp_path / "blob.bin"
path.write_bytes(payload)
fetcher = RangeFetcher(str(path))
assert fetcher.fetch(0, 15) == payload[0:16]
assert fetcher.fetch(100, 355) == payload[100:356]
assert fetcher.fetch(len(payload) - 4, len(payload) - 1) == payload[-4:]
assert fetcher.fetch(10, 9) == b""
def test_ensure_ready_unknown_episode_raises(tmp_path):
cache = _make_cache(tmp_path)
with pytest.raises(KeyError):
cache.ensure_ready(99)
cache.close()