mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 08:07:03 +00:00
perf(streaming): parallel per-camera range fetch in the episode byte cache
The cluster benchmark showed fetch-bound throughput: resident decode 1824 samples/s vs stream keep-up 693 (target 1000), with fetch at ~465 MiB/s aggregate (~233 MiB/s per rank, i.e. ~4 effective HTTPS streams).

Fixes:
- One prefetch future per (episode, camera) instead of per episode: cameras no longer fetch back-to-back on a single thread, so the worker pool converts directly into concurrent range GETs.
- Default fetch workers 4 -> 16, exposed as video_fetch_workers on StreamingLeRobotDataset for sweeps.
- RangeFetcher uses fs.cat_file (one ranged GET per fetch, no open/seek/read layering) and resolves any fsspec URL, so S3-compatible stores (e.g. Backblaze B2 via s3://) work identically to hf://.

Also fixed in passing: a latent deadlock on the payload-cache hit path (_get_or_build_decoder re-acquired the non-reentrant lock; unhit so far because payload hits are rare), and episode_byte_cache no longer imports private torchcodec symbols at module import time (they vary across torchcodec versions and broke the module on macOS wheels).

New unit tests (decoder layer stubbed): cameras fetch in parallel (wall-clock bound), error propagation through ensure_ready/get_decoder, cache-hit deadlock regression, cat_file range correctness.

Local Hub microbench shows +46% aggregate at 16 vs 4 workers on a residential link that saturates at ~15 MiB/s; the real before/after needs the cluster benchmark where per-stream throughput, not the link, binds.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -14,7 +14,6 @@ import fsspec
|
||||
|
||||
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
|
||||
from .mp4_episode_slice import SparseMp4Reader
|
||||
from .torchcodec_utils import open_video_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -61,23 +60,24 @@ class CacheStats:
|
||||
@dataclass
|
||||
class _EpisodeEntry:
|
||||
decoders: dict[str, Any] = field(default_factory=dict)
|
||||
ready: threading.Event = field(default_factory=threading.Event)
|
||||
futures: dict[str, Future] = field(default_factory=dict)
|
||||
error: Exception | None = None
|
||||
|
||||
|
||||
class RangeFetcher:
    """Byte-range GETs via fsspec, one request per range (no open/seek/read layering).

    Attributes:
        path: The filesystem-relative path returned by ``fsspec.core.url_to_fs`` for
            the URL passed to the constructor.
    """

    def __init__(self, path: str):
        """Resolve ``path`` to an fsspec filesystem and a path on that filesystem.

        Args:
            path: Any fsspec URL or plain local path.
        """
        # Resolve any fsspec URL (hf://, s3://, gs://, plain local paths, ...), so S3-compatible
        # stores (e.g. Backblaze B2 via s3://) work identically to the Hub.
        self._fs, self.path = fsspec.core.url_to_fs(path)

    def fetch(self, lo: int, hi: int) -> bytes:
        """Return the bytes at offsets ``lo`` through ``hi``, both inclusive.

        Args:
            lo: Offset of the first byte to read.
            hi: Offset of the last byte to read.

        Returns:
            The requested bytes, or ``b""`` when ``hi < lo`` (empty range, no request made).
        """
        if hi < lo:
            return b""
        # cat_file issues a single ranged GET (end-exclusive); fs.open would add a metadata
        # round-trip and buffered-read layering per fetch.
        return self._fs.cat_file(self.path, start=lo, end=hi + 1)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
@@ -91,7 +91,7 @@ class EpisodeByteCache:
|
||||
max_bytes: int,
|
||||
*,
|
||||
data_root: str,
|
||||
max_prefetch_workers: int = 4,
|
||||
max_prefetch_workers: int = 16,
|
||||
):
|
||||
if max_bytes <= 0:
|
||||
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
|
||||
@@ -106,7 +106,6 @@ class EpisodeByteCache:
|
||||
self._episodes: dict[int, _EpisodeEntry] = {}
|
||||
self._stats = CacheStats()
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
|
||||
self._futures: dict[int, Future] = {}
|
||||
|
||||
@property
|
||||
def stats(self) -> CacheStats:
|
||||
@@ -114,48 +113,49 @@ class EpisodeByteCache:
|
||||
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
|
||||
|
||||
def submit_prefetch(self, ep_idx: int) -> None:
    """Schedule a background fetch for every camera of episode ``ep_idx``.

    A no-op when the episode already has an entry, so callers may submit the same
    episode repeatedly.

    Args:
        ep_idx: Index of the episode to prefetch.
    """
    # One future per (episode, camera): an episode's cameras fetch in parallel instead of
    # back-to-back on one thread, so the worker pool converts directly into concurrent
    # range GETs (the fetch throughput lever).
    with self._lock:
        if ep_idx in self._episodes:
            return
        entry = _EpisodeEntry()
        self._episodes[ep_idx] = entry
        self._stats.prefetch_submitted += 1
        # The futures are registered while the lock is held, so a reader that takes the
        # lock sees either no entry at all or the entry with its full set of futures.
        for cam in self.byte_index.video_keys:
            entry.futures[cam] = self._executor.submit(self._prefetch_camera, ep_idx, cam, entry)


