Add in-memory byte index and manifest-driven episode MP4 cache.

Build moov-derived byte ranges in RAM or from sidecar parquet, fetch tight mdat slices over the network, and decode via TorchCodec custom_frame_mappings to skip full-file metadata scans.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-16 15:03:17 +00:00
parent 4940281120
commit 7b6f4f2b11
10 changed files with 1682 additions and 5 deletions
+8 -3
View File
@@ -335,9 +335,14 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next
# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
+51
View File
@@ -0,0 +1,51 @@
#!/usr/bin/env python
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from lerobot.datasets.byte_index_builder import (
build_byte_index_tables,
load_existing_file_ids,
write_byte_index,
)
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main() -> None:
    """Parse CLI arguments, build the byte-index tables and write them as sidecar parquet.

    File ids already present in ``files.parquet`` under ``--output`` are loaded and
    handed to the builder as ``existing_files``; the written tables are merged with
    whatever is already on disk (``merge_existing=True``).
    """
    cli = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
    cli.add_argument("--repo-id", required=True)
    cli.add_argument("--revision", default=None)
    cli.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
    cli.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
    cli.add_argument("--workers", type=int, default=8)
    cli.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
    cli.add_argument("--no-keyframes", action="store_true")
    opts = cli.parse_args()

    metadata = LeRobotDatasetMetadata(opts.repo_id, revision=opts.revision)
    out_dir = opts.output

    already_indexed = load_existing_file_ids(out_dir)
    if already_indexed:
        logger.info("resuming: %s files already indexed", len(already_indexed))

    tables = build_byte_index_tables(
        metadata,
        opts.data_root,
        include_keyframes=not opts.no_keyframes,
        workers=opts.workers,
        existing_files=already_indexed,
        max_episodes=opts.max_episodes,
    )
    files_tbl, episodes_tbl, keyframes_tbl = tables

    write_byte_index(out_dir, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
    logger.info("wrote byte index to %s", out_dir)
if __name__ == "__main__":
main()
+228
View File
@@ -0,0 +1,228 @@
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
from .mp4_episode_slice import episode_custom_frame_mappings_json
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class EpisodeSliceLookup:
    """Immutable result of looking up one (episode, camera) row of the byte index."""

    global_episode_id: int
    file_id: int
    mdat_offset: int
    mdat_length: int
    frame_count: int
    first_pts: float
    last_pts: float
    avg_fps: float

    @property
    def fetch_bytes(self) -> int:
        """Number of mdat bytes covered by this slice (same value as ``mdat_length``)."""
        byte_count = self.mdat_length
        return byte_count
@dataclass(frozen=True)
class FileLookup:
    """Immutable per-file record returned by ``EpisodeByteIndex.file_lookup``.

    Each field mirrors the column of the same name in the files table.
    """

    file_id: int
    file_path: str
    file_size: int
    moov_offset: int
    moov_length: int
    header_length: int
    faststart: bool
    avg_fps: float
    codec: str
class EpisodeByteIndex:
    """Columnar byte index resident in numpy arrays for O(1) episode lookup.

    File columns are reordered on load so that array position equals ``file_id``.
    Episode columns are sorted by ``global_episode_id``, which must be exactly
    ``0 .. num_episodes * num_cameras - 1`` because ``lookup`` addresses rows as
    ``episode_index * num_cameras + camera_index``.
    """

    def __init__(
        self,
        index_dir: str | Path | None,
        *,
        video_keys: list[str],
        num_episodes: int,
        mmap: bool = True,
        files_table: pa.Table | None = None,
        episodes_table: pa.Table | None = None,
        mp4_by_rel: dict[str, Any] | None = None,
    ):
        """Load the index from in-memory Arrow tables or from parquet under ``index_dir``.

        Args:
            index_dir: Directory holding the sidecar parquet files. May be ``None``
                when both ``files_table`` and ``episodes_table`` are supplied.
            video_keys: Camera keys; their order defines each camera's index.
            num_episodes: Number of episodes the index is expected to cover.
            mmap: Forwarded as ``memory_map`` to ``pyarrow.parquet.read_table``.
            files_table: Pre-built files table (used only together with ``episodes_table``).
            episodes_table: Pre-built episodes table (used only together with ``files_table``).
            mp4_by_rel: Optional mapping from relative file path to a parsed MP4 index,
                consulted by ``custom_frame_mappings``.

        Raises:
            ValueError: No tables and no ``index_dir`` were given, or the tables are
                inconsistent (see ``_load_tables``).
            FileNotFoundError: The files or episodes parquet is missing under ``index_dir``.
        """
        self.index_dir = Path(index_dir) if index_dir is not None else None
        self.video_keys = list(video_keys)
        self.num_episodes = num_episodes
        self.num_cameras = len(video_keys)
        self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
        self._mp4_by_rel = mp4_by_rel
        self._frame_mappings_by_gid: dict[int, bytes] = {}

        t0 = time.perf_counter()
        if files_table is not None and episodes_table is not None:
            files_tbl, episodes_tbl = files_table, episodes_table
        else:
            if self.index_dir is None:
                raise ValueError("index_dir or in-memory tables required")
            files_path = self.index_dir / FILES_NAME
            episodes_path = self.index_dir / EPISODES_NAME
            if not files_path.exists() or not episodes_path.exists():
                raise FileNotFoundError(f"byte index missing under {self.index_dir}")
            files_tbl = pq.read_table(files_path, memory_map=mmap)
            episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)
        self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
        self.build_time_s = time.perf_counter() - t0
        self.load_time_s = self.build_time_s

    @staticmethod
    def _file_order(file_ids: np.ndarray) -> np.ndarray:
        """Return the permutation that puts files rows in ``file_id`` order.

        Raises:
            ValueError: If the ids are not exactly ``0 .. len(file_ids) - 1``
                (duplicates or gaps), since other methods index file columns
                directly by ``file_id``.
        """
        ids = np.asarray(file_ids)
        order = np.argsort(ids, kind="stable")
        if not np.array_equal(ids[order], np.arange(len(ids))):
            raise ValueError("byte index file_id values must be unique and contiguous from 0")
        return order

    def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
        """Materialise both tables as numpy arrays / Python lists on ``self``.

        Raises:
            ValueError: If the episodes table does not have exactly
                ``num_episodes * num_cameras`` rows, if its ``global_episode_id``
                values are not ``0..n-1``, or if ``file_id`` values are not ``0..m-1``.
        """

        def col(tbl, name: str):
            array = tbl.column(name).combine_chunks()
            if pa.types.is_boolean(array.type):
                # Boolean arrays cannot be converted without a copy.
                return array.to_numpy(zero_copy_only=False)
            return array.to_numpy()

        # Row order of the files table is not guaranteed to follow file_id (ids may be
        # assigned out of order by the builder, and merged sidecars append rows), so
        # reorder every file column to make ``column[file_id]`` valid.
        raw_file_ids = col(files_tbl, "file_id")
        file_order = self._file_order(raw_file_ids)
        file_paths = files_tbl.column("file_path").to_pylist()
        codecs = files_tbl.column("codec").to_pylist()

        self.file_id = raw_file_ids[file_order]
        self.file_path = [file_paths[i] for i in file_order]
        self.file_size = col(files_tbl, "file_size")[file_order]
        self.moov_offset = col(files_tbl, "moov_offset")[file_order]
        self.moov_length = col(files_tbl, "moov_length")[file_order]
        self.header_length = col(files_tbl, "header_length")[file_order]
        self.faststart = col(files_tbl, "faststart")[file_order]
        self.file_avg_fps = col(files_tbl, "avg_fps")[file_order]
        self.codec = [codecs[i] for i in file_order]

        ep = episodes_tbl
        n = len(ep)
        gid = col(ep, "global_episode_id")
        order = np.argsort(gid)
        self._global_episode_id = gid[order]
        self._episode_index = col(ep, "episode_index")[order]
        self._camera_index = col(ep, "camera_index")[order]
        self._file_id = col(ep, "file_id")[order]
        self._mdat_offset = col(ep, "mdat_offset")[order]
        self._mdat_length = col(ep, "mdat_length")[order]
        self._frame_count = col(ep, "frame_count")[order]
        self._first_pts = col(ep, "first_pts")[order]
        self._last_pts = col(ep, "last_pts")[order]

        expected = self.num_episodes * self.num_cameras
        if n != expected:
            raise ValueError(f"byte index episodes rows {n} != expected {expected}")
        # lookup() uses the global episode id as a row number, so ids must be 0..n-1.
        if not np.array_equal(self._global_episode_id, np.arange(n)):
            raise ValueError("byte index global_episode_id values must be unique and contiguous from 0")

        if self.index_dir is not None:
            keyframes_path = self.index_dir / KEYFRAMES_NAME
            if keyframes_path.exists():
                kf_tbl = pq.read_table(keyframes_path, memory_map=mmap)
                self._keyframes_rows = len(kf_tbl)
            else:
                self._keyframes_rows = 0
        else:
            self._keyframes_rows = 0

        self.resident_bytes = int(
            self._global_episode_id.nbytes
            + self._file_id.nbytes
            + self._mdat_offset.nbytes
            + self._mdat_length.nbytes
            + self.file_size.nbytes
        )

    @classmethod
    def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
        """Load the index from ``meta_root / BYTE_INDEX_DIR``."""
        return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)

    @classmethod
    def from_memory_build(
        cls,
        meta,
        data_root: str,
        *,
        workers: int = 8,
        max_episodes: int | None = None,
        include_frame_mappings_cache: bool = True,
    ) -> EpisodeByteIndex:
        """Build a complete byte index in RAM (no parquet write, no dataset push)."""
        from .byte_index_builder import build_byte_index_in_memory

        return build_byte_index_in_memory(
            meta,
            data_root,
            workers=workers,
            max_episodes=max_episodes,
            include_frame_mappings_cache=include_frame_mappings_cache,
        )

    def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
        """Return the byte slice description for one (episode, camera) pair.

        Raises:
            KeyError: ``camera_key`` is not one of ``video_keys``.
            IndexError: The computed row is outside the episodes table.
        """
        cam_idx = self._cam_to_idx[camera_key]
        row = int(episode_index * self.num_cameras + cam_idx)
        if row < 0 or row >= len(self._global_episode_id):
            raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
        file_id = int(self._file_id[row])
        return EpisodeSliceLookup(
            global_episode_id=row,
            file_id=file_id,
            mdat_offset=int(self._mdat_offset[row]),
            mdat_length=int(self._mdat_length[row]),
            frame_count=int(self._frame_count[row]),
            first_pts=float(self._first_pts[row]),
            last_pts=float(self._last_pts[row]),
            avg_fps=float(self.file_avg_fps[file_id]),
        )

    def file_lookup(self, file_id: int) -> FileLookup:
        """Return the files-table record for ``file_id``."""
        return FileLookup(
            file_id=file_id,
            file_path=self.file_path[file_id],
            file_size=int(self.file_size[file_id]),
            moov_offset=int(self.moov_offset[file_id]),
            moov_length=int(self.moov_length[file_id]),
            header_length=int(self.header_length[file_id]),
            faststart=bool(self.faststart[file_id]),
            avg_fps=float(self.file_avg_fps[file_id]),
            codec=self.codec[file_id],
        )

    def header_byte_range(self, file_id: int) -> tuple[int, int]:
        """Return the inclusive ``(first, last)`` byte range of the file header."""
        length = int(self.header_length[file_id])
        return 0, max(0, length - 1)

    def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
        """Return (and memoise) the frame-mappings JSON for one episode/camera.

        Returns ``None`` when no parsed MP4 index is available for the episode's file.
        """
        cam_idx = self._cam_to_idx[camera_key]
        gid = episode_index * self.num_cameras + cam_idx
        cached = self._frame_mappings_by_gid.get(gid)
        if cached is not None:
            return cached
        if self._mp4_by_rel is None:
            return None
        lookup = self.lookup(episode_index, camera_key)
        rel = self.file_path[lookup.file_id]
        mp4_index = self._mp4_by_rel.get(rel)
        if mp4_index is None:
            return None
        payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
        self._frame_mappings_by_gid[gid] = payload
        return payload

    def stats_dict(self) -> dict[str, float | int]:
        """Return load timings, resident size and cache counters as a flat dict."""
        return {
            "load_time_s": self.load_time_s,
            "build_time_s": self.build_time_s,
            "resident_bytes": self.resident_bytes,
            "frame_mappings_cached": len(self._frame_mappings_by_gid),
            "mp4_indices_cached": len(self._mp4_by_rel or {}),
            "num_files": len(self.file_path),
            "num_episode_rows": len(self._global_episode_id),
        }
