mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Add in-memory byte index and manifest-driven episode MP4 cache.
Build moov-derived byte ranges in RAM or from sidecar parquet, fetch tight mdat slices over the network, and decode via TorchCodec custom_frame_mappings to skip full-file metadata scans. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
+8
-3
@@ -335,9 +335,14 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
# Temporary: the native streaming pipeline needs batch(by_column=...) to survive shard/shuffle
|
||||
# re-creation (datasets#8259), reshard() per row group (#8193), and shuffle(max_buffer_input_shards=...)
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Pin to the merge commit until the next
|
||||
# datasets release ships them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", rev = "2c45eab1bb975ac3d846f2aa6217b82adec8eba3" }
|
||||
# (#8194) — all merged, not yet in a tagged 5.0 release. Track main until the next datasets release ships
|
||||
# them, then drop this and rely on the `datasets>=5.0.0` floor in `dependencies`.
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
# Temporary: huggingface_hub main carries the 408-retry fix (not yet released). NOTE: main still closes the
|
||||
# shared httpx.Client on every ConnectError, which races with concurrent streaming requests
|
||||
# ("Cannot send a request, as the client has been closed"); we patch that out locally in
|
||||
# huggingface_hub/utils/_http.py. A fresh `uv sync` re-installs main *without* that local patch.
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming datasets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.datasets.byte_index_builder import (
|
||||
build_byte_index_tables,
|
||||
load_existing_file_ids,
|
||||
write_byte_index,
|
||||
)
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the byte-index sidecar builder."""
    cli = argparse.ArgumentParser(description="Build LeRobot video byte-index sidecar.")
    cli.add_argument("--repo-id", required=True)
    cli.add_argument("--revision", default=None)
    cli.add_argument("--data-root", required=True, help="fsspec root for videos/ + data/")
    cli.add_argument("--output", type=Path, required=True, help="Output meta/byte_index directory")
    cli.add_argument("--workers", type=int, default=8)
    cli.add_argument("--max-episodes", type=int, default=None, help="Limit episodes (debug/smoke)")
    cli.add_argument("--no-keyframes", action="store_true")
    return cli


def main() -> None:
    """Build the byte-index tables for a dataset and write them as parquet sidecars.

    Files already present in the output directory's files table are passed to the
    builder as ``existing_files`` so they are not assigned new ids, and the result
    is written with ``merge_existing=True``.
    """
    opts = _build_arg_parser().parse_args()

    dataset_meta = LeRobotDatasetMetadata(opts.repo_id, revision=opts.revision)
    destination = opts.output

    already_indexed = load_existing_file_ids(destination)
    if already_indexed:
        logger.info("resuming: %s files already indexed", len(already_indexed))

    tables = build_byte_index_tables(
        dataset_meta,
        opts.data_root,
        include_keyframes=not opts.no_keyframes,
        workers=opts.workers,
        existing_files=already_indexed,
        max_episodes=opts.max_episodes,
    )
    files_tbl, episodes_tbl, keyframes_tbl = tables

    write_byte_index(destination, files_tbl, episodes_tbl, keyframes_tbl, merge_existing=True)
    logger.info("wrote byte index to %s", destination)
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Runtime in-memory byte index loaded from precomputed sidecar parquet."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .byte_index_builder import BYTE_INDEX_DIR, EPISODES_NAME, FILES_NAME, KEYFRAMES_NAME
|
||||
from .mp4_episode_slice import episode_custom_frame_mappings_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EpisodeSliceLookup:
    """Immutable byte-range record for one (episode, camera) pair."""

    # episode_index * num_cameras + camera_index
    global_episode_id: int
    # Id of the video file holding this episode
    file_id: int
    # Absolute byte offset of the first byte of the episode's span in the file
    mdat_offset: int
    # Number of bytes in the episode's span
    mdat_length: int
    frame_count: int
    # Episode start / end timestamps (seconds) inside the shared video file
    first_pts: float
    last_pts: float
    # Average frame rate of the containing file
    avg_fps: float

    @property
    def fetch_bytes(self) -> int:
        """Number of bytes that must be fetched for this episode (the span length)."""
        span_length = self.mdat_length
        return span_length
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FileLookup:
    """Immutable per-file record of the byte index (one row of the files table)."""

    file_id: int
    # Path as stored in the files table
    file_path: str
    # Total size of the file in bytes
    file_size: int
    # Position and size of the moov box
    moov_offset: int
    moov_length: int
    # Number of leading bytes that make up the file header
    header_length: int
    # False when the moov box sits after mdat
    faststart: bool
    avg_fps: float
    codec: str
|
||||
|
||||
|
||||
class EpisodeByteIndex:
    """Columnar byte index held in numpy arrays for constant-time episode lookup.

    File-level columns are stored ordered by ``file_id`` and episode-level columns
    ordered by ``global_episode_id`` (``episode_index * num_cameras + camera_index``),
    so both ids double as row positions.
    """

    def __init__(
        self,
        index_dir: str | Path | None,
        *,
        video_keys: list[str],
        num_episodes: int,
        mmap: bool = True,
        files_table: pa.Table | None = None,
        episodes_table: pa.Table | None = None,
        mp4_by_rel: dict[str, Any] | None = None,
    ):
        """Load the index from in-memory Arrow tables or from parquet sidecars.

        Args:
            index_dir: Directory holding the sidecar parquet files, or ``None`` when
                both tables are supplied in memory.
            video_keys: Camera keys; their order defines ``camera_index``.
            num_episodes: Number of episodes the index must cover.
            mmap: Whether parquet files are opened memory-mapped.
            files_table: Optional pre-built files table.
            episodes_table: Optional pre-built episodes table.
            mp4_by_rel: Optional parsed MP4 indices keyed by file path, used to
                derive custom frame mappings on demand.

        Raises:
            ValueError: If neither a directory nor both tables are given, or the
                tables are inconsistent with ``num_episodes`` / ``video_keys``.
            FileNotFoundError: If the sidecar parquet files are missing.
        """
        self.index_dir = Path(index_dir) if index_dir is not None else None
        self.video_keys = list(video_keys)
        self.num_episodes = num_episodes
        self.num_cameras = len(video_keys)
        self._cam_to_idx = {cam: i for i, cam in enumerate(self.video_keys)}
        self._mp4_by_rel = mp4_by_rel
        self._frame_mappings_by_gid: dict[int, bytes] = {}

        t0 = time.perf_counter()
        if files_table is not None and episodes_table is not None:
            files_tbl, episodes_tbl = files_table, episodes_table
        else:
            if self.index_dir is None:
                raise ValueError("index_dir or in-memory tables required")
            files_path = self.index_dir / FILES_NAME
            episodes_path = self.index_dir / EPISODES_NAME
            if not files_path.exists() or not episodes_path.exists():
                raise FileNotFoundError(f"byte index missing under {self.index_dir}")
            files_tbl = pq.read_table(files_path, memory_map=mmap)
            episodes_tbl = pq.read_table(episodes_path, memory_map=mmap)

        self._load_tables(files_tbl, episodes_tbl, mmap=mmap)
        self.build_time_s = time.perf_counter() - t0
        self.load_time_s = self.build_time_s

    def _load_tables(self, files_tbl: pa.Table, episodes_tbl: pa.Table, *, mmap: bool) -> None:
        """Convert the Arrow tables into id-ordered numpy columns and validate them.

        Raises:
            ValueError: If the episode row count does not match
                ``num_episodes * num_cameras``, or if ``file_id`` /
                ``global_episode_id`` values are not exactly ``0..n-1``.
        """

        def col(tbl, name: str):
            array = tbl.column(name).combine_chunks()
            if pa.types.is_boolean(array.type):
                # Arrow booleans are bit-packed, so a zero-copy numpy view is impossible.
                return array.to_numpy(zero_copy_only=False)
            return array.to_numpy()

        n = len(episodes_tbl)
        expected = self.num_episodes * self.num_cameras
        if n != expected:
            raise ValueError(f"byte index episodes rows {n} != expected {expected}")

        # lookup()/file_lookup() index file columns by file_id, so the rows must be
        # ordered by file_id regardless of the order they were written in.
        raw_file_id = col(files_tbl, "file_id")
        file_order = np.argsort(raw_file_id, kind="stable")
        self.file_id = raw_file_id[file_order]
        if not np.array_equal(self.file_id, np.arange(len(self.file_id))):
            raise ValueError("byte index file_id values must be exactly 0..num_files-1")

        paths = files_tbl.column("file_path").to_pylist()
        codecs = files_tbl.column("codec").to_pylist()
        self.file_path = [paths[i] for i in file_order]
        self.codec = [codecs[i] for i in file_order]
        self.file_size = col(files_tbl, "file_size")[file_order]
        self.moov_offset = col(files_tbl, "moov_offset")[file_order]
        self.moov_length = col(files_tbl, "moov_length")[file_order]
        self.header_length = col(files_tbl, "header_length")[file_order]
        self.faststart = col(files_tbl, "faststart")[file_order]
        self.file_avg_fps = col(files_tbl, "avg_fps")[file_order]

        ep = episodes_tbl
        gid = col(ep, "global_episode_id")
        order = np.argsort(gid)
        self._global_episode_id = gid[order]
        # lookup() uses the global episode id directly as a row position.
        if not np.array_equal(self._global_episode_id, np.arange(n)):
            raise ValueError("byte index global_episode_id values must be exactly 0..num_rows-1")
        self._episode_index = col(ep, "episode_index")[order]
        self._camera_index = col(ep, "camera_index")[order]
        self._file_id = col(ep, "file_id")[order]
        self._mdat_offset = col(ep, "mdat_offset")[order]
        self._mdat_length = col(ep, "mdat_length")[order]
        self._frame_count = col(ep, "frame_count")[order]
        self._first_pts = col(ep, "first_pts")[order]
        self._last_pts = col(ep, "last_pts")[order]

        self._keyframes_rows = 0
        if self.index_dir is not None:
            keyframes_path = self.index_dir / KEYFRAMES_NAME
            if keyframes_path.exists():
                # Only the row count is needed; read the parquet footer instead of
                # materialising the whole keyframes table.
                self._keyframes_rows = pq.read_metadata(keyframes_path, memory_map=mmap).num_rows

        self.resident_bytes = int(
            self._global_episode_id.nbytes
            + self._file_id.nbytes
            + self._mdat_offset.nbytes
            + self._mdat_length.nbytes
            + self.file_size.nbytes
        )

    @classmethod
    def from_metadata_root(cls, meta_root: Path, *, video_keys: list[str], num_episodes: int) -> EpisodeByteIndex:
        """Load the index from the ``BYTE_INDEX_DIR`` sidecar directory under ``meta_root``."""
        return cls(meta_root / BYTE_INDEX_DIR, video_keys=video_keys, num_episodes=num_episodes)

    @classmethod
    def from_memory_build(
        cls,
        meta,
        data_root: str,
        *,
        workers: int = 8,
        max_episodes: int | None = None,
        include_frame_mappings_cache: bool = True,
    ) -> EpisodeByteIndex:
        """Build a complete byte index in RAM (no parquet write, no dataset push)."""
        # Imported lazily: byte_index_builder imports this module in turn.
        from .byte_index_builder import build_byte_index_in_memory

        return build_byte_index_in_memory(
            meta,
            data_root,
            workers=workers,
            max_episodes=max_episodes,
            include_frame_mappings_cache=include_frame_mappings_cache,
        )

    def lookup(self, episode_index: int, camera_key: str) -> EpisodeSliceLookup:
        """Return the byte-range record for one episode and camera.

        Raises:
            KeyError: If ``camera_key`` is not one of the indexed video keys.
            IndexError: If the episode is outside the indexed range.
        """
        cam_idx = self._cam_to_idx[camera_key]
        gid = episode_index * self.num_cameras + cam_idx
        row = int(gid)
        if row < 0 or row >= len(self._global_episode_id):
            raise IndexError(f"episode_index={episode_index} camera={camera_key} out of range")
        file_id = int(self._file_id[row])
        return EpisodeSliceLookup(
            global_episode_id=gid,
            file_id=file_id,
            mdat_offset=int(self._mdat_offset[row]),
            mdat_length=int(self._mdat_length[row]),
            frame_count=int(self._frame_count[row]),
            first_pts=float(self._first_pts[row]),
            last_pts=float(self._last_pts[row]),
            avg_fps=float(self.file_avg_fps[file_id]),
        )

    def file_lookup(self, file_id: int) -> FileLookup:
        """Return the file-level record for ``file_id``."""
        return FileLookup(
            file_id=file_id,
            file_path=self.file_path[file_id],
            file_size=int(self.file_size[file_id]),
            moov_offset=int(self.moov_offset[file_id]),
            moov_length=int(self.moov_length[file_id]),
            header_length=int(self.header_length[file_id]),
            faststart=bool(self.faststart[file_id]),
            avg_fps=float(self.file_avg_fps[file_id]),
            codec=self.codec[file_id],
        )

    def header_byte_range(self, file_id: int) -> tuple[int, int]:
        """Return the inclusive ``(first, last)`` byte range of the file header."""
        length = int(self.header_length[file_id])
        return 0, max(0, length - 1)

    def custom_frame_mappings(self, episode_index: int, camera_key: str) -> bytes | None:
        """Return the cached or freshly derived frame-mappings payload for an episode.

        Returns ``None`` when no parsed MP4 indices are available for the file.
        """
        cam_idx = self._cam_to_idx[camera_key]
        gid = episode_index * self.num_cameras + cam_idx
        cached = self._frame_mappings_by_gid.get(gid)
        if cached is not None:
            return cached
        if self._mp4_by_rel is None:
            return None
        lookup = self.lookup(episode_index, camera_key)
        rel = self.file_path[lookup.file_id]
        mp4_index = self._mp4_by_rel.get(rel)
        if mp4_index is None:
            return None
        payload = episode_custom_frame_mappings_json(mp4_index, lookup.first_pts, lookup.last_pts)
        self._frame_mappings_by_gid[gid] = payload
        return payload

    def stats_dict(self) -> dict[str, float | int]:
        """Return load timings, resident size and cache/row counts for reporting."""
        return {
            "load_time_s": self.load_time_s,
            "build_time_s": self.build_time_s,
            "resident_bytes": self.resident_bytes,
            "frame_mappings_cached": len(self._frame_mappings_by_gid),
            "mp4_indices_cached": len(self._mp4_by_rel or {}),
            "num_files": len(self.file_path),
            "num_episode_rows": len(self._global_episode_id),
        }
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Build mmap-able byte-index sidecars for LeRobot streaming video fetch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from .mp4_episode_slice import (
|
||||
HEADER_PROBE_BYTES,
|
||||
MAX_HEADER_PROBE_BYTES,
|
||||
average_fps_from_index,
|
||||
episode_keyframes,
|
||||
parse_mp4_file_layout,
|
||||
parse_mp4_index,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Location and file names of the byte-index sidecar tables.
BYTE_INDEX_DIR = "meta/byte_index"  # joined onto the root given to EpisodeByteIndex.from_metadata_root
FILES_NAME = "files.parquet"  # one row per indexed video file
EPISODES_NAME = "episodes.parquet"  # one row per (episode, camera) pair
KEYFRAMES_NAME = "keyframes.parquet"  # optional: one row per keyframe
|
||||
|
||||
|
||||
@dataclass
class IndexedFile:
    """Mutable per-file metadata gathered while indexing one video file."""

    # Set to -1 by index_video_file; the table builder assigns the real id afterwards.
    file_id: int
    file_path: str
    # Total size of the file in bytes
    file_size: int
    # Position and size of the moov box
    moov_offset: int
    moov_length: int
    # Taken from the parsed layout's ``header_end``
    header_length: int
    # False when the moov box sits after mdat
    faststart: bool
    avg_fps: float
    codec: str
|
||||
|
||||
|
||||
def fetch_header_bytes(path: str, file_size: int) -> bytes:
    """Read enough leading bytes of an MP4 for its top-level layout to parse.

    Starts with ``HEADER_PROBE_BYTES`` and doubles the read size (capped by
    ``MAX_HEADER_PROBE_BYTES`` and the file size) while the layout parser reports
    that the mdat box was not found.

    Raises:
        ValueError: Re-raised from the layout parser when the error is not the
            "mdat box not found" case or the probe size can no longer grow.
    """
    protocol = "hf" if path.startswith("hf://") else "file"
    filesystem = fsspec.filesystem(protocol)

    size = HEADER_PROBE_BYTES
    while True:
        with filesystem.open(path, "rb", block_size=max(size, 2**20), cache_type="none") as handle:
            data = handle.read(min(size, file_size))
        try:
            parse_mp4_file_layout(data, file_size)
        except ValueError as exc:
            at_limit = size >= min(MAX_HEADER_PROBE_BYTES, file_size)
            if at_limit or "mdat box not found" not in str(exc):
                raise
            # Header is larger than the probe: retry with twice as many bytes.
            size = min(size * 2, MAX_HEADER_PROBE_BYTES, file_size)
        else:
            return data
|
||||
|
||||
|
||||
def index_video_file(path: str, *, rel_path: str | None = None) -> tuple[IndexedFile, Any]:
    """Parse one video file's header into an ``IndexedFile`` and its MP4 index.

    Args:
        path: Full path of the file; ``hf://`` paths use the ``hf`` filesystem,
            anything else the local one.
        rel_path: Path recorded in the result; falls back to ``path``.

    Returns:
        ``(indexed, mp4_index)`` where ``indexed.file_id`` is the placeholder ``-1``.
    """
    protocol = "hf" if path.startswith("hf://") else "file"
    size = fsspec.filesystem(protocol).info(path)["size"]

    head = fetch_header_bytes(path, size)
    layout = parse_mp4_file_layout(head, size)
    if not layout.faststart:
        logger.warning("non-faststart MP4 (moov after mdat): %s", path)

    sample_index = parse_mp4_index(head, size)
    record = IndexedFile(
        file_id=-1,
        file_path=rel_path or path,
        file_size=size,
        moov_offset=layout.moov_offset,
        moov_length=layout.moov_length,
        header_length=layout.header_end,
        faststart=layout.faststart,
        avg_fps=average_fps_from_index(sample_index),
        codec=layout.codec,
    )
    return record, sample_index
|
||||
|
||||
|
||||
def build_byte_index_tables(
    meta,
    data_root: str,
    *,
    file_paths: list[str] | None = None,
    include_keyframes: bool = True,
    workers: int = 8,
    existing_files: dict[str, int] | None = None,
    max_episodes: int | None = None,
    return_mp4_indices: bool = False,
    complete_files_table: bool = False,
) -> tuple[pa.Table, pa.Table, pa.Table | None] | tuple[pa.Table, pa.Table, pa.Table | None, dict[str, Any]]:
    """Build files/episodes/(optional keyframes) Arrow tables.

    Args:
        meta: Dataset metadata providing ``video_keys``, ``total_episodes``, ``fps``,
            ``episodes`` and ``get_video_file_path``.
        data_root: Root that relative video paths are joined onto.
        file_paths: Accepted for backward compatibility; not used.
        include_keyframes: Whether to emit the keyframes table.
        workers: Thread count used to parse video headers.
        existing_files: ``file_path -> file_id`` for files indexed by an earlier
            run. The mapping is read, never modified.
        max_episodes: Optional cap on the number of episodes indexed.
        return_mp4_indices: When true, also return the parsed MP4 index per file.
        complete_files_table: When true the files table covers every referenced
            file; otherwise only files that are not in ``existing_files``.

    Returns:
        ``(files, episodes, keyframes)`` plus the MP4 index mapping when requested.
        ``files`` is ``None`` when nothing new was indexed and a complete table was
        not requested; ``keyframes`` is ``None`` when disabled or empty.

    Raises:
        KeyError: If ``complete_files_table`` is set while a referenced file is
            only known through ``existing_files`` (no metadata row is available).
    """
    video_keys = list(meta.video_keys)
    n_cams = len(video_keys)
    cam_to_idx = {cam: i for i, cam in enumerate(video_keys)}
    num_episodes = meta.total_episodes if max_episodes is None else min(max_episodes, meta.total_episodes)
    root = data_root.rstrip("/")

    # Resolve every (episode, camera) video path once and reuse it below.
    rel_by_slot = {
        (ep_idx, cam): str(meta.get_video_file_path(ep_idx, cam))
        for ep_idx in range(num_episodes)
        for cam in video_keys
    }
    referenced_rels = sorted(set(rel_by_slot.values()))

    # Work on a copy so the caller's mapping is left untouched.
    file_ids = dict(existing_files or {})
    to_index = [rel for rel in referenced_rels if rel not in file_ids]

    # Every referenced file is parsed: new files need a metadata row, and files
    # indexed earlier still need their MP4 index to compute episode byte spans.
    parsed: dict[str, tuple[IndexedFile, Any]] = {}
    if referenced_rels:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(index_video_file, f"{root}/{rel}", rel_path=rel): rel for rel in referenced_rels
            }
            for fut in as_completed(futures):
                parsed[futures[fut]] = fut.result()
    mp4_by_rel: dict[str, Any] = {rel: parsed[rel][1] for rel in referenced_rels}

    # Assign ids in sorted path order rather than completion order, so ids are
    # reproducible and follow the row order of the files table.
    next_file_id = max(file_ids.values(), default=-1) + 1
    file_meta_by_rel: dict[str, dict[str, Any]] = {}
    for rel in to_index:
        indexed, _ = parsed[rel]
        indexed.file_id = next_file_id
        file_meta_by_rel[rel] = {
            "file_id": indexed.file_id,
            "file_path": rel,
            "file_size": indexed.file_size,
            "moov_offset": indexed.moov_offset,
            "moov_length": indexed.moov_length,
            "header_length": indexed.header_length,
            "faststart": indexed.faststart,
            "avg_fps": indexed.avg_fps,
            "codec": indexed.codec,
        }
        file_ids[rel] = indexed.file_id
        next_file_id += 1

    episode_rows: list[dict[str, Any]] = []
    keyframe_rows: list[dict[str, Any]] = []
    for ep_idx in range(num_episodes):
        ep = meta.episodes[ep_idx]
        for cam in video_keys:
            rel = rel_by_slot[(ep_idx, cam)]
            mp4_index = mp4_by_rel[rel]
            from_ts = float(ep[f"videos/{cam}/from_timestamp"])
            to_ts = float(ep[f"videos/{cam}/to_timestamp"])
            span = mp4_index.episode_byte_span(from_ts, to_ts)
            global_episode_id = ep_idx * n_cams + cam_to_idx[cam]
            episode_rows.append(
                {
                    "global_episode_id": global_episode_id,
                    "episode_index": ep_idx,
                    "camera_key": cam,
                    "camera_index": cam_to_idx[cam],
                    "file_id": file_ids[rel],
                    "mdat_offset": span.slice_lo,
                    # slice_lo / slice_hi are inclusive bounds.
                    "mdat_length": span.slice_hi - span.slice_lo + 1,
                    "frame_count": max(1, round((to_ts - from_ts) * meta.fps)),
                    "first_pts": from_ts,
                    "last_pts": to_ts,
                }
            )
            if include_keyframes:
                timescale = mp4_index.timescale
                keyframe_rows.extend(
                    {
                        "global_episode_id": global_episode_id,
                        "pts": int(round(pts_s * timescale)),
                        "byte_offset": byte_off,
                    }
                    for pts_s, byte_off in episode_keyframes(mp4_index, from_ts, to_ts)
                )

    if complete_files_table:
        files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in referenced_rels])
    elif to_index:
        files_table = pa.Table.from_pylist([file_meta_by_rel[rel] for rel in to_index])
    else:
        files_table = None
    episodes_table = pa.Table.from_pylist(episode_rows)
    keyframes_table = pa.Table.from_pylist(keyframe_rows) if include_keyframes and keyframe_rows else None
    if return_mp4_indices:
        return files_table, episodes_table, keyframes_table, mp4_by_rel
    return files_table, episodes_table, keyframes_table
|
||||
|
||||
|
||||
def build_byte_index_in_memory(
    meta,
    data_root: str,
    *,
    workers: int = 8,
    max_episodes: int | None = None,
    include_frame_mappings_cache: bool = False,
):
    """Build a RAM-resident ``EpisodeByteIndex`` without writing any parquet.

    Keyframes are not built. When ``include_frame_mappings_cache`` is set, the
    frame mappings of every (episode, camera) pair are computed up front so later
    calls hit the index's cache.
    """
    # Imported lazily: byte_index imports this module at import time.
    from .byte_index import EpisodeByteIndex

    total = meta.total_episodes
    num_episodes = total if max_episodes is None else min(max_episodes, total)

    tables = build_byte_index_tables(
        meta,
        data_root,
        include_keyframes=False,
        workers=workers,
        max_episodes=max_episodes,
        return_mp4_indices=True,
        complete_files_table=True,
    )
    files_tbl, episodes_tbl, _unused_keyframes, mp4_by_rel = tables

    index = EpisodeByteIndex(
        None,
        video_keys=list(meta.video_keys),
        num_episodes=num_episodes,
        files_table=files_tbl,
        episodes_table=episodes_tbl,
        mp4_by_rel=mp4_by_rel,
    )
    if not include_frame_mappings_cache:
        return index

    # Warm the per-episode frame-mappings cache.
    for ep_idx in range(num_episodes):
        for cam in meta.video_keys:
            index.custom_frame_mappings(ep_idx, cam)
    return index
|
||||
|
||||
|
||||
def write_byte_index(
    output_dir: Path,
    files_table: pa.Table | None,
    episodes_table: pa.Table,
    keyframes_table: pa.Table | None = None,
    *,
    merge_existing: bool = True,
) -> None:
    """Write the byte-index tables as parquet files under ``output_dir``.

    Args:
        output_dir: Destination directory (created when missing).
        files_table: New file rows, or ``None`` to leave the files table untouched.
        episodes_table: Episode rows; always overwrites the existing file.
        keyframes_table: Keyframe rows, or ``None`` to leave keyframes untouched.
        merge_existing: When true, new file rows are appended to the existing
            files table, and existing keyframe rows are kept only for episodes
            that ``keyframes_table`` does not cover.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    files_path = output_dir / FILES_NAME
    episodes_path = output_dir / EPISODES_NAME
    keyframes_path = output_dir / KEYFRAMES_NAME

    if files_table is not None:
        if merge_existing and files_path.exists():
            files_table = pa.concat_tables([pq.read_table(files_path), files_table])
        pq.write_table(files_table, files_path)

    pq.write_table(episodes_table, episodes_path)

    if keyframes_table is not None:
        if merge_existing and keyframes_path.exists():
            import pyarrow.compute as pc

            # A resumed build re-emits keyframes for every episode it processes, so
            # plain concatenation would duplicate them on each run. Drop previous
            # rows whose episode is covered by the new table before appending.
            previous = pq.read_table(keyframes_path)
            fresh_ids = pc.unique(keyframes_table.column("global_episode_id"))
            superseded = pc.is_in(previous.column("global_episode_id"), value_set=fresh_ids)
            previous = previous.filter(pc.invert(superseded))
            keyframes_table = pa.concat_tables([previous, keyframes_table])
        pq.write_table(keyframes_table, keyframes_path)