def _prefetch_camera(self, ep_idx: int, cam: str, entry: "_EpisodeEntry") -> None:
    """Worker task: build the decoder for one (episode, camera) and store it on ``entry``.

    A failure is recorded on ``entry.error`` rather than raised, so the task's future
    completes normally and waiters re-raise the stored exception themselves.
    """
    try:
        entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
    except Exception as exc:
        entry.error = exc


def ensure_ready(self, ep_idx: int) -> None:
    """Block until every camera of episode ``ep_idx`` has finished prefetching.

    Args:
        ep_idx: Index of an episode previously passed to ``submit_prefetch``.

    Raises:
        KeyError: If the episode was never submitted.
        Exception: The error recorded by a failed camera prefetch, re-raised as-is.
    """
    # Snapshot under the lock so a concurrent submit_prefetch cannot resize the futures
    # dict while it is being iterated.
    with self._lock:
        entry = self._episodes.get(ep_idx)
        futures = [] if entry is None else list(entry.futures.values())
    if entry is None:
        raise KeyError(f"episode {ep_idx} not prefetched")
    if any(not fut.done() for fut in futures):
        with self._lock:
            self._stats.prefetch_waits += 1
    # Wait outside the lock: worker threads need it to finish their fetches.
    for fut in futures:
        fut.result()
    if entry.error is not None:
        raise entry.error


def get_decoder(self, ep_idx: int, video_key: str) -> Any:
    """Return the decoder for one camera, waiting only for that camera's fetch.

    Args:
        ep_idx: Index of an episode previously passed to ``submit_prefetch``.
        video_key: Camera key to return the decoder for.

    Returns:
        The decoder stored for ``(ep_idx, video_key)``.

    Raises:
        KeyError: If the episode was never submitted, or no decoder exists for the key.
        Exception: The error recorded on the episode entry by a failed prefetch.
    """
    with self._lock:
        entry = self._episodes[ep_idx]
        fut = entry.futures.get(video_key)
    if fut is not None:
        fut.result()
    if entry.error is not None:
        raise entry.error
    return entry.decoders[video_key]


def close(self) -> None:
    """Shut down the prefetch executor without waiting, cancelling tasks not yet started."""
    self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