+281
View File
@@ -0,0 +1,281 @@
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
from __future__ import annotations
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import fsspec
import pyarrow as pa
import pyarrow.parquet as pq
from .mp4_episode_slice import (
HEADER_PROBE_BYTES,
MAX_HEADER_PROBE_BYTES,
average_fps_from_index,
episode_keyframes,
parse_mp4_file_layout,
parse_mp4_index,
)
logger = logging.getLogger(__name__)
# Location of the sidecar directory, relative to the dataset/metadata root.
BYTE_INDEX_DIR = "meta/byte_index"
# Parquet file names written by write_byte_index inside the sidecar directory.
FILES_NAME = "files.parquet"
EPISODES_NAME = "episodes.parquet"
KEYFRAMES_NAME = "keyframes.parquet"
@dataclass
class IndexedFile:
    """Mutable per-file record produced by ``index_video_file``.

    ``file_id`` starts out as ``-1`` there and is assigned later by the table builder.
    """

    file_id: int
    file_path: str
    file_size: int
    moov_offset: int
    moov_length: int
    header_length: int
    faststart: bool
    avg_fps: float
    codec: str
def fetch_header_bytes(path: str, file_size: int) -> bytes:
    """Read a prefix of ``path`` large enough for ``parse_mp4_file_layout`` to succeed.

    Starts with ``HEADER_PROBE_BYTES`` and doubles the probe (capped at
    ``MAX_HEADER_PROBE_BYTES`` and ``file_size``) while the parser reports that the
    mdat box was not found.

    Raises:
        ValueError: Re-raised from the parser for any other failure, or once the
            probe has reached its cap.
    """
    protocol = "hf" if path.startswith("hf://") else "file"
    fs = fsspec.filesystem(protocol)
    probe = HEADER_PROBE_BYTES
    while True:
        with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as handle:
            header = handle.read(min(probe, file_size))
        try:
            parse_mp4_file_layout(header, file_size)
        except ValueError as exc:
            at_cap = probe >= min(MAX_HEADER_PROBE_BYTES, file_size)
            if at_cap or "mdat box not found" not in str(exc):
                raise
        else:
            return header
        probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size)
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
    """Fetch and parse the header of one MP4 and describe it as an ``IndexedFile``.

    Args:
        path: Full path of the file (``hf://`` paths use the ``hf`` filesystem,
            anything else the local one).
        rel_path: Path to record in the result; falls back to ``path``.

    Returns:
        ``(indexed_file, mp4_index)`` where ``indexed_file.file_id`` is ``-1``
        (to be assigned by the caller) and ``mp4_index`` is the parsed sample index.
    """
    scheme = "hf" if path.startswith("hf://") else "file"
    file_size = fsspec.filesystem(scheme).info(path)["size"]

    header = fetch_header_bytes(path, file_size)
    layout = parse_mp4_file_layout(header, file_size)
    if not layout.faststart:
        logger.warning("non-faststart MP4 (moov after mdat): %s", path)

    mp4_index = parse_mp4_index(header, file_size)
    fps = average_fps_from_index(mp4_index)
    record = IndexedFile(
        file_id=-1,
        file_path=rel_path or path,
        file_size=file_size,
        moov_offset=layout.moov_offset,
        moov_length=layout.moov_length,
        header_length=layout.header_end,
        faststart=layout.faststart,
        avg_fps=fps,
        codec=layout.codec,
    )
    return record, mp4_index