|
||||
|
||||
|
||||
def load_existing_file_ids(index_dir: Path) -> dict[str, int]:
    """Return ``file_path -> file_id`` from an existing files table.

    Yields an empty mapping when ``index_dir`` holds no files table yet.
    """
    sidecar = index_dir / FILES_NAME
    if not sidecar.exists():
        return {}

    rows = pq.read_table(sidecar, columns=["file_id", "file_path"]).to_pylist()
    known: dict[str, int] = {}
    for row in rows:
        known[row["file_path"]] = int(row["file_id"])
    return known
|
||||
@@ -0,0 +1,263 @@
|
||||
"""Node-local LRU byte cache using precomputed byte-index manifest sidecars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import fsspec
|
||||
|
||||
from .byte_index import EpisodeByteIndex, EpisodeSliceLookup
|
||||
from .mp4_episode_slice import SparseMp4Reader
|
||||
from .torchcodec_utils import open_video_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheStats:
    """Additive counters and timers describing byte-cache activity."""

    hits: int = 0
    misses: int = 0
    bytes_fetched: int = 0
    full_file_fallbacks: int = 0
    prefetch_submitted: int = 0
    prefetch_waits: int = 0
    mdat_slices: int = 0
    prefix_fetches: int = 0
    fetch_to_buffer_s: float = 0.0
    buffer_to_decoder_s: float = 0.0
    buffer_hit_decoder_s: float = 0.0
    decode_frame_s: float = 0.0
    decode_frames: int = 0

    def merge(self, other: CacheStats) -> None:
        """Add every counter of ``other`` onto this instance in place."""
        for name in self.__dataclass_fields__:
            combined = getattr(self, name) + getattr(other, name)
            setattr(self, name, combined)

    def stats_dict(self) -> dict[str, int | float]:
        """Return the counters plus per-miss / per-hit / per-frame averages.

        Averages divide by at least 1 so an unused cache reports zeros.
        """
        per_miss = max(1, self.misses)
        per_hit = max(1, self.hits)
        per_frame = max(1, self.decode_frames)

        report: dict[str, int | float] = {}
        report["byte_cache_hits"] = self.hits
        report["byte_cache_misses"] = self.misses
        report["byte_cache_bytes_fetched"] = self.bytes_fetched
        report["byte_cache_bytes_per_miss"] = self.bytes_fetched / per_miss
        report["byte_cache_full_file_fallbacks"] = self.full_file_fallbacks
        report["byte_cache_prefetch_submitted"] = self.prefetch_submitted
        report["byte_cache_prefetch_waits"] = self.prefetch_waits
        report["byte_cache_mdat_slices"] = self.mdat_slices
        report["byte_cache_prefix_fetches"] = self.prefix_fetches
        report["fetch_to_buffer_ms_per_miss"] = 1000 * self.fetch_to_buffer_s / per_miss
        report["buffer_to_decoder_ms_per_miss"] = 1000 * self.buffer_to_decoder_s / per_miss
        report["buffer_hit_decoder_ms_per_hit"] = 1000 * self.buffer_hit_decoder_s / per_hit
        report["decode_ms_per_frame"] = 1000 * self.decode_frame_s / per_frame
        return report
|
||||
|
||||
|
||||
@dataclass
class _EpisodeEntry:
    """Prefetch state of one episode: its decoders, a completion flag and any failure."""

    # Decoder per camera key, filled in by the prefetch worker
    decoders: dict[str, Any] = field(default_factory=dict)
    # Set once the prefetch has finished, whether it succeeded or failed
    ready: threading.Event = field(default_factory=threading.Event)
    # Exception raised while prefetching, if any
    error: Exception | None = None
|
||||
|
||||
|
||||
class RangeFetcher:
    """Reads inclusive byte ranges of a single file through fsspec."""

    def __init__(self, path: str):
        """Bind to ``path``; ``hf://`` paths use the ``hf`` filesystem, others the local one."""
        self.path = path
        protocol = "hf" if path.startswith("hf://") else "file"
        self._fs = fsspec.filesystem(protocol)

    def fetch(self, lo: int, hi: int) -> bytes:
        """Return bytes ``lo..hi`` (both inclusive); an inverted range yields ``b""``."""
        length = hi - lo + 1
        if length < 1:
            return b""
        # Block size covers the whole range so it is served by one read.
        with self._fs.open(self.path, "rb", block_size=max(2**20, length), cache_type="none") as handle:
            handle.seek(lo)
            return handle.read(length)