|
||||
key = (ep_idx, cam)
|
||||
with self._lock:
|
||||
@@ -163,12 +163,15 @@ class EpisodeByteCache:
|
||||
if cached is not None:
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.hits += 1
|
||||
payload, _ = cached
|
||||
t0 = time.perf_counter()
|
||||
dec = self._decoder_from_payload(payload, ep_idx, cam)
|
||||
with self._lock:
|
||||
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
|
||||
return dec
|
||||
if cached is not None:
|
||||
# Build the decoder outside the lock: self._lock is non-reentrant, and decoding
|
||||
# while holding it would also serialize every other fetch thread.
|
||||
payload, _ = cached
|
||||
t0 = time.perf_counter()
|
||||
dec = self._decoder_from_payload(payload, ep_idx, cam)
|
||||
with self._lock:
|
||||
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
|
||||
return dec
|
||||
|
||||
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
|
||||
|
||||
@@ -240,9 +243,11 @@ class EpisodeByteCache:
|
||||
self._stats.bytes_fetched += len(header)
|
||||
return header
|
||||
|
||||
def _decoder_from_payload(self, payload: SparseMp4Reader, ep_idx: int, cam: str) -> Any:
    """Open a video decoder over an already-fetched payload.

    Args:
        payload: Seekable reader over the episode's bytes; it is rewound to offset 0
            before the decoder is opened.
        ep_idx: Episode index, used to look up the frame mappings.
        cam: Camera key, used to look up the frame mappings.

    Returns:
        Whatever ``open_video_decoder`` returns for the payload and mappings.
    """
    # Lazy import: torchcodec_utils touches private torchcodec symbols that vary across
    # torchcodec versions; importing this module must not require them.
    from .torchcodec_utils import open_video_decoder

    payload.seek(0)
    mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
    return open_video_decoder(payload, frame_mappings=mappings)
|
||||
@@ -252,7 +257,7 @@ class EpisodeByteCache:
|
||||
end = float(dec.metadata.end_stream_seconds)
|
||||
duration = max(0.01, end - begin)
|
||||
for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
|
||||
dec.get_frames_played_at([ts]).data
|
||||
_ = dec.get_frames_played_at([ts]).data
|
||||
|
||||
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
    """Reset ``payload``'s read position to the start of the stream."""
    payload.seek(0)
|
||||
|
||||
@@ -129,6 +129,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
byte_index_build_in_memory: bool | None = None,
|
||||
byte_index_workers: int = 8,
|
||||
byte_index_max_episodes: int | None = None,
|
||||
video_fetch_workers: int = 16,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -188,6 +189,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
if the sidecar parquet is missing on disk.
|
||||
byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds.
|
||||
byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests).
|
||||
video_fetch_workers (int, optional): Concurrent byte-range fetch threads per consumer
|
||||
feeding the episode byte cache. Each episode's cameras fetch in parallel, so this
|
||||
converts directly into concurrent range GETs — the fetch-throughput knob.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -230,6 +234,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.byte_index_build_in_memory = byte_index_build_in_memory
|
||||
self.byte_index_workers = byte_index_workers
|
||||
self.byte_index_max_episodes = byte_index_max_episodes
|
||||
self.video_fetch_workers = video_fetch_workers
|
||||
self._episode_byte_cache = None
|
||||
self._byte_index = None
|
||||
self._data_root = None
|
||||
@@ -256,7 +261,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
data_root = self._resolve_data_root()
|
||||
index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index")
|
||||
sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists()
|
||||
sidecar_exists = (index_dir / "files.parquet").exists() and (
|
||||
index_dir / "episodes.parquet"
|
||||
).exists()
|
||||
build_in_memory = (
|
||||
self.byte_index_build_in_memory
|
||||
if self.byte_index_build_in_memory is not None
|
||||
@@ -529,10 +536,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
|
||||
|
||||
def _use_episode_byte_cache(self) -> bool:
    """Report whether the episode byte cache should be used.

    Returns:
        True when ``video_byte_cache_gb`` is neither ``None`` nor ``0`` and
        ``data_files_root`` is set; False otherwise.
    """
    return self.video_byte_cache_gb not in (None, 0) and self.data_files_root is not None
|
||||
|
||||
def _make_episode_byte_cache(self):
|
||||
from .episode_byte_cache import EpisodeByteCache
|
||||
@@ -544,6 +548,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self._byte_index,
|
||||
max_bytes,
|
||||
data_root=self._data_root,
|
||||
max_prefetch_workers=self.video_fetch_workers,
|
||||
)
|
||||
|
||||
def _submit_episode_prefetch(self, episode_batch: dict[str, list[list]]) -> dict[str, list[list]]:
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for EpisodeByteCache fetch concurrency (decoder layer stubbed out)."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.episode_byte_cache import EpisodeByteCache, RangeFetcher
|
||||
|
||||
|
||||
@dataclass
class _FakeLookup:
    """Test stand-in for a per-(episode, camera) byte-index lookup record."""

    file_id: int  # key into the fake index's file table
    mdat_offset: int  # offset of the media data within the file
    mdat_length: int  # length of the media data in bytes