def build_byte_index_tables(
    meta,
    data_root: str,
    *,
    file_paths: list[str] | None = None,
    include_keyframes: bool = True,
    workers: int = 8,
    existing_files: dict[str, int] | None = None,
    max_episodes: int | None = None,
    return_mp4_indices: bool = False,
    complete_files_table: bool = False,
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
    """Build files/episodes/(optional keyframes) Arrow tables.

    Args:
        meta: Dataset metadata; this function reads ``video_keys``, ``total_episodes``,
            ``fps``, ``episodes`` and calls ``get_video_file_path``.
        data_root: Root prepended to every relative video path.
        file_paths: Accepted for interface compatibility; currently unused.
        include_keyframes: Also emit one keyframes row per sync sample in each episode.
        workers: Thread count used to index video files concurrently.
        existing_files: ``{relative path: file_id}`` of files that already have an id.
            The mapping is not modified.
        max_episodes: Only cover the first ``max_episodes`` episodes.
        return_mp4_indices: Append the ``{relative path: parsed MP4 index}`` dict to the result.
        complete_files_table: Emit a files row for every referenced file (ordered by
            ``file_id``) instead of only for files that were newly assigned an id.

    Returns:
        ``(files_table, episodes_table, keyframes_table)``, plus the MP4 index dict
        when ``return_mp4_indices`` is set. ``files_table`` is ``None`` when
        ``complete_files_table`` is false and no new file was indexed;
        ``keyframes_table`` is ``None`` when disabled or empty.
    """
    video_keys = list(meta.video_keys)
    n_cams = len(video_keys)
    cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
    num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
    root = data_root.rstrip("/")

    # Resolve every (episode, camera) pair to its video file exactly once.
    rel_by_key = {
        (ep_idx, cam): str(meta.get_video_file_path(ep_idx, cam))
        for ep_idx in range(num_episodes)
        for cam in video_keys
    }
    referenced_rels = sorted(set(rel_by_key.values()))
    path_by_rel = {rel: f"{root}/{rel}" for rel in referenced_rels}

    # Work on a copy so the caller's mapping is left untouched.
    file_ids: dict[str, int] = dict(existing_files or {})
    to_index = [rel for rel in referenced_rels if rel not in file_ids]

    # Assign ids in sorted-path order rather than thread-completion order so that
    # ids are deterministic and rows written in this order are already id-sorted.
    next_file_id = max(file_ids.values(), default=-1) + 1
    for rel in to_index:
        file_ids[rel] = next_file_id
        next_file_id += 1

    # Every referenced file is parsed (also those that already have an id) because
    # the episode byte spans below need its MP4 sample index.
    indexed_by_rel: dict[str, IndexedFile] = {}
    mp4_by_rel: dict[str, Any] = {}
    if referenced_rels:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in referenced_rels
            }
            for fut in as_completed(futures):
                rel = futures[fut]
                indexed, mp4_index = fut.result()
                indexed.file_id = file_ids[rel]
                indexed_by_rel[rel] = indexed
                mp4_by_rel[rel] = mp4_index

    file_meta_by_rel: dict[str, dict[str, Any]] = {
        rel: {
            "file_id": indexed.file_id,
            "file_path": rel,
            "file_size": indexed.file_size,
            "moov_offset": indexed.moov_offset,
            "moov_length": indexed.moov_length,
            "header_length": indexed.header_length,
            "faststart": indexed.faststart,
            "avg_fps": indexed.avg_fps,
            "codec": indexed.codec,
        }
        for rel, indexed in indexed_by_rel.items()
    }

    episode_rows: list[dict[str, Any]] = []
    keyframe_rows: list[dict[str, Any]] = []
    for ep_idx in range(num_episodes):
        ep = meta.episodes[ep_idx]
        for cam in video_keys:
            rel = rel_by_key[(ep_idx, cam)]
            mp4_index = mp4_by_rel[rel]
            from_ts = float(ep[f"videos/{cam}/from_timestamp"])
            to_ts = float(ep[f"videos/{cam}/to_timestamp"])
            span = mp4_index.episode_byte_span(from_ts, to_ts)
            global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
            mdat_length = span.slice_hi - span.slice_lo + 1
            episode_rows.append(
                {
                    "global_episode_id": global_episode_id,
                    "episode_index": ep_idx,
                    "camera_key": cam,
                    "camera_index": cam_to_idx[cam],
                    "file_id": file_ids[rel],
                    "mdat_offset": span.slice_lo,
                    "mdat_length": mdat_length,
                    "frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
                    "first_pts": from_ts,
                    "last_pts": to_ts,
                }
            )
            if include_keyframes:
                timescale = mp4_index.timescale
                keyframe_rows.extend(
                    {
                        "global_episode_id": global_episode_id,
                        "pts": int(round(pts_s * timescale)),
                        "byte_offset": byte_off,
                    }
                    for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts)
                )

    if complete_files_table:
        by_id = sorted(referenced_rels, key=file_ids.__getitem__)
        files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in by_id])
    elif to_index:
        files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in to_index])
    else:
        files_table = None
    episodes_table = pa.Table.from_pylist(episode_rows)
    keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None

    if return_mp4_indices:
        return files_table, episodes_table, keyframes_table, mp4_by_rel
    return files_table, episodes_table, keyframes_table
def build_byte_index_in_memory(
    meta,
    data_root: str,
    *,
    workers: int = 8,
    max_episodes: int | None = None,
    include_frame_mappings_cache: bool = False,
):
    """Build a complete byte index resident in RAM (no parquet write, no dataset push).

    The tables are built without keyframes and with a complete files table, then
    wrapped in an ``EpisodeByteIndex`` that also keeps the parsed MP4 indices. When
    ``include_frame_mappings_cache`` is set, the frame mappings of every
    (episode, camera) pair are computed up front.
    """
    from .byte_index import EpisodeByteIndex

    total = meta.total_episodes
    num_episodes = total if max_episodes is None else min(max_episodes, total)

    built = build_byte_index_tables(
        meta,
        data_root,
        include_keyframes=False,
        workers=workers,
        max_episodes=max_episodes,
        return_mp4_indices=True,
        complete_files_table=True,
    )
    files_tbl, episodes_tbl, _, mp4_by_rel = built

    index = EpisodeByteIndex(
        None,
        video_keys=list(meta.video_keys),
        num_episodes=num_episodes,
        files_table=files_tbl,
        episodes_table=episodes_tbl,
        mp4_by_rel=mp4_by_rel,
    )
    if not include_frame_mappings_cache:
        return index

    # Warm the per-episode frame-mappings cache.
    for ep_idx in range(num_episodes):
        for cam in meta.video_keys:
            index.custom_frame_mappings(ep_idx, cam)
    return index
def write_byte_index(
    output_dir: Path,
    files_table: pa.Table | None,
    episodes_table: pa.Table,
    keyframes_table: pa.Table | None = None,
    *,
    merge_existing: bool = True,
) -> None:
    """Write the byte-index tables as parquet files under ``output_dir``.

    Args:
        output_dir: Target directory; created if missing.
        files_table: New files rows, or ``None`` to leave any existing files parquet untouched.
        episodes_table: Episodes rows; always written, replacing any previous file.
        keyframes_table: Optional keyframes rows.
        merge_existing: When true, rows already on disk are kept in front of the new
            files/keyframes rows, except rows the new table supersedes (same
            ``file_path`` for files, same ``global_episode_id`` for keyframes).
            Without that exclusion a resumed build, which regenerates keyframes for
            every episode, would duplicate rows on each run.
    """

    def merged_with_previous(path: Path, new_table: pa.Table, key: str) -> pa.Table:
        # Keep only previous rows whose key does not reappear in the new table.
        previous = pq.read_table(path)
        fresh_keys = set(new_table.column(key).to_pylist())
        keep = pa.array(
            [value not in fresh_keys for value in previous.column(key).to_pylist()],
            type=pa.bool_(),
        )
        return pa.concat_tables([previous.filter(keep), new_table])

    output_dir.mkdir(parents=True, exist_ok=True)
    files_path = output_dir / FILES_NAME
    episodes_path = output_dir / EPISODES_NAME
    keyframes_path = output_dir / KEYFRAMES_NAME

    if files_table is not None:
        if merge_existing and files_path.exists():
            files_table = merged_with_previous(files_path, files_table, "file_path")
        pq.write_table(files_table, files_path)

    pq.write_table(episodes_table, episodes_path)

    if keyframes_table is not None:
        if merge_existing and keyframes_path.exists():
            keyframes_table = merged_with_previous(keyframes_path, keyframes_table, "global_episode_id")
        pq.write_table(keyframes_table, keyframes_path)
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
    """Return ``{file_path: file_id}`` from the files parquet in ``index_dir``.

    Returns an empty dict when the files parquet does not exist.
    """
    files_path = index_dir / FILES_NAME
    if files_path.exists():
        table = pq.read_table(files_path, columns=["file_id", "file_path"])
        paths = table.column("file_path").to_pylist()
        ids = table.column("file_id").to_pylist()
        return {path: int(fid) for path, fid in zip(paths, ids)}
    return {}
+263
View File
@@ -0,0 +1,263 @@
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
from __future__ import annotations
import logging
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
import fsspec
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
from .mp4_episode_slice import SparseMp4Reader
from .torchcodec_utils import open_video_decoder
logger = logging.getLogger(__name__)
@dataclass
class CacheStats:
    """Additive counters and timers collected by the episode byte cache."""

    hits: int = 0
    misses: int = 0
    bytes_fetched: int = 0
    full_file_fallbacks: int = 0
    prefetch_submitted: int = 0
    prefetch_waits: int = 0
    mdat_slices: int = 0
    prefix_fetches: int = 0
    fetch_to_buffer_s: float = 0.0
    buffer_to_decoder_s: float = 0.0
    buffer_hit_decoder_s: float = 0.0
    decode_frame_s: float = 0.0
    decode_frames: int = 0

    def merge(self, other: CacheStats) -> None:
        """Add every field of ``other`` onto this instance, in place."""
        for name in self.__dataclass_fields__:
            combined = getattr(self, name) + getattr(other, name)
            setattr(self, name, combined)

    def stats_dict(self) -> dict[str, int | float]:
        """Return raw counters plus per-miss / per-hit / per-frame averages.

        Denominators are clamped to at least 1 so empty stats yield zeros.
        """
        per_miss = max(1, self.misses)
        per_hit = max(1, self.hits)
        per_frame = max(1, self.decode_frames)
        report: dict[str, int | float] = {
            "byte_cache_hits": self.hits,
            "byte_cache_misses": self.misses,
            "byte_cache_bytes_fetched": self.bytes_fetched,
            "byte_cache_bytes_per_miss": self.bytes_fetched / per_miss,
            "byte_cache_full_file_fallbacks": self.full_file_fallbacks,
            "byte_cache_prefetch_submitted": self.prefetch_submitted,
            "byte_cache_prefetch_waits": self.prefetch_waits,
            "byte_cache_mdat_slices": self.mdat_slices,
            "byte_cache_prefix_fetches": self.prefix_fetches,
        }
        report["fetch_to_buffer_ms_per_miss"] = 1000 * self.fetch_to_buffer_s / per_miss
        report["buffer_to_decoder_ms_per_miss"] = 1000 * self.buffer_to_decoder_s / per_miss
        report["buffer_hit_decoder_ms_per_hit"] = 1000 * self.buffer_hit_decoder_s / per_hit
        report["decode_ms_per_frame"] = 1000 * self.decode_frame_s / per_frame
        return report
@dataclass
class _EpisodeEntry:
    """Per-episode prefetch slot: decoders by camera key, a readiness event and any error."""

    # Filled by the prefetch task, one decoder per camera key.
    decoders: dict[str, Any] = field(default_factory=dict)
    # Set by the prefetch task when it finishes, whether it succeeded or failed.
    ready: threading.Event = field(default_factory=threading.Event)
    # Exception raised while building decoders, if any.
    error: Exception | None = None
class RangeFetcher:
    """Sequential byte-range GETs via fsspec."""

    def __init__(self, path: str):
        """Bind to ``path``; ``hf://`` paths use the ``hf`` filesystem, others the local one."""
        self.path = path
        protocol = "hf" if path.startswith("hf://") else "file"
        self._fs = fsspec.filesystem(protocol)

    def fetch(self, lo: int, hi: int) -> bytes:
        """Return bytes ``lo..hi`` (both inclusive); empty when ``hi < lo``."""
        if hi < lo:
            return b""
        span = hi - lo + 1
        with self._fs.open(self.path, "rb", block_size=max(2**20, span), cache_type="none") as handle:
            handle.seek(lo)
            return handle.read(span)
class EpisodeByteCache:
    """Manifest-driven episode MP4 fetch + in-memory sparse decode.

    Episodes are prefetched on a thread pool; each prefetch fetches the file header
    (cached per file) plus the episode's mdat slice, wraps them in a sparse reader
    and opens one decoder per camera. Fetched payloads are kept in an LRU keyed by
    ``(episode, camera)`` and bounded by ``max_bytes``.
    """

    # Payloads larger than this trigger a warning on a cache miss.
    MAX_BYTES_PER_MISS = 25 * 1024 * 1024

    def __init__(
        self,
        byte_index: EpisodeByteIndex,
        max_bytes: int,
        *,
        data_root: str,
        max_prefetch_workers: int = 4,
    ):
        """Create the cache.

        Args:
            byte_index: Index used to resolve byte ranges and frame mappings.
            max_bytes: Byte budget of the payload LRU; must be positive.
            data_root: Root prepended to relative file paths from the index.
            max_prefetch_workers: Size of the prefetch thread pool.

        Raises:
            ValueError: ``max_bytes`` is not positive.
        """
        if max_bytes <= 0:
            raise ValueError(f"max_bytes must be positive; got {max_bytes}")
        self.byte_index = byte_index
        self.max_bytes = max_bytes
        self.data_root = data_root.rstrip("/")
        self._bytes_used = 0
        self._lock = threading.Lock()
        self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
        self._header_cache: dict[int, bytes] = {}
        self._fetcher_cache: dict[int, RangeFetcher] = {}
        self._episodes: dict[int, _EpisodeEntry] = {}
        self._stats = CacheStats()
        self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
        self._futures: dict[int, Future] = {}

    @property
    def stats(self) -> CacheStats:
        """Return a snapshot copy of the counters."""
        with self._lock:
            return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})

    def submit_prefetch(self, ep_idx: int) -> None:
        """Schedule a prefetch of ``ep_idx`` unless one is already registered or pending."""
        with self._lock:
            if ep_idx in self._episodes or ep_idx in self._futures:
                return
            self._stats.prefetch_submitted += 1
            fut = self._executor.submit(self._prefetch_episode, ep_idx)
            self._futures[ep_idx] = fut

    def ensure_ready(self, ep_idx: int) -> None:
        """Block until the prefetch of ``ep_idx`` has finished.

        Raises:
            KeyError: The episode was never prefetched.
            Exception: Whatever error the prefetch task recorded.
        """
        with self._lock:
            fut = self._futures.pop(ep_idx, None)
            if fut is not None:
                self._stats.prefetch_waits += 1
        if fut is not None:
            fut.result()
        entry = self._episodes.get(ep_idx)
        if entry is None:
            raise KeyError(f"episode {ep_idx} not prefetched")
        # Wait first: the error is only final once the prefetch task has signalled.
        entry.ready.wait()
        if entry.error is not None:
            raise entry.error

    def get_decoder(self, ep_idx: int, video_key: str) -> Any:
        """Return the decoder for ``(ep_idx, video_key)``, waiting for the prefetch if needed.

        Raises:
            KeyError: The episode was never prefetched, or has no such camera.
            Exception: Whatever error the prefetch task recorded.
        """
        entry = self._episodes[ep_idx]
        # Wait before inspecting ``error`` so a failure that happens while we wait is
        # reported as itself rather than as a missing-decoder KeyError.
        entry.ready.wait()
        if entry.error is not None:
            raise entry.error
        return entry.decoders[video_key]

    def close(self) -> None:
        """Shut the prefetch pool down without waiting, cancelling queued tasks."""
        self._executor.shutdown(wait=False, cancel_futures=True)

    def _prefetch_episode(self, ep_idx: int) -> None:
        """Build decoders for every camera of ``ep_idx``; errors are stored on the entry."""
        entry = _EpisodeEntry()
        with self._lock:
            self._episodes[ep_idx] = entry
        try:
            for cam in self.byte_index.video_keys:
                entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
        except Exception as exc:
            entry.error = exc
        finally:
            entry.ready.set()

    def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
        """Return a decoder from a cached payload, or fetch the payload and cache it."""
        key = (ep_idx, cam)
        with self._lock:
            cached = self._cache.get(key)
            if cached is not None:
                self._cache.move_to_end(key)
                self._stats.hits += 1
        if cached is not None:
            payload, _ = cached
            t0 = time.perf_counter()
            dec = self._decoder_from_payload(payload, ep_idx, cam)
            with self._lock:
                self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
            return dec

        payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
        if payload_bytes > self.MAX_BYTES_PER_MISS:
            logger.warning(
                "byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
                payload_bytes / 1e6,
                ep_idx,
                cam,
            )
        with self._lock:
            self._stats.misses += 1
            # Another thread may have inserted the same key while we were fetching;
            # release its accounted bytes so ``_bytes_used`` is not counted twice.
            replaced = self._cache.pop(key, None)
            if replaced is not None:
                self._bytes_used -= replaced[1]
            self._evict_until(payload_bytes)
            self._cache[key] = (payload, payload_bytes)
            self._bytes_used += payload_bytes
        return dec

    def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
        """Fetch header + mdat slice for one episode/camera and open a validated decoder.

        Returns:
            ``(reader, payload_bytes, decoder)`` where ``payload_bytes`` is the size of
            the header plus the mdat slice held by ``reader``.
        """
        lookup = self.byte_index.lookup(ep_idx, cam)
        file_info = self.byte_index.file_lookup(lookup.file_id)
        fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)

        t_fetch = time.perf_counter()
        header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
        lo = lookup.mdat_offset
        hi = lo + lookup.mdat_length - 1
        mdat = fetcher.fetch(lo, hi)
        fetch_s = time.perf_counter() - t_fetch

        nbytes = len(header) + len(mdat)
        with self._lock:
            # Header bytes are accounted by _get_header_bytes when actually fetched,
            # so only the mdat slice is added here.
            self._stats.bytes_fetched += len(mdat)
            self._stats.mdat_slices += 1
            self._stats.fetch_to_buffer_s += fetch_s

        def lazy_fetch(pos: int, end: int) -> bytes:
            data = fetcher.fetch(pos, end)
            with self._lock:
                self._stats.bytes_fetched += len(data)
            return data

        reader = SparseMp4Reader(
            file_size=file_info.file_size,
            header=header,
            mdat_lo=lo,
            mdat_bytes=mdat,
            lazy_fetch=lazy_fetch,
        )
        t_init = time.perf_counter()
        dec = self._decoder_from_payload(reader, ep_idx, cam)
        self._validate_decoder(dec, lookup)
        init_s = time.perf_counter() - t_init
        with self._lock:
            self._stats.buffer_to_decoder_s += init_s
        self._rewind_payload(reader)
        return reader, nbytes, dec

    def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
        """Return the (memoised) range fetcher for ``file_id``."""
        if file_id not in self._fetcher_cache:
            path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
            self._fetcher_cache[file_id] = RangeFetcher(path)
        return self._fetcher_cache[file_id]

    def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
        """Return the file header, fetching and caching it on first use."""
        if file_id in self._header_cache:
            return self._header_cache[file_id]
        hi = max(0, header_length - 1)
        header = fetcher.fetch(0, hi)
        with self._lock:
            self._header_cache[file_id] = header
            self._stats.bytes_fetched += len(header)
        return header

    def _decoder_from_payload(
        self, payload: SparseMp4Reader, ep_idx: int, cam: str
    ) -> Any:
        """Rewind ``payload`` and open a decoder on it with the episode's frame mappings."""
        payload.seek(0)
        mappings = self.byte_index.custom_frame_mappings(ep_idx, cam)
        return open_video_decoder(payload, frame_mappings=mappings)

    def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
        """Decode one frame near the start, middle and end of the stream.

        Any decoding error propagates to the caller. ``lookup`` is currently unused.
        """
        begin = float(dec.metadata.begin_stream_seconds)
        end = float(dec.metadata.end_stream_seconds)
        duration = max(0.01, end - begin)
        for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3):
            # Accessing ``.data`` forces the frame to be produced.
            dec.get_frames_played_at([ts]).data

    def _rewind_payload(self, payload: SparseMp4Reader) -> None:
        """Reset the reader position to the start of the file."""
        payload.seek(0)

    def _evict_until(self, need: int) -> None:
        """Evict least-recently-used payloads until ``need`` more bytes fit (caller holds the lock)."""
        while self._bytes_used + need > self.max_bytes and self._cache:
            _, (_, size) = self._cache.popitem(last=False)
            self._bytes_used -= size
+555
View File
@@ -0,0 +1,555 @@
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
For streaming we fetch only the file header plus the episode's contiguous mdat span
instead of the ``0..episode_end`` prefix.
"""
from __future__ import annotations
import io
import struct
import threading
from dataclasses import dataclass, field
from typing import Callable
# Default padding, in seconds, applied around an episode's [from_ts, to_ts] window.
KEYFRAME_PAD_S = 0.1
# Initial number of leading bytes read when probing a file header (4 MiB).
HEADER_PROBE_BYTES = 4 * 1024 * 1024
# Upper bound for the header probe when it has to grow (16 MiB).
MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024
@dataclass
class Mp4FileLayout:
    """Top-level MP4 layout as reported by ``parse_mp4_file_layout``."""

    file_size: int
    # Offset and size of the first moov box.
    moov_offset: int
    moov_length: int
    # Set to the mdat box offset by the parser.
    header_end: int
    # Offset and size of the first mdat box.
    mdat_offset: int
    mdat_size: int
    # True when moov precedes mdat.
    faststart: bool
    codec: str
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
    """Return top-level MP4 layout (moov/mdat positions, faststart flag).

    Only the first moov and the first mdat box found in ``header_bytes`` are used.

    Raises:
        ValueError: If no moov box, or no mdat box, is found in ``header_bytes``.
    """
    moov_at = mdat_at = -1
    moov_len = mdat_len = 0
    for offset, size, box_type, _ in list(_iter_boxes(header_bytes)):
        if moov_at < 0 and box_type == b"moov":
            moov_at, moov_len = offset, size
        if mdat_at < 0 and box_type == b"mdat":
            mdat_at, mdat_len = offset, size

    if moov_at < 0:
        raise ValueError("moov box not found in header probe")
    if mdat_at < 0:
        raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")

    return Mp4FileLayout(
        file_size=file_size,
        moov_offset=moov_at,
        moov_length=moov_len,
        # The header is everything that precedes the mdat box.
        header_end=mdat_at,
        mdat_offset=mdat_at,
        mdat_size=mdat_len,
        faststart=moov_at < mdat_at,
        codec=_parse_video_codec(header_bytes),
    )
def _parse_video_codec(header_bytes: bytes) -> str:
    """Return the four-character format of the first video sample entry, or ``"unknown"``.

    ``"unknown"`` is returned when the moov box, the video track or a sufficiently
    long stsd payload cannot be located.
    """
    moov = _find_box_payload(header_bytes, b"moov")
    if moov is None:
        return "unknown"
    trak = _find_video_trak(moov)
    if trak is None:
        return "unknown"
    stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd")
    # stsd payload: version(1) + flags(3) + entry_count(4), followed by the first
    # sample entry: entry_size(4) + format(4). The format therefore sits at bytes
    # 12..16; bytes 8..12 are the entry size, not the codec.
    # NOTE(review): assumes _find_box_payload returns the bytes after the 8-byte
    # box header — confirm against its definition.
    if stsd is None or len(stsd) < 16:
        return "unknown"
    return stsd[12:16].decode("latin1", errors="replace").strip("\x00")
def average_fps_from_index(index: Mp4VideoIndex) -> float:
    """Return the average frame rate implied by the sample timestamps.

    ``n`` samples span ``n - 1`` frame intervals, so the rate is
    ``(n - 1) / (last_pts - first_pts)``. Falls back to ``30.0`` when there are
    fewer than two samples or the time span is not positive.
    """
    index.ensure_tables()
    count = index.num_samples
    if count < 2:
        return 30.0
    elapsed = index.sample_pts(count - 1) - index.sample_pts(0)
    if elapsed <= 0:
        return 30.0
    return (count - 1) / elapsed
def episode_custom_frame_mappings_json(
    index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> bytes:
    """Build TorchCodec ``custom_frame_mappings`` JSON for one episode span.

    The span runs from the sync sample at or before ``from_ts - keyframe_pad_s``
    (as located by ``_keyframe_back``) to the last sample at or before
    ``to_ts + keyframe_pad_s``. Each frame carries its pts and duration in
    timescale units and a ``key_frame`` flag.
    """
    import json

    index.ensure_tables()
    pts = index._pts
    first = _first_sample_at_or_after(pts, max(0.0, from_ts - keyframe_pad_s))
    last = min(_last_sample_at_or_before(pts, to_ts + keyframe_pad_s), index.num_samples - 1)
    first = _keyframe_back(index.sync_samples, first)
    sync_set = set(index.sync_samples)
    timescale = index.timescale

    # Expand the run-length stts entries into one delta per sample.
    durations: list[int] = [delta for count, delta in index.stts for _ in range(count)]
    shortfall = index.num_samples - len(durations)
    if shortfall > 0:
        # Pad with the last known delta, or a 30 fps guess when stts is empty.
        filler = durations[-1] if durations else timescale // 30
        durations.extend([filler] * shortfall)

    def is_key(sample: int) -> int:
        # sync_samples holds 1-based sample numbers; without any, only the first
        # frame of the span is flagged.
        if sync_set:
            return int((sample + 1) in sync_set)
        return int(sample == first)

    frames = [
        {
            "pts": int(round(pts[sample] * timescale)),
            "duration": int(durations[sample]),
            "key_frame": is_key(sample),
        }
        for sample in range(first, last + 1)
    ]
    return json.dumps({"frames": frames}).encode()
def episode_keyframes(
    index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> list[tuple[float, int]]:
    """Return (pts_seconds, byte_offset) for sync samples in the episode span.

    When the index has no sync samples, or none fall inside the padded window,
    a single entry for the first sample of the window is returned instead.
    """
    index.ensure_tables()
    # The span itself is not used; the call is kept because it raises ValueError
    # for an index without samples.
    index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)

    first = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
    last = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)

    if not index.sync_samples:
        return [(index.sample_pts(first), index.sample_offset(first))]

    # sync_samples are 1-based sample numbers.
    hits = [
        (index.sample_pts(number - 1), index.sample_offset(number - 1))
        for number in index.sync_samples
        if first <= number - 1 <= last
    ]
    if hits:
        return hits
    return [(index.sample_pts(first), index.sample_offset(first))]
@dataclass
class EpisodeByteSpan:
    """Absolute file byte ranges to fetch for one episode."""

    file_size: int
    header_end: int
    slice_lo: int
    slice_hi: int

    @property
    def header_bytes(self) -> tuple[int, int]:
        """Inclusive byte range of the header: ``(0, header_end - 1)``."""
        last_header_byte = self.header_end - 1
        return (0, last_header_byte)

    @property
    def mdat_bytes(self) -> tuple[int, int]:
        """Byte range of the mdat slice: ``(slice_lo, slice_hi)``."""
        return (self.slice_lo, self.slice_hi)

    @property
    def total_fetch_bytes(self) -> int:
        """Header length plus the inclusive length of the mdat slice."""
        mdat_len = self.slice_hi - self.slice_lo + 1
        return self.header_end + mdat_len
@dataclass
class Mp4VideoIndex:
    """Parsed sample tables of one MP4 video track plus lazily derived lookups.

    ``_pts`` and ``_offsets`` are filled on first use by ``ensure_tables`` from the
    stts and stsc/stco/stsz tables respectively.
    """

    file_size: int
    header_end: int
    mdat_offset: int
    mdat_size: int
    timescale: int
    stts: list[tuple[int, int]]
    stsz: list[int]
    stsc: list[tuple[int, int, int]]
    stco: list[int]
    sync_samples: list[int]
    _pts: list[float] = field(default_factory=list, repr=False)
    _offsets: list[int] = field(default_factory=list, repr=False)

    def ensure_tables(self) -> None:
        """Derive per-sample timestamps and byte offsets once."""
        if self._pts:
            return
        self._pts = _pts_from_stts(self.stts, self.timescale)
        self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)

    @property
    def num_samples(self) -> int:
        """Number of samples, taken from the stsz table."""
        return len(self.stsz)

    def sample_pts(self, index: int) -> float:
        """Return the timestamp in seconds of sample ``index``."""
        self.ensure_tables()
        return self._pts[index]

    def sample_offset(self, index: int) -> int:
        """Return the byte offset of sample ``index`` (clamped to the valid range)."""
        self.ensure_tables()
        index = max(0, min(index, len(self._offsets) - 1))
        return self._offsets[index]

    def sample_end(self, index: int) -> int:
        """Return the offset one past the last byte of sample ``index``.

        ``index`` is clamped the same way as in ``sample_offset`` so that the
        offset and the size always refer to the same sample.
        """
        self.ensure_tables()
        index = max(0, min(index, len(self._offsets) - 1))
        return self._offsets[index] + self.stsz[index]

    def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
        """Return the byte span covering ``[from_ts, to_ts]`` plus padding.

        The padding is the larger of ``keyframe_pad_s`` and 5% of the episode
        duration; the lower bound is then moved back via ``_keyframe_back``.
        ``slice_lo`` and ``slice_hi`` are both inclusive byte positions.

        Raises:
            ValueError: The index has no samples.
        """
        self.ensure_tables()
        n = self.num_samples
        if n == 0:
            raise ValueError("MP4 has no video samples")
        pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts))
        lo_ts = max(0.0, from_ts - pad)
        hi_ts = to_ts + pad
        lo_idx = _first_sample_at_or_after(self._pts, lo_ts)
        hi_idx = _last_sample_at_or_before(self._pts, hi_ts)
        hi_idx = min(hi_idx, n - 1)
        lo_idx = min(lo_idx, n - 1)
        lo_idx = _keyframe_back(self.sync_samples, lo_idx)
        slice_lo = self.sample_offset(lo_idx)
        # sample_end is exclusive; the span stores an inclusive last byte.
        last_byte = self.sample_end(min(hi_idx, len(self._offsets) - 1)) - 1
        return EpisodeByteSpan(
            file_size=self.file_size,
            header_end=self.header_end,
            slice_lo=slice_lo,
            slice_hi=min(last_byte, self.file_size - 1),
        )
class SparseMp4Reader(io.BufferedIOBase):
"""Range-backed MP4 reader: header + one mdat span at absolute offsets."""
def __init__(
self,
file_size: int,
header: bytes,
mdat_lo: int,
mdat_bytes: bytes,
lazy_fetch: Callable[[int, int], bytes] | None = None,
):
self._size = file_size
self._header = header
self._mdat_lo = mdat_lo
self._mdat_hi = mdat_lo + len(mdat_bytes)
self._mdat = mdat_bytes
self._lazy_fetch = lazy_fetch
self._pos = 0
self._lock = threading.Lock()
def readable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def tell(self) -> int:
return self._pos
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
self._pos = offset
elif whence == io.SEEK_CUR:
self._pos += offset
elif whence == io.SEEK_END:
self._pos = self._size + offset
else:
raise ValueError(f"invalid whence: {whence}")
self._pos = max(0, min(self._pos, self._size))
return self._pos
def read(self, size: int = -1) -> bytes:
if size < 0:
size = self._size - self._pos
if size <= 0:
return b""
out = bytearray()
remaining = size
pos = self._pos
while remaining > 0 and pos < self._size:
chunk = self._read_at(pos, remaining)
if not chunk:
break
out.extend(chunk)
pos += len(chunk)
remaining -= len(chunk)
self._pos = pos
return bytes(out)
def _read_at(self, pos: int, n: int) -> bytes:
header_len = len(self._header)
if pos < header_len:
end = min(pos + n, header_len)
return self._header[pos:end]
if self._mdat_lo <= pos < self._mdat_hi:
end = min(pos + n, self._mdat_hi)
off = pos - self._mdat_lo
return self._mdat[off : off + (end - pos)]
if self._lazy_fetch is not None:
with self._lock:
end = min(pos + n, self._size)
return self._lazy_fetch(pos, end - 1)
return b"\x00" * min(n, self._size - pos)
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
    """Build an ``Mp4VideoIndex`` from the leading bytes of an MP4 file.

    The moov box and its video-track sample tables must be contained in
    ``header_bytes`` (faststart layout).

    Raises:
        ValueError: if moov, the video trak, mdhd, stbl, a required sample table
            or both chunk-offset tables are missing.
    """
    layout = parse_mp4_file_layout(header_bytes, file_size)

    def required(container: bytes, box_type: bytes, message: str) -> bytes:
        payload = _find_box_payload(container, box_type)
        if payload is None:
            raise ValueError(message)
        return payload

    moov = required(header_bytes, b"moov", "moov box not found in MP4 header probe")
    trak = _find_video_trak(moov)
    if trak is None:
        raise ValueError("video trak not found in moov")

    timescale = _parse_mdhd_timescale(required(trak, b"mdhd", "mdhd not found"))
    stbl = required(trak, b"stbl", "stbl not found")

    stts = _parse_stts(_find_box_payload(stbl, b"stts"))
    stsz = _parse_stsz(_find_box_payload(stbl, b"stsz"))
    stsc = _parse_stsc(_find_box_payload(stbl, b"stsc"))

    # 32-bit chunk offsets win when both tables are present.
    stco_payload = _find_box_payload(stbl, b"stco")
    co64_payload = _find_box_payload(stbl, b"co64")
    if stco_payload is not None:
        chunk_offsets = _parse_stco(stco_payload)
    elif co64_payload is not None:
        chunk_offsets = _parse_co64(co64_payload)
    else:
        raise ValueError("stco/co64 not found")

    # A missing (or empty) stss payload leaves the sync table empty.
    stss_payload = _find_box_payload(stbl, b"stss")
    sync_samples = _parse_stss(stss_payload) if stss_payload else []

    return Mp4VideoIndex(
        file_size=file_size,
        header_end=layout.header_end,
        mdat_offset=layout.mdat_offset,
        mdat_size=layout.mdat_size,
        timescale=timescale,
        stts=stts,
        stsz=stsz,
        stsc=stsc,
        stco=chunk_offsets,
        sync_samples=sync_samples,
    )
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
if offset + 8 > len(data):
return None
size, typ = struct.unpack_from(">I4s", data, offset)
header = 8
if size == 1:
if offset + 16 > len(data):
return None
size = struct.unpack_from(">Q", data, offset + 8)[0]
header = 16
elif size == 0:
size = len(data) - offset
return size, typ, header
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
    """Yield ``(offset, size, type, payload)`` for consecutive boxes in ``data[start:end]``.

    Iteration stops at the first header that cannot be decoded or whose declared
    size is smaller than its own header.
    """
    limit = len(data) if end is None else end
    cursor = start
    while cursor + 8 <= limit:
        parsed = _box_header(data, cursor)
        if parsed is None:
            return
        size, typ, header = parsed
        if size < header:
            return
        yield cursor, size, typ, data[cursor + header : cursor + size]
        cursor += size
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
    """Depth-first search for the first box of type ``target``; return its payload.

    Only the container types moov/trak/mdia/minf/stbl are descended into.
    Returns ``None`` when no matching box is found.
    """
    containers = (b"moov", b"trak", b"mdia", b"minf", b"stbl")
    for _offset, _size, box_type, body in _iter_boxes(data):
        if box_type == target:
            return body
        if box_type not in containers:
            continue
        nested = _find_box_payload(body, target)
        if nested is not None:
            return nested
    return None
def _find_video_trak(moov: bytes) -> bytes | None:
    """Return the payload of the first trak whose hdlr handler type is ``vide``."""
    for _offset, _size, box_type, body in _iter_boxes(moov):
        if box_type == b"trak":
            handler = _find_box_payload(body, b"hdlr")
            # Bytes 8..12 of the hdlr payload are compared against the handler type.
            if handler is not None and len(handler) >= 12 and handler[8:12] == b"vide":
                return body
    return None
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
    """Locate the top-level mdat box inside the header probe.

    Args:
        header_bytes: leading bytes of the file.
        file_size: total file size, used to resolve an mdat box that declares
            size 0 ("extends to end of file").

    Returns:
        ``(box_offset, box_size)`` of the mdat box.

    Raises:
        ValueError: if no mdat box header lies within ``header_bytes``.
    """
    for off, size, typ, _ in _iter_boxes(header_bytes):
        if typ != b"mdat":
            continue
        # _iter_boxes resolves a size-0 box against the probe length; the box
        # really runs to the end of the file, so recompute it from file_size.
        if struct.unpack_from(">I", header_bytes, off)[0] == 0:
            size = file_size - off
        return off, size
    # The mdat header must be inside the probe; there is no fallback scan.
    raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
def _parse_mdhd_timescale(mdhd: bytes) -> int:
version = mdhd[0]
if version == 0:
return struct.unpack_from(">I", mdhd, 12)[0]
return struct.unpack_from(">I", mdhd, 20)[0]
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
if stts is None:
raise ValueError("stts missing")
count = struct.unpack_from(">I", stts, 4)[0]
out = []
off = 8
for _ in range(count):
sample_count, delta = struct.unpack_from(">II", stts, off)
out.append((sample_count, delta))
off += 8
return out
def _parse_stsz(stsz: bytes | None) -> list[int]:
if stsz is None:
raise ValueError("stsz missing")
sample_size, sample_count = struct.unpack_from(">II", stsz, 4)
if sample_size != 0:
return [sample_size] * sample_count
off = 12
return list(struct.unpack_from(f">{sample_count}I", stsz, off))
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
if stsc is None:
raise ValueError("stsc missing")
count = struct.unpack_from(">I", stsc, 4)[0]
out = []
off = 8
for _ in range(count):
first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off)
out.append((first_chunk, samples_per_chunk, sample_desc))
off += 12
return out
def _parse_stco(stco: bytes) -> list[int]:
count = struct.unpack_from(">I", stco, 4)[0]
return list(struct.unpack_from(f">{count}I", stco, 8))
def _parse_co64(co64: bytes) -> list[int]:
count = struct.unpack_from(">I", co64, 4)[0]
return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)]
def _parse_stss(stss: bytes) -> list[int]:
count = struct.unpack_from(">I", stss, 4)[0]
return list(struct.unpack_from(f">{count}I", stss, 8))
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
pts: list[float] = []
t = 0
for count, delta in stts:
for _ in range(count):
pts.append(t / timescale)
t += delta
return pts
def _sample_byte_offsets(
stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
) -> list[int]:
if not stsc:
stsc = [(1, len(stsz), 1)]
offsets: list[int] = []
chunk_idx = 0
sample_idx = 0
sc_idx = 0
num_chunks = len(stco)
while chunk_idx < num_chunks and sample_idx < len(stsz):
first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
if sc_idx + 1 < len(stsc):
next_first = stsc[sc_idx + 1][0]
chunks_in_entry = next_first - first_chunk
else:
chunks_in_entry = num_chunks - chunk_idx
for _ in range(chunks_in_entry):
if chunk_idx >= num_chunks:
break
offset = stco[chunk_idx]
_, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)]
for _ in range(samples_per_chunk):
if sample_idx >= len(stsz):
break
offsets.append(offset)
offset += stsz[sample_idx]
sample_idx += 1
chunk_idx += 1
sc_idx += 1
if len(offsets) < len(stsz):
# Pad with last known offset progression for malformed stsc edge cases.
last = offsets[-1] if offsets else 0
while len(offsets) < len(stsz):
idx = len(offsets)
offsets.append(last)
last += stsz[idx]
return offsets
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] < ts:
lo = mid + 1
else:
hi = mid
return min(lo, len(pts) - 1)
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
lo, hi = 0, len(pts)
while lo < hi:
mid = (lo + hi) // 2
if pts[mid] <= ts:
lo = mid + 1
else:
hi = mid
return max(0, lo - 1)
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
if not sync_samples:
return max(0, sample_idx - 2)
# stss stores 1-based sample numbers
one_based = sample_idx + 1
prev = [s for s in sync_samples if s <= one_based]
if prev:
return prev[-1] - 1
return 0
+92
View File
@@ -124,6 +124,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
video_decoder_cache_size: int | None = None,
data_files_root: str | None = None,
validate_row_groups: bool = True,
video_byte_cache_gb: float | None = 80.0,
byte_index_path: str | Path | None = None,
byte_index_build_in_memory: bool | None = None,
byte_index_workers: int = 8,
byte_index_max_episodes: int | None = None,
):
"""Initialize a StreamingLeRobotDataset.
@@ -173,6 +178,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear
``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging);
the divisibility check then downgrades to a warning.
video_byte_cache_gb (float | None, optional): Node-local LRU for episode MP4 mdat slices.
When set (default 80 GB), episodes are prefetched via tight byte ranges before decode.
Set to 0 or None to disable and use remote per-seek decoding.
byte_index_path (str | Path | None, optional): Path to precomputed ``meta/byte_index/``
sidecar parquet tables. Defaults to ``{meta.root}/meta/byte_index``.
byte_index_build_in_memory (bool | None, optional): When True, build the byte index in RAM
at init (moov-only fetches, no parquet write). When None (default), build in memory only
if the sidecar parquet is missing on disk.
byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds.
byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests).
"""
super().__init__()
self.repo_id = repo_id
@@ -210,6 +225,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
self.video_decoder_cache_size = video_decoder_cache_size
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
self.video_byte_cache_gb = video_byte_cache_gb
self.byte_index_path = Path(byte_index_path) if byte_index_path is not None else None
self.byte_index_build_in_memory = byte_index_build_in_memory
self.byte_index_workers = byte_index_workers
self.byte_index_max_episodes = byte_index_max_episodes
self._episode_byte_cache = None
self._byte_index = None
self._data_root = None
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
@@ -228,6 +251,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
if self._use_episode_byte_cache():
from .byte_index import EpisodeByteIndex
data_root = self._resolve_data_root()
index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index")
sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists()
build_in_memory = (
self.byte_index_build_in_memory
if self.byte_index_build_in_memory is not None
else not sidecar_exists
)
if build_in_memory:
logger.info(
"Building byte index in memory from %s (%s episodes, %d workers)",
data_root,
self.byte_index_max_episodes or self.meta.total_episodes,
self.byte_index_workers,
)
self._byte_index = EpisodeByteIndex.from_memory_build(
self.meta,
data_root,
workers=self.byte_index_workers,
max_episodes=self.byte_index_max_episodes,
)
else:
self._byte_index = EpisodeByteIndex(
index_dir,
video_keys=self.meta.video_keys,
num_episodes=self.meta.total_episodes,
)
self.delta_timestamps = None
self.delta_indices = None
@@ -417,6 +471,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
buffer_size=self.episode_pool_size,
max_buffer_input_shards=max_input_shards,
)
if self._use_episode_byte_cache():
ds = ds.map(self._submit_episode_prefetch, batched=True)
# A row-count-changing batched map must drop the input columns explicitly; the exploded
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
@@ -472,6 +528,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return VideoDecoderCache()
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
def _use_episode_byte_cache(self) -> bool:
return (
self.video_byte_cache_gb not in (None, 0)
and self.data_files_root is not None
)
def _make_episode_byte_cache(self):
    """Create the ``EpisodeByteCache`` sized from ``video_byte_cache_gb`` (decimal GB).

    Raises:
        RuntimeError: if no byte index has been loaded or built.
    """
    from .episode_byte_cache import EpisodeByteCache

    index = self._byte_index
    if index is None:
        raise RuntimeError("byte index required for episode byte cache; run build_byte_index.py")
    capacity_bytes = int(float(self.video_byte_cache_gb) * 1e9)
    return EpisodeByteCache(index, capacity_bytes, data_root=self._data_root)
def _submit_episode_prefetch(self, episode_batch: dict[str, list[list]]) -> dict[str, list[list]]:
if self._episode_byte_cache is None:
return episode_batch
for ep_idx in {int(v[0]) for v in episode_batch["episode_index"]}:
self._episode_byte_cache.submit_prefetch(ep_idx)
return episode_batch
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
@@ -486,6 +567,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
self._in_flight_epoch = 0
self._pipeline.set_epoch(self._in_flight_epoch)
self.video_decoder_cache = self._make_video_decoder_cache()
self._data_root = self._resolve_data_root()
if self._use_episode_byte_cache():
self._episode_byte_cache = self._make_episode_byte_cache()
else:
self._episode_byte_cache = None
iterator = iter(self._pipeline)
while True:
@@ -623,6 +709,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""
item = {}
if self._episode_byte_cache is not None:
self._episode_byte_cache.ensure_ready(ep_idx)
for video_key, query_ts in query_timestamps.items():
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
@@ -635,12 +723,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
else:
root = self.root
video_path = f"{root}/{rel_path}"
episode_decoder = None
if self._episode_byte_cache is not None:
episode_decoder = self._episode_byte_cache.get_decoder(ep_idx, video_key)
frames = decode_video_frames_torchcodec(
video_path,
shifted_query_ts,
self.tolerance_s,
decoder_cache=self.video_decoder_cache,
return_uint8=self._return_uint8,
episode_decoder=episode_decoder,
)
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+49
View File
@@ -0,0 +1,49 @@
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
from __future__ import annotations
import json
from typing import Any
import torch
from torchcodec import FrameBatch, _core as core
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Turn a JSON frame-mappings payload into ``(pts, key_frame, duration)`` tensors.

    ``payload`` must decode to ``{"frames": [{"pts", "key_frame", "duration"}, ...]}``.
    ``pts`` and ``duration`` come back as int64, ``key_frame`` as bool.
    """
    frames = json.loads(payload)["frames"]

    def column(field: str, cast, dtype: torch.dtype) -> torch.Tensor:
        return torch.tensor([cast(frame[field]) for frame in frames], dtype=dtype)

    pts = column("pts", int, torch.int64)
    key = column("key_frame", bool, torch.bool)
    dur = column("duration", int, torch.int64)
    return pts, key, dur