|
||||
|
||||
|
||||
class EpisodeByteCache:
|
||||
"""Manifest-driven episode MP4 fetch + in-memory sparse decode."""
|
||||
|
||||
MAX_BYTES_PER_MISS = 25 * 1024 * 1024
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
byte_index: EpisodeByteIndex,
|
||||
max_bytes: int,
|
||||
*,
|
||||
data_root: str,
|
||||
max_prefetch_workers: int = 4,
|
||||
):
|
||||
if max_bytes <= 0:
|
||||
raise ValueError(f"max_bytes must be positive; got {max_bytes}")
|
||||
self.byte_index = byte_index
|
||||
self.max_bytes = max_bytes
|
||||
self.data_root = data_root.rstrip("/")
|
||||
self._bytes_used = 0
|
||||
self._lock = threading.Lock()
|
||||
self._cache: OrderedDict[tuple[Any, ...], tuple[Any, int]] = OrderedDict()
|
||||
self._header_cache: dict[int, bytes] = {}
|
||||
self._fetcher_cache: dict[int, RangeFetcher] = {}
|
||||
self._episodes: dict[int, _EpisodeEntry] = {}
|
||||
self._stats = CacheStats()
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_prefetch_workers)
|
||||
self._futures: dict[int, Future] = {}
|
||||
|
||||
@property
|
||||
def stats(self) -> CacheStats:
|
||||
with self._lock:
|
||||
return CacheStats(**{k: getattr(self._stats, k) for k in CacheStats.__dataclass_fields__})
|
||||
|
||||
def submit_prefetch(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
if ep_idx in self._episodes or ep_idx in self._futures:
|
||||
return
|
||||
self._stats.prefetch_submitted += 1
|
||||
fut = self._executor.submit(self._prefetch_episode, ep_idx)
|
||||
self._futures[ep_idx] = fut
|
||||
|
||||
def ensure_ready(self, ep_idx: int) -> None:
|
||||
with self._lock:
|
||||
fut = self._futures.pop(ep_idx, None)
|
||||
if fut is not None:
|
||||
with self._lock:
|
||||
self._stats.prefetch_waits += 1
|
||||
fut.result()
|
||||
entry = self._episodes.get(ep_idx)
|
||||
if entry is None:
|
||||
raise KeyError(f"episode {ep_idx} not prefetched")
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
|
||||
def get_decoder(self, ep_idx: int, video_key: str) -> Any:
|
||||
entry = self._episodes[ep_idx]
|
||||
if entry.error is not None:
|
||||
raise entry.error
|
||||
entry.ready.wait()
|
||||
return entry.decoders[video_key]
|
||||
|
||||
def close(self) -> None:
|
||||
self._executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
def _prefetch_episode(self, ep_idx: int) -> None:
|
||||
entry = _EpisodeEntry()
|
||||
self._episodes[ep_idx] = entry
|
||||
try:
|
||||
for cam in self.byte_index.video_keys:
|
||||
entry.decoders[cam] = self._get_or_build_decoder(ep_idx, cam)
|
||||
except Exception as exc:
|
||||
entry.error = exc
|
||||
finally:
|
||||
entry.ready.set()
|
||||
|
||||
def _get_or_build_decoder(self, ep_idx: int, cam: str) -> Any:
|
||||
key = (ep_idx, cam)
|
||||
with self._lock:
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
self._cache.move_to_end(key)
|
||||
self._stats.hits += 1
|
||||
payload, _ = cached
|
||||
t0 = time.perf_counter()
|
||||
dec = self._decoder_from_payload(payload, ep_idx, cam)
|
||||
with self._lock:
|
||||
self._stats.buffer_hit_decoder_s += time.perf_counter() - t0
|
||||
return dec
|
||||
|
||||
payload, payload_bytes, dec = self._fetch_manifest_slice(ep_idx, cam)
|
||||
|
||||
with self._lock:
|
||||
self._stats.misses += 1
|
||||
if payload_bytes > self.MAX_BYTES_PER_MISS:
|
||||
logger.warning(
|
||||
"byte cache miss fetched %.1f MB (>25 MB) for ep=%s cam=%s",
|
||||
payload_bytes / 1e6,
|
||||
ep_idx,
|
||||
cam,
|
||||
)
|
||||
self._evict_until(payload_bytes)
|
||||
self._cache[key] = (payload, payload_bytes)
|
||||
self._bytes_used += payload_bytes
|
||||
return dec
|
||||
|
||||
def _fetch_manifest_slice(self, ep_idx: int, cam: str) -> tuple[SparseMp4Reader, int, Any]:
|
||||
lookup = self.byte_index.lookup(ep_idx, cam)
|
||||
file_info = self.byte_index.file_lookup(lookup.file_id)
|
||||
fetcher = self._get_fetcher(lookup.file_id, file_info.file_path)
|
||||
t_fetch = time.perf_counter()
|
||||
header = self._get_header_bytes(lookup.file_id, fetcher, file_info.header_length)
|
||||
lo = lookup.mdat_offset
|
||||
hi = lo + lookup.mdat_length - 1
|
||||
mdat = fetcher.fetch(lo, hi)
|
||||
fetch_s = time.perf_counter() - t_fetch
|
||||
nbytes = len(header) + len(mdat)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += nbytes
|
||||
self._stats.mdat_slices += 1
|
||||
self._stats.fetch_to_buffer_s += fetch_s
|
||||
|
||||
def lazy_fetch(pos: int, end: int) -> bytes:
|
||||
data = fetcher.fetch(pos, end)
|
||||
with self._lock:
|
||||
self._stats.bytes_fetched += len(data)
|
||||
return data
|
||||
|
||||
reader = SparseMp4Reader(
|
||||
file_size=file_info.file_size,
|
||||
header=header,
|
||||
mdat_lo=lo,
|
||||
mdat_bytes=mdat,
|
||||
lazy_fetch=lazy_fetch,
|
||||
)
|
||||
t_init = time.perf_counter()
|
||||
dec = self._decoder_from_payload(reader, ep_idx, cam)
|
||||
self._validate_decoder(dec, lookup)
|
||||
init_s = time.perf_counter() - t_init
|
||||
with self._lock:
|
||||
self._stats.buffer_to_decoder_s += init_s
|
||||
self._rewind_payload(reader)
|
||||
return reader, nbytes, dec
|
||||
|
||||
def _get_fetcher(self, file_id: int, rel_path: str) -> RangeFetcher:
|
||||
if file_id not in self._fetcher_cache:
|
||||
path = rel_path if rel_path.startswith("hf://") else f"{self.data_root}/{rel_path}"
|
||||
self._fetcher_cache[file_id] = RangeFetcher(path)
|
||||
return self._fetcher_cache[file_id]
|
||||
|
||||
def _get_header_bytes(self, file_id: int, fetcher: RangeFetcher, header_length: int) -> bytes:
|
||||
if file_id in self._header_cache:
|
||||
return self._header_cache[file_id]
|
||||
hi = max(0, header_length - 1)
|
||||
header = fetcher.fetch(0, hi)
|
||||
with self._lock:
|
||||
self._header_cache[file_id] = header
|
||||
self._stats.bytes_fetched += len(header)
|
||||
return header
|
||||
|
||||
def _decoder_from_payload(
    self, payload: SparseMp4Reader, ep_idx: int, cam: str
) -> Any:
    """Open a video decoder over ``payload`` with the episode's custom frame mappings.

    The payload is rewound to offset 0 first, then the mappings for
    ``(ep_idx, cam)`` are looked up in ``self.byte_index``.
    """
    payload.seek(0)
    return open_video_decoder(
        payload,
        frame_mappings=self.byte_index.custom_frame_mappings(ep_idx, cam),
    )
|
||||
|
||||
def _validate_decoder(self, dec: Any, lookup: EpisodeSliceLookup) -> None:
    """Smoke-test ``dec`` by decoding one frame near the start, middle and end of its stream.

    The probe timestamps come from ``dec.metadata``; ``lookup`` is accepted
    but not used. Any exception raised by the decoder propagates.
    """
    start = float(dec.metadata.begin_stream_seconds)
    stop = float(dec.metadata.end_stream_seconds)
    span = max(0.01, stop - start)
    probes = (start + 1e-3, start + 0.5 * span, stop - 1e-3)
    for probe in probes:
        # Touching ``.data`` is the point: the result itself is discarded.
        dec.get_frames_played_at([probe]).data
|
||||
|
||||
def _rewind_payload(self, payload: SparseMp4Reader) -> None:
    """Move ``payload``'s read position back to offset 0."""
    payload.seek(0)
|
||||
|
||||
def _evict_until(self, need: int) -> None:
    """Pop entries from the front of ``self._cache`` until ``need`` more bytes fit.

    Stops as soon as ``self._bytes_used + need`` no longer exceeds
    ``self.max_bytes`` or the cache is empty. Each cache value is a
    ``(payload, size)`` pair; ``size`` is subtracted from ``self._bytes_used``.
    """
    while self._cache and self._bytes_used + need > self.max_bytes:
        _key, entry = self._cache.popitem(last=False)
        _payload, freed = entry
        self._bytes_used -= freed
|
||||
@@ -0,0 +1,555 @@
|
||||
"""MP4 moov parsing and tight per-episode mdat byte-range fetching.
|
||||
|
||||
LeRobot v3 concatenates episodes into shared MP4 files (faststart: moov at head).
|
||||
For streaming we fetch only the file header plus the episode's contiguous mdat span
|
||||
instead of the ``0..episode_end`` prefix.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import struct
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable
|
||||
|
||||
# Default ``keyframe_pad_s`` (seconds) used by the episode span helpers in this module.
KEYFRAME_PAD_S = 0.1
# Header probe sizes in bytes (4 MiB and 16 MiB). The "mdat box not found" error
# in this module tells the user to increase HEADER_PROBE_BYTES.
HEADER_PROBE_BYTES = 4 * 1024**2
MAX_HEADER_PROBE_BYTES = 16 * 1024**2
|
||||
|
||||
|
||||
@dataclass
class Mp4FileLayout:
    """Top-level box layout of one MP4 file (see ``parse_mp4_file_layout``)."""

    # Total file size, as supplied by the caller.
    file_size: int
    # Offset and size of the first ``moov`` box found in the parsed header.
    moov_offset: int
    moov_length: int
    # Offset at which the header region ends (set to ``mdat_offset`` by the parser).
    header_end: int
    # Offset and box size of the first ``mdat`` box found in the parsed header.
    mdat_offset: int
    mdat_size: int
    # True when ``moov`` precedes ``mdat``.
    faststart: bool
    # Video codec tag, or "unknown" when it could not be read.
    codec: str
|
||||
|
||||
|
||||
def parse_mp4_file_layout(header_bytes: bytes, file_size: int) -> Mp4FileLayout:
    """Return top-level MP4 layout (moov/mdat positions, faststart flag).

    Args:
        header_bytes: Leading bytes of the file; must contain the headers of
            both the ``moov`` and the ``mdat`` box.
        file_size: Total file size, copied into the result unchanged.

    Raises:
        ValueError: If no ``moov`` or no ``mdat`` box is found in ``header_bytes``.
    """
    moov_at, moov_len = -1, 0
    mdat_at, mdat_len = -1, 0
    # Only the first occurrence of each box type is recorded.
    for box_off, box_size, box_type, _payload in _iter_boxes(header_bytes):
        if box_type == b"moov":
            if moov_at < 0:
                moov_at, moov_len = box_off, box_size
        elif box_type == b"mdat":
            if mdat_at < 0:
                mdat_at, mdat_len = box_off, box_size

    if moov_at < 0:
        raise ValueError("moov box not found in header probe")
    if mdat_at < 0:
        raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")

    return Mp4FileLayout(
        file_size=file_size,
        moov_offset=moov_at,
        moov_length=moov_len,
        header_end=mdat_at,  # everything before mdat counts as header
        mdat_offset=mdat_at,
        mdat_size=mdat_len,
        faststart=moov_at < mdat_at,
        codec=_parse_video_codec(header_bytes),
    )
|
||||
|
||||
|
||||
def _parse_video_codec(header_bytes: bytes) -> str:
    """Return the fourcc of the first video sample entry (e.g. ``avc1``), or ``"unknown"``.

    ``"unknown"`` is returned when ``moov``, the video ``trak``, ``stbl`` or
    ``stsd`` is missing, or when ``stsd`` holds no complete sample entry.
    """
    moov = _find_box_payload(header_bytes, b"moov")
    if moov is None:
        return "unknown"
    trak = _find_video_trak(moov)
    if trak is None:
        return "unknown"
    stbl = _find_box_payload(trak, b"stbl")
    if stbl is None:
        return "unknown"
    stsd = _find_box_payload(stbl, b"stsd")
    # stsd payload: version(1) + flags(3) + entry_count(4), followed by the first
    # sample entry: entry_size(4) + format fourcc(4). The fourcc therefore lives
    # at bytes 12..16; bytes 8..12 are the entry size, not the codec.
    if stsd is None or len(stsd) < 16:
        return "unknown"
    (entry_count,) = struct.unpack_from(">I", stsd, 4)
    if entry_count == 0:
        return "unknown"
    return stsd[12:16].decode("latin1", errors="replace").strip("\x00")
|
||||
|
||||
|
||||
def average_fps_from_index(index: Mp4VideoIndex) -> float:
    """Return the average frame rate implied by the index's sample timestamps.

    Falls back to ``30.0`` when there are fewer than two samples or the last
    sample's timestamp is not positive.
    """
    index.ensure_tables()
    n = index.num_samples
    if n < 2:
        return 30.0
    last_pts = index.sample_pts(n - 1)
    if last_pts <= 0:
        return 30.0
    # Timestamps start at 0, so the last sample's PTS covers n - 1 frame
    # intervals; dividing n by it would over-estimate the rate.
    return (n - 1) / last_pts
|
||||
|
||||
|
||||
def episode_custom_frame_mappings_json(
    index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> bytes:
    """Build TorchCodec ``custom_frame_mappings`` JSON for one episode span.

    The span is ``[from_ts - keyframe_pad_s, to_ts + keyframe_pad_s]`` with its
    start moved back by ``_keyframe_back``. Each frame entry carries ``pts`` and
    ``duration`` in timescale ticks plus a 0/1 ``key_frame`` flag.

    Returns:
        UTF-8 encoded JSON of the form ``{"frames": [...]}``.
    """
    import json

    index.ensure_tables()
    pts_table = index._pts
    total = index.num_samples
    first = _first_sample_at_or_after(pts_table, max(0.0, from_ts - keyframe_pad_s))
    last = min(_last_sample_at_or_before(pts_table, to_ts + keyframe_pad_s), total - 1)
    first = _keyframe_back(index.sync_samples, first)

    keyframe_numbers = set(index.sync_samples)  # 1-based sample numbers
    ticks_per_second = index.timescale

    # Expand the run-length stts table into one duration per sample.
    durations: list[int] = [delta for count, delta in index.stts for _ in range(count)]
    if len(durations) < total:
        # Table shorter than the sample count: repeat the last delta (or 1/30 s).
        filler = durations[-1] if durations else ticks_per_second // 30
        durations.extend([filler] * (total - len(durations)))

    frames = [
        {
            "pts": int(round(pts_table[i] * ticks_per_second)),
            "duration": int(durations[i]),
            # Without a sync table only the first emitted frame is flagged.
            "key_frame": int((i + 1) in keyframe_numbers) if keyframe_numbers else int(i == first),
        }
        for i in range(first, last + 1)
    ]
    return json.dumps({"frames": frames}).encode()
|
||||
|
||||
|
||||
def episode_keyframes(
    index: Mp4VideoIndex, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S
) -> list[tuple[float, int]]:
    """Return (pts_seconds, byte_offset) for sync samples in the episode span.

    When no sync sample falls inside the padded span (or the index has no sync
    table), a single entry for the first sample of the span is returned.
    """
    index.ensure_tables()
    # Result unused; the call is kept because it raises for an index without samples.
    index.episode_byte_span(from_ts, to_ts, keyframe_pad_s)
    first = _first_sample_at_or_after(index._pts, max(0.0, from_ts - keyframe_pad_s))
    last = _last_sample_at_or_before(index._pts, to_ts + keyframe_pad_s)

    # sync_samples holds 1-based sample numbers.
    hits: list[tuple[float, int]] = [
        (index.sample_pts(number - 1), index.sample_offset(number - 1))
        for number in index.sync_samples
        if first <= number - 1 <= last
    ]
    if hits:
        return hits
    return [(index.sample_pts(first), index.sample_offset(first))]
|
||||
|
||||
|
||||
@dataclass
class EpisodeByteSpan:
    """Absolute file byte ranges to fetch for one episode."""

    file_size: int
    header_end: int
    slice_lo: int
    slice_hi: int

    @property
    def header_bytes(self) -> tuple[int, int]:
        """``(0, header_end - 1)``: the header range."""
        last = self.header_end - 1
        return (0, last)

    @property
    def mdat_bytes(self) -> tuple[int, int]:
        """``(slice_lo, slice_hi)``: the episode's media-data range."""
        return (self.slice_lo, self.slice_hi)

    @property
    def total_fetch_bytes(self) -> int:
        """Header length plus slice length, with ``slice_hi`` counted as inclusive."""
        return self.header_end + (self.slice_hi - self.slice_lo + 1)
|
||||
|
||||
|
||||
@dataclass
class Mp4VideoIndex:
    """Sample tables of one MP4 video track plus lazily derived per-sample tables.

    ``stts``/``stsz``/``stsc``/``stco``/``sync_samples`` hold the parsed moov
    tables; ``_pts`` and ``_offsets`` are filled by ``ensure_tables``.
    """

    file_size: int
    header_end: int
    mdat_offset: int
    mdat_size: int
    timescale: int
    stts: list[tuple[int, int]]
    stsz: list[int]
    stsc: list[tuple[int, int, int]]
    stco: list[int]
    sync_samples: list[int]
    _pts: list[float] = field(default_factory=list, repr=False)
    _offsets: list[int] = field(default_factory=list, repr=False)

    def ensure_tables(self) -> None:
        """Compute ``_pts`` and ``_offsets`` unless ``_pts`` is already populated."""
        if not self._pts:
            self._pts = _pts_from_stts(self.stts, self.timescale)
            self._offsets = _sample_byte_offsets(self.stsc, self.stco, self.stsz)

    @property
    def num_samples(self) -> int:
        """Number of samples, taken from the length of ``stsz``."""
        return len(self.stsz)

    def sample_pts(self, index: int) -> float:
        """Timestamp in seconds of sample ``index``."""
        self.ensure_tables()
        return self._pts[index]

    def sample_offset(self, index: int) -> int:
        """Byte offset of sample ``index``; the index is clamped to the offset table."""
        self.ensure_tables()
        clamped = max(0, min(index, len(self._offsets) - 1))
        return self._offsets[clamped]

    def sample_end(self, index: int) -> int:
        """Offset just past sample ``index`` (its offset plus its size)."""
        return self.sample_offset(index) + self.stsz[index]

    def episode_byte_span(self, from_ts: float, to_ts: float, keyframe_pad_s: float = KEYFRAME_PAD_S) -> EpisodeByteSpan:
        """Return the byte span covering ``[from_ts, to_ts]`` plus padding.

        The padding is the larger of ``keyframe_pad_s`` and 5% of the episode
        duration; the start is then moved back via ``_keyframe_back``.

        Raises:
            ValueError: If the index contains no samples.
        """
        self.ensure_tables()
        total = self.num_samples
        if total == 0:
            raise ValueError("MP4 has no video samples")

        duration = max(0.01, to_ts - from_ts)
        pad = max(keyframe_pad_s, 0.05 * duration)

        first = _first_sample_at_or_after(self._pts, max(0.0, from_ts - pad))
        last = _last_sample_at_or_before(self._pts, to_ts + pad)
        last = min(last, total - 1)
        first = min(first, total - 1)
        first = _keyframe_back(self.sync_samples, first)

        start_byte = self.sample_offset(first)
        # NOTE(review): sample_end() is one past the last byte of the sample, while
        # EpisodeByteSpan treats slice_hi as inclusive — confirm the extra byte is intended.
        end_byte = self.sample_end(min(last, len(self._offsets) - 1))
        return EpisodeByteSpan(
            file_size=self.file_size,
            header_end=self.header_end,
            slice_lo=start_byte,
            slice_hi=min(end_byte, self.file_size - 1),
        )
|
||||
|
||||
|
||||
class SparseMp4Reader(io.BufferedIOBase):
    """Range-backed MP4 reader: header + one mdat span at absolute offsets.

    The object presents a seekable, read-only file of ``file_size`` bytes.
    Bytes inside the header or the cached mdat span are served from memory.
    Any other byte is obtained from ``lazy_fetch(first, last)`` (inclusive
    range) when provided, or read as zeros otherwise.
    """

    def __init__(
        self,
        file_size: int,
        header: bytes,
        mdat_lo: int,
        mdat_bytes: bytes,
        lazy_fetch: Callable[[int, int], bytes] | None = None,
    ):
        super().__init__()
        self._size = file_size
        self._header = header
        self._mdat_lo = mdat_lo
        self._mdat_hi = mdat_lo + len(mdat_bytes)  # exclusive end of the cached span
        self._mdat = mdat_bytes
        self._lazy_fetch = lazy_fetch
        self._pos = 0
        self._lock = threading.Lock()

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        return self._pos

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Move the read position; the result is clamped to ``[0, file_size]``.

        Raises:
            ValueError: If ``whence`` is not one of the ``io.SEEK_*`` constants.
        """
        if whence == io.SEEK_SET:
            self._pos = offset
        elif whence == io.SEEK_CUR:
            self._pos += offset
        elif whence == io.SEEK_END:
            self._pos = self._size + offset
        else:
            raise ValueError(f"invalid whence: {whence}")
        self._pos = max(0, min(self._pos, self._size))
        return self._pos

    def read(self, size: int | None = -1) -> bytes:
        """Read up to ``size`` bytes; ``None`` or a negative size reads to end of file."""
        # The io read() protocol allows None to mean "read everything".
        if size is None or size < 0:
            size = self._size - self._pos
        if size <= 0:
            return b""

        out = bytearray()
        remaining = size
        pos = self._pos
        while remaining > 0 and pos < self._size:
            chunk = self._read_at(pos, remaining)
            if not chunk:
                break
            out.extend(chunk)
            pos += len(chunk)
            remaining -= len(chunk)
        self._pos = pos
        return bytes(out)

    def _read_at(self, pos: int, n: int) -> bytes:
        """Return up to ``n`` bytes starting at ``pos`` from a single backing region."""
        header_len = len(self._header)
        if pos < header_len:
            return self._header[pos : min(pos + n, header_len)]

        if self._mdat_lo <= pos < self._mdat_hi:
            start = pos - self._mdat_lo
            return self._mdat[start : start + min(n, self._mdat_hi - pos)]

        # Uncached region. Stop at the start of the cached mdat span so bytes
        # already in memory are neither re-fetched nor replaced by zeros; the
        # read() loop picks up the cached span on its next iteration.
        end = min(pos + n, self._size)
        if pos < self._mdat_lo:
            end = min(end, self._mdat_lo)

        if self._lazy_fetch is not None:
            with self._lock:
                return self._lazy_fetch(pos, end - 1)

        return b"\x00" * (end - pos)
|
||||
|
||||
|
||||
def parse_mp4_index(header_bytes: bytes, file_size: int) -> Mp4VideoIndex:
    """Parse moov sample tables from the file header (faststart layout).

    Raises:
        ValueError: If ``moov``, the video ``trak``, ``mdhd``, ``stbl``, a
            required sample table, or both ``stco`` and ``co64`` are missing.
    """
    layout = parse_mp4_file_layout(header_bytes, file_size)

    moov = _find_box_payload(header_bytes, b"moov")
    if moov is None:
        raise ValueError("moov box not found in MP4 header probe")

    trak = _find_video_trak(moov)
    if trak is None:
        raise ValueError("video trak not found in moov")

    mdhd = _find_box_payload(trak, b"mdhd")
    if mdhd is None:
        raise ValueError("mdhd not found")
    timescale = _parse_mdhd_timescale(mdhd)

    stbl = _find_box_payload(trak, b"stbl")
    if stbl is None:
        raise ValueError("stbl not found")

    time_to_sample = _parse_stts(_find_box_payload(stbl, b"stts"))
    sample_sizes = _parse_stsz(_find_box_payload(stbl, b"stsz"))
    sample_to_chunk = _parse_stsc(_find_box_payload(stbl, b"stsc"))

    # Chunk offsets come from stco (32-bit) when present, else co64 (64-bit).
    offsets32 = _find_box_payload(stbl, b"stco")
    offsets64 = _find_box_payload(stbl, b"co64")
    if offsets32 is not None:
        chunk_offsets = _parse_stco(offsets32)
    elif offsets64 is not None:
        chunk_offsets = _parse_co64(offsets64)
    else:
        raise ValueError("stco/co64 not found")

    # A missing or empty stss yields an empty sync-sample list.
    sync_payload = _find_box_payload(stbl, b"stss")
    if sync_payload:
        sync_samples = _parse_stss(sync_payload)
    else:
        sync_samples = []

    return Mp4VideoIndex(
        file_size=file_size,
        header_end=layout.header_end,
        mdat_offset=layout.mdat_offset,
        mdat_size=layout.mdat_size,
        timescale=timescale,
        stts=time_to_sample,
        stsz=sample_sizes,
        stsc=sample_to_chunk,
        stco=chunk_offsets,
        sync_samples=sync_samples,
    )
|
||||
|
||||
|
||||
def _box_header(data: bytes, offset: int) -> tuple[int, bytes, int] | None:
    """Decode the box header that starts at ``offset``.

    Returns:
        ``(box_size, box_type, header_length)``, or ``None`` when ``data`` is
        too short to hold the header.
    """
    available = len(data)
    if available < offset + 8:
        return None
    box_size, box_type = struct.unpack_from(">I4s", data, offset)
    if box_size == 0:
        # Size 0: the box runs to the end of the supplied data.
        return available - offset, box_type, 8
    if box_size != 1:
        return box_size, box_type, 8
    # Size 1: the real size is the 64-bit field after the type.
    if available < offset + 16:
        return None
    (large_size,) = struct.unpack_from(">Q", data, offset + 8)
    return large_size, box_type, 16
|
||||
|
||||
|
||||
def _iter_boxes(data: bytes, start: int = 0, end: int | None = None):
    """Yield ``(offset, size, type, payload)`` for consecutive boxes in ``data``.

    Iteration starts at ``start`` and stops at ``end`` (default: end of data),
    at the first header that cannot be decoded, or at a box whose declared
    size is smaller than its own header.
    """
    limit = len(data) if end is None else end
    cursor = start
    while cursor + 8 <= limit:
        parsed = _box_header(data, cursor)
        if parsed is None:
            return
        box_size, box_type, header_len = parsed
        if box_size < header_len:
            # Malformed size: stop here, the cursor could not advance reliably.
            return
        yield cursor, box_size, box_type, data[cursor + header_len : cursor + box_size]
        cursor += box_size
|
||||
|
||||
|
||||
def _find_box_payload(data: bytes, target: bytes) -> bytes | None:
    """Return the payload of the first ``target`` box found in ``data``, searching depth-first.

    Only ``moov``, ``trak``, ``mdia``, ``minf`` and ``stbl`` boxes are
    descended into. Returns ``None`` when no matching box exists.
    """
    containers = (b"moov", b"trak", b"mdia", b"minf", b"stbl")
    for _off, _size, box_type, payload in _iter_boxes(data):
        if box_type == target:
            return payload
        if box_type not in containers:
            continue
        nested = _find_box_payload(payload, target)
        if nested is not None:
            return nested
    return None
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> bytes | None:
    """Return the payload of the first ``trak`` whose ``hdlr`` handler type is ``vide``."""
    for _off, _size, box_type, payload in _iter_boxes(moov):
        if box_type == b"trak":
            handler = _find_box_payload(payload, b"hdlr")
            if handler is None or len(handler) < 12:
                continue
            # The handler type is read from bytes 8..12 of the hdlr payload.
            if handler[8:12] == b"vide":
                return payload
    return None
|
||||
|
||||
|
||||
def _find_mdat(header_bytes: bytes, file_size: int) -> tuple[int, int]:
    """Return ``(offset, size)`` of the first top-level ``mdat`` box in ``header_bytes``.

    ``file_size`` is accepted but not used.

    Raises:
        ValueError: If no ``mdat`` box header lies within ``header_bytes``.
    """
    located = next(
        ((off, size) for off, size, typ, _ in _iter_boxes(header_bytes) if typ == b"mdat"),
        None,
    )
    if located is None:
        # mdat may start beyond the probe; there is no fallback scan, so require a probe hit.
        raise ValueError("mdat box not found in header probe; increase HEADER_PROBE_BYTES")
    return located
|
||||
|
||||
|
||||
def _parse_mdhd_timescale(mdhd: bytes) -> int:
    """Read the 32-bit timescale from an ``mdhd`` payload.

    Read from byte 12 when the version byte is 0, otherwise from byte 20.
    """
    timescale_at = 12 if mdhd[0] == 0 else 20
    (timescale,) = struct.unpack_from(">I", mdhd, timescale_at)
    return timescale
|
||||
|
||||
|
||||
def _parse_stts(stts: bytes | None) -> list[tuple[int, int]]:
    """Decode an ``stts`` payload into ``(sample_count, sample_delta)`` runs.

    Raises:
        ValueError: If ``stts`` is ``None``.
    """
    if stts is None:
        raise ValueError("stts missing")
    (entries,) = struct.unpack_from(">I", stts, 4)
    # Entries are 8 bytes each and start right after the 8-byte preamble.
    return [struct.unpack_from(">II", stts, 8 + 8 * i) for i in range(entries)]
|
||||
|
||||
|
||||
def _parse_stsz(stsz: bytes | None) -> list[int]:
    """Decode an ``stsz`` payload into one size per sample.

    A non-zero uniform sample size is repeated ``sample_count`` times;
    otherwise the per-sample table that follows is returned.

    Raises:
        ValueError: If ``stsz`` is ``None``.
    """
    if stsz is None:
        raise ValueError("stsz missing")
    uniform_size, sample_count = struct.unpack_from(">II", stsz, 4)
    if uniform_size == 0:
        return list(struct.unpack_from(f">{sample_count}I", stsz, 12))
    return [uniform_size] * sample_count
|
||||
|
||||
|
||||
def _parse_stsc(stsc: bytes | None) -> list[tuple[int, int, int]]:
    """Decode an ``stsc`` payload into ``(first_chunk, samples_per_chunk, sample_desc)`` entries.

    Raises:
        ValueError: If ``stsc`` is ``None``.
    """
    if stsc is None:
        raise ValueError("stsc missing")
    (entries,) = struct.unpack_from(">I", stsc, 4)
    # Entries are 12 bytes each and start right after the 8-byte preamble.
    return [struct.unpack_from(">III", stsc, 8 + 12 * i) for i in range(entries)]
|
||||
|
||||
|
||||
def _parse_stco(stco: bytes) -> list[int]:
    """Decode an ``stco`` payload into its 32-bit chunk offsets."""
    (entries,) = struct.unpack_from(">I", stco, 4)
    chunk_offsets = struct.unpack_from(">%dI" % entries, stco, 8)
    return list(chunk_offsets)
|
||||
|
||||
|
||||
def _parse_co64(co64: bytes) -> list[int]:
    """Decode a ``co64`` payload into its 64-bit chunk offsets."""
    (entries,) = struct.unpack_from(">I", co64, 4)
    return list(struct.unpack_from(f">{entries}Q", co64, 8))
|
||||
|
||||
|
||||
def _parse_stss(stss: bytes) -> list[int]:
    """Decode an ``stss`` payload into its list of sync-sample numbers."""
    entries = struct.unpack_from(">I", stss, 4)[0]
    return [*struct.unpack_from(f">{entries}I", stss, 8)]
|
||||
|
||||
|
||||
def _pts_from_stts(stts: list[tuple[int, int]], timescale: int) -> list[float]:
    """Expand ``(count, delta)`` runs into per-sample timestamps in seconds.

    The first sample is at 0; each later sample is the running sum of the
    preceding deltas divided by ``timescale``.
    """
    timeline: list[float] = []
    ticks = 0
    for run_length, step in stts:
        timeline.extend((ticks + k * step) / timescale for k in range(run_length))
        ticks += run_length * step
    return timeline
|
||||
|
||||
|
||||
def _sample_byte_offsets(
    stsc: list[tuple[int, int, int]], stco: list[int], stsz: list[int]
) -> list[int]:
    """Compute the absolute byte offset of every sample.

    Args:
        stsc: ``(first_chunk, samples_per_chunk, sample_desc)`` entries with
            1-based chunk numbers. An empty table is treated as a single chunk
            holding every sample.
        stco: Byte offset of each chunk.
        stsz: Size in bytes of each sample.

    Returns:
        One offset per entry of ``stsz``. Samples the chunk tables do not
        cover (malformed input) are laid out contiguously after the last
        mapped sample.
    """
    total_samples = len(stsz)
    total_chunks = len(stco)
    if not stsc:
        stsc = [(1, total_samples, 1)]

    offsets: list[int] = []
    chunk = 0
    sample = 0
    for entry, (first_chunk, per_chunk, _desc) in enumerate(stsc):
        if chunk >= total_chunks or sample >= total_samples:
            break
        # An entry applies until the next entry's first chunk; the last entry
        # covers every remaining chunk.
        if entry + 1 < len(stsc):
            run = stsc[entry + 1][0] - first_chunk
        else:
            run = total_chunks - chunk
        for _ in range(run):
            if chunk >= total_chunks:
                break
            cursor = stco[chunk]
            for _ in range(per_chunk):
                if sample >= total_samples:
                    break
                offsets.append(cursor)
                cursor += stsz[sample]
                sample += 1
            chunk += 1

    if len(offsets) < total_samples:
        # Malformed stsc/stco: continue directly after the last mapped sample.
        # Starting at the end of that sample (not at its own offset) keeps the
        # first padded sample from overlapping it.
        known = len(offsets)
        cursor = offsets[-1] + stsz[known - 1] if offsets else 0
        for size in stsz[known:]:
            offsets.append(cursor)
            cursor += size

    return offsets
|
||||
|
||||
|
||||
def _first_sample_at_or_after(pts: list[float], ts: float) -> int:
    """Return the index of the first entry of sorted ``pts`` that is ``>= ts``.

    The result is clamped to the last index, so a ``ts`` beyond every entry
    yields ``len(pts) - 1`` (``-1`` for an empty list).
    """
    count = len(pts)
    left = 0
    right = count
    while right - left > 0:
        probe = (left + right) // 2
        if pts[probe] < ts:
            left = probe + 1
        else:
            right = probe
    return left if left < count - 1 else count - 1
|
||||
|
||||
|
||||
def _last_sample_at_or_before(pts: list[float], ts: float) -> int:
    """Return the index of the last entry of sorted ``pts`` that is ``<= ts``.

    The result is never negative: a ``ts`` before every entry (or an empty
    list) yields ``0``.
    """
    left = 0
    right = len(pts)
    while right - left > 0:
        probe = (left + right) // 2
        if pts[probe] <= ts:
            left = probe + 1
        else:
            right = probe
    return left - 1 if left > 0 else 0
|
||||
|
||||
|
||||
def _keyframe_back(sync_samples: list[int], sample_idx: int) -> int:
    """Return the 0-based index of the sync sample to start decoding from.

    With a sync table, this is the last listed sync sample at or before
    ``sample_idx`` (``0`` when none qualifies). Without one, the index is
    simply moved back two samples, floored at ``0``.
    """
    if not sync_samples:
        return max(0, sample_idx - 2)
    # stss stores 1-based sample numbers
    limit = sample_idx + 1
    for number in reversed(sync_samples):
        if number <= limit:
            return number - 1
    return 0
|
||||
@@ -124,6 +124,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
video_decoder_cache_size: int | None = None,
|
||||
data_files_root: str | None = None,
|
||||
validate_row_groups: bool = True,
|
||||
video_byte_cache_gb: float | None = 80.0,
|
||||
byte_index_path: str | Path | None = None,
|
||||
byte_index_build_in_memory: bool | None = None,
|
||||
byte_index_workers: int = 8,
|
||||
byte_index_max_episodes: int | None = None,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
@@ -173,6 +178,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
``num_shards`` is divisible by ``world_size`` for distributed runs, raising a clear
|
||||
``ValueError`` otherwise. Set False to skip the checks (e.g. single-process debugging);
|
||||
the divisibility check then downgrades to a warning.
|
||||
video_byte_cache_gb (float | None, optional): Node-local LRU for episode MP4 mdat slices.
|
||||
When set (default 80 GB), episodes are prefetched via tight byte ranges before decode.
|
||||
Set to 0 or None to disable and use remote per-seek decoding.
|
||||
byte_index_path (str | Path | None, optional): Path to precomputed ``meta/byte_index/``
|
||||
sidecar parquet tables. Defaults to ``{meta.root}/meta/byte_index``.
|
||||
byte_index_build_in_memory (bool | None, optional): When True, build the byte index in RAM
|
||||
at init (moov-only fetches, no parquet write). When None (default), build in memory only
|
||||
if the sidecar parquet is missing on disk.
|
||||
byte_index_workers (int, optional): Parallel moov-index workers for in-memory builds.
|
||||
byte_index_max_episodes (int | None, optional): Cap episodes indexed (debug/smoke tests).
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -210,6 +225,14 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self.rank, self.world_size = self._resolve_distributed(rank, world_size)
|
||||
self.video_decoder_cache_size = video_decoder_cache_size
|
||||
self.data_files_root = data_files_root.rstrip("/") if data_files_root else None
|
||||
self.video_byte_cache_gb = video_byte_cache_gb
|
||||
self.byte_index_path = Path(byte_index_path) if byte_index_path is not None else None
|
||||
self.byte_index_build_in_memory = byte_index_build_in_memory
|
||||
self.byte_index_workers = byte_index_workers
|
||||
self.byte_index_max_episodes = byte_index_max_episodes
|
||||
self._episode_byte_cache = None
|
||||
self._byte_index = None
|
||||
self._data_root = None
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
@@ -228,6 +251,37 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
if self._use_episode_byte_cache():
|
||||
from .byte_index import EpisodeByteIndex
|
||||
|
||||
data_root = self._resolve_data_root()
|
||||
index_dir = self.byte_index_path or (self.meta.root / "meta" / "byte_index")
|
||||
sidecar_exists = (index_dir / "files.parquet").exists() and (index_dir / "episodes.parquet").exists()
|
||||
build_in_memory = (
|
||||
self.byte_index_build_in_memory
|
||||
if self.byte_index_build_in_memory is not None
|
||||
else not sidecar_exists
|
||||
)
|
||||
if build_in_memory:
|
||||
logger.info(
|
||||
"Building byte index in memory from %s (%s episodes, %d workers)",
|
||||
data_root,
|
||||
self.byte_index_max_episodes or self.meta.total_episodes,
|
||||
self.byte_index_workers,
|
||||
)
|
||||
self._byte_index = EpisodeByteIndex.from_memory_build(
|
||||
self.meta,
|
||||
data_root,
|
||||
workers=self.byte_index_workers,
|
||||
max_episodes=self.byte_index_max_episodes,
|
||||
)
|
||||
else:
|
||||
self._byte_index = EpisodeByteIndex(
|
||||
index_dir,
|
||||
video_keys=self.meta.video_keys,
|
||||
num_episodes=self.meta.total_episodes,
|
||||
)
|
||||
|
||||
self.delta_timestamps = None
|
||||
self.delta_indices = None
|
||||
|
||||
@@ -417,6 +471,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
buffer_size=self.episode_pool_size,
|
||||
max_buffer_input_shards=max_input_shards,
|
||||
)
|
||||
if self._use_episode_byte_cache():
|
||||
ds = ds.map(self._submit_episode_prefetch, batched=True)
|
||||
# A row-count-changing batched map must drop the input columns explicitly; the exploded
|
||||
# frames re-emit them (windowed keys replaced by their delta windows + *_is_pad masks).
|
||||
ds = ds.map(self._explode_episodes, batched=True, remove_columns=episode_columns)
|
||||
@@ -472,6 +528,31 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
return VideoDecoderCache()
|
||||
return VideoDecoderCache(max_size=min((self.episode_pool_size + 1) * num_cameras, 128))
|
||||
|
||||
def _use_episode_byte_cache(self) -> bool:
    """Return True when the episode byte cache should be used.

    Requires a non-zero, non-None ``video_byte_cache_gb`` and a configured
    ``data_files_root``.
    """
    if self.video_byte_cache_gb in (None, 0):
        return False
    return self.data_files_root is not None
|
||||
|
||||
def _make_episode_byte_cache(self):
    """Create an ``EpisodeByteCache`` sized from ``video_byte_cache_gb`` (decimal gigabytes).

    Raises:
        RuntimeError: If no byte index has been loaded or built.
    """
    from .episode_byte_cache import EpisodeByteCache

    if self._byte_index is None:
        raise RuntimeError("byte index required for episode byte cache; run build_byte_index.py")
    budget_bytes = int(float(self.video_byte_cache_gb) * 1e9)
    return EpisodeByteCache(self._byte_index, budget_bytes, data_root=self._data_root)
|
||||
|
||||
def _submit_episode_prefetch(self, episode_batch: dict[str, list[list]]) -> dict[str, list[list]]:
    """Queue a prefetch for every distinct episode in the batch and return the batch unchanged.

    The episode id is the first value of each ``episode_index`` row. Does
    nothing when no episode byte cache is active.
    """
    cache = self._episode_byte_cache
    if cache is not None:
        episode_ids = {int(rows[0]) for rows in episode_batch["episode_index"]}
        for episode_id in episode_ids:
            cache.submit_prefetch(episode_id)
    return episode_batch
|
||||
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
# `datasets` reshuffles (and re-permutes shard order) per epoch from (seed, epoch);
|
||||
# DataLoader workers each advance their own copy's counter in lockstep. The in-flight
|
||||
@@ -486,6 +567,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
self._in_flight_epoch = 0
|
||||
self._pipeline.set_epoch(self._in_flight_epoch)
|
||||
self.video_decoder_cache = self._make_video_decoder_cache()
|
||||
self._data_root = self._resolve_data_root()
|
||||
if self._use_episode_byte_cache():
|
||||
self._episode_byte_cache = self._make_episode_byte_cache()
|
||||
else:
|
||||
self._episode_byte_cache = None
|
||||
|
||||
iterator = iter(self._pipeline)
|
||||
while True:
|
||||
@@ -623,6 +709,8 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""
|
||||
|
||||
item = {}
|
||||
if self._episode_byte_cache is not None:
|
||||
self._episode_byte_cache.ensure_ready(ep_idx)
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
# query_ts is episode-local; shift to the absolute in-file timeline by the episode's offset.
|
||||
from_timestamp = self.meta.episodes[ep_idx][f"videos/{video_key}/from_timestamp"]
|
||||
@@ -635,12 +723,16 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
else:
|
||||
root = self.root
|
||||
video_path = f"{root}/{rel_path}"
|
||||
episode_decoder = None
|
||||
if self._episode_byte_cache is not None:
|
||||
episode_decoder = self._episode_byte_cache.get_decoder(ep_idx, video_key)
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self.tolerance_s,
|
||||
decoder_cache=self.video_decoder_cache,
|
||||
return_uint8=self._return_uint8,
|
||||
episode_decoder=episode_decoder,
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""TorchCodec helpers for sparse MP4 IO with optional custom frame mappings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torchcodec import FrameBatch, _core as core
|
||||
from torchcodec.decoders._video_decoder import _get_and_validate_stream_metadata
|
||||
|
||||
|
||||
def frame_mappings_tensors(payload: bytes) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert frame-mappings JSON into ``(pts, key_frame, duration)`` tensors.

    ``payload`` must decode to ``{"frames": [{"pts", "key_frame", "duration"}, ...]}``.
    ``pts`` and ``duration`` become int64 tensors, ``key_frame`` a bool tensor.
    """
    frames = json.loads(payload)["frames"]

    def column(key: str, cast, dtype: torch.dtype) -> torch.Tensor:
        return torch.tensor([cast(frame[key]) for frame in frames], dtype=dtype)

    pts = column("pts", int, torch.int64)
    key = column("key_frame", bool, torch.bool)
    dur = column("duration", int, torch.int64)
    return pts, key, dur
|
||||
|
||||
|
||||
class VideoDecoderLike:
    """Minimal VideoDecoder surface used by episode byte cache."""

    def __init__(self, decoder: torch.Tensor, *, stream_index: int | None = None):
        """Wrap a core decoder handle and load its validated stream metadata."""
        self._decoder = decoder
        validated = _get_and_validate_stream_metadata(decoder=decoder, stream_index=stream_index)
        (
            self.metadata,
            self.stream_index,
            self._begin_stream_seconds,
            self._end_stream_seconds,
            self._num_frames,
        ) = validated

    def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
        """Return the frames played at the given timestamps (in seconds) as a ``FrameBatch``."""
        batch = core.get_frames_by_pts(self._decoder, timestamps=seconds)
        return FrameBatch(*batch)
|
||||
|
||||
|
||||
def open_video_decoder(source: Any, *, frame_mappings: bytes | None = None) -> VideoDecoderLike:
    """Open a decoder on sparse or full MP4 IO, skipping metadata scan when mappings exist.

    Args:
        source: File-like object handed to the core decoder.
        frame_mappings: Frame-mappings JSON as accepted by
            ``frame_mappings_tensors``. When ``None`` the decoder is opened in
            ``"approximate"`` mode instead of ``"custom_frame_mappings"``.
    """
    if frame_mappings is not None:
        # Parse the mappings before the decoder is created.
        tensors = frame_mappings_tensors(frame_mappings)
        decoder = core.create_from_file_like(source, "custom_frame_mappings")
        core.add_video_stream(decoder, custom_frame_mappings=tensors)
        return VideoDecoderLike(decoder)

    decoder = core.create_from_file_like(source, "approximate")
    core.add_video_stream(decoder)
    return VideoDecoderLike(decoder)
|
||||
@@ -326,6 +326,7 @@ def decode_video_frames_torchcodec(
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
return_uint8: bool = False,
|
||||
episode_decoder: Any | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
@@ -347,8 +348,10 @@ def decode_video_frames_torchcodec(
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
if episode_decoder is not None:
|
||||
decoder = episode_decoder
|
||||
else:
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Acceptance tests for manifest byte-index sidecars.
|
||||
|
||||
Run on a compute node (not login-node):
|
||||
|
||||
srun --partition=hopper-dev --nodes=1 --ntasks=1 --cpus-per-task=8 --mem=32G --time=00:30:00 \\
|
||||
bash -lc 'cd /admin/home/pepijn/lerobot && conda run --no-capture-output -n lerobot \\
|
||||
env -u HF_HUB_ENABLE_HF_TRANSFER python -m pytest tests/datasets/test_byte_index.py -m integration -v'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("torchcodec")
|
||||
|
||||
REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
REV = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
BUCKET = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
MAX_EPISODES = 64
|
||||
|
||||
COMPUTE_NODE = pytest.mark.skipif(
|
||||
"login" in socket.gethostname(),
|
||||
reason="run on compute node via srun (see module docstring), not login-node",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def byte_index_dir(tmp_path_factory):
    """Build and write a byte-index sidecar for the first ``MAX_EPISODES`` episodes, once per module.

    Returns:
        ``(output_dir, metadata)``.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_tables, write_byte_index
    from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

    out = tmp_path_factory.mktemp("byte_index")
    meta = LeRobotDatasetMetadata(REPO, revision=REV)
    files, episodes, _keyframes = build_byte_index_tables(
        meta,
        BUCKET,
        workers=4,
        max_episodes=MAX_EPISODES,
        include_keyframes=False,
    )
    write_byte_index(out, files, episodes, None, merge_existing=False)
    return out, meta
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_index_load_fast_and_small(byte_index_dir):
    """The on-disk index must report a sub-second load and under 1e9 resident bytes."""
    from lerobot.datasets.byte_index import EpisodeByteIndex

    index_dir, metadata = byte_index_dir
    loaded = EpisodeByteIndex(
        index_dir,
        video_keys=metadata.video_keys,
        num_episodes=MAX_EPISODES,
    )

    # Both figures are reported by the index object itself.
    assert loaded.load_time_s < 1.0
    assert loaded.resident_bytes < 1_000_000_000
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_tight_fetch_under_25mb(byte_index_dir):
    """Fetching a few episodes must stay under 25 MiB per cache miss.

    Builds the byte index in memory, prefetches the first, middle and last
    indexed episodes through ``EpisodeByteCache``, then checks the cache's
    reported ``byte_cache_bytes_per_miss`` statistic.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, meta = byte_index_dir
    index = build_byte_index_in_memory(meta, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)

    for ep in [0, MAX_EPISODES // 2, MAX_EPISODES - 1]:
        cache.submit_prefetch(ep)
        cache.ensure_ready(ep)

    # Checked once, after all sampled episodes have been fetched.
    stats = cache.stats.stats_dict()
    assert stats["byte_cache_bytes_per_miss"] < 25 * 1024 * 1024
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_in_memory_build_matches_parquet(byte_index_dir):
    """The index loaded from disk and the in-memory build must agree per (episode, camera).

    Compares ``mdat_offset`` and ``mdat_length`` exactly and ``first_pts``
    within 1e-6 for the first, middle and last indexed episodes.
    """
    from lerobot.datasets.byte_index import EpisodeByteIndex
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    index_dir, metadata = byte_index_dir
    from_disk = EpisodeByteIndex(index_dir, video_keys=metadata.video_keys, num_episodes=MAX_EPISODES)
    from_memory = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)

    sampled_episodes = [0, MAX_EPISODES // 2, MAX_EPISODES - 1]
    for episode in sampled_episodes:
        for camera in metadata.video_keys:
            disk_entry = from_disk.lookup(episode, camera)
            memory_entry = from_memory.lookup(episode, camera)

            assert disk_entry.mdat_offset == memory_entry.mdat_offset
            assert disk_entry.mdat_length == memory_entry.mdat_length
            assert abs(disk_entry.first_pts - memory_entry.first_pts) < 1e-6
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_custom_frame_mappings_available(byte_index_dir):
    """The in-memory index must yield a JSON frame-mapping payload for an episode.

    The payload must decode to more than 10 frames, at least one of them
    flagged ``key_frame``, and every frame must carry ``pts`` and ``duration``.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)

    camera = metadata.video_keys[0]
    episode = MAX_EPISODES // 2
    payload = index.custom_frame_mappings(episode, camera)
    assert payload is not None

    frames = json.loads(payload)["frames"]
    assert len(frames) > 10
    assert any(frame["key_frame"] for frame in frames)
    assert all("pts" in frame and "duration" in frame for frame in frames)
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_metadata_skip_decoder_init(byte_index_dir):
    """A decoder from the byte cache must expose a frame count and decode a frame.

    Prefetches episode 0, takes the decoder for the first camera, checks
    ``metadata.num_frames`` is a positive number, and decodes the frame at the
    midpoint of the stream's reported time range.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=8_000_000_000, data_root=BUCKET)

    episode = 0
    camera = metadata.video_keys[0]
    cache.submit_prefetch(episode)
    cache.ensure_ready(episode)
    decoder = cache.get_decoder(episode, camera)

    assert decoder.metadata.num_frames is not None
    assert decoder.metadata.num_frames > 0

    # Sample the timestamp halfway between the stream's begin and end.
    begin = float(decoder.metadata.begin_stream_seconds)
    end = float(decoder.metadata.end_stream_seconds)
    midpoint = begin + 0.5 * (end - begin)

    decoded = decoder.get_frames_played_at([midpoint]).data
    assert decoded.ndim == 4
|
||||
|
||||
|
||||
@pytest.mark.integration
@COMPUTE_NODE
def test_sparse_decode_produces_frames(byte_index_dir):
    """Decoding from cached bytes must produce a non-empty, non-constant frame.

    Prefetches episode 0, decodes the frame at the midpoint of the stream's
    reported time range, and checks that the result is 4-D, has elements, and
    has a pixel standard deviation above 1.0.
    """
    from lerobot.datasets.byte_index_builder import build_byte_index_in_memory
    from lerobot.datasets.episode_byte_cache import EpisodeByteCache

    _, metadata = byte_index_dir
    index = build_byte_index_in_memory(metadata, BUCKET, workers=4, max_episodes=MAX_EPISODES)
    cache = EpisodeByteCache(index, max_bytes=80_000_000_000, data_root=BUCKET)

    episode = 0
    camera = metadata.video_keys[0]
    cache.submit_prefetch(episode)
    cache.ensure_ready(episode)
    decoder = cache.get_decoder(episode, camera)

    # Sample the timestamp halfway between the stream's begin and end.
    begin = float(decoder.metadata.begin_stream_seconds)
    end = float(decoder.metadata.end_stream_seconds)
    midpoint = begin + 0.5 * (end - begin)

    decoded = decoder.get_frames_played_at([midpoint]).data
    assert decoded.ndim == 4
    assert decoded.numel() > 0
    # A standard deviation above 1.0 rules out a uniform frame.
    assert float(decoded.float().std()) > 1.0
|
||||
Reference in New Issue
Block a user