|
||||
|
||||
|
||||
@dataclass
class _FakeFileInfo:
    """Test stand-in for a byte-index file record."""

    file_path: str  # path of the backing file on disk
    file_size: int  # total size of the file in bytes
    header_length: int  # number of leading header bytes
|
||||
|
||||
|
||||
class _FakeByteIndex:
    """Minimal byte-index stand-in: two cameras, one on-disk file per (episode, camera)."""

    def __init__(self, tmp_path, num_episodes=4, mdat_len=64):
        self.video_keys = ["cam0", "cam1"]
        self._files = {}
        self._lookups = {}
        # Media section: the 0..63 byte pattern repeated enough times to cover mdat_len.
        media = bytes(range(64)) * ((mdat_len + 63) // 64)
        pairs = ((ep, cam) for ep in range(num_episodes) for cam in self.video_keys)
        for file_id, (ep, cam) in enumerate(pairs):
            # 16-byte header filled with the episode number, then the media section.
            blob = bytes([ep]) * 16 + media
            target = tmp_path / f"ep{ep}_{cam}.bin"
            target.write_bytes(blob)
            self._files[file_id] = _FakeFileInfo(str(target), len(blob), header_length=16)
            self._lookups[(ep, cam)] = _FakeLookup(file_id, mdat_offset=16, mdat_length=mdat_len)

    def lookup(self, ep_idx, cam):
        """Return the lookup record stored for ``(ep_idx, cam)``."""
        return self._lookups[(ep_idx, cam)]

    def file_lookup(self, file_id):
        """Return the file record stored for ``file_id``."""
        return self._files[file_id]

    def custom_frame_mappings(self, ep_idx, cam):
        """The fake index has no frame mappings; always returns None."""
        return None
|
||||
|
||||
|
||||
class _SlowFetchCache(EpisodeByteCache):
    """Cache with decoding stubbed out and a fixed sleep per fetch, so tests can observe
    whether fetches overlap."""

    # Seconds each stubbed fetch sleeps for.
    fetch_delay_s = 0.15

    def _fetch_manifest_slice(self, ep_idx, cam):
        """Sleep for the configured delay, then return a (payload, byte count, decoder) stub."""
        time.sleep(self.fetch_delay_s)
        payload = f"payload-{ep_idx}-{cam}"
        decoder = f"decoder-{ep_idx}-{cam}"
        return payload, 32, decoder

    def _decoder_from_payload(self, payload, ep_idx, cam):
        """Return a string standing in for the decoder; ``payload`` is ignored."""
        return f"decoder-{ep_idx}-{cam}"
|
||||
|
||||
|
||||
def _make_cache(tmp_path, **kwargs):
    """Build a _SlowFetchCache over a fresh fake index rooted at ``tmp_path``.

    Extra keyword arguments are forwarded to the cache constructor.
    """
    return _SlowFetchCache(
        _FakeByteIndex(tmp_path),
        max_bytes=10_000_000,
        data_root=str(tmp_path),
        **kwargs,
    )
|
||||
|
||||
|
||||
def test_cameras_fetch_in_parallel(tmp_path):
    """An episode's cameras must not fetch back-to-back on one thread."""
    cache = _make_cache(tmp_path, max_prefetch_workers=8)
    # Close in `finally` so a failed assertion does not leave the executor running.
    try:
        start = time.perf_counter()
        for ep in range(4):
            cache.submit_prefetch(ep)
        for ep in range(4):
            cache.ensure_ready(ep)
        elapsed = time.perf_counter() - start
        # 4 episodes x 2 cams x 0.15s = 1.2s sequential; 8 workers -> one wave ~0.15s.
        assert elapsed < 0.6, f"fetches serialized: {elapsed:.2f}s for 8 fetches on 8 workers"
        assert cache.get_decoder(2, "cam1") == "decoder-2-cam1"
    finally:
        cache.close()
|
||||
|
||||
|
||||
def test_prefetch_error_propagates(tmp_path):
    """A fetch failure must surface from both ensure_ready and get_decoder."""
    cache = _make_cache(tmp_path, max_prefetch_workers=2)

    def boom(ep_idx, cam):
        raise RuntimeError("fetch failed")

    # Close in `finally` so a failed assertion does not leave the executor running.
    try:
        # Instance-level override: every fetch on this cache now fails.
        cache._fetch_manifest_slice = boom
        cache.submit_prefetch(0)
        with pytest.raises(RuntimeError, match="fetch failed"):
            cache.ensure_ready(0)
        with pytest.raises(RuntimeError, match="fetch failed"):
            cache.get_decoder(0, "cam0")
    finally:
        cache.close()
|
||||
|
||||
|
||||
def test_payload_cache_hit_does_not_deadlock(tmp_path):
    """Regression: the hit path used to re-acquire the non-reentrant lock (deadlock)."""
    cache = _make_cache(tmp_path, max_prefetch_workers=2)
    # Close in `finally` so a failed assertion does not leave the executor running.
    try:
        # Seed the payload cache directly so the next lookup takes the hit path.
        cache._cache[(0, "cam0")] = ("payload-0-cam0", 32)
        cache._bytes_used = 32

        result = {}

        def hit():
            result["dec"] = cache._get_or_build_decoder(0, "cam0")

        # Run the lookup on a daemon thread with a bounded join: a deadlock then fails
        # the test instead of hanging the whole run.
        thread = threading.Thread(target=hit, daemon=True)
        thread.start()
        thread.join(timeout=5)
        assert not thread.is_alive(), "cache-hit path deadlocked"
        assert result["dec"] == "decoder-0-cam0"
        assert cache.stats.hits == 1
    finally:
        cache.close()
|
||||
|
||||
|
||||
def test_range_fetcher_cat_file_correctness(tmp_path):
    """Inclusive ``(lo, hi)`` ranges must come back as exactly the bytes ``payload[lo : hi + 1]``."""
    payload = bytes(range(256)) * 4
    blob_path = tmp_path / "blob.bin"
    blob_path.write_bytes(payload)
    fetcher = RangeFetcher(str(blob_path))
    size = len(payload)
    # Start of file, a span across the 256-byte pattern boundary, and the last 4 bytes.
    for lo, hi in [(0, 15), (100, 355), (size - 4, size - 1)]:
        assert fetcher.fetch(lo, hi) == payload[lo : hi + 1]
    # hi < lo is an empty range.
    assert fetcher.fetch(10, 9) == b""
|
||||
|
||||
|
||||
def test_ensure_ready_unknown_episode_raises(tmp_path):
    """ensure_ready on an episode that was never submitted must raise KeyError."""
    cache = _make_cache(tmp_path)
    # Close in `finally` so a failed expectation does not leave the executor running.
    try:
        with pytest.raises(KeyError):
            cache.ensure_ready(99)
    finally:
        cache.close()
|
||||
Reference in New Issue
Block a user