mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 08:37:10 +00:00
bc876949ff
The cluster benchmark showed fetch-bound throughput: resident decode 1824 samples/s vs stream keep-up 693 (target 1000), with fetch at ~465 MiB/s aggregate (~233/rank, i.e. ~4 effective HTTPS streams). Fixes: - One prefetch future per (episode, camera) instead of per episode: cameras no longer fetch back-to-back on a single thread, so the worker pool converts directly into concurrent range GETs. - Default fetch workers 4 -> 16, exposed as video_fetch_workers on StreamingLeRobotDataset for sweeps. - RangeFetcher uses fs.cat_file (one ranged GET per fetch, no open/seek/read layering) and resolves any fsspec URL, so S3-compatible stores (e.g. Backblaze B2 via s3://) work identically to hf://. Also fixed in passing: a latent deadlock on the payload-cache hit path (_get_or_build_decoder re-acquired the non-reentrant lock; unhit so far because payload hits are rare), and episode_byte_cache no longer imports private torchcodec symbols at module import time (they vary across torchcodec versions and broke the module on macOS wheels). New unit tests (decoder layer stubbed): cameras fetch in parallel (wall-clock bound), error propagation through ensure_ready/get_decoder, cache-hit deadlock regression, cat_file range correctness. Local Hub microbench shows +46% aggregate at 16 vs 4 workers on a residential link that saturates at ~15 MiB/s; the real before/after needs the cluster benchmark where per-stream throughput, not the link, binds. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
151 lines
4.8 KiB
Python
151 lines
4.8 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Unit tests for EpisodeByteCache fetch concurrency (decoder layer stubbed out)."""
|
|
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
from lerobot.datasets.episode_byte_cache import EpisodeByteCache, RangeFetcher
|
|
|
|
|
|
@dataclass
|
|
class _FakeLookup:
|
|
file_id: int
|
|
mdat_offset: int
|
|
mdat_length: int
|
|
|
|
|
|
@dataclass
|
|
class _FakeFileInfo:
|
|
file_path: str
|
|
file_size: int
|
|
header_length: int
|
|
|
|
|
|
class _FakeByteIndex:
|
|
"""Two cameras, one file per (episode, cam); fetch delay is injectable."""
|
|
|
|
def __init__(self, tmp_path, num_episodes=4, mdat_len=64):
|
|
self.video_keys = ["cam0", "cam1"]
|
|
self._files = {}
|
|
self._lookups = {}
|
|
file_id = 0
|
|
for ep in range(num_episodes):
|
|
for cam in self.video_keys:
|
|
path = tmp_path / f"ep{ep}_{cam}.bin"
|
|
payload = bytes([ep]) * 16 + bytes(range(64)) * ((mdat_len + 63) // 64)
|
|
path.write_bytes(payload)
|
|
self._files[file_id] = _FakeFileInfo(str(path), len(payload), header_length=16)
|
|
self._lookups[(ep, cam)] = _FakeLookup(file_id, mdat_offset=16, mdat_length=mdat_len)
|
|
file_id += 1
|
|
|
|
def lookup(self, ep_idx, cam):
|
|
return self._lookups[(ep_idx, cam)]
|
|
|
|
def file_lookup(self, file_id):
|
|
return self._files[file_id]
|
|
|
|
def custom_frame_mappings(self, ep_idx, cam):
|
|
return None
|
|
|
|
|
|
class _SlowFetchCache(EpisodeByteCache):
|
|
"""Stub decode/validation; add a per-fetch delay to observe fetch parallelism."""
|
|
|
|
fetch_delay_s = 0.15
|
|
|
|
def _fetch_manifest_slice(self, ep_idx, cam):
|
|
time.sleep(self.fetch_delay_s)
|
|
return f"payload-{ep_idx}-{cam}", 32, f"decoder-{ep_idx}-{cam}"
|
|
|
|
def _decoder_from_payload(self, payload, ep_idx, cam):
|
|
return f"decoder-{ep_idx}-{cam}"
|
|
|
|
|
|
def _make_cache(tmp_path, **kwargs):
|
|
index = _FakeByteIndex(tmp_path)
|
|
return _SlowFetchCache(index, max_bytes=10_000_000, data_root=str(tmp_path), **kwargs)
|
|
|
|
|
|
def test_cameras_fetch_in_parallel(tmp_path):
|
|
"""An episode's cameras must not fetch back-to-back on one thread."""
|
|
cache = _make_cache(tmp_path, max_prefetch_workers=8)
|
|
start = time.perf_counter()
|
|
for ep in range(4):
|
|
cache.submit_prefetch(ep)
|
|
for ep in range(4):
|
|
cache.ensure_ready(ep)
|
|
elapsed = time.perf_counter() - start
|
|
# 4 episodes x 2 cams x 0.15s = 1.2s sequential; 8 workers -> one wave ~0.15s.
|
|
assert elapsed < 0.6, f"fetches serialized: {elapsed:.2f}s for 8 fetches on 8 workers"
|
|
assert cache.get_decoder(2, "cam1") == "decoder-2-cam1"
|
|
cache.close()
|
|
|
|
|
|
def test_prefetch_error_propagates(tmp_path):
|
|
cache = _make_cache(tmp_path, max_prefetch_workers=2)
|
|
|
|
def boom(ep_idx, cam):
|
|
raise RuntimeError("fetch failed")
|
|
|
|
cache._fetch_manifest_slice = boom
|
|
cache.submit_prefetch(0)
|
|
with pytest.raises(RuntimeError, match="fetch failed"):
|
|
cache.ensure_ready(0)
|
|
with pytest.raises(RuntimeError, match="fetch failed"):
|
|
cache.get_decoder(0, "cam0")
|
|
cache.close()
|
|
|
|
|
|
def test_payload_cache_hit_does_not_deadlock(tmp_path):
|
|
"""Regression: the hit path used to re-acquire the non-reentrant lock (deadlock)."""
|
|
cache = _make_cache(tmp_path, max_prefetch_workers=2)
|
|
cache._cache[(0, "cam0")] = ("payload-0-cam0", 32)
|
|
cache._bytes_used = 32
|
|
|
|
result = {}
|
|
|
|
def hit():
|
|
result["dec"] = cache._get_or_build_decoder(0, "cam0")
|
|
|
|
thread = threading.Thread(target=hit, daemon=True)
|
|
thread.start()
|
|
thread.join(timeout=5)
|
|
assert not thread.is_alive(), "cache-hit path deadlocked"
|
|
assert result["dec"] == "decoder-0-cam0"
|
|
assert cache.stats.hits == 1
|
|
cache.close()
|
|
|
|
|
|
def test_range_fetcher_cat_file_correctness(tmp_path):
|
|
payload = bytes(range(256)) * 4
|
|
path = tmp_path / "blob.bin"
|
|
path.write_bytes(payload)
|
|
fetcher = RangeFetcher(str(path))
|
|
assert fetcher.fetch(0, 15) == payload[0:16]
|
|
assert fetcher.fetch(100, 355) == payload[100:356]
|
|
assert fetcher.fetch(len(payload) - 4, len(payload) - 1) == payload[-4:]
|
|
assert fetcher.fetch(10, 9) == b""
|
|
|
|
|
|
def test_ensure_ready_unknown_episode_raises(tmp_path):
|
|
cache = _make_cache(tmp_path)
|
|
with pytest.raises(KeyError):
|
|
cache.ensure_ready(99)
|
|
cache.close()
|