From 7b6f4f2b11b566bddcf5b34a23ead64f04ba9418 Mon Sep 17 00:00:00 2001 From: pepijn Date: Tue, 16 Jun 2026 15:03:17 +0000 Subject: [PATCH] Add in-memory byte index and manifest-driven episode MP4 cache. Build moov-derived byte ranges in RAM or from sidecar parquet, fetch tight mdat slices over the network, and decode via TorchCodec custom_frame_mappings to skip full-file metadata scans. Co-authored-by: Cursor --- pyproject.toml | 11 +- scripts/build_byte_index.py | 51 ++ src/lerobot/datasets/byte_index.py | 228 +++++++++ src/lerobot/datasets/byte_index_builder.py | 281 +++++++++++ src/lerobot/datasets/episode_byte_cache.py | 263 ++++++++++ src/lerobot/datasets/mp4_episode_slice.py | 555 +++++++++++++++++++++ src/lerobot/datasets/streaming_dataset.py | 92 ++++ src/lerobot/datasets/torchcodec_utils.py | 49 ++ src/lerobot/datasets/video_utils.py | 7 +- tests/datasets/test_byte_index.py | 150 ++++++ 10 files changed, 1682 insertions(+), 5 deletions(-) create mode 100644 scripts/build_byte_index.py create mode 100644 src/lerobot/datasets/byte_index.py create mode 100644 src/lerobot/datasets/byte_index_builder.py create mode 100644 src/lerobot/datasets/episode_byte_cache.py create mode 100644 src/lerobot/datasets/mp4_episode_slice.py create mode 100644 src/lerobot/datasets/torchcodec_utils.py create mode 100644 tests/datasets/test_byte_index.py diff --git a/pyproject.toml b/pyproject.toml index 42116722a..67be61677 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -335,9 +335,14 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] # Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle # re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...) -# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next -# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`. -datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" } +# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships +# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`. +datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" } +# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the +# shared httpx.Client on every ConnectError, which races with concurrent streaming requests +# ("Cannot send a request, as the client has been closed"); we patch that out locally in +# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch. +huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" } [tool.setuptools.package-data] lerobot = ["envs/*.json"] diff --git a/scripts/build_byte_index.py b/scripts/build_byte_index.py new file mode 100644 index 000000000..1b2b2463b --- /dev/null +++ b/scripts/build_byte_index.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +"""Build mmap-able byte-index sidecars for LeRobot streaming datasets.""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from lerobot.datasets.byte_index_builder import ( + build_byte_index_tables, + load_existing_file_ids, + write_byte_index, +) +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.") + parser.add_argument("--repo-id", required=True) + parser.add_argument("--revision", default=None) + parser.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/") + parser.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory") + parser.add_argument("--workers", type=int, default=8) + parser.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)") + parser.add_argument("--no-keyframes", action="store_true") + args = parser.parse_args() + + meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision) + output = args.output + existing = load_existing_file_ids(output) + if existing: + logger.info("resuming: %s files already indexed", len(existing)) + + files_tbl, episodes_tbl, keyframes_tbl = build_byte_index_tables( + meta, + args.data_root, + include_keyframes=not args.no_keyframes, + workers=args.workers, + existing_files=existing, + max_episodes=args.max_episodes, + ) + write_byte_index(output, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True) + logger.info("wrote byte index to %s", output) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/datasets/byte_index.py b/src/lerobot/datasets/byte_index.py new file mode 100644 index 000000000..b59680888 --- /dev/null +++ b/src/lerobot/datasets/byte_index.py @@ -0,0 +1,228 @@ +"""Runtime in-memory byte index loaded from precomputed sidecar parquet.""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + +from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME +from .mp4_episode_slice import episode_custom_frame_mappings_json + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class EpisodeSliceLookup: + global_episode_id: int + file_id: int + mdat_offset: int + mdat_length: int + frame_count: int + first_pts: float + last_pts: float + avg_fps: float + + @property + def fetch_bytes(self) -> int: + return self.mdat_length + + +@dataclass(frozen=True) +class FileLookup: + file_id: int + file_path: str + file_size: int + moov_offset: int + moov_length: int + header_length: int + faststart: bool + avg_fps: float + codec: str + + +class EpisodeByteIndex: + """Columnar byte-index resident in numpy arrays for O(1) episode lookup.""" + + def __init__( + self, + index_dir: str | Path | None, + *, + video_keys: list[str], + num_episodes: int, + mmap: bool = True, + files_table: pa.Table | None = None, + episodes_table: pa.Table | None = None, + mp4_by_rel: dict[str, Any] | None = None, + ): + self.index_dir = Path(index_dir) if index_dir is not None else None + self.video_keys = list(video_keys) + self.num_episodes = num_episodes + self.num_cameras = len(video_keys) + self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)} + self._mp4_by_rel = mp4_by_rel + self._frame_mappings_by_gid: dict[int, bytes] = {} + + t0 = time.perf_counter() + if files_table is not None and episodes_table is not None: + files_tbl, episodes_tbl = files_table, episodes_table + else: + if self.index_dir is None: + raise ValueError("index_dir or in-memory tables required") + files_path = self.index_dir / FILES_NAME + episodes_path = self.index_dir / EPISODES_NAME + if not files_path.exists() or not episodes_path.exists(): + raise FileNotFoundError(f"byte index missing under {self.index_dir}") + files_tbl = pq.read_table(files_path, memory_map=mmap) + episodes_tbl = pq.read_table(episodes_path, memory_map=mmap) + + self._load_tables(files_tbl, episodes_tbl, mmap=mmap) + self.build_time_s = time.perf_counter() - t0 + self.load_time_s = self.build_time_s + + def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None: + def col(tbl, name: str): + array = tbl.column(name).combine_chunks() + if pa.types.is_boolean(array.type): + return array.to_numpy(zero_copy_only=False) + return array.to_numpy() + + self.file_id = col(files_tbl, "file_id") + self.file_path = files_tbl.column("file_path").to_pylist() + self.file_size = col(files_tbl, "file_size") + self.moov_offset = col(files_tbl, "moov_offset") + self.moov_length = col(files_tbl, "moov_length") + self.header_length = col(files_tbl, "header_length") + self.faststart = col(files_tbl, "faststart") + self.file_avg_fps = col(files_tbl, "avg_fps") + self.codec = files_tbl.column("codec").to_pylist() + + ep = episodes_tbl + n = len(ep) + gid = col(ep, "global_episode_id") + order = np.argsort(gid) + self._global_episode_id = gid[order] + self._episode_index = col(ep, "episode_index")[order] + self._camera_index = col(ep, "camera_index")[order] + self._file_id = col(ep, "file_id")[order] + self._mdat_offset = col(ep, "mdat_offset")[order] + self._mdat_length = col(ep, "mdat_length")[order] + self._frame_count = col(ep, "frame_count")[order] + self._first_pts = col(ep, "first_pts")[order] + self._last_pts = col(ep, "last_pts")[order] + + expected = self.num_episodes * self.num_cameras + if n != expected: + raise ValueError(f"byte index episodes rows {n} != expected {expected}") + + if self.index_dir is not None: + keyframes_path = self.index_dir / KEYFRAMES_NAME + if keyframes_path.exists(): + kf_tbl = pq.read_table(keyframes_path, memory_map=mmap) + self._keyframes_rows = len(kf_tbl) + else: + self._keyframes_rows = 0 + else: + self._keyframes_rows = 0 + + self.resident_bytes = int( + self._global_episode_id.nbytes + + self._file_id.nbytes + + self._mdat_offset.nbytes + + self._mdat_length.nbytes + + self.file_size.nbytes + ) + + @classmethod + def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex: + return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes) + + @classmethod + def from_memory_build( + cls, + meta, + data_root: str, + *, + workers: int = 8, + max_episodes: int | None = None, + include_frame_mappings_cache: bool = True, + ) -> EpisodeByteIndex: + """Build a complete byte index in RAM (no parquet write, no dataset push).""" + from .byte_index_builder import build_byte_index_in_memory + + return build_byte_index_in_memory( + meta, + data_root, + workers=workers, + max_episodes=max_episodes, + include_frame_mappings_cache=include_frame_mappings_cache, + ) + + def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup: + cam_idx = self._cam_to_idx[camera_key] + gid = episode_index * self.num_cameras + cam_idx + row = int(gid) + if row < 0 or row >= len(self._global_episode_id): + raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range") + file_id = int(self._file_id[row]) + return EpisodeSliceLookup( + global_episode_id=gid, + file_id=file_id, + mdat_offset=int(self._mdat_offset[row]), + mdat_length=int(self._mdat_length[row]), + frame_count=int(self._frame_count[row]), + first_pts=float(self._first_pts[row]), + last_pts=float(self._last_pts[row]), + avg_fps=float(self.file_avg_fps[file_id]), + ) + + def file_lookup(self, file_id: int) -> FileLookup: + return FileLookup( + file_id=file_id, + file_path=self.file_path[file_id], + file_size=int(self.file_size[file_id]), + moov_offset=int(self.moov_offset[file_id]), + moov_length=int(self.moov_length[file_id]), + header_length=int(self.header_length[file_id]), + faststart=bool(self.faststart[file_id]), + avg_fps=float(self.file_avg_fps[file_id]), + codec=self.codec[file_id], + ) + + def header_byte_range(self, file_id: int) -> tuple[int, int]: + length = int(self.header_length[file_id]) + return 0, max(0, length - 1) + + def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None: + cam_idx = self._cam_to_idx[camera_key] + gid = episode_index * self.num_cameras + cam_idx + cached = self._frame_mappings_by_gid.get(gid) + if cached is not None: + return cached + if self._mp4_by_rel is None: + return None + lookup = self.lookup(episode_index, camera_key) + rel = self.file_path[lookup.file_id] + mp4_index = self._mp4_by_rel.get(rel) + if mp4_index is None: + return None + payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts) + self._frame_mappings_by_gid[gid] = payload + return payload + + def stats_dict(self) -> dict[str, float | int]: + return { + "load_time_s": self.load_time_s, + "build_time_s": self.build_time_s, + "resident_bytes": self.resident_bytes, + "frame_mappings_cached": len(self._frame_mappings_by_gid), + "mp4_indices_cached": len(self._mp4_by_rel or {}), + "num_files": len(self.file_path), + "num_episode_rows": len(self._global_episode_id), + } diff --git a/src/lerobot/datasets/byte_index_builder.py b/src/lerobot/datasets/byte_index_builder.py new file mode 100644 index 000000000..de95dd2b9 --- /dev/null +++ b/src/lerobot/datasets/byte_index_builder.py @@ -0,0 +1,281 @@ +"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch.""" + +from __future__ import annotations + +import json +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import fsspec +import pyarrow as pa +import pyarrow.parquet as pq + +from .mp4_episode_slice import ( + HEADER_PROBE_BYTES, + MAX_HEADER_PROBE_BYTES, + average_fps_from_index, + episode_keyframes, + parse_mp4_file_layout, + parse_mp4_index, +) + +logger = logging.getLogger(__name__) + +BYTE_INDEX_DIR = "meta/byte_index" +FILES_NAME = "files.parquet" +EPISODES_NAME = "episodes.parquet" +KEYFRAMES_NAME = "keyframes.parquet" + + +@dataclass +class IndexedFile: + file_id: int + file_path: str + file_size: int + moov_offset: int + moov_length: int + header_length: int + faststart: bool + avg_fps: float + codec: str + + +def fetch_header_bytes(path: str, file_size: int) -> bytes: + fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file") + probe = HEADER_PROBE_BYTES + while True: + with fs.open(path, "rb", block_size=max(probe, 2**20), cache_type="none") as f: + header = f.read(min(probe, file_size)) + try: + parse_mp4_file_layout(header, file_size) + return header + except ValueError as exc: + if probe >= min(MAX_HEADER_PROBE_BYTES, file_size) or "mdat box not found" not in str(exc): + raise + probe = min(probe * 2, MAX_HEADER_PROBE_BYTES, file_size) + + +def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]: + fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file") + file_size = fs.info(path)["size"] + header = fetch_header_bytes(path, file_size) + layout = parse_mp4_file_layout(header, file_size) + if not layout.faststart: + logger.warning("non-faststart MP4 (moov after mdat): %s", path) + mp4_index = parse_mp4_index(header, file_size) + indexed = IndexedFile( + file_id=-1, + file_path=rel_path or path, + file_size=file_size, + moov_offset=layout.moov_offset, + moov_length=layout.moov_length, + header_length=layout.header_end, + faststart=layout.faststart, + avg_fps=average_fps_from_index(mp4_index), + codec=layout.codec, + ) + return indexed, mp4_index + + +def build_byte_index_tables( + meta, + data_root: str, + *, + file_paths: list[str] | None = None, + include_keyframes: bool = True, + workers: int = 8, + existing_files: dict[str, int] | None = None, + max_episodes: int | None = None, + return_mp4_indices: bool = False, + complete_files_table: bool = False, +) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]: + """Build files/episodes/(optional keyframes) Arrow tables.""" + video_keys = list(meta.video_keys) + n_cams = len(video_keys) + cam_to_idx = {cam: i for i, cam in enumerate(video_keys)} + num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes) + + rel_paths: set[str] = set() + for ep_idx in range(num_episodes): + for cam in video_keys: + rel_paths.add(str(meta.get_video_file_path(ep_idx, cam))) + path_by_rel = {rel: f"{data_root.rstrip('/')}/{rel}" for rel in sorted(rel_paths)} + if file_paths is None: + file_paths = list(path_by_rel.values()) + rel_by_path = {path_by_rel[rel]: rel for rel in path_by_rel} + + existing_files = existing_files or {} + file_meta_by_rel: dict[str, dict[str, Any]] = {} + mp4_by_rel: dict[str, Any] = {} + next_file_id = max(existing_files.values(), default=-1) + 1 + + to_index = [rel for rel in sorted(rel_paths) if rel not in existing_files] + if to_index: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel for rel in to_index + } + for fut in as_completed(futures): + rel = futures[fut] + indexed, mp4_index = fut.result() + indexed.file_id = next_file_id + mp4_by_rel[rel] = mp4_index + file_meta_by_rel[rel] = { + "file_id": indexed.file_id, + "file_path": rel, + "file_size": indexed.file_size, + "moov_offset": indexed.moov_offset, + "moov_length": indexed.moov_length, + "header_length": indexed.header_length, + "faststart": indexed.faststart, + "avg_fps": indexed.avg_fps, + "codec": indexed.codec, + } + existing_files[rel] = indexed.file_id + next_file_id += 1 + + missing_rels = { + str(meta.get_video_file_path(ep, cam)) + for ep in range(num_episodes) + for cam in video_keys + } - set(mp4_by_rel.keys()) + if missing_rels: + with ThreadPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit(index_video_file, path_by_rel[rel], rel_path=rel): rel + for rel in missing_rels + if rel not in mp4_by_rel + } + for fut in as_completed(futures): + rel = futures[fut] + _, mp4_index = fut.result() + mp4_by_rel[rel] = mp4_index + + episode_rows: list[dict[str, Any]] = [] + keyframe_rows: list[dict[str, Any]] = [] + for ep_idx in range(num_episodes): + for cam in video_keys: + rel = str(meta.get_video_file_path(ep_idx, cam)) + path = f"{data_root.rstrip('/')}/{rel}" + if rel not in existing_files: + raise KeyError(f"file not indexed: {rel}") + mp4_index = mp4_by_rel[rel] + ep = meta.episodes[ep_idx] + from_ts = float(ep[f"videos/{cam}/from_timestamp"]) + to_ts = float(ep[f"videos/{cam}/to_timestamp"]) + span = mp4_index.episode_byte_span(from_ts, to_ts) + global_episode_id = ep_idx * n_cams + cam_to_idx[cam] + mdat_length = span.slice_hi - span.slice_lo + 1 + episode_rows.append( + { + "global_episode_id": global_episode_id, + "episode_index": ep_idx, + "camera_key": cam, + "camera_index": cam_to_idx[cam], + "file_id": existing_files[rel], + "mdat_offset": span.slice_lo, + "mdat_length": mdat_length, + "frame_count": max(1, round((to_ts - from_ts) * meta.fps)), + "first_pts": from_ts, + "last_pts": to_ts, + } + ) + if include_keyframes: + timescale = mp4_index.timescale + for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts): + keyframe_rows.append( + { + "global_episode_id": global_episode_id, + "pts": int(round(pts_s * timescale)), + "byte_offset": byte_off, + } + ) + + referenced_rels = { + str(meta.get_video_file_path(ep, cam)) for ep in range(num_episodes) for cam in video_keys + } + if complete_files_table: + files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(referenced_rels)]) + elif to_index: + files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in sorted(to_index)]) + else: + files_table = None + episodes_table = pa.Table.from_pylist(episode_rows) + keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None + if return_mp4_indices: + return files_table, episodes_table, keyframes_table, mp4_by_rel + return files_table, episodes_table, keyframes_table + + +def build_byte_index_in_memory( + meta, + data_root: str, + *, + workers: int = 8, + max_episodes: int | None = None, + include_frame_mappings_cache: bool = False, +): + """Build a complete byte index resident in RAM (no parquet write, no dataset push).""" + from .byte_index import EpisodeByteIndex + + num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes) + files_tbl, episodes_tbl, _, mp4_by_rel = build_byte_index_tables( + meta, + data_root, + include_keyframes=False, + workers=workers, + max_episodes=max_episodes, + return_mp4_indices=True, + complete_files_table=True, + ) + index = EpisodeByteIndex( + None, + video_keys=list(meta.video_keys), + num_episodes=num_episodes, + files_table=files_tbl, + episodes_table=episodes_tbl, + mp4_by_rel=mp4_by_rel, + ) + if include_frame_mappings_cache: + for ep_idx in range(num_episodes): + for cam in meta.video_keys: + index.custom_frame_mappings(ep_idx, cam) + return index + + +def write_byte_index( + output_dir: Path, + files_table: pa.Table | None, + episodes_table: pa.Table, + keyframes_table: pa.Table | None = None, + *, + merge_existing: bool = True, +) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + files_path = output_dir / FILES_NAME + episodes_path = output_dir / EPISODES_NAME + keyframes_path = output_dir / KEYFRAMES_NAME + + if merge_existing and files_path.exists() and files_table is not None: + prev = pq.read_table(files_path) + files_table = pa.concat_tables([prev, files_table]) + + if files_table is not None: + pq.write_table(files_table, files_path) + + pq.write_table(episodes_table, episodes_path) + if keyframes_table is not None: + if merge_existing and keyframes_path.exists(): + keyframes_table = pa.concat_tables([pq.read_table(keyframes_path), keyframes_table]) + pq.write_table(keyframes_table, keyframes_path) + + +def load_existing_file_ids(index_dir: Path) -> dict[str, int]: + files_path = index_dir / FILES_NAME + if not files_path.exists(): + return {} + table = pq.read_table(files_path, columns=["file_id", "file_path"]) + return {row["file_path"]: int(row["file_id"]) for row in table.to_pylist()} diff --git a/src/lerobot/datasets/episode_byte_cache.py b/src/lerobot/datasets/episode_byte_cache.py new file mode 100644 index 000000000..bf7aefd16 --- /dev/null +++ b/src/lerobot/datasets/episode_byte_cache.py @@ -0,0 +1,263 @@ +"""Node-local LRU byte cache using precomputed byte-index manifest sidecars.""" + +from __future__ import annotations + +import logging +import threading +import time +from collections import OrderedDict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import fsspec + +from .byte_index import EpisodeByteIndex, EpisodeSliceLookup +from .mp4_episode_slice import SparseMp4Reader +from .torchcodec_utils import open_video_decoder + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheStats: + hits: int = 0 + misses: int = 0 + bytes_fetched: int = 0 + full_file_fallbacks: int = 0 + prefetch_submitted: int = 0 + prefetch_waits: int = 0 + mdat_slices: int = 0 + prefix_fetches: int = 0 + fetch_to_buffer_s: float = 0.0 + buffer_to_decoder_s: float = 0.0 + buffer_hit_decoder_s: float = 0.0 + decode_frame_s: float = 0.0 + decode_frames: int = 0 + + def merge(self, other: CacheStats) -> None: + for name in self.__dataclass_fields__: + setattr(self, name, getattr(self, name) + getattr(other, name)) + + def stats_dict(self) -> dict[str, int | float]: + avg_miss = self.bytes_fetched / max(1, self.misses) + return { + "byte_cache_hits": self.hits, + "byte_cache_misses": self.misses, + "byte_cache_bytes_fetched": self.bytes_fetched, + "byte_cache_bytes_per_miss": avg_miss, + "byte_cache_full_file_fallbacks": self.full_file_fallbacks, + "byte_cache_prefetch_submitted": self.prefetch_submitted, + "byte_cache_prefetch_waits": self.prefetch_waits, + "byte_cache_mdat_slices": self.mdat_slices, + "byte_cache_prefix_fetches": self.prefix_fetches, + "fetch_to_buffer_ms_per_miss": 1000 * self.fetch_to_buffer_s / max(1, self.misses), + "buffer_to_decoder_ms_per_miss": 1000 * self.buffer_to_decoder_s / max(1, self.misses), + "buffer_hit_decoder_ms_per_hit": 1000 * self.buffer_hit_decoder_s / max(1, self.hits), + "decode_ms_per_frame": 1000 * self.decode_frame_s / max(1, self.decode_frames), + } + + +@dataclass +class _EpisodeEntry: + decoders: dict[str, Any] = field(default_factory=dict) + ready: threading.Event = field(default_factory=threading.Event) + error: Exception | None = None + + +class RangeFetcher: + """Sequential byte-range GETs via fsspec.""" + + def __init__(self, path: str): + self.path = path + self._fs = fsspec.filesystem("hf") if path.startswith("hf://") else fsspec.filesystem("file") + + def fetch(self, lo: int, hi: int) -> bytes: + if hi < lo: + return b"" + with self._fs.open(self.path, "rb", block_size=max(2**20, hi - lo + 1), cache_type="none") as f: + f.seek(lo) + return f.read(hi - lo + 1) + + +class EpisodeByteCache: + """Manifest-driven episode MP4 fetch + in-memory sparse decode.""" + + MAX_BYTES_PER_MISS = 25 * 1024 * 1024 + + def __init__( + self, + byte_index: EpisodeByteIndex, + max_bytes: int, + *, + data_root: str, + max_prefetch_workers: int = 4, + ): + if max_bytes <= 0: + raise ValueError(f"max_bytes must be positive; got {max_bytes}") + self.byte_index = byte_index + self.max_bytes = max_bytes + self.data_root = data_root.rstrip("/") + self._bytes_used = 0 + self._lock = threading.Lock() + self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict() + self._header_cache: dict[int, bytes] = {} + self._fetcher_cache: dict[int, RangeFetcher] = {} + self._episodes: dict[int, _EpisodeEntry] = {} + self._stats = CacheStats() + self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers) + self._futures: dict[int, Future] = {} + + @property + def stats(self) -> CacheStats: + with self._lock: + return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__}) + + def submit_prefetch(self, ep_idx: int) -> None: + with self._lock: + if ep_idx in self._episodes or ep_idx in self._futures: + return + self._stats.prefetch_submitted += 1 + fut = self._executor.submit(self._prefetch_episode, ep_idx) + self._futures[ep_idx] = fut + + def ensure_ready(self, ep_idx: int) -> None: + with self._lock: + fut = self._futures.pop(ep_idx, None) + if fut is not None: + with self._lock: + self._stats.prefetch_waits += 1 + fut.result() + entry = self._episodes.get(ep_idx) + if entry is None: + raise KeyError(f"episode {ep_idx} not prefetched") + if entry.error is not None: + raise entry.error + entry.ready.wait() + + def get_decoder(self, ep_idx: int, video_key: str) -> Any: + entry = self._episodes[ep_idx] + if entry.error is not None: + raise entry.error + entry.ready.wait() + return entry.decoders[video_key] + + def close(self) -> None: + self._executor.shutdown(wait=False, cancel_futures=True) + + def _prefetch_episode(self, ep_idx: int) -> None: + entry = _EpisodeEntry() + self._episodes[ep_idx] = entry + try: + for cam in self.byte_index.video_keys: + entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam) + except Exception as exc: + entry.error = exc + finally: + entry.ready.set() + + def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any: + key = (ep_idx, cam) + with self._lock: + cached = self._cache.get(key) + if cached is not None: + self._cache.move_to_end(key) + self._stats.hits += 1 + payload, _ = cached + t0 = time.perf_counter() + dec = self._decoder_from_payload(payload, ep_idx, cam) + with self._lock: + self._stats.buffer_hit_decoder_s += time.perf_counter() - t0 + return dec + + payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam) + + with self._lock: + self._stats.misses += 1 + if payload_bytes > self.MAX_BYTES_PER_MISS: + logger.warning( + "byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s", + payload_bytes / 1e6, + ep_idx, + cam, + ) + self._evict_until(payload_bytes) + self._cache[key] = (payload, payload_bytes) + self._bytes_used += payload_bytes + return dec + + def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]: + lookup = self.byte_index.lookup(ep_idx, cam) + file_info = self.byte_index.file_lookup(lookup.file_id) + fetcher = self._get_fetcher(lookup.file_id, file_info.file_path) + t_fetch = time.perf_counter() + header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length) + lo = lookup.mdat_offset + hi = lo + lookup.mdat_length - 1 + mdat = fetcher.fetch(lo, hi) + fetch_s = time.perf_counter() - t_fetch + nbytes = len(header) + len(mdat) + with self._lock: + self._stats.bytes_fetched += nbytes + self._stats.mdat_slices += 1 + self._stats.fetch_to_buffer_s += fetch_s + + def lazy_fetch(pos: int, end: int) -> bytes: + data = fetcher.fetch(pos, end) + with self._lock: + self._stats.bytes_fetched += len(data) + return data + + reader = SparseMp4Reader( + file_size=file_info.file_size, + header=header, + mdat_lo=lo, + mdat_bytes=mdat, + lazy_fetch=lazy_fetch, + ) + t_init = time.perf_counter() + dec = self._decoder_from_payload(reader, ep_idx, cam) + self._validate_decoder(dec, lookup) + init_s = time.perf_counter() - t_init + with self._lock: + self._stats.buffer_to_decoder_s += init_s + self._rewind_payload(reader) + return reader, nbytes, dec + + def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher: + if file_id not in self._fetcher_cache: + path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}" + self._fetcher_cache[file_id] = RangeFetcher(path) + return self._fetcher_cache[file_id] + + def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes: + if file_id in self._header_cache: + return self._header_cache[file_id] + hi = max(0, header_length - 1) + header = fetcher.fetch(0, hi) + with self._lock: + self._header_cache[file_id] = header + self._stats.bytes_fetched += len(header) + return header + + def _decoder_from_payload( + self, payload: SparseMp4Reader, ep_idx: int, cam: str + ) -> Any: + payload.seek(0) + mappings = self.byte_index.custom_frame_mappings(ep_idx, cam) + return open_video_decoder(payload, frame_mappings=mappings) + + def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None: + begin = float(dec.metadata.begin_stream_seconds) + end = float(dec.metadata.end_stream_seconds) + duration = max(0.01, end - begin) + for ts in (begin + 1e-3, begin + 0.5 * duration, end - 1e-3): + dec.get_frames_played_at([ts]).data + + def _rewind_payload(self, payload: SparseMp4Reader) -> None: + payload.seek(0) + + def _evict_until(self, need: int) -> None: + while self._bytes_used + need > self.max_bytes and self._cache: + _, (_, size) = self._cache.popitem(last=False) + self._bytes_used -= size diff --git a/src/lerobot/datasets/mp4_episode_slice.py b/src/lerobot/datasets/mp4_episode_slice.py new file mode 100644 index 000000000..6c0c641e9 --- /dev/null +++ b/src/lerobot/datasets/mp4_episode_slice.py @@ -0,0 +1,555 @@ +"""MP4 moov parsing and tight per-episode mdat byte-range fetching. + +LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head). +For streaming we fetch only the file header plus the episode's contiguous mdat span +instead of the ``0..episode_end`` prefix. +""" + +from __future__ import annotations + +import io +import struct +import threading +from dataclasses import dataclass, field +from typing import Callable + +KEYFRAME_PAD_S = 0.1 +HEADER_PROBE_BYTES = 4 * 1024 * 1024 +MAX_HEADER_PROBE_BYTES = 16 * 1024 * 1024 + + +@dataclass +class Mp4FileLayout: + file_size: int + moov_offset: int + moov_length: int + header_end: int + mdat_offset: int + mdat_size: int + faststart: bool + codec: str + + +def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout: + """Return top-level MP4 layout (moov/mdat positions, faststart flag).""" + boxes = list(_iter_boxes(header_bytes)) + moov_offset = mdat_offset = -1 + moov_length = mdat_size = 0 + for off, size, typ, _ in boxes: + if typ == b"moov" and moov_offset < 0: + moov_offset, moov_length = off, size + if typ == b"mdat" and mdat_offset < 0: + mdat_offset, mdat_size = off, size + if moov_offset < 0: + raise ValueError("moov box not found in header probe") + if mdat_offset < 0: + raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES") + faststart = moov_offset < mdat_offset + header_end = mdat_offset + codec = _parse_video_codec(header_bytes) + return Mp4FileLayout( + file_size=file_size, + moov_offset=moov_offset, + moov_length=moov_length, + header_end=header_end, + mdat_offset=mdat_offset, + mdat_size=mdat_size, + faststart=faststart, + codec=codec, + ) + + +def _parse_video_codec(header_bytes: bytes) -> str: + moov = _find_box_payload(header_bytes, b"moov") + if moov is None: + return "unknown" + trak = _find_video_trak(moov) + if trak is None: + return "unknown" + stsd = _find_box_payload(_find_box_payload(trak, b"stbl") or b"", b"stsd") + if stsd is None or len(stsd) < 12: + return "unknown" + # stsd: version(1)+flags(3)+entry_count(4)+entry_size(4)+codec(4) + if len(stsd) >= 12: + return stsd[8:12].decode("latin1", errors="replace").strip("\x00") + return "unknown" + + +def average_fps_from_index(index: Mp4VideoIndex) -> float: + index.ensure_tables() + if index.num_samples < 2: + return 30.0 + duration = index.sample_pts(index.num_samples - 1) + if duration <= 0: + return 30.0 + return index.num_samples / duration + + +def episode_custom_frame_mappings_json( + index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S +) -> bytes: + """Build TorchCodec ``custom_frame_mappings`` JSON for one episode span.""" + import json + + index.ensure_tables() + lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s)) + hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s) + hi_idx = min(hi_idx, index.num_samples - 1) + lo_idx = _keyframe_back(index.sync_samples, lo_idx) + + sync = set(index.sync_samples) + timescale = index.timescale + # stts deltas for duration per sample (expand stts entries to per-sample delta) + sample_deltas: list[int] = [] + for count, delta in index.stts: + sample_deltas.extend([delta] * count) + while len(sample_deltas) < index.num_samples: + sample_deltas.append(sample_deltas[-1] if sample_deltas else timescale // 30) + + frames = [] + for idx in range(lo_idx, hi_idx + 1): + frames.append( + { + "pts": int(round(index._pts[idx] * timescale)), + "duration": int(sample_deltas[idx]), + "key_frame": int((idx + 1) in sync) if sync else int(idx == lo_idx), + } + ) + return json.dumps({"frames": frames}).encode() + + +def episode_keyframes( + index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S +) -> list[tuple[float, int]]: + """Return (pts_seconds, byte_offset) for sync samples in the episode span.""" + index.ensure_tables() + span = index.episode_byte_span(from_ts, to_ts, keyframe_pad_s) + lo_idx = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s)) + hi_idx = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s) + if not index.sync_samples: + return [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))] + out: list[tuple[float, int]] = [] + for sync_one_based in index.sync_samples: + idx = sync_one_based - 1 + if lo_idx <= idx <= hi_idx: + out.append((index.sample_pts(idx), index.sample_offset(idx))) + return out or [(index.sample_pts(lo_idx), index.sample_offset(lo_idx))] + + +@dataclass +class EpisodeByteSpan: + """Absolute file byte ranges to fetch for one episode.""" + + file_size: int + header_end: int + slice_lo: int + slice_hi: int + + @property + def header_bytes(self) -> tuple[int, int]: + return 0, self.header_end - 1 + + @property + def mdat_bytes(self) -> tuple[int, int]: + return self.slice_lo, self.slice_hi + + @property + def total_fetch_bytes(self) -> int: + header = self.header_end + mdat = self.slice_hi - self.slice_lo + 1 + return header + mdat + + +@dataclass +class Mp4VideoIndex: + file_size: int + header_end: int + mdat_offset: int + mdat_size: int + timescale: int + stts: list[tuple[int, int]] + stsz: list[int] + stsc: list[tuple[int, int, int]] + stco: list[int] + sync_samples: list[int] + _pts: list[float] = field(default_factory=list, repr=False) + _offsets: list[int] = field(default_factory=list, repr=False) + + def ensure_tables(self) -> None: + if self._pts: + return + self._pts = _pts_from_stts(self.stts, self.timescale) + self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz) + + @property + def num_samples(self) -> int: + return len(self.stsz) + + def sample_pts(self, index: int) -> float: + self.ensure_tables() + return self._pts[index] + + def sample_offset(self, index: int) -> int: + self.ensure_tables() + index = max(0, min(index, len(self._offsets) - 1)) + return self._offsets[index] + + def sample_end(self, index: int) -> int: + return self.sample_offset(index) + self.stsz[index] + + def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan: + self.ensure_tables() + n = self.num_samples + if n == 0: + raise ValueError("MP4 has no video samples") + + pad = max(keyframe_pad_s, 0.05 * max(0.01, to_ts - from_ts)) + lo_ts = max(0.0, from_ts - pad) + hi_ts = to_ts + pad + + lo_idx = _first_sample_at_or_after(self._pts, lo_ts) + hi_idx = _last_sample_at_or_before(self._pts, hi_ts) + hi_idx = min(hi_idx, n - 1) + lo_idx = min(lo_idx, n - 1) + + lo_idx = _keyframe_back(self.sync_samples, lo_idx) + + slice_lo = self.sample_offset(lo_idx) + slice_hi = self.sample_end(min(hi_idx, len(self._offsets) - 1)) + return EpisodeByteSpan( + file_size=self.file_size, + header_end=self.header_end, + slice_lo=slice_lo, + slice_hi=min(slice_hi, self.file_size - 1), + ) + + +class SparseMp4Reader(io.BufferedIOBase): + """Range-backed MP4 reader: header + one mdat span at absolute offsets.""" + + def __init__( + self, + file_size: int, + header: bytes, + mdat_lo: int, + mdat_bytes: bytes, + lazy_fetch: Callable[[int, int], bytes] | None = None, + ): + self._size = file_size + self._header = header + self._mdat_lo = mdat_lo + self._mdat_hi = mdat_lo + len(mdat_bytes) + self._mdat = mdat_bytes + self._lazy_fetch = lazy_fetch + self._pos = 0 + self._lock = threading.Lock() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return True + + def tell(self) -> int: + return self._pos + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + if whence == io.SEEK_SET: + self._pos = offset + elif whence == io.SEEK_CUR: + self._pos += offset + elif whence == io.SEEK_END: + self._pos = self._size + offset + else: + raise ValueError(f"invalid whence: {whence}") + self._pos = max(0, min(self._pos, self._size)) + return self._pos + + def read(self, size: int = -1) -> bytes: + if size < 0: + size = self._size - self._pos + if size <= 0: + return b"" + + out = bytearray() + remaining = size + pos = self._pos + while remaining > 0 and pos < self._size: + chunk = self._read_at(pos, remaining) + if not chunk: + break + out.extend(chunk) + pos += len(chunk) + remaining -= len(chunk) + self._pos = pos + return bytes(out) + + def _read_at(self, pos: int, n: int) -> bytes: + header_len = len(self._header) + if pos < header_len: + end = min(pos + n, header_len) + return self._header[pos:end] + + if self._mdat_lo <= pos < self._mdat_hi: + end = min(pos + n, self._mdat_hi) + off = pos - self._mdat_lo + return self._mdat[off : off + (end - pos)] + + if self._lazy_fetch is not None: + with self._lock: + end = min(pos + n, self._size) + return self._lazy_fetch(pos, end - 1) + + return b"\x00" * min(n, self._size - pos) + + +def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex: + """Parse moov sample tables from the file header (faststart layout).""" + layout = parse_mp4_file_layout(header_bytes, file_size) + mdat_offset, mdat_size = layout.mdat_offset, layout.mdat_size + moov = _find_box_payload(header_bytes, b"moov") + if moov is None: + raise ValueError("moov box not found in MP4 header probe") + + trak = _find_video_trak(moov) + if trak is None: + raise ValueError("video trak not found in moov") + + mdhd = _find_box_payload(trak, b"mdhd") + if mdhd is None: + raise ValueError("mdhd not found") + timescale = _parse_mdhd_timescale(mdhd) + + stbl = _find_box_payload(trak, b"stbl") + if stbl is None: + raise ValueError("stbl not found") + + stts = _parse_stts(_find_box_payload(stbl, b"stts")) + stsz = _parse_stsz(_find_box_payload(stbl, b"stsz")) + stsc = _parse_stsc(_find_box_payload(stbl, b"stsc")) + stco_payload = _find_box_payload(stbl, b"stco") + co64_payload = _find_box_payload(stbl, b"co64") + if stco_payload is not None: + stco = _parse_stco(stco_payload) + elif co64_payload is not None: + stco = _parse_co64(co64_payload) + else: + raise ValueError("stco/co64 not found") + + stss_payload = _find_box_payload(stbl, b"stss") + sync_samples = _parse_stss(stss_payload) if stss_payload else [] + + return Mp4VideoIndex( + file_size=file_size, + header_end=layout.header_end, + mdat_offset=mdat_offset, + mdat_size=mdat_size, + timescale=timescale, + stts=stts, + stsz=stsz, + stsc=stsc, + stco=stco, + sync_samples=sync_samples, + ) + + +def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None: + if offset + 8 > len(data): + return None + size, typ = struct.unpack_from(">I4s", data, offset) + header = 8 + if size == 1: + if offset + 16 > len(data): + return None + size = struct.unpack_from(">Q", data, offset + 8)[0] + header = 16 + elif size == 0: + size = len(data) - offset + return size, typ, header + + +def _iter_boxes(data: bytes, start: int = 0, end: int | None = None): + end = end if end is not None else len(data) + off = start + while off + 8 <= end: + hdr = _box_header(data, off) + if hdr is None or hdr[0] < hdr[2]: + break + size, typ, header = hdr + yield off, size, typ, data[off + header : off + size] + off += size + + +def _find_box_payload(data: bytes, target: bytes) -> bytes | None: + for _, _, typ, payload in _iter_boxes(data): + if typ == target: + return payload + if typ in (b"moov", b"trak", b"mdia", b"minf", b"stbl"): + found = _find_box_payload(payload, target) + if found is not None: + return found + return None + + +def _find_video_trak(moov: bytes) -> bytes | None: + for _, _, typ, payload in _iter_boxes(moov): + if typ != b"trak": + continue + hdlr = _find_box_payload(payload, b"hdlr") + if hdlr is not None and len(hdlr) >= 12 and hdlr[8:12] == b"vide": + return payload + return None + + +def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]: + for off, size, typ, _ in _iter_boxes(header_bytes): + if typ == b"mdat": + return off, size + # mdat may start beyond probe; scan from file_size hint unavailable — require probe hit + raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES") + + +def _parse_mdhd_timescale(mdhd: bytes) -> int: + version = mdhd[0] + if version == 0: + return struct.unpack_from(">I", mdhd, 12)[0] + return struct.unpack_from(">I", mdhd, 20)[0] + + +def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]: + if stts is None: + raise ValueError("stts missing") + count = struct.unpack_from(">I", stts, 4)[0] + out = [] + off = 8 + for _ in range(count): + sample_count, delta = struct.unpack_from(">II", stts, off) + out.append((sample_count, delta)) + off += 8 + return out + + +def _parse_stsz(stsz: bytes | None) -> list[int]: + if stsz is None: + raise ValueError("stsz missing") + sample_size, sample_count = struct.unpack_from(">II", stsz, 4) + if sample_size != 0: + return [sample_size] * sample_count + off = 12 + return list(struct.unpack_from(f">{sample_count}I", stsz, off)) + + +def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]: + if stsc is None: + raise ValueError("stsc missing") + count = struct.unpack_from(">I", stsc, 4)[0] + out = [] + off = 8 + for _ in range(count): + first_chunk, samples_per_chunk, sample_desc = struct.unpack_from(">III", stsc, off) + out.append((first_chunk, samples_per_chunk, sample_desc)) + off += 12 + return out + + +def _parse_stco(stco: bytes) -> list[int]: + count = struct.unpack_from(">I", stco, 4)[0] + return list(struct.unpack_from(f">{count}I", stco, 8)) + + +def _parse_co64(co64: bytes) -> list[int]: + count = struct.unpack_from(">I", co64, 4)[0] + return [struct.unpack_from(">Q", co64, 8 + i * 8)[0] for i in range(count)] + + +def _parse_stss(stss: bytes) -> list[int]: + count = struct.unpack_from(">I", stss, 4)[0] + return list(struct.unpack_from(f">{count}I", stss, 8)) + + +def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]: + pts: list[float] = [] + t = 0 + for count, delta in stts: + for _ in range(count): + pts.append(t / timescale) + t += delta + return pts + + +def _sample_byte_offsets( + stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int] +) -> list[int]: + if not stsc: + stsc = [(1, len(stsz), 1)] + + offsets: list[int] = [] + chunk_idx = 0 + sample_idx = 0 + sc_idx = 0 + num_chunks = len(stco) + + while chunk_idx < num_chunks and sample_idx < len(stsz): + first_chunk, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)] + if sc_idx + 1 < len(stsc): + next_first = stsc[sc_idx + 1][0] + chunks_in_entry = next_first - first_chunk + else: + chunks_in_entry = num_chunks - chunk_idx + + for _ in range(chunks_in_entry): + if chunk_idx >= num_chunks: + break + offset = stco[chunk_idx] + _, samples_per_chunk, _ = stsc[min(sc_idx, len(stsc) - 1)] + for _ in range(samples_per_chunk): + if sample_idx >= len(stsz): + break + offsets.append(offset) + offset += stsz[sample_idx] + sample_idx += 1 + chunk_idx += 1 + sc_idx += 1 + + if len(offsets) < len(stsz): + # Pad with last known offset progression for malformed stsc edge cases. + last = offsets[-1] if offsets else 0 + while len(offsets) < len(stsz): + idx = len(offsets) + offsets.append(last) + last += stsz[idx] + + return offsets + + +def _first_sample_at_or_after(pts: list[float], ts: float) -> int: + lo, hi = 0, len(pts) + while lo < hi: + mid = (lo + hi) // 2 + if pts[mid] < ts: + lo = mid + 1 + else: + hi = mid + return min(lo, len(pts) - 1) + + +def _last_sample_at_or_before(pts: list[float], ts: float) -> int: + lo, hi = 0, len(pts) + while lo < hi: + mid = (lo + hi) // 2 + if pts[mid] <= ts: + lo = mid + 1 + else: + hi = mid + return max(0, lo - 1) + + +def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int: + if not sync_samples: + return max(0, sample_idx - 2) + # stss stores 1-based sample numbers + one_based = sample_idx + 1 + prev = [s for s in sync_samples if s <= one_based] + if prev: + return prev[-1] - 1 + return 0 diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index fb6b9eef2..6718fde6f 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -124,6 +124,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): video_decoder_cache_size: int | None = None, data_files_root: str | None = None, validate_row_groups: bool = True, + video_byte_cache_gb: float | None = 80.0, + byte_index_path: str | Path | None = None, + byte_index_build_in_memory: bool | None = None, + byte_index_workers: int = 8, + byte_index_max_episodes: int | None = None, ): """Initialize a StreamingLeRobotDataset. @@ -173,6 +178,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): ``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear ``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging); the divisibility check then downgrades to a warning. + video_byte_cache_gb (float | None, optional): Node-local LRU for episode MP4 mdat slices. + When set (default 80 GB), episodes are prefetched via tight byte ranges before decode. + Set to 0 or None to disable and use remote per-seek decoding. + byte_index_path (str | Path | None, optional): Path to precomputed ``meta/byte_index/`` + sidecar parquet tables. Defaults to ``{meta.root}/meta/byte_index``. + byte_index_build_in_memory (bool | None, optional): When True, build the byte index in RAM + at init (moov-only fetches, no parquet write). When None (default), build in memory only + if the sidecar parquet is missing on disk. + byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds. + byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests). """ super().__init__() self.repo_id = repo_id @@ -210,6 +225,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self.rank, self.world_size = self._resolve_distributed(rank, world_size) self.video_decoder_cache_size = video_decoder_cache_size self.data_files_root = data_files_root.rstrip("/") if data_files_root else None + self.video_byte_cache_gb = video_byte_cache_gb + self.byte_index_path = Path(byte_index_path) if byte_index_path is not None else None + self.byte_index_build_in_memory = byte_index_build_in_memory + self.byte_index_workers = byte_index_workers + self.byte_index_max_episodes = byte_index_max_episodes + self._episode_byte_cache = None + self._byte_index = None + self._data_root = None # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) self.video_decoder_cache = None @@ -228,6 +251,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + if self._use_episode_byte_cache(): + from .byte_index import EpisodeByteIndex + + data_root = self._resolve_data_root() + index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index") + sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists() + build_in_memory = ( + self.byte_index_build_in_memory + if self.byte_index_build_in_memory is not None + else not sidecar_exists + ) + if build_in_memory: + logger.info( + "Building byte index in memory from %s (%s episodes, %d workers)", + data_root, + self.byte_index_max_episodes or self.meta.total_episodes, + self.byte_index_workers, + ) + self._byte_index = EpisodeByteIndex.from_memory_build( + self.meta, + data_root, + workers=self.byte_index_workers, + max_episodes=self.byte_index_max_episodes, + ) + else: + self._byte_index = EpisodeByteIndex( + index_dir, + video_keys=self.meta.video_keys, + num_episodes=self.meta.total_episodes, + ) + self.delta_timestamps = None self.delta_indices = None @@ -417,6 +471,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): buffer_size=self.episode_pool_size, max_buffer_input_shards=max_input_shards, ) + if self._use_episode_byte_cache(): + ds = ds.map(self._submit_episode_prefetch, batched=True) # A row-count-changing batched map must drop the input columns explicitly; the exploded # frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks). ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns) @@ -472,6 +528,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return VideoDecoderCache() return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128)) + def _use_episode_byte_cache(self) -> bool: + return ( + self.video_byte_cache_gb not in (None, 0) + and self.data_files_root is not None + ) + + def _make_episode_byte_cache(self): + from .episode_byte_cache import EpisodeByteCache + + if self._byte_index is None: + raise RuntimeError("byte index required for episode byte cache; run build_byte_index.py") + max_bytes = int(float(self.video_byte_cache_gb) * 1e9) + return EpisodeByteCache( + self._byte_index, + max_bytes, + data_root=self._data_root, + ) + + def _submit_episode_prefetch(self, episode_batch: dict[str, list[list]]) -> dict[str, list[list]]: + if self._episode_byte_cache is None: + return episode_batch + for ep_idx in {int(v[0]) for v in episode_batch["episode_index"]}: + self._episode_byte_cache.submit_prefetch(ep_idx) + return episode_batch + def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: # `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch); # DataLoader workers each advance their own copy's counter in lockstep. The in-flight @@ -486,6 +567,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): self._in_flight_epoch = 0 self._pipeline.set_epoch(self._in_flight_epoch) self.video_decoder_cache = self._make_video_decoder_cache() + self._data_root = self._resolve_data_root() + if self._use_episode_byte_cache(): + self._episode_byte_cache = self._make_episode_byte_cache() + else: + self._episode_byte_cache = None iterator = iter(self._pipeline) while True: @@ -623,6 +709,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): """ item = {} + if self._episode_byte_cache is not None: + self._episode_byte_cache.ensure_ready(ep_idx) for video_key, query_ts in query_timestamps.items(): # query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset. from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"] @@ -635,12 +723,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): else: root = self.root video_path = f"{root}/{rel_path}" + episode_decoder = None + if self._episode_byte_cache is not None: + episode_decoder = self._episode_byte_cache.get_decoder(ep_idx, video_key) frames = decode_video_frames_torchcodec( video_path, shifted_query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache, return_uint8=self._return_uint8, + episode_decoder=episode_decoder, ) item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames diff --git a/src/lerobot/datasets/torchcodec_utils.py b/src/lerobot/datasets/torchcodec_utils.py new file mode 100644 index 000000000..3a4a1b0d1 --- /dev/null +++ b/src/lerobot/datasets/torchcodec_utils.py @@ -0,0 +1,49 @@ +"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings.""" + +from __future__ import annotations + +import json +from typing import Any + +import torch +from torchcodec import FrameBatch, _core as core +from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata + + +def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + data = json.loads(payload) + frames = data["frames"] + pts = torch.tensor([int(f["pts"]) for f in frames], dtype=torch.int64) + key = torch.tensor([bool(f["key_frame"]) for f in frames], dtype=torch.bool) + dur = torch.tensor([int(f["duration"]) for f in frames], dtype=torch.int64) + return pts, key, dur + + +class VideoDecoderLike: + """Minimal VideoDecoder surface used by episode byte cache.""" + + def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None): + self._decoder = decoder + ( + self.metadata, + self.stream_index, + self._begin_stream_seconds, + self._end_stream_seconds, + self._num_frames, + ) = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index) + + def get_frames_played_at(self, seconds: list[float]) -> FrameBatch: + return FrameBatch(*core.get_frames_by_pts(self._decoder, timestamps=seconds)) + + +def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike: + """Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist.""" + if frame_mappings is None: + decoder = core.create_from_file_like(source, "approximate") + core.add_video_stream(decoder) + return VideoDecoderLike(decoder) + + mappings = frame_mappings_tensors(frame_mappings) + decoder = core.create_from_file_like(source, "custom_frame_mappings") + core.add_video_stream(decoder, custom_frame_mappings=mappings) + return VideoDecoderLike(decoder) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 30fda72d1..89431ead0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -326,6 +326,7 @@ def decode_video_frames_torchcodec( log_loaded_timestamps: bool = False, decoder_cache: VideoDecoderCache | None = None, return_uint8: bool = False, + episode_decoder: Any | None = None, ) -> torch.Tensor: """Loads frames associated with the requested timestamps of a video using torchcodec. @@ -347,8 +348,10 @@ def decode_video_frames_torchcodec( if decoder_cache is None: decoder_cache = _default_decoder_cache - # Use cached decoder instead of creating new one each time - decoder = decoder_cache.get_decoder(str(video_path)) + if episode_decoder is not None: + decoder = episode_decoder + else: + decoder = decoder_cache.get_decoder(str(video_path)) loaded_ts = [] loaded_frames = [] diff --git a/tests/datasets/test_byte_index.py b/tests/datasets/test_byte_index.py new file mode 100644 index 000000000..1990156da --- /dev/null +++ b/tests/datasets/test_byte_index.py @@ -0,0 +1,150 @@ +"""Acceptance tests for manifest byte-index sidecars. + +Run on a compute node (not login-node): + + srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\ + bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\ + env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v' +""" + +from __future__ import annotations + +import json +import socket + +import pytest + +pytest.importorskip("torchcodec") + +REPO = "allenai/MolmoAct2-BimanualYAM-Dataset" +REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c" +BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket" +MAX_EPISODES = 64 + +COMPUTE_NODE = pytest.mark.skipif( + "login" in socket.gethostname(), + reason="run on compute node via srun (see module docstring), not login-node", +) + + +@pytest.fixture(scope="module") +def byte_index_dir(tmp_path_factory): + from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + + out = tmp_path_factory.mktemp("byte_index") + meta = LeRobotDatasetMetadata(REPO, revision=REV) + files, episodes, _ = build_byte_index_tables( + meta, BUCKET, workers=4, max_episodes=MAX_EPISODES, include_keyframes=False + ) + write_byte_index(out, files, episodes, None, merge_existing=False) + return out, meta + + +@pytest.mark.integration +@COMPUTE_NODE +def test_index_load_fast_and_small(byte_index_dir): + from lerobot.datasets.byte_index import EpisodeByteIndex + + out, meta = byte_index_dir + index = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES) + assert index.load_time_s < 1.0 + assert index.resident_bytes < 1_000_000_000 + + +@pytest.mark.integration +@COMPUTE_NODE +def test_tight_fetch_under_25mb(byte_index_dir): + from lerobot.datasets.byte_index import EpisodeByteIndex + from lerobot.datasets.byte_index_builder import build_byte_index_in_memory + from lerobot.datasets.episode_byte_cache import EpisodeByteCache + + _, meta = byte_index_dir + index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES) + cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET) + for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]: + cache.submit_prefetch(ep) + cache.ensure_ready(ep) + stats = cache.stats.stats_dict() + assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024 + + +@pytest.mark.integration +@COMPUTE_NODE +def test_in_memory_build_matches_parquet(byte_index_dir): + from lerobot.datasets.byte_index import EpisodeByteIndex + from lerobot.datasets.byte_index_builder import build_byte_index_in_memory + + out, meta = byte_index_dir + disk = EpisodeByteIndex(out, video_keys=meta.video_keys, num_episodes=MAX_EPISODES) + mem = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES) + for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]: + for cam in meta.video_keys: + a = disk.lookup(ep, cam) + b = mem.lookup(ep, cam) + assert a.mdat_offset == b.mdat_offset + assert a.mdat_length == b.mdat_length + assert abs(a.first_pts - b.first_pts) < 1e-6 + + +@pytest.mark.integration +@COMPUTE_NODE +def test_custom_frame_mappings_available(byte_index_dir): + from lerobot.datasets.byte_index_builder import build_byte_index_in_memory + + _, meta = byte_index_dir + index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES) + cam = meta.video_keys[0] + ep = MAX_EPISODES // 2 + payload = index.custom_frame_mappings(ep, cam) + assert payload is not None + data = json.loads(payload) + assert len(data["frames"]) > 10 + assert any(f["key_frame"] for f in data["frames"]) + assert all("pts" in f and "duration" in f for f in data["frames"]) + + +@pytest.mark.integration +@COMPUTE_NODE +def test_metadata_skip_decoder_init(byte_index_dir): + from lerobot.datasets.byte_index_builder import build_byte_index_in_memory + from lerobot.datasets.episode_byte_cache import EpisodeByteCache + + _, meta = byte_index_dir + index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES) + cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET) + cam = meta.video_keys[0] + ep = 0 + cache.submit_prefetch(ep) + cache.ensure_ready(ep) + dec = cache.get_decoder(ep, cam) + assert dec.metadata.num_frames is not None + assert dec.metadata.num_frames > 0 + begin = float(dec.metadata.begin_stream_seconds) + end = float(dec.metadata.end_stream_seconds) + ts = begin + 0.5 * (end - begin) + frame = dec.get_frames_played_at([ts]).data + assert frame.ndim == 4 + + +@pytest.mark.integration +@COMPUTE_NODE +def test_sparse_decode_produces_frames(byte_index_dir): + from lerobot.datasets.byte_index_builder import build_byte_index_in_memory + from lerobot.datasets.episode_byte_cache import EpisodeByteCache + + _, meta = byte_index_dir + index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES) + cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET) + cam = meta.video_keys[0] + ep = 0 + cache.submit_prefetch(ep) + cache.ensure_ready(ep) + dec = cache.get_decoder(ep, cam) + begin = float(dec.metadata.begin_stream_seconds) + end = float(dec.metadata.end_stream_seconds) + ts = begin + 0.5 * (end - begin) + frame = dec.get_frames_played_at([ts]).data + assert frame.ndim == 4 + assert frame.numel() > 0 + assert float(frame.float().std()) > 1.0