class VideoDecoderLike:
    """Small wrapper exposing the part of the VideoDecoder API the episode byte cache uses."""

    def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
        self._decoder = decoder
        validated = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
        (
            self.metadata,
            self.stream_index,
            self._begin_stream_seconds,
            self._end_stream_seconds,
            self._num_frames,
        ) = validated

    def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
        """Return the frames looked up by the given timestamps, wrapped in a ``FrameBatch``."""
        frames = core.get_frames_by_pts(self._decoder, timestamps=seconds)
        return FrameBatch(*frames)
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
    """Open a decoder on a file-like MP4 source.

    With ``frame_mappings`` (a JSON payload accepted by ``frame_mappings_tensors``)
    the decoder is created in ``custom_frame_mappings`` seek mode and the mappings
    are handed to the video stream; otherwise ``approximate`` seek mode is used.
    """
    if frame_mappings is not None:
        mappings = frame_mappings_tensors(frame_mappings)
        decoder = core.create_from_file_like(source, "custom_frame_mappings")
        core.add_video_stream(decoder, custom_frame_mappings=mappings)
    else:
        decoder = core.create_from_file_like(source, "approximate")
        core.add_video_stream(decoder)
    return VideoDecoderLike(decoder)
+5 -2
View File
@@ -326,6 +326,7 @@ def decode_video_frames_torchcodec(
log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None,
return_uint8: bool = False,
episode_decoder: Any | None = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
@@ -347,8 +348,10 @@ def decode_video_frames_torchcodec(
if decoder_cache is None:
decoder_cache = _default_decoder_cache
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
if episode_decoder is not None:
decoder = episode_decoder
else:
decoder = decoder_cache.get_decoder(str(video_path))
loaded_ts = []
loaded_frames = []
+150
View File
@@ -0,0 +1,150 @@
"""Acceptance tests for manifest byte-index sidecars.
Run on a compute node (not login-node):
srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
"""
from __future__ import annotations
import json
import socket
import pytest
pytest.importorskip("torchcodec")
# Dataset repository the metadata is loaded from.
REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
# Revision passed to LeRobotDatasetMetadata when loading REPO.
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
# Data root handed to the byte-index builders and to EpisodeByteCache.
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
# Number of episodes indexed by every test in this module.
MAX_EPISODES = 64
# Skip marker: applies when the hostname contains "login".
COMPUTE_NODE = pytest.mark.skipif(
    "login" in socket.gethostname(),
    reason="run on compute node via srun (see module docstring), not login-node",
)
@pytest.fixture(scope="module")
def byte_index_dir(tmp_path_factory):
    """Build the byte-index tables once per module and write them to a temp directory.

    Returns ``(index_directory, dataset_metadata)``.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
    from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

    index_dir = tmp_path_factory.mktemp("byte_index")
    metadata = LeRobotDatasetMetadata(REPO, revision=REV)
    file_table, episode_table, _ = build_byte_index_tables(
        metadata,
        BUCKET,
        workers=4,
        max_episodes=MAX_EPISODES,
        include_keyframes=False,
    )
    write_byte_index(index_dir, file_table, episode_table, None, merge_existing=False)
    return index_dir, metadata
@pytest.mark.integration
@COMPUTE_NODE
def test_index_load_fast_and_small(byte_index_dir):
    """The parquet-backed index loads in under a second and stays below 1 GB resident."""
    from lerobot.datasets.byte_index import EpisodeByteIndex

    index_dir, metadata = byte_index_dir
    loaded = EpisodeByteIndex(
        index_dir,
        video_keys=metadata.video_keys,
        num_episodes=MAX_EPISODES,
    )
    assert loaded.load_time_s < 1.0
    assert loaded.resident_bytes < 1_000_000_000
@pytest.mark.integration
@COMPUTE_NODE
def test_tight_fetch_under_25mb(byte_index_dir):
    """Prefetching three episodes stays under 25 MiB per cache miss."""
    # The unused EpisodeByteIndex import was dropped; the index here is built in memory.
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
    # First, middle and last indexed episode.
    for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
        cache.submit_prefetch(ep)
        cache.ensure_ready(ep)
    stats = cache.stats.stats_dict()
    assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024
@pytest.mark.integration
@COMPUTE_NODE
def test_in_memory_build_matches_parquet(byte_index_dir):
    """The in-memory index and the parquet sidecar agree on span offset, length and first pts."""
    from lerobot.datasets.byte_index import EpisodeByteIndex
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    index_dir, metadata = byte_index_dir
    on_disk = EpisodeByteIndex(index_dir, video_keys=metadata.video_keys, num_episodes=MAX_EPISODES)
    in_memory = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)

    # First, middle and last indexed episode, for every camera.
    for episode in (0, MAX_EPISODES // 2, MAX_EPISODES - 1):
        for camera in metadata.video_keys:
            expected = on_disk.lookup(episode, camera)
            actual = in_memory.lookup(episode, camera)
            assert expected.mdat_offset == actual.mdat_offset
            assert expected.mdat_length == actual.mdat_length
            assert abs(expected.first_pts - actual.first_pts) < 1e-6
@pytest.mark.integration
@COMPUTE_NODE
def test_custom_frame_mappings_available(byte_index_dir):
    """The in-memory index returns a frame-mappings payload with pts, duration and a key frame."""
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    camera = metadata.video_keys[0]
    episode = MAX_EPISODES // 2

    payload = index.custom_frame_mappings(episode, camera)
    assert payload is not None

    frames = json.loads(payload)["frames"]
    assert len(frames) > 10
    assert any(frame["key_frame"] for frame in frames)
    assert all("pts" in frame and "duration" in frame for frame in frames)
@pytest.mark.integration
@COMPUTE_NODE
def test_metadata_skip_decoder_init(byte_index_dir):
    """A cache-provided decoder reports frame metadata and decodes the mid-stream frame."""
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)
    camera = metadata.video_keys[0]
    episode = 0

    cache.submit_prefetch(episode)
    cache.ensure_ready(episode)
    decoder = cache.get_decoder(episode, camera)

    assert decoder.metadata.num_frames is not None
    assert decoder.metadata.num_frames > 0

    begin = float(decoder.metadata.begin_stream_seconds)
    end = float(decoder.metadata.end_stream_seconds)
    midpoint = begin + 0.5 * (end - begin)
    decoded = decoder.get_frames_played_at([midpoint]).data
    assert decoded.ndim == 4
@pytest.mark.integration
@COMPUTE_NODE
def test_sparse_decode_produces_frames(byte_index_dir):
    """Decoding from the cached episode bytes yields a non-empty, non-constant frame."""
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)
    camera = metadata.video_keys[0]
    episode = 0

    cache.submit_prefetch(episode)
    cache.ensure_ready(episode)
    decoder = cache.get_decoder(episode, camera)

    begin = float(decoder.metadata.begin_stream_seconds)
    end = float(decoder.metadata.end_stream_seconds)
    midpoint = begin + 0.5 * (end - begin)
    decoded = decoder.get_frames_played_at([midpoint]).data

    assert decoded.ndim == 4
    assert decoded.numel() > 0
    # A pixel standard deviation above 1.0 rules out a constant (e.g. all-zero) frame.
    assert float(decoded.float().std()) > 1.0