mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a130a9db39 | |||
| 4f5e6596be | |||
| afeeeb8982 | |||
| 040c6b3d66 | |||
| 287c823f13 | |||
| acd31c7de2 | |||
| 58ccc01508 | |||
| 240393d238 |
@@ -355,8 +355,6 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
@@ -423,7 +421,6 @@ exclude_dirs = [
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-words = { trak = "trak" }
|
||||
default.extend-ignore-re = [
|
||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||
|
||||
@@ -1,860 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import resource
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeByteCache,
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
# Default dataset coordinates used when the CLI does not override them.
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"

# Local cache location (under the system temp dir) and file name of the sidecar.
FULL_SIDECAR_NAME = "molmoact2-full.npz"
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir(), "lerobot-sidecars")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Define the benchmark command-line interface and parse the process arguments.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
    add = parser.add_argument

    # Dataset location.
    add("--repo-id", default=DEFAULT_REPO)
    add("--revision", default=DEFAULT_REVISION)
    add("--data-root", default=DEFAULT_DATA_ROOT)

    # Strategy and range backend; ``--strategy`` is hidden from ``--help``.
    add(
        "--strategy",
        choices=("both", "full", "indexed", "remote-decoder", "native-http"),
        default="both",
        help=argparse.SUPPRESS,
    )
    add(
        "--range-backend",
        choices=("fsspec", "native-http"),
        default="fsspec",
        help="Range reader used by indexed/full episode-pool fetch tracks.",
    )

    # Workload sizing.
    add("--num-episodes", type=int, default=512)
    add(
        "--manifest-episodes",
        type=int,
        default=None,
        help="Limit manifest construction to the first N episodes for local smoke tests.",
    )
    add("--pool-size", type=int, default=16)
    add("--workers", type=int, default=8)

    # Native HTTP tuning.
    add(
        "--native-http-connections",
        type=int,
        default=None,
        help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
    )
    add(
        "--native-http-retries",
        type=int,
        default=8,
        help="Retries per native HTTP range request.",
    )
    add(
        "--native-http-timeout",
        type=float,
        default=120.0,
        help="Timeout in seconds for native HTTP requests.",
    )

    # Optional decode tracks and their knobs.
    add(
        "--include-decode",
        action="store_true",
        help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
    )
    add("--decode-workers", type=int, default=1)
    add("--prefetch-ahead", type=int, default=8)
    add("--frames-per-episode", type=int, default=16)
    add("--max-probe-mb", type=int, default=64)
    add("--seed", type=int, default=0)
    add("--byte-budget-gb", type=float, default=80)
    add(
        "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
    )
    add("--no-hub-branch-assert", action="store_true")

    return parser.parse_args()
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span = manifest.lookup(ep, camera_key)
|
||||
lo = span.first_pts
|
||||
hi = max(span.last_pts, lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _timestamps_from_meta(
|
||||
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
||||
) -> dict[tuple[int, str], list[float]]:
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
row = meta.episodes[ep]
|
||||
for camera_key in meta.video_keys:
|
||||
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
||||
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
total = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
total += manifest.lookup(ep, camera_key).mdat_length
|
||||
return total
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
items = list(timestamps.items())
|
||||
if decode_workers <= 1:
|
||||
for (ep, camera_key), ts in items:
|
||||
cache.get_frames(ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
for ep in episodes:
|
||||
cache.ensure_ready(ep)
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
||||
if elapsed_s <= 0:
|
||||
return float("inf")
|
||||
return len(episodes) * frames_per_episode / elapsed_s
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.1f}s"
|
||||
if seconds < 3600:
|
||||
return f"{seconds / 60:.1f}m"
|
||||
return f"{seconds / 3600:.1f}h"
|
||||
|
||||
|
||||
def _current_rss_mib() -> float | None:
|
||||
status_path = Path("/proc/self/status")
|
||||
if not status_path.exists():
|
||||
return None
|
||||
for line in status_path.read_text().splitlines():
|
||||
if line.startswith("VmRSS:"):
|
||||
return float(line.split()[1]) / 1024
|
||||
return None
|
||||
|
||||
|
||||
def _peak_rss_mib() -> float:
|
||||
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
# Linux reports KiB; macOS reports bytes.
|
||||
if rss > 10**8:
|
||||
return rss / 1024**2
|
||||
return rss / 1024
|
||||
|
||||
|
||||
def _memory_snapshot() -> dict[str, float | None]:
    """Capture the current and peak RSS readings under ``rss_mib`` and ``peak_rss_mib``."""
    snapshot: dict[str, float | None] = {}
    snapshot["rss_mib"] = _current_rss_mib()
    snapshot["peak_rss_mib"] = _peak_rss_mib()
    return snapshot
|
||||
|
||||
|
||||
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
|
||||
start_rss = start["rss_mib"]
|
||||
end_rss = end["rss_mib"]
|
||||
delta = None if start_rss is None or end_rss is None else end_rss - start_rss
|
||||
print()
|
||||
print("| Memory | MiB |")
|
||||
print("|---|---:|")
|
||||
if start_rss is not None:
|
||||
print(f"| rss start | {start_rss:.1f} |")
|
||||
if end_rss is not None:
|
||||
print(f"| rss end | {end_rss:.1f} |")
|
||||
if delta is not None:
|
||||
print(f"| rss delta | {delta:.1f} |")
|
||||
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
return str(Path(data_root) / relative_path)
|
||||
|
||||
|
||||
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Return a local path to the full MP4 sidecar, downloading it if needed.

    A valid cached copy is returned as-is. An invalid cached copy is reported
    and deleted. Otherwise the sidecar is looked up under
    ``meta/mp4-sidecars/`` in ``data_root`` and copied into the cache
    directory.

    Args:
        data_root: ``hf://`` URI or local directory of the dataset.
        manifest_episode_count: Accepted but not used.

    Returns:
        The cached sidecar path, or ``None`` when the remote sidecar does not exist.
    """
    del manifest_episode_count  # intentionally unused

    cached = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
    if _valid_sidecar(cached):
        return cached
    if cached.exists():
        print(f"mp4_sidecar_invalid_local: {cached}")
        cached.unlink()

    on_hub = data_root.startswith("hf://")
    relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
    source = _root_join(data_root, relative)
    filesystem = fsspec.filesystem("hf" if on_hub else "file")
    if not filesystem.exists(source):
        return None

    cached.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {source} -> {cached}")
    if on_hub:
        _download_sidecar_native_http(data_root, relative, cached)
    else:
        filesystem.get(source, str(cached))
    return cached
|
||||
|
||||
|
||||
def _valid_sidecar(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
with np.load(path, allow_pickle=False) as data:
|
||||
return "manifest_json" in data
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
    """Download ``relative_path`` from ``data_root`` into ``local`` using parallel range reads.

    The file is fetched in 16 MiB ranges by 8 worker threads into a ``.tmp``
    sibling of ``local``, which is renamed over ``local`` only after every
    range has been written. Progress is printed after each range.

    Compared with a plain seek-and-write loop this version also:
      * rejects a range whose payload length differs from the requested
        length, so a truncated response cannot be promoted to the final path
        as a zero-padded file;
      * removes the temporary file when the download does not complete.

    Args:
        data_root: Root passed to ``NativeHTTPRangeFetcher``.
        relative_path: Path of the sidecar relative to ``data_root``.
        local: Final destination path.

    Raises:
        OSError: if a range read returns an unexpected number of bytes.
    """
    chunk_size = 16 * 1024 * 1024
    tmp = local.with_suffix(local.suffix + ".tmp")
    fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
    completed = False
    try:
        size = fetcher.info_size(relative_path)
        ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]

        def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
            offset, length = offset_length
            return offset, fetcher.read_range(relative_path, offset, length)

        start = time.perf_counter()
        done = 0
        # The pool is exited (and drained) before the file handle is closed.
        with tmp.open("wb") as out_file, ThreadPoolExecutor(max_workers=8) as pool:
            # Pre-size the file so ranges can be written at their final offsets.
            out_file.truncate(size)
            futures = [pool.submit(read_chunk, item) for item in ranges]
            for (_, expected_length), future in zip(ranges, futures):
                offset, data = future.result()
                if len(data) != expected_length:
                    raise OSError(
                        f"short range read for {relative_path} at offset {offset}: "
                        f"expected {expected_length} bytes, got {len(data)}"
                    )
                out_file.seek(offset)
                out_file.write(data)
                done += len(data)
                elapsed = max(time.perf_counter() - start, 1e-9)
                print(
                    f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
                    f"({done / elapsed / 1024**2:.1f} MiB/s)",
                    flush=True,
                )
        tmp.replace(local)
        completed = True
    finally:
        try:
            if not completed:
                # Do not leave a partial download behind.
                tmp.unlink(missing_ok=True)
        finally:
            fetcher.close()
|
||||
|
||||
|
||||
class EpisodeParquetReader:
    """Reads per-episode parquet data through fsspec, caching whole tables by file path.

    Used by the benchmark tracks to exercise the parquet read path; the
    per-episode filter result is computed and then discarded.
    """

    def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
        """Pick the ``hf`` or ``file`` filesystem from ``data_root`` and set up the table cache."""
        self.meta = meta
        self.data_root = data_root
        self.fs = fsspec.filesystem("hf" if data_root.startswith("hf://") else "file")
        self._episode_row_groups = self._build_episode_row_groups()
        # Tables keyed by their dataset-relative path; guarded by ``_cache_lock``.
        self._table_cache: dict[str, pa.Table] = {}
        self._cache_lock = threading.Lock()

    def read_episode(self, episode_index: int) -> None:
        """Load the table holding ``episode_index`` and filter it down to that episode.

        The filtered table is not returned or stored.
        """
        data_path = str(self.meta.get_data_file_path(episode_index))
        whole_table = self._read_table(data_path)
        episode_mask = pc.equal(whole_table["episode_index"], episode_index)
        whole_table.filter(episode_mask)

    def _read_table(self, relative_path: str) -> pa.Table:
        """Return the parquet table at ``relative_path``, reading it on a cache miss.

        The read itself happens outside the lock, so concurrent misses on the
        same path may each read the file; ``setdefault`` keeps the first table
        stored.
        """
        with self._cache_lock:
            cached = self._table_cache.get(relative_path)
        if cached is not None:
            return cached

        full_path = _root_join(self.data_root, relative_path)
        with self.fs.open(full_path, "rb", block_size=2**20, cache_type="none") as handle:
            loaded = pq.ParquetFile(handle).read()

        with self._cache_lock:
            return self._table_cache.setdefault(relative_path, loaded)

    def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
        """Schedule ``read_episode(episode_index)`` on ``pool`` and return the future."""
        return pool.submit(self.read_episode, episode_index)

    def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
        """Read every episode (inline when ``workers <= 1``, else on a thread pool).

        Returns:
            Elapsed wall-clock seconds.
        """
        began = time.perf_counter()
        if workers > 1:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                pending = [executor.submit(self.read_episode, episode) for episode in episodes]
                for task in pending:
                    task.result()
        else:
            for episode in episodes:
                self.read_episode(episode)
        return time.perf_counter() - began

    def _build_episode_row_groups(self) -> dict[int, int]:
        """Map each episode index to its running position within its data file.

        Episodes are walked in index order; each one receives the count of
        earlier episodes sharing the same ``(data/chunk_index,
        data/file_index)`` pair.
        """
        next_slot: dict[tuple[int, int], int] = {}
        assignment = {}
        for episode_index in range(int(self.meta.total_episodes)):
            record = self.meta.episodes[episode_index]
            file_key = (int(record["data/chunk_index"]), int(record["data/file_index"]))
            slot = next_slot.get(file_key, 0)
            assignment[episode_index] = slot
            next_slot[file_key] = slot + 1
        return assignment
|
||||
|
||||
|
||||
def run_fetch_pool(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    byte_budget: int,
    workers: int,
    range_backend: str,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Fill an ``EpisodeByteCache`` (decoders not opened) and summarise the fetch.

    Returns:
        A dict with overall throughput figures (``fetch_s``, ``fetch_mbps``,
        ``fetch_episodes_s``, ``episode_mb``, ``avg_mb_miss``), per-job stage
        averages in milliseconds, the job count, and every ``range_*`` entry of
        the cache's timing summary copied through unchanged.
    """
    cache_options = {
        "byte_budget": byte_budget,
        "workers": workers,
        "range_backend": range_backend,
        "native_http_connections": args.native_http_connections,
        "native_http_timeout": args.native_http_timeout,
        "native_http_retries": args.native_http_retries,
        "open_decoders": False,
    }
    with EpisodeByteCache(manifest, data_root, **cache_options) as cache:
        elapsed = _fill_cache(cache, episodes)
        timings = cache.timing_summary()
        total_bytes = _bytes_for(manifest, episodes)
        per_episode_mb = total_bytes / len(episodes) / 1024**2
        # Avoid dividing by zero when the cache reports no jobs.
        jobs_divisor = max(timings["jobs"], 1.0)

        summary = {
            "fetch_s": elapsed,
            "fetch_mbps": total_bytes / elapsed / 1024**2,
            "fetch_episodes_s": len(episodes) / elapsed,
            "episode_mb": per_episode_mb,
            "avg_mb_miss": total_bytes / (len(episodes) * len(manifest.video_keys)) / 1024**2,
            "jobs": timings["jobs"],
            "lookup_ms": timings["lookup_s"] * 1000 / jobs_divisor,
            "range_fetch_ms": timings["fetch_s"] * 1000 / jobs_divisor,
            "synthesize_ms": timings["synthesize_s"] * 1000 / jobs_divisor,
            "store_ms": timings["store_s"] * 1000 / jobs_divisor,
        }
        for name, value in timings.items():
            if name.startswith("range_"):
                summary[name] = value
        return summary
|
||||
|
||||
|
||||
def run_parallel(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the phases one after another: parquet reads, cache fill, decoder opens, frame decodes.

    Returns:
        Fetch throughput figures, the parquet read time, the average decoder
        open time per (episode, camera) in milliseconds, and decode samples/s.
    """
    # NOTE(review): unlike run_fetch_pool, no native-http connection/timeout/retry
    # options are passed to the cache here.
    cache_cm = EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    )
    with cache_cm as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
        fetch_s = _fill_cache(cache, episodes)

        # Open one decoder per (episode, camera) and time the whole sweep.
        decoder_began = time.perf_counter()
        for episode in episodes:
            for camera_key in manifest.video_keys:
                cache.get_decoder(episode, camera_key)
        decoder_s = time.perf_counter() - decoder_began

        decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
        total_bytes = _bytes_for(manifest, episodes)
        decoder_count = len(episodes) * len(manifest.video_keys)
        return {
            "fetch_s": fetch_s,
            "fetch_mbps": total_bytes / fetch_s / 1024**2,
            "fetch_episodes_s": len(episodes) / fetch_s,
            "parquet_s": parquet_s,
            "decoder_ms_miss": decoder_s * 1000 / decoder_count,
            "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
        }
|
||||
|
||||
|
||||
def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the end-to-end track where prefetching overlaps with consumption.

    Video prefetches and parquet reads are kept ``prefetch_ahead`` episodes
    ahead of the episode currently being consumed. For each episode the
    function waits for its parquet read, waits for its video bytes, and then
    requests frames for every camera (inline, or on a per-episode thread pool
    when ``decode_workers > 1``).

    The parquet executor is managed by a ``with`` block so it is shut down even
    if one of the initial submissions raises, and a negative ``prefetch_ahead``
    is treated as 0.

    Returns:
        Overall, video-only and parquet-only samples/s, plus the wall time and
        the accumulated video and parquet wait times in seconds.
    """
    lookahead = max(prefetch_ahead, 0)
    # NOTE(review): unlike run_fetch_pool, no native-http connection/timeout/retry
    # options are passed to the cache here.
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        with ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes)))) as parquet_pool:
            # Prime the pipeline with the first ``lookahead`` episodes.
            parquet_futures = {
                ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:lookahead]
            }
            for ep in episodes[:lookahead]:
                cache.submit_prefetch(ep)

            for idx, ep in enumerate(episodes):
                # Keep the window full before blocking on the current episode.
                next_idx = idx + lookahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)

                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start

                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_workers <= 1:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    with ThreadPoolExecutor(max_workers=decode_workers) as pool:
                        futures = [
                            pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                            for camera_key in manifest.video_keys
                        ]
                        for future in futures:
                            future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        # Measured after the parquet pool has been shut down.
        elapsed = time.perf_counter() - start
        return {
            "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
            "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
            "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
            "wall_s": elapsed,
            "video_wait_decode_s": video_wait_decode_s,
            "parquet_wait_s": parquet_wait_s,
        }
|
||||
|
||||
|
||||
# Per-thread storage for the decoder cache used by the remote-decoder track.
_remote_decoder_local = threading.local()


def _remote_decoder_cache() -> VideoDecoderCache:
    """Return the calling thread's ``VideoDecoderCache``, creating it on first use."""
    existing = getattr(_remote_decoder_local, "cache", None)
    if existing is not None:
        return existing
    fresh = VideoDecoderCache(max_size=None)
    _remote_decoder_local.cache = fresh
    return fresh
|
||||
|
||||
|
||||
def _decode_remote_source(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episode_index: int,
    camera_key: str,
    timestamps: list[float],
):
    """Decode ``timestamps`` straight from the episode's source video under ``data_root``.

    Uses the calling thread's decoder cache, a tolerance of one frame period
    (``1 / meta.fps``), and requests uint8 output.

    Returns:
        Whatever ``decode_video_frames_torchcodec`` returns.
    """
    relative_video = str(meta.get_video_file_path(episode_index, camera_key))
    source = _root_join(data_root, relative_video)
    one_frame_s = 1.0 / float(meta.fps)
    return decode_video_frames_torchcodec(
        source,
        timestamps,
        tolerance_s=one_frame_s,
        decoder_cache=_remote_decoder_cache(),
        return_uint8=True,
    )
|
||||
|
||||
|
||||
def run_remote_decoder(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    *,
    frames_per_episode: int,
    decode_workers: int,
    parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
    """Benchmark decoding directly from the source videos, sequentially and then in parallel.

    Both passes cover the same (episode, camera) work list and read each
    episode's parquet data once. The second pass runs inline when
    ``decode_workers <= 1`` and on a thread pool otherwise.

    Returns:
        ``sequential_samples_s`` and ``parallel_samples_s``.
    """
    # NOTE(review): the second pass runs after the first in the same process,
    # so any state cached by parquet_reader during pass one is reused — confirm
    # this is the intended comparison.
    work = [
        (episode, camera_key, timestamps[(episode, camera_key)])
        for episode in episodes
        for camera_key in meta.video_keys
    ]

    def read_and_decode(episode: int, camera_key: str, ts: list[float]) -> None:
        # The parquet read is triggered once per episode, on its first camera.
        if camera_key == meta.video_keys[0]:
            parquet_reader.read_episode(episode)
        _decode_remote_source(meta, data_root, episode, camera_key, ts)

    began = time.perf_counter()
    for job in work:
        read_and_decode(*job)
    sequential_s = time.perf_counter() - began

    began = time.perf_counter()
    if decode_workers > 1:
        with ThreadPoolExecutor(max_workers=decode_workers) as executor:
            parquet_tasks = [executor.submit(parquet_reader.read_episode, episode) for episode in episodes]
            decode_tasks = [
                executor.submit(_decode_remote_source, meta, data_root, episode, camera_key, ts)
                for episode, camera_key, ts in work
            ]
            for task in parquet_tasks:
                task.result()
            for task in decode_tasks:
                task.result()
    else:
        for job in work:
            read_and_decode(*job)
    parallel_s = time.perf_counter() - began

    return {
        "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
        "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
    }
|
||||
|
||||
|
||||
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
range_jobs = fetch_pool.get("range_jobs", 0.0)
|
||||
if range_jobs <= 0:
|
||||
return
|
||||
|
||||
print()
|
||||
print("| Range Read Stage | avg ms/range |")
|
||||
print("|---|---:|")
|
||||
for key, label in (
|
||||
("range_open_s", "fsspec handle open/lookup"),
|
||||
("range_seek_s", "fsspec seek"),
|
||||
("range_read_s", "fsspec read"),
|
||||
("range_resolve_s", "http URL resolve"),
|
||||
("range_header_s", "http response headers"),
|
||||
("range_first_byte_s", "http first body byte"),
|
||||
("range_body_s", "http body drain"),
|
||||
("range_retry_sleep_s", "http retry sleep"),
|
||||
):
|
||||
value = fetch_pool.get(key)
|
||||
if value is not None:
|
||||
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
|
||||
if "range_retry_attempts" in fetch_pool:
|
||||
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
|
||||
if fetch_pool.get("range_failed_requests"):
|
||||
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
|
||||
def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Build the episode video manifest, run the fetch benchmark, and print the report.

    The episode pool is sampled from the same index range the manifest was
    built for. With ``--manifest-episodes`` below ``--num-episodes`` the pool
    would otherwise be drawn from a wider range than
    ``range(manifest_episode_count)``. When ``--manifest-episodes`` is unset
    the sampled pool is the same as sampling from
    ``min(total_episodes, num_episodes)``.

    When ``args.include_decode`` is set, the sequential decode comparison and
    the overlapped end-to-end tracks are run and reported as well.

    Args:
        meta: Dataset metadata.
        data_root: ``hf://`` URI or local directory holding the dataset files.
        args: Parsed CLI namespace.
        parquet_reader: Reader used by the decode tracks.
        range_backend: Range reader passed to the manifest builder and caches.
        label: Name used in log lines and in the report.
        sidecar_path: Optional sidecar file passed to the manifest builder.

    Raises:
        ValueError: if ``args.pool_size`` is below 1 or exceeds the number of
            episodes available for sampling.
    """
    if args.pool_size < 1:
        # An empty pool would only fail later with a ZeroDivisionError.
        raise ValueError(f"pool-size={args.pool_size} must be at least 1")

    _log(f"starting_strategy: {label}")
    memory_start = _memory_snapshot()
    manifest_start = time.perf_counter()
    dataset_episode_count = int(meta.total_episodes)
    manifest_episode_count = args.manifest_episodes or dataset_episode_count
    manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")

    benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
    # Sample only from the episodes the manifest was built for.
    episodes = _episode_pool(manifest_episode_count, args.num_episodes, args.pool_size, args.seed)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    _log(
        f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
        f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
    )

    _log(f"{label}: filling episode byte cache with {args.workers} workers")
    fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
    # Extrapolate the measured episodes/s to the full dataset and to --num-episodes.
    estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
    estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]

    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"dataset_episodes: {dataset_episode_count}")
    print(f"benchmark_episodes: {benchmark_episode_count}")
    print(f"pool_episodes: {len(episodes)}")
    print(f"sampled_episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()
    print(
        "| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
    )
    print("|---|---:|---:|---:|---:|---:|---:|---|")
    print(
        f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
        f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
        f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
        f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
    )
    print()
    print("| Camera Job Stage | avg ms/job |")
    print("|---|---:|")
    print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
    print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
    print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
    print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
    print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
    _print_range_timing_summary(fetch_pool)
    _print_memory_summary(memory_start, _memory_snapshot())

    if args.include_decode:
        timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
        _log(f"{label}: running parallel video fetch + decode-only")
        parallel = run_parallel(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            parquet_reader,
            range_backend,
        )
        _log(f"{label}: running overlapped end-to-end")
        overlapped = run_overlapped(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            args.prefetch_ahead,
            parquet_reader,
            range_backend,
        )
        print(
            f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
            f"{parallel['fetch_s']:.2f} | "
            f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
            f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
            f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
        )
        print(
            f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"{overlapped['samples_s']:.1f} samples/s; video+decode "
            f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
        )
|
||||
|
||||
|
||||
def run_remote_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
) -> None:
    """Run the direct source-MP4 decoder benchmark and print its two-row report.

    Episodes are sampled with ``args.seed`` and timestamps with
    ``args.seed + 1``; the timestamp bounds come from the dataset metadata.
    """
    _log("starting_strategy: remote-decoder")
    total_episodes = int(meta.total_episodes)
    episodes = _episode_pool(total_episodes, args.num_episodes, args.pool_size, args.seed)
    timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)

    _log("remote-decoder: running direct source MP4 decoder")
    result = run_remote_decoder(
        meta,
        data_root,
        episodes,
        timestamps,
        frames_per_episode=args.frames_per_episode,
        decode_workers=args.decode_workers,
        parquet_reader=parquet_reader,
    )

    # Run header.
    print("strategy: remote-decoder")
    print(f"data_root: {data_root}")
    print(f"episodes: {episodes}")
    print(f"cameras: {list(meta.video_keys)}")
    print()

    # Results table.
    print("| Track | samples/s | notes |")
    print("|---|---:|---|")
    sequential_rate = result["sequential_samples_s"]
    parallel_rate = result["parallel_samples_s"]
    print(f"| REMOTE SEQUENTIAL | {sequential_rate:.1f} | direct source MP4 decoder |")
    print(
        f"| REMOTE PARALLEL | {parallel_rate:.1f} | "
        f"direct source MP4 decoder, {args.decode_workers} workers |"
    )
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: pick the benchmark strategies to run and execute them.

    When an MP4 sidecar is available, the indexed strategies run once against
    it and the function returns. Otherwise the indexed variants run without a
    sidecar for both range backends.
    """
    args = parse_args()

    # Normalise strategy aliases up front; ``args`` is forwarded to the
    # strategy runners, so the mutation is intentional.
    if args.strategy == "full":
        args.strategy = "both"
    if args.strategy == "native-http":
        args.range_backend = "native-http"
    strategy = args.strategy

    data_root = args.data_root
    needs_branch_check = data_root.startswith("hf://") and not args.no_hub_branch_assert
    if needs_branch_check:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)

    total_episodes = int(meta.total_episodes)
    manifest_episode_count = min(
        args.manifest_episodes or total_episodes,
        total_episodes,
        args.num_episodes,
    )
    sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
    decode_with_both = strategy == "both" and args.include_decode

    if sidecar_path is not None:
        print(f"using_mp4_sidecar: {sidecar_path}")
        if strategy in ("both", "indexed", "native-http"):
            if decode_with_both:
                run_remote_strategy(meta, data_root, args, parquet_reader)
                print()
            # For "native-http" the backend was forced above, so the label and
            # backend below match the dedicated native-http sidecar run.
            run_indexed_strategy(
                meta,
                data_root,
                args,
                parquet_reader,
                range_backend=args.range_backend,
                label=f"indexed-sidecar-{args.range_backend}",
                sidecar_path=str(sidecar_path),
            )
            return

    if strategy == "both":
        # No sidecar found: report where it was expected and how to build it.
        expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
        expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
        print(f"mp4_sidecar_missing_local: {expected_sidecar}")
        print(f"mp4_sidecar_missing_remote: {expected_remote}")
        print(
            "build_mp4_sidecar: "
            "uv run --no-sync python scripts/build_mp4_sidecar.py "
            f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
        )
        print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
        print()

    if strategy in ("both", "indexed"):
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="fsspec",
            label="indexed",
            sidecar_path=None,
        )
        if strategy == "both":
            print()

    if strategy == "remote-decoder" or decode_with_both:
        run_remote_strategy(meta, data_root, args, parquet_reader)
        if decode_with_both:
            print()

    if strategy in ("both", "native-http"):
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="native-http",
            label="indexed-native-http",
            sidecar_path=None,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the sidecar builder.

    Returns:
        The parsed namespace; ``--output`` is the only required option.
    """
    parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")

    # Plain string options that only differ by flag name and default.
    string_options = (
        ("--repo-id", DEFAULT_REPO),
        ("--revision", DEFAULT_REVISION),
        ("--data-root", DEFAULT_DATA_ROOT),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, default=default_value)

    parser.add_argument("--output", required=True)

    # Integer options; ``--episodes`` left unset means "all episodes".
    integer_options = (
        ("--episodes", None),
        ("--workers", 8),
    )
    for flag, default_value in integer_options:
        parser.add_argument(flag, type=int, default=default_value)

    parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
    parser.add_argument("--max-probe-mb", type=int, default=64)
    parser.add_argument(
        "--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
    )
    parser.add_argument("--no-hub-branch-assert", action="store_true")

    namespace = parser.parse_args()
    return namespace
|
||||
|
||||
|
||||
def push_sidecar(local_path: str, data_root: str) -> list[str]:
    """Upload a sidecar file next to the dataset it indexes.

    Args:
        local_path: Path of the sidecar file on local disk.
        data_root: Dataset root; only ``hf://`` roots are uploaded to.

    Returns:
        The remote paths written, or an empty list when ``data_root`` is not
        an ``hf://`` location (nothing is uploaded in that case).
    """
    is_hub_root = data_root.startswith("hf://")
    if not is_hub_root:
        return []

    source = Path(local_path)
    hub_fs = fsspec.filesystem("hf")

    # The sidecar keeps its local file name under <data_root>/meta/mp4-sidecars.
    destination = "/".join((data_root.rstrip("/"), "meta", "mp4-sidecars", source.name))
    hub_fs.put(str(source), destination)
    return [destination]
|
||||
|
||||
|
||||
def main() -> None:
    """Build the MP4 byte-index sidecar and, unless disabled, upload it."""
    args = parse_args()

    wants_branch_check = args.data_root.startswith("hf://") and not args.no_hub_branch_assert
    if wants_branch_check:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()

    # ``--episodes`` caps the number of indexed episodes; unset means all.
    available = int(meta.total_episodes)
    total = available if args.episodes is None else min(args.episodes, available)

    # Several episodes can share one video file, so paths are de-duplicated.
    unique_paths = set()
    for ep_idx in range(total):
        for key in meta.video_keys:
            unique_paths.add(str(meta.get_video_file_path(ep_idx, key)))
    rel_paths = sorted(unique_paths)

    probe_limit_bytes = args.max_probe_mb * 1024 * 1024
    start = time.perf_counter()
    EpisodeVideoManifest.write_file_sidecar(
        args.output,
        rel_paths,
        args.data_root,
        range_backend=args.range_backend,
        workers=args.workers,
        max_probe_bytes=probe_limit_bytes,
    )
    elapsed = time.perf_counter() - start

    print(f"wrote {args.output}")
    print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")

    if args.no_push:
        print("push_skipped: --no-push")
        return

    for remote in push_sidecar(args.output, args.data_root):
        print(f"pushed {remote}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -54,6 +54,7 @@ from typing import Any
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
||||
# them into one). Write to a sibling tmp path and atomically rename so a
|
||||
# crash mid-write can't leave a half-written shard.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
write_table_one_row_group_per_episode(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
|
||||
@@ -73,8 +73,17 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot dataset on disk.
|
||||
recording: bool = False
|
||||
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
|
||||
# suffixed by task and env index). Requires recording=true.
|
||||
recording_repo_id: str | None = None
|
||||
# Whether the pushed recording repositories should be private.
|
||||
recording_private: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.recording_repo_id is not None and not self.recording:
|
||||
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
|
||||
if self.batch_size == 0:
|
||||
self.batch_size = self._auto_batch_size()
|
||||
if self.batch_size > self.n_episodes:
|
||||
|
||||
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_one_row_group_per_episode,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
one_row_group_per_episode=True,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
one_row_group_per_episode: bool = False,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
||||
the episodes-metadata parquet (already one row per episode).
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(df, dst_path)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, (dst_chunk, dst_file)
|
||||
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
|
||||
@@ -1,890 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from importlib import metadata
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote, urljoin, urlparse
|
||||
|
||||
import fsspec
|
||||
import httpx
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi, HfFileSystem, constants
|
||||
from huggingface_hub.utils import hf_raise_for_status
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EpisodeVideoSpan:
    """Immutable description of one episode's slice of one camera's video file.

    The field names match the per-episode arrays filled by
    ``EpisodeVideoManifest.build``; the per-field notes below are taken from
    how that method populates the arrays.
    NOTE(review): confirm against the code that actually constructs this class.
    """

    # Index of the video file in the manifest's file list.
    file_id: int
    # Byte offset of the episode's sample data within the file.
    mdat_offset: int
    # Number of bytes covered starting at ``mdat_offset``.
    mdat_length: int
    # Episode start timestamp inside the video file.
    first_pts: float
    # Episode end timestamp inside the video file.
    last_pts: float
    # Number of samples in the slice (``sample_hi - sample_lo + 1``).
    frame_count: int
    # Index of the first sample of the slice.
    sample_lo: int
    # Index of the last sample of the slice (inclusive).
    sample_hi: int
    # Start timestamp reported by the sample slice.
    source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VideoFileRecord:
    """Immutable index entry for a single video file of the dataset."""

    # Path of the video file relative to the data root.
    file_path: str
    # Size of the file in bytes, as reported by the range fetcher.
    file_size: int
    # Parsed MP4 index used to map timestamps to byte ranges.
    mp4: Mp4Index
|
||||
|
||||
|
||||
class ThreadLocalRangeFetcher:
|
||||
"""Range reader that gives each worker thread independent file handles."""
|
||||
|
||||
def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
protocol = "hf" if self.data_root.startswith("hf://") else "file"
|
||||
self.fs = fsspec.filesystem(protocol)
|
||||
self.block_size = block_size
|
||||
self.cache_type = cache_type
|
||||
self._local = threading.local()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_open_s": 0.0,
|
||||
"range_seek_s": 0.0,
|
||||
"range_read_s": 0.0,
|
||||
}
|
||||
|
||||
def _url(self, relative_path: str) -> str:
|
||||
if self.data_root.startswith("hf://"):
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
return str(Path(self.data_root) / relative_path)
|
||||
|
||||
def _handle(self, relative_path: str):
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
handles = {}
|
||||
self._local.handles = handles
|
||||
handle = handles.get(relative_path)
|
||||
if handle is None or getattr(handle, "closed", False):
|
||||
handle = self.fs.open(
|
||||
self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
|
||||
)
|
||||
handles[relative_path] = handle
|
||||
return handle
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
return int(self.fs.info(self._url(relative_path))["size"])
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
open_start = time.perf_counter()
|
||||
handle = self._handle(relative_path)
|
||||
open_s = time.perf_counter() - open_start
|
||||
seek_start = time.perf_counter()
|
||||
handle.seek(offset)
|
||||
seek_s = time.perf_counter() - seek_start
|
||||
read_start = time.perf_counter()
|
||||
data = handle.read(length)
|
||||
read_s = time.perf_counter() - read_start
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(data)),
|
||||
range_open_s=open_s,
|
||||
range_seek_s=seek_s,
|
||||
range_read_s=read_s,
|
||||
)
|
||||
return data
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] += value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
|
||||
with self._timing_lock:
|
||||
return dict(self._timing_totals)
|
||||
|
||||
def close(self) -> None:
|
||||
handles = getattr(self._local, "handles", None)
|
||||
if handles is None:
|
||||
return
|
||||
for handle in handles.values():
|
||||
with contextlib.suppress(Exception):
|
||||
handle.close()
|
||||
handles.clear()
|
||||
|
||||
|
||||
class NativeHTTPRangeFetcher:
    """Direct pooled HTTP range reader for hf:// paths.

    Each relative path is mapped to a Hub "source" URL, resolved once with a
    HEAD request to the URL that actually serves the bytes, and then read with
    ``Range`` GET requests through a shared ``httpx.Client``. Source URLs,
    resolved URLs and file sizes are cached per instance and in class-level
    dictionaries keyed by ``(data_root, relative_path)``.
    """

    # Class-level caches shared by all instances; guarded by _GLOBAL_LOCK.
    _GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
    _GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
    _GLOBAL_SIZES: dict[tuple[str, str], int] = {}
    _GLOBAL_LOCK = threading.Lock()

    # Transport-level failures that are retried with exponential backoff.
    _RETRYABLE_EXCEPTIONS = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.RemoteProtocolError,
        httpx.PoolTimeout,
    )

    def __init__(
        self,
        data_root: str | Path,
        *,
        max_connections: int = 32,
        timeout: float = 60.0,
        max_retries: int = 4,
    ):
        """Create a fetcher for an ``hf://`` root.

        Args:
            data_root: ``hf://`` root; ``hf://buckets/<owner>/<name>[/prefix]``
                roots are addressed through the bucket resolve endpoint, any
                other root through ``HfFileSystem.url``.
            max_connections: Connection-pool size of the HTTP client.
            timeout: Timeout passed to the HTTP client.
            max_retries: Number of retries after the first failed attempt.

        Raises:
            ValueError: If ``data_root`` is not an ``hf://`` URL or is a
                malformed bucket root.
        """
        self.data_root = str(data_root).rstrip("/")
        if not self.data_root.startswith("hf://"):
            raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
        self.max_retries = max_retries
        self.api = HfApi()
        self.fs: HfFileSystem | None = None
        self._bucket_id: str | None = None
        self._bucket_prefix = ""
        if self.data_root.startswith("hf://buckets/"):
            bucket_root = self.data_root.removeprefix("hf://buckets/")
            parts = bucket_root.split("/", 2)
            if len(parts) < 2:
                raise ValueError(f"Invalid bucket root: {self.data_root}")
            self._bucket_id = f"{parts[0]}/{parts[1]}"
            self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
        else:
            self.fs = HfFileSystem()
        self.client = httpx.Client(
            timeout=timeout,
            limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
            follow_redirects=False,
        )
        self._resolved_urls: dict[str, str] = {}
        self._source_urls: dict[str, str] = {}
        self._sizes: dict[str, int] = {}
        self._lock = threading.Lock()
        self._timing_lock = threading.Lock()
        self._timing_totals = {
            "range_jobs": 0.0,
            "range_bytes": 0.0,
            "range_resolve_s": 0.0,
            "range_header_s": 0.0,
            "range_first_byte_s": 0.0,
            "range_body_s": 0.0,
            "range_retry_attempts": 0.0,
            "range_retry_sleep_s": 0.0,
            "range_failed_requests": 0.0,
        }

    def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Send a request, retrying retryable transport errors with backoff."""
        last_exc: Exception | None = None
        for attempt in range(self.max_retries + 1):
            try:
                return self.client.request(method, url, **kwargs)
            except self._RETRYABLE_EXCEPTIONS as exc:
                last_exc = exc
                if attempt >= self.max_retries:
                    break
                time.sleep(min(0.5 * 2**attempt, 5.0))
        if last_exc is None:
            raise RuntimeError("HTTP request failed without an exception")
        raise last_exc

    def _cache_key(self, relative_path: str) -> tuple[str, str]:
        """Return the key used in the class-level caches."""
        return self.data_root, relative_path

    def _path(self, relative_path: str) -> str:
        """Return the full ``hf://`` path for ``relative_path``."""
        return f"{self.data_root}/{relative_path}"

    def _bucket_path(self, relative_path: str) -> str:
        """Return ``relative_path`` prefixed with the bucket sub-directory, if any."""
        if self._bucket_prefix:
            return f"{self._bucket_prefix}/{relative_path}"
        return relative_path

    def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
        """Build request headers, dropping credentials for a different host.

        The authorization header is only sent when ``request_url`` is on the
        same host as the Hub ``source_url``.
        """
        headers = self.api._build_hf_headers()
        if urlparse(request_url).netloc != urlparse(source_url).netloc:
            headers.pop("authorization", None)
            headers.pop("Authorization", None)
        return headers

    def _source_url(self, relative_path: str) -> str:
        """Return the (cached) Hub URL that ``relative_path`` is resolved from."""
        with self._lock:
            source = self._source_urls.get(relative_path)
        if source is not None:
            return source
        key = self._cache_key(relative_path)
        with self._GLOBAL_LOCK:
            source = self._GLOBAL_SOURCE_URLS.get(key)
        if source is None:
            if self._bucket_id is not None:
                source = (
                    f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
                    f"{quote(self._bucket_path(relative_path))}"
                )
            else:
                if self.fs is None:
                    raise RuntimeError("HfFileSystem fallback was not initialized")
                source = self.fs.url(self._path(relative_path))
            with self._GLOBAL_LOCK:
                self._GLOBAL_SOURCE_URLS[key] = source
        with self._lock:
            self._source_urls[relative_path] = source
        return source

    @staticmethod
    def _size_from_head(response: httpx.Response, *, redirected: bool) -> int | None:
        """Extract the target file size from a HEAD response, if it is reliable.

        ``X-Linked-Size`` (set by the Hub on redirects to the storage backend)
        always describes the target file. ``Content-Length`` is only trusted
        when the response is not a redirect, because on a redirect it
        describes the redirect response itself rather than the file.
        """
        linked_size = response.headers.get("X-Linked-Size")
        if linked_size is not None:
            return int(linked_size)
        if redirected:
            return None
        content_length = response.headers.get("Content-Length")
        return int(content_length) if content_length is not None else None

    def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
        """Return the URL that serves ``relative_path``, resolving it if needed.

        Args:
            relative_path: Path relative to the data root.
            refresh: When True, bypass both caches and resolve again.
        """
        with self._lock:
            if not refresh and relative_path in self._resolved_urls:
                return self._resolved_urls[relative_path]
        key = self._cache_key(relative_path)
        if not refresh:
            with self._GLOBAL_LOCK:
                resolved = self._GLOBAL_RESOLVED_URLS.get(key)
                size = self._GLOBAL_SIZES.get(key)
            if resolved is not None:
                with self._lock:
                    self._resolved_urls[relative_path] = resolved
                    if size is not None:
                        self._sizes[relative_path] = size
                return resolved

        source = self._source_url(relative_path)
        response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
        try:
            hf_raise_for_status(response)
            location = response.headers.get("Location")
            resolved = urljoin(source, location) if location else source
            # Only cache a size that really belongs to the target file; when
            # none is available, info_size() asks the resolved URL instead.
            size = self._size_from_head(response, redirected=bool(location))
            with self._lock:
                self._resolved_urls[relative_path] = resolved
                if size is not None:
                    self._sizes[relative_path] = size
            with self._GLOBAL_LOCK:
                self._GLOBAL_RESOLVED_URLS[key] = resolved
                if size is not None:
                    self._GLOBAL_SIZES[key] = size
            return resolved
        finally:
            response.close()

    def info_size(self, relative_path: str) -> int:
        """Return the file size in bytes, using cached values when available."""
        with self._lock:
            size = self._sizes.get(relative_path)
        if size is not None:
            return size
        key = self._cache_key(relative_path)
        with self._GLOBAL_LOCK:
            size = self._GLOBAL_SIZES.get(key)
        if size is not None:
            with self._lock:
                self._sizes[relative_path] = size
            return size

        resolved = self._resolve_url(relative_path)
        # Resolving may itself have cached a reliable size.
        with self._lock:
            size = self._sizes.get(relative_path)
        if size is not None:
            return size

        source = self._source_url(relative_path)
        response = self._request(
            "HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
        )
        try:
            hf_raise_for_status(response)
            size = int(response.headers["Content-Length"])
            with self._lock:
                self._sizes[relative_path] = size
            with self._GLOBAL_LOCK:
                self._GLOBAL_SIZES[key] = size
            return size
        finally:
            response.close()

    def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
        """Read ``length`` bytes starting at ``offset``.

        A 403 response triggers one refresh of the resolved URL followed by a
        single retry.

        Raises:
            ValueError: If ``length`` is negative.
            PermissionError: If the request still returns 403 after the
                resolved URL was refreshed.
        """
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")
        if length == 0:
            # "bytes=N-(N-1)" is not a valid range; servers may ignore it and
            # send the whole file, so answer locally instead.
            return b""

        byte_range = f"bytes={offset}-{offset + length - 1}"
        resolve_start = time.perf_counter()
        resolved = self._resolve_url(relative_path)
        source = self._source_url(relative_path)
        resolve_s = time.perf_counter() - resolve_start
        headers = self._headers_for(resolved, source)
        headers["Range"] = byte_range
        payload, status_code, timings = self._read_range_response(resolved, headers)
        if status_code == 403:
            refresh_start = time.perf_counter()
            resolved = self._resolve_url(relative_path, refresh=True)
            resolve_s += time.perf_counter() - refresh_start
            headers = self._headers_for(resolved, source)
            headers["Range"] = byte_range
            payload, status_code, retry_timings = self._read_range_response(resolved, headers)
            for key, value in retry_timings.items():
                timings[key] += value
        if status_code == 403:
            raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
        self._record_timing(
            range_jobs=1.0,
            range_bytes=float(len(payload)),
            range_resolve_s=resolve_s,
            **timings,
        )
        return payload

    def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
        """Run one range GET with retries; return ``(payload, status, timings)``.

        When every attempt fails, the retry counters are recorded as a failed
        request and the last exception is re-raised.
        """
        last_exc: Exception | None = None
        retry_attempts = 0.0
        retry_sleep_s = 0.0
        for attempt in range(self.max_retries + 1):
            try:
                payload, status_code, timings = self._read_range_response_once(url, headers)
                timings["range_retry_attempts"] = retry_attempts
                timings["range_retry_sleep_s"] = retry_sleep_s
                return payload, status_code, timings
            except self._RETRYABLE_EXCEPTIONS as exc:
                last_exc = exc
                if attempt >= self.max_retries:
                    break
                retry_attempts += 1.0
                sleep_s = min(0.5 * 2**attempt, 5.0)
                retry_sleep_s += sleep_s
                time.sleep(sleep_s)
        self._record_timing(
            range_failed_requests=1.0,
            range_retry_attempts=retry_attempts,
            range_retry_sleep_s=retry_sleep_s,
        )
        if last_exc is None:
            raise RuntimeError("HTTP range request failed without an exception")
        raise last_exc

    def _read_range_response_once(
        self, url: str, headers: dict[str, str]
    ) -> tuple[bytes, int, dict[str, float]]:
        """Perform a single streamed GET and time headers, first byte and body.

        A 403 is returned to the caller (with an empty payload) instead of
        being raised, so the caller can refresh the resolved URL.
        """
        header_start = time.perf_counter()
        with self.client.stream("GET", url, headers=headers) as response:
            header_s = time.perf_counter() - header_start
            if response.status_code == 403:
                return (
                    b"",
                    response.status_code,
                    {
                        "range_header_s": header_s,
                        "range_first_byte_s": 0.0,
                        "range_body_s": 0.0,
                    },
                )
            hf_raise_for_status(response)
            chunks = []
            first_byte_s = 0.0
            first_chunk = True
            body_start = time.perf_counter()
            for chunk in response.iter_bytes():
                if first_chunk:
                    first_byte_s = time.perf_counter() - body_start
                    first_chunk = False
                chunks.append(chunk)
            body_s = time.perf_counter() - body_start
            return (
                b"".join(chunks),
                response.status_code,
                {
                    "range_header_s": header_s,
                    "range_first_byte_s": first_byte_s,
                    "range_body_s": body_s,
                },
            )

    def _record_timing(self, **kwargs: float) -> None:
        """Add the given amounts to the running timing totals."""
        with self._timing_lock:
            for key, value in kwargs.items():
                self._timing_totals[key] += value

    def timing_summary(self) -> dict[str, float]:
        """Return a snapshot copy of the accumulated timing totals."""
        with self._timing_lock:
            return dict(self._timing_totals)

    def close(self) -> None:
        """Close the underlying HTTP client and its connection pool."""
        self.client.close()
|
||||
|
||||
|
||||
def make_range_fetcher(
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
):
|
||||
if range_backend == "fsspec":
|
||||
return ThreadLocalRangeFetcher(data_root)
|
||||
if range_backend == "native-http":
|
||||
max_connections = native_http_connections or max(8, workers)
|
||||
return NativeHTTPRangeFetcher(
|
||||
data_root,
|
||||
max_connections=max_connections,
|
||||
timeout=native_http_timeout,
|
||||
max_retries=native_http_retries,
|
||||
)
|
||||
raise ValueError(f"Unknown range backend: {range_backend}")
|
||||
|
||||
|
||||
class EpisodeVideoManifest:
|
||||
_FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
|
||||
_FILE_SIDECAR_CACHE_LOCK = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
video_keys: list[str],
|
||||
files: list[VideoFileRecord],
|
||||
spans: dict[str, np.ndarray],
|
||||
):
|
||||
self.video_keys = list(video_keys)
|
||||
self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
|
||||
self.files = files
|
||||
self.spans = spans
|
||||
|
||||
@classmethod
|
||||
def build(
    cls,
    meta: LeRobotDatasetMetadata,
    data_root: str | Path,
    *,
    episode_indices: list[int] | range | None = None,
    range_backend: str = "fsspec",
    workers: int = 8,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
    keyframe_pad_s: float = 0.1,
    keyframe_pad_fraction: float = 0.05,
    sidecar_path: str | Path | None = None,
) -> EpisodeVideoManifest:
    """Build a manifest mapping each (episode, camera) to a byte span.

    Args:
        meta: Dataset metadata (episodes, video keys, video file paths).
        data_root: Root the video files are read from when no sidecar is used.
        episode_indices: Episodes to index; defaults to every episode.
        range_backend: Range backend used to probe files without a sidecar.
        workers: Number of threads used to probe files.
        header_probe_bytes: Forwarded to the MP4 index probe.
        max_probe_bytes: Forwarded to the MP4 index probe.
        keyframe_pad_s: Forwarded to ``Mp4Index.sample_slice``.
        keyframe_pad_fraction: Forwarded to ``Mp4Index.sample_slice``.
        sidecar_path: Optional precomputed file index to load instead of
            probing the files.

    Returns:
        The manifest. Span arrays have one row per dataset episode; rows of
        episodes outside ``episode_indices`` stay zero.

    Raises:
        ValueError: If the sidecar lacks a file needed by the episodes.
    """
    meta.ensure_readable()
    video_keys = list(meta.video_keys)
    total = int(meta.total_episodes)
    if episode_indices is None:
        episode_indices = range(total)

    # Several episodes can share one video file, so collect unique paths.
    unique_paths = set()
    for ep_idx in episode_indices:
        for key in video_keys:
            unique_paths.add(str(meta.get_video_file_path(ep_idx, key)))
    rel_paths = sorted(unique_paths)
    path_to_id = {path: file_id for file_id, path in enumerate(rel_paths)}

    if sidecar_path is not None:
        records = cls.load_file_sidecar(sidecar_path)
        missing = [path for path in rel_paths if path not in records]
        if missing:
            raise ValueError(
                f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
            )
        files = [records[path] for path in rel_paths]
    else:
        files = cls._build_file_records(
            rel_paths,
            data_root,
            range_backend=range_backend,
            workers=workers,
            header_probe_bytes=header_probe_bytes,
            max_probe_bytes=max_probe_bytes,
        )

    num_cameras = len(video_keys)
    column_dtypes = (
        ("file_id", np.int32),
        ("mdat_offset", np.int64),
        ("mdat_length", np.int64),
        ("first_pts", np.float64),
        ("last_pts", np.float64),
        ("frame_count", np.int32),
        ("sample_lo", np.int32),
        ("sample_hi", np.int32),
        ("source_start_pts", np.float64),
    )
    spans: dict[str, np.ndarray] = {
        name: np.zeros((total, num_cameras), dtype=dtype) for name, dtype in column_dtypes
    }

    for ep_idx in episode_indices:
        episode = meta.episodes[ep_idx]
        for cam_idx, key in enumerate(video_keys):
            file_id = path_to_id[str(meta.get_video_file_path(ep_idx, key))]
            record = files[file_id]
            from_ts = float(episode[f"videos/{key}/from_timestamp"])
            to_ts = float(episode[f"videos/{key}/to_timestamp"])
            sample_slice = record.mp4.sample_slice(
                from_ts,
                to_ts,
                keyframe_pad_s=keyframe_pad_s,
                keyframe_pad_fraction=keyframe_pad_fraction,
                file_size=record.file_size,
            )
            cell_values = {
                "file_id": file_id,
                "mdat_offset": sample_slice.byte_offset,
                "mdat_length": sample_slice.byte_length,
                "first_pts": from_ts,
                "last_pts": to_ts,
                "frame_count": sample_slice.sample_hi - sample_slice.sample_lo + 1,
                "sample_lo": sample_slice.sample_lo,
                "sample_hi": sample_slice.sample_hi,
                "source_start_pts": sample_slice.source_start_pts,
            }
            for name, value in cell_values.items():
                spans[name][ep_idx, cam_idx] = value

    return cls(video_keys=video_keys, files=files, spans=spans)
|
||||
|
||||
@staticmethod
def _build_file_records(
    rel_paths: list[str],
    data_root: str | Path,
    *,
    range_backend: str,
    workers: int,
    header_probe_bytes: int,
    max_probe_bytes: int,
) -> list[VideoFileRecord]:
    """Index every file in ``rel_paths`` concurrently and return one record each.

    One range fetcher (from ``make_range_fetcher``) is shared by a thread pool
    of ``workers`` threads. Records come back in the order of ``rel_paths``.
    The fetcher is closed whether or not indexing succeeds.
    """
    range_fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)

    def index_one(rel_path: str) -> VideoFileRecord:
        # The size is queried first because the index probe is bounded by it.
        size = range_fetcher.info_size(rel_path)
        index = fetch_mp4_index(
            rel_path,
            range_fetcher.read_range,
            file_size=size,
            header_probe_bytes=header_probe_bytes,
            max_probe_bytes=max_probe_bytes,
        )
        return VideoFileRecord(rel_path, size, index)

    try:
        with ThreadPoolExecutor(max_workers=workers) as executor:
            records = list(executor.map(index_one, rel_paths))
    finally:
        range_fetcher.close()
    return records
|
||||
|
||||
@classmethod
def write_file_sidecar(
    cls,
    sidecar_path: str | Path,
    rel_paths: list[str],
    data_root: str | Path,
    *,
    range_backend: str = "native-http",
    workers: int = 8,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
) -> None:
    """Index the distinct files in ``rel_paths`` and persist them as a sidecar.

    Paths are de-duplicated and sorted before indexing, so the sidecar content
    does not depend on the order or multiplicity of ``rel_paths``.
    """
    unique_paths = sorted(set(rel_paths))
    file_records = cls._build_file_records(
        unique_paths,
        data_root,
        range_backend=range_backend,
        workers=workers,
        header_probe_bytes=header_probe_bytes,
        max_probe_bytes=max_probe_bytes,
    )
    cls.save_file_sidecar(sidecar_path, file_records)
|
||||
|
||||
@staticmethod
|
||||
def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
|
||||
sidecar_path = Path(sidecar_path)
|
||||
sidecar_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"version": 1,
|
||||
"files": [
|
||||
{"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
|
||||
for record in records
|
||||
],
|
||||
}
|
||||
arrays = {}
|
||||
for file_idx, record in enumerate(records):
|
||||
arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
|
||||
arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
|
||||
arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
|
||||
arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
|
||||
arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
|
||||
np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
|
||||
|
||||
@staticmethod
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
    """Load a sidecar archive into ``{file_path: VideoFileRecord}``, memoised per path.

    Results are cached on ``EpisodeVideoManifest`` under the user-expanded path,
    so repeated calls for the same path return the same dictionary object and do
    not re-read the file.

    Args:
        sidecar_path: Archive previously produced by ``save_file_sidecar``.

    Returns:
        Mapping from each stored ``file_path`` to its reconstructed record.
    """
    # Expand "~" once and use the result for BOTH the cache key and the read;
    # previously only the key was expanded, so "~/..." paths failed to load.
    resolved_path = Path(sidecar_path).expanduser()
    cache_key = str(resolved_path)
    with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
        cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
    if cached is not None:
        return cached

    array_names = ("sample_pts", "sample_durations", "sample_sizes", "sample_offsets", "sync_samples")
    with np.load(resolved_path, allow_pickle=False) as data:
        payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
        records = {}
        for file_idx, item in enumerate(payload["files"]):
            arrays = {name: data[f"{file_idx}/{name}"] for name in array_names}
            mp4 = Mp4Index.from_dict(item["mp4"], arrays)
            records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
    # The lock is not held while reading, so two threads may both load the file;
    # the last writer wins.
    with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
        EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
    return records
|
||||
|
||||
def camera_id(self, camera_key: str) -> int:
    """Return the integer id registered for ``camera_key`` (``KeyError`` if unknown)."""
    id_by_camera = self._camera_to_id
    return id_by_camera[camera_key]
|
||||
|
||||
def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
    """Return the stored span for one (episode, camera) pair as plain Python scalars."""
    cam = self.camera_id(camera_key)
    spans = self.spans

    def cell(column: str):
        # Every span column is indexed the same way: [episode, camera].
        return spans[column][episode_index, cam]

    return EpisodeVideoSpan(
        file_id=int(cell("file_id")),
        mdat_offset=int(cell("mdat_offset")),
        mdat_length=int(cell("mdat_length")),
        first_pts=float(cell("first_pts")),
        last_pts=float(cell("last_pts")),
        frame_count=int(cell("frame_count")),
        sample_lo=int(cell("sample_lo")),
        sample_hi=int(cell("sample_hi")),
        source_start_pts=float(cell("source_start_pts")),
    )
|
||||
|
||||
def file_lookup(self, file_id: int) -> VideoFileRecord:
    """Return the file record stored at position ``file_id`` of ``self.files``."""
    record = self.files[file_id]
    return record
|
||||
|
||||
def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
    """Return the MP4 index of the file that backs one (episode, camera) pair."""
    span = self.lookup(episode_index, camera_key)
    record = self.files[span.file_id]
    return record.mp4
|
||||
|
||||
def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
    """Rebuild the ``Mp4SampleSlice`` for one (episode, camera) pair from its stored span."""
    located = self.lookup(episode_index, camera_key)
    # The span's mdat_offset/mdat_length are the slice's byte window.
    return Mp4SampleSlice(
        sample_lo=located.sample_lo,
        sample_hi=located.sample_hi,
        byte_offset=located.mdat_offset,
        byte_length=located.mdat_length,
        source_start_pts=located.source_start_pts,
    )
|
||||
|
||||
|
||||
class EpisodeByteCache:
    """Byte-budgeted LRU cache of synthesized per-episode MP4s, filled by a thread pool.

    Entries are keyed by ``(episode_index, camera_key)`` and hold the synthesized
    MP4 ``bytes`` plus an optional opened decoder. Fetches run on an internal
    ``ThreadPoolExecutor``; in-flight work is tracked in ``_futures`` so that
    concurrent requests for the same key share one fetch.
    """

    def __init__(
        self,
        manifest: EpisodeVideoManifest,
        data_root: str | Path,
        *,
        byte_budget: int = 80 * 1024**3,
        workers: int = 8,
        range_backend: str = "fsspec",
        native_http_connections: int | None = None,
        native_http_timeout: float = 60.0,
        native_http_retries: int = 4,
        open_decoders: bool = True,
    ):
        """Create the range fetcher, worker pool and empty cache.

        Args:
            manifest: Source of spans and file records.
            data_root: Root passed to ``make_range_fetcher``.
            byte_budget: Eviction threshold on the summed size of cached MP4 bytes.
            workers: Thread count for the pool (also forwarded to the fetcher).
            range_backend: Backend name forwarded to ``make_range_fetcher``.
            native_http_connections: Forwarded to ``make_range_fetcher``.
            native_http_timeout: Forwarded to ``make_range_fetcher``.
            native_http_retries: Forwarded to ``make_range_fetcher``.
            open_decoders: When true, a decoder is opened as part of each fetch job.
        """
        self.manifest = manifest
        self.fetcher = make_range_fetcher(
            data_root,
            range_backend=range_backend,
            workers=workers,
            native_http_connections=native_http_connections,
            native_http_timeout=native_http_timeout,
            native_http_retries=native_http_retries,
        )
        self.byte_budget = byte_budget
        self.open_decoders = open_decoders
        self._pool = ThreadPoolExecutor(max_workers=workers)
        self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
        self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
        self._bytes = 0
        self._lock = threading.Lock()
        self._timing_totals = {
            "lookup_s": 0.0,
            "fetch_s": 0.0,
            "synthesize_s": 0.0,
            "store_s": 0.0,
            "jobs": 0.0,
        }

    def close(self) -> None:
        """Wait for pending jobs, drop all cached state and close the fetcher."""
        try:
            self._pool.shutdown(wait=True)
        finally:
            # Release cached bytes and the fetcher even if the pool shutdown fails.
            with self._lock:
                self._cache.clear()
                self._futures.clear()
                self._bytes = 0
            self.fetcher.close()

    def __enter__(self) -> EpisodeByteCache:
        return self

    def __exit__(self, *_exc) -> None:
        self.close()

    def submit_prefetch(self, episode_index: int) -> None:
        """Start background fetches for every camera of ``episode_index`` without waiting."""
        for camera_key in self.manifest.video_keys:
            self._submit(episode_index, camera_key)

    def ensure_ready(self, episode_index: int) -> None:
        """Block until every camera of ``episode_index`` is cached."""
        for camera_key in self.manifest.video_keys:
            self.get_bytes(episode_index, camera_key)

    def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
        """Return the synthesized MP4 bytes for one (episode, camera) pair."""
        return self._get_entry(episode_index, camera_key)["bytes"]

    def get_decoder(self, episode_index: int, camera_key: str):
        """Return the entry's decoder, opening one lazily over the cached bytes if absent."""
        entry = self._get_entry(episode_index, camera_key)
        decoder = entry.get("decoder")
        if decoder is None:
            decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
            entry["decoder"] = decoder
        return decoder

    def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
        """Decode frames at source-file ``timestamps`` for one (episode, camera) pair.

        Timestamps are shifted by the span's ``source_start_pts`` because the
        synthesized clip has its own local timeline.
        """
        span = self.manifest.lookup(episode_index, camera_key)
        local_ts = [ts - span.source_start_pts for ts in timestamps]
        decoder = self.get_decoder(episode_index, camera_key)
        if hasattr(decoder, "get_frames_played_at"):
            return decoder.get_frames_played_at(local_ts).data
        # Fallback for decoders without timestamp lookup: convert to frame indices.
        metadata = decoder.metadata
        fps = getattr(metadata, "average_fps", None)
        if fps is None:
            duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9)
            fps = metadata.num_frames / duration
        return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data

    def timing_summary(self) -> dict[str, float]:
        """Return accumulated stage timings, merged with the fetcher's own summary if it has one."""
        with self._lock:
            summary = dict(self._timing_totals)
        fetcher_summary = getattr(self.fetcher, "timing_summary", None)
        if fetcher_summary is not None:
            summary.update(fetcher_summary())
        return summary

    def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
        """Return a future for the entry: resolved if cached, shared if in flight, else new."""
        key = (episode_index, camera_key)
        with self._lock:
            if key in self._cache:
                future: Future[dict[str, Any]] = Future()
                future.set_result(self._cache[key])
                return future
            future = self._futures.get(key)
            if future is None:
                future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
                self._futures[key] = future
            return future

    def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Return the cached entry for the key, fetching and storing it if needed.

        Raises:
            Whatever the fetch job raised. The failed future is forgotten so a
            later call starts a fresh fetch.
        """
        key = (episode_index, camera_key)
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None:
                self._cache.move_to_end(key)
                return entry
        future = self._submit(episode_index, camera_key)
        try:
            entry = future.result()
        except BaseException:
            # Drop the failed future; otherwise every later request for this key
            # would receive the same stored exception and never retry.
            with self._lock:
                if self._futures.get(key) is future:
                    del self._futures[key]
            raise
        store_start = time.perf_counter()
        with self._lock:
            self._futures.pop(key, None)
            existing = self._cache.get(key)
            if existing is not None:
                # Another caller stored the entry first; keep a single copy.
                self._cache.move_to_end(key)
                return existing
            self._cache[key] = entry
            self._bytes += len(entry["bytes"])
            self._evict_locked()
            timings = entry.pop("_timings", None)
            if timings is not None:
                self._timing_totals["lookup_s"] += timings["lookup_s"]
                self._timing_totals["fetch_s"] += timings["fetch_s"]
                self._timing_totals["synthesize_s"] += timings["synthesize_s"]
                self._timing_totals["store_s"] += time.perf_counter() - store_start
                self._timing_totals["jobs"] += 1
        return entry

    def _evict_locked(self) -> None:
        """Evict least-recently-used entries until within budget; caller holds ``_lock``."""
        while self._bytes > self.byte_budget and self._cache:
            _key, entry = self._cache.popitem(last=False)
            self._bytes -= len(entry["bytes"])

    def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Worker job: read the span's byte range and wrap it into a standalone MP4.

        Raises:
            OSError: If the fetcher returns fewer or more bytes than requested.
        """
        lookup_start = time.perf_counter()
        span = self.manifest.lookup(episode_index, camera_key)
        file_record = self.manifest.file_lookup(span.file_id)
        sample_slice = Mp4SampleSlice(
            sample_lo=span.sample_lo,
            sample_hi=span.sample_hi,
            byte_offset=span.mdat_offset,
            byte_length=span.mdat_length,
            source_start_pts=span.source_start_pts,
        )
        lookup_s = time.perf_counter() - lookup_start
        fetch_start = time.perf_counter()
        payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
        fetch_s = time.perf_counter() - fetch_start
        if len(payload) != span.mdat_length:
            raise OSError(
                f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
            )
        synthesize_start = time.perf_counter()
        mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
        synthesize_s = time.perf_counter() - synthesize_start
        entry: dict[str, Any] = {
            "bytes": mp4_bytes,
            "decoder": None,
            "_timings": {
                "lookup_s": lookup_s,
                "fetch_s": fetch_s,
                "synthesize_s": synthesize_s,
            },
        }
        if self.open_decoders:
            entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
        return entry
|
||||
|
||||
|
||||
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
    """Open a torchcodec ``VideoDecoder`` in approximate seek mode over the given source.

    Raises:
        ValueError: If ``frame_mappings`` is anything other than ``None``.
    """
    if frame_mappings is None:
        # Imported lazily so torchcodec is only required when a decoder is opened.
        from torchcodec.decoders import VideoDecoder

        return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
    raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
|
||||
|
||||
|
||||
def assert_hf_hub_range_cache_branch() -> None:
    """Fail unless huggingface_hub was installed from the required range-cache branch.

    The check looks for the branch name in the distribution's ``direct_url.json``:
    in its raw text, and in its ``url``, ``vcs_info.requested_revision`` and
    ``vcs_info.commit_id`` fields when the file parses as JSON.

    Raises:
        AssertionError: If the package is missing or the branch name is not found.
    """
    try:
        distribution = metadata.distribution("huggingface_hub")
    except metadata.PackageNotFoundError as exc:
        raise AssertionError("huggingface_hub is not installed") from exc

    evidence = []
    raw_direct_url = distribution.read_text("direct_url.json")
    if raw_direct_url:
        evidence.append(raw_direct_url)
        # Unparseable JSON is tolerated: the raw text above is still searched.
        with contextlib.suppress(json.JSONDecodeError):
            info = json.loads(raw_direct_url)
            evidence.append(str(info.get("url", "")))
            for field in ("requested_revision", "commit_id"):
                evidence.append(str(info.get("vcs_info", {}).get(field, "")))

    haystack = "\n".join(evidence)
    if "feat/hffs-cache-cdn-range-reads" in haystack:
        return
    raise AssertionError(
        "huggingface_hub must be installed from "
        "git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
    )
|
||||
|
||||
|
||||
@dataclass
class StageTimer:
    """Running totals for fetch/decode time, bytes read and cache misses."""

    fetch_ms: float = 0.0
    decode_ms: float = 0.0
    bytes_read: int = 0
    misses: int = 0

    def record_fetch(self, start: float, byte_count: int) -> None:
        """Account one fetch that began at ``start`` (a ``time.perf_counter()`` reading)."""
        elapsed_ms = (time.perf_counter() - start) * 1000
        self.fetch_ms += elapsed_ms
        self.bytes_read += byte_count
        # Each recorded fetch is counted as one miss.
        self.misses += 1
|
||||
@@ -20,6 +20,7 @@ import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
return items_dict
|
||||
|
||||
|
||||
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
    """Write ``table`` to parquet, starting a new row group wherever ``episode_index`` changes.

    Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
    mirroring the recording writer. ``table`` must carry a contiguous
    ``episode_index`` column.
    """
    episodes = table.column("episode_index").to_numpy(zero_copy_only=False)
    # A boundary is every row whose episode differs from the previous row's.
    boundaries = np.flatnonzero(np.diff(episodes)) + 1
    starts = np.concatenate(([0], boundaries))
    stops = np.append(starts[1:], len(episodes))
    writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
    try:
        for first_row, end_row in zip(starts, stops, strict=True):
            # One write_table call per episode -> one row group per episode.
            writer.write_table(table.slice(first_row, end_row - first_row))
    finally:
        writer.close()
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
||||
does not embed external image files like ``Dataset.to_parquet`` does).
|
||||
``features`` types image columns as ``Image()`` in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
ds = embed_images(ds)
|
||||
table = ds.with_format("arrow")[:]
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
# No episode boundaries to align row groups to — keep a single write.
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
    """Write a (non-image) DataFrame to parquet with one row group per episode.

    Without an ``episode_index`` column there is nothing to align row groups to,
    so the table is written in a single call.
    """
    arrow_table = pa.Table.from_pandas(df, preserve_index=False)
    if "episode_index" not in arrow_table.column_names:
        pq.write_table(arrow_table, str(path))
        return
    write_table_one_row_group_per_episode(arrow_table, path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -1,666 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Box:
    """Location of one MP4 box: four-byte type plus start/end offsets and header length."""

    type: bytes
    start: int
    header_size: int
    end: int

    @property
    def payload_start(self) -> int:
        """Offset of the first byte after the box header."""
        header_end = self.start + self.header_size
        return header_end

    @property
    def size(self) -> int:
        """Total box length, header included."""
        total = self.end - self.start
        return total
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Mp4SampleSlice:
    """An inclusive sample range of an MP4 track and the byte window that contains it."""

    # First and last sample index of the range (both inclusive).
    sample_lo: int
    sample_hi: int
    # Byte window in the source file covering every sample of the range.
    byte_offset: int
    byte_length: int
    # Timestamp, in seconds on the source timeline, of sample ``sample_lo``.
    source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Mp4Index:
    """Parsed layout and per-sample tables of the video track of one MP4 file.

    The five array fields are parallel per-sample tables, except ``sync_samples``
    which lists sample indices (in the same indexing as the other arrays).
    """

    file_path: str
    file_size: int
    ftyp: bytes
    moov_offset: int
    mdat_offset: int
    mdat_payload_offset: int
    mdat_payload_size: int
    faststart: bool
    codec: str
    timescale: int
    duration: int
    track_id: int
    width: int
    height: int
    stsd_body: bytes
    sample_pts: np.ndarray
    sample_durations: np.ndarray
    sample_sizes: np.ndarray
    sample_offsets: np.ndarray
    sync_samples: np.ndarray

    def sample_slice(
        self,
        from_ts: float,
        to_ts: float,
        *,
        keyframe_pad_s: float = 0.1,
        keyframe_pad_fraction: float = 0.05,
        file_size: int | None = None,
    ) -> Mp4SampleSlice:
        """Select the samples covering ``[from_ts, to_ts]`` plus padding, starting on a sync sample.

        The window is widened on both sides by the larger of ``keyframe_pad_s`` and
        ``keyframe_pad_fraction`` of the span, then its start is moved back to the
        nearest preceding sync sample (or the first sync sample if none precedes).

        Raises:
            ValueError: If ``to_ts < from_ts`` or the index has no samples.
        """
        if to_ts < from_ts:
            raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
        sample_count = len(self.sample_pts)
        if sample_count == 0:
            raise ValueError(f"{self.file_path} contains no indexed samples")
        last_index = sample_count - 1

        pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
        first = int(np.searchsorted(self.sample_pts, max(0.0, from_ts - pad), side="left"))
        final = int(np.searchsorted(self.sample_pts, to_ts + pad, side="right")) - 1
        first = min(max(first, 0), last_index)
        final = min(max(final, first), last_index)

        if len(self.sync_samples):
            earlier_syncs = self.sync_samples[self.sync_samples <= first]
            first = int(earlier_syncs[-1]) if len(earlier_syncs) else int(self.sync_samples[0])
            # Snapping forward to the first sync sample can overtake the end.
            final = max(final, first)

        window = slice(first, final + 1)
        window_offsets = self.sample_offsets[window]
        byte_begin = int(window_offsets.min())
        byte_end = int((window_offsets + self.sample_sizes[window]).max())
        if file_size is not None:
            byte_end = min(byte_end, int(file_size))
        return Mp4SampleSlice(
            sample_lo=first,
            sample_hi=final,
            byte_offset=byte_begin,
            byte_length=byte_end - byte_begin,
            source_start_pts=float(self.sample_pts[first]),
        )

    def to_dict(self) -> dict:
        """Return the JSON-friendly scalar fields; byte fields are hex-encoded, arrays omitted."""
        out = {
            "file_path": self.file_path,
            "file_size": self.file_size,
            "ftyp": self.ftyp.hex(),
        }
        for name in (
            "moov_offset",
            "mdat_offset",
            "mdat_payload_offset",
            "mdat_payload_size",
            "faststart",
            "codec",
            "timescale",
            "duration",
            "track_id",
            "width",
            "height",
        ):
            out[name] = getattr(self, name)
        out["stsd_body"] = self.stsd_body.hex()
        return out

    @classmethod
    def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
        """Rebuild an index from ``to_dict()`` output plus the five per-sample arrays."""
        int_fields = (
            "file_size",
            "moov_offset",
            "mdat_offset",
            "mdat_payload_offset",
            "mdat_payload_size",
            "timescale",
            "duration",
            "track_id",
            "width",
            "height",
        )
        array_fields = ("sample_pts", "sample_durations", "sample_sizes", "sample_offsets", "sync_samples")
        kwargs = {name: int(data[name]) for name in int_fields}
        kwargs.update(
            file_path=data["file_path"],
            ftyp=bytes.fromhex(data["ftyp"]),
            faststart=bool(data["faststart"]),
            codec=data["codec"],
            stsd_body=bytes.fromhex(data["stsd_body"]),
        )
        kwargs.update({name: arrays[name] for name in array_fields})
        return cls(**kwargs)
|
||||
|
||||
|
||||
def fetch_mp4_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    *,
    file_size: int,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
    """Build an ``Mp4Index`` for ``path`` by reading a growing prefix of the file.

    The prefix starts at ``header_probe_bytes`` and doubles until it contains both
    an ``mdat`` box and a complete ``moov`` box, capped at ``max_probe_bytes`` and
    ``file_size``. If the cap is reached with ``mdat`` but no ``moov``, the bytes
    following ``mdat`` are probed for a trailing ``moov``.

    Args:
        path: File identifier passed through to ``read_range``.
        read_range: Callable ``(path, offset, length) -> bytes``.
        file_size: Total size of the file in bytes.
        header_probe_bytes: Initial prefix length; must be positive.
        max_probe_bytes: Upper bound on the prefix (and tail) read.

    Raises:
        ValueError: If ``header_probe_bytes`` is not positive, or if ``mdat`` /
            a complete ``moov`` cannot be located within the probe limits.
    """
    if header_probe_bytes <= 0:
        # A non-positive probe can never grow by doubling, so the loop below
        # would spin forever re-reading the same empty range.
        raise ValueError(f"header_probe_bytes must be positive, got {header_probe_bytes}")

    probe_limit = min(max_probe_bytes, file_size)
    probe_size = min(header_probe_bytes, file_size)
    while True:
        data = read_range(path, 0, probe_size)
        top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
        has_mdat = any(box.type == b"mdat" for box in top)
        # moov only counts when it lies entirely inside the bytes read so far.
        has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
        if has_mdat and has_moov:
            return parse_mp4_index(path, data, file_size=file_size)
        if probe_size >= probe_limit:
            if has_mdat and not has_moov:
                tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
                if tail_index is not None:
                    return tail_index
            missing = []
            if not has_mdat:
                missing.append("mdat")
            if not has_moov:
                missing.append("moov")
            raise ValueError(
                f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
            )
        probe_size = min(probe_size * 2, max_probe_bytes, file_size)
|
||||
|
||||
|
||||
def _fetch_tail_moov_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    prefix: bytes,
    top_boxes: list[Box],
    file_size: int,
    max_probe_bytes: int,
) -> Mp4Index | None:
    """Look for a complete ``moov`` box in the bytes that follow ``mdat`` and index from it.

    Returns ``None`` when ``mdat`` already reaches the end of the file or when no
    complete ``moov`` is found within ``max_probe_bytes`` after it.
    """
    mdat_box = _one(top_boxes, b"mdat")
    if mdat_box is None or mdat_box.end >= file_size:
        return None

    tail_start = mdat_box.end
    tail = read_range(path, tail_start, min(max_probe_bytes, file_size - tail_start))
    tail_limit = tail_start + len(tail)
    moov_box = None
    for candidate in iter_boxes(tail, 0, len(tail), absolute_base=tail_start, allow_truncated=True):
        # Box offsets are absolute here, so compare against the absolute tail end.
        if candidate.type == b"moov" and candidate.end <= tail_limit:
            moov_box = candidate
            break
    if moov_box is None:
        return None

    ftyp_box = _one(top_boxes, b"ftyp", required=False)
    if ftyp_box is not None:
        ftyp = prefix[ftyp_box.start : ftyp_box.end]
    else:
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    # Convert absolute box offsets back to positions inside ``tail``.
    moov_payload = tail[moov_box.payload_start - tail_start : moov_box.end - tail_start]
    return _parse_mp4_index_from_layout(
        path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=moov_payload,
        mdat_box=mdat_box,
    )
|
||||
|
||||
|
||||
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
    """Index an MP4 whose ``moov`` box is fully contained in ``data``.

    ``data`` is treated as starting at file offset 0. ``file_size`` defaults to
    ``len(data)``. A synthetic ``ftyp`` is used when the file has none.

    Raises:
        ValueError: If ``moov`` or ``mdat`` is missing, or ``moov`` is truncated.
    """
    effective_size = len(data) if file_size is None else file_size
    top_level = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
    ftyp_box = _one(top_level, b"ftyp", required=False)
    moov_box = _one(top_level, b"moov")
    mdat_box = _one(top_level, b"mdat")
    if moov_box.end > len(data):
        raise ValueError(f"{path}: moov box is truncated")

    moov_payload = data[moov_box.payload_start : moov_box.end]
    if ftyp_box is None:
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    else:
        ftyp = data[ftyp_box.start : ftyp_box.end]
    return _parse_mp4_index_from_layout(
        path,
        file_size=effective_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=moov_payload,
        mdat_box=mdat_box,
    )
|
||||
|
||||
|
||||
def _parse_mp4_index_from_layout(
    path: str,
    *,
    file_size: int,
    ftyp: bytes,
    moov_offset: int,
    moov: bytes,
    mdat_box: Box,
) -> Mp4Index:
    """Build an ``Mp4Index`` from an already-located ``moov`` payload and ``mdat`` box.

    Args:
        path: File identifier stored in the index and used in error messages.
        file_size: Total file size; also clamps the recorded mdat payload size.
        ftyp: Complete ``ftyp`` box bytes to store in the index.
        moov_offset: Absolute file offset of the ``moov`` box.
        moov: Payload of the ``moov`` box (header excluded).
        mdat_box: Location of the ``mdat`` box in absolute file offsets.

    Raises:
        ValueError: If a required box is missing, the sample tables disagree, or
            the media timescale is not positive.
    """
    mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
    _, trak_payload = _find_video_trak(moov)
    tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
    mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
    if mdhd_timescale <= 0:
        # Timestamps are durations divided by this value; zero would silently
        # produce inf/nan sample times instead of a usable index.
        raise ValueError(f"{path}: invalid media timescale {mdhd_timescale}")
    stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])

    stsd = _find_child(stbl, b"stsd")
    stsd_body = stbl[stsd.payload_start : stsd.end]
    codec = _parse_stsd_codec(stsd_body)
    stts = _parse_stts(_payload(stbl, b"stts"))
    sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
    stsc = _parse_stsc(_payload(stbl, b"stsc"))
    chunk_offsets = _parse_chunk_offsets(stbl)
    sync_samples = _parse_stss(stbl, len(sample_sizes))

    sample_durations = _expand_stts(stts, len(sample_sizes))
    # Each sample starts where the previous one ended; the first starts at 0.
    # NOTE(review): only stts is used, so these are decode-order times; no
    # composition offsets (ctts) are applied — confirm sources have none.
    sample_pts_units = np.zeros(len(sample_durations), dtype=np.int64)
    if len(sample_durations) > 1:
        sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
    sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
    sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)

    if mdhd_duration:
        duration = mdhd_duration
    elif mvhd_timescale:
        # The movie header counts in its own timescale; rescale the fallback so
        # ``duration`` is always expressed in the stored (media) ``timescale``.
        duration = mvhd_duration * mdhd_timescale // mvhd_timescale
    else:
        duration = mvhd_duration

    if mdat_box.end <= file_size:
        mdat_payload_size = mdat_box.end - mdat_box.payload_start
    else:
        # Declared box end lies beyond the file: clamp to what actually exists.
        mdat_payload_size = file_size - mdat_box.payload_start

    return Mp4Index(
        file_path=path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_offset,
        mdat_offset=mdat_box.start,
        mdat_payload_offset=mdat_box.payload_start,
        mdat_payload_size=mdat_payload_size,
        faststart=moov_offset < mdat_box.start,
        codec=codec,
        timescale=mdhd_timescale,
        duration=duration,
        track_id=tkhd["track_id"],
        width=tkhd["width"],
        height=tkhd["height"],
        stsd_body=stsd_body,
        sample_pts=sample_pts,
        sample_durations=sample_durations,
        sample_sizes=sample_sizes,
        sample_offsets=sample_offsets,
        sync_samples=sync_samples,
    )
|
||||
|
||||
|
||||
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
    """Assemble a standalone MP4 (``ftyp`` + ``moov`` + ``mdat``) for one sample slice.

    ``mdat_payload`` must be the bytes read at ``sample_slice.byte_offset``; sample
    offsets are rebased onto it.

    Raises:
        ValueError: If the sample range is empty or out of bounds, does not start at
            the lowest referenced offset, or is not fully covered by ``mdat_payload``.
    """
    first = sample_slice.sample_lo
    stop = sample_slice.sample_hi + 1  # exclusive upper bound
    if not (0 <= first < stop <= len(index.sample_sizes)):
        raise ValueError(f"Invalid sample range [{first}, {stop}) for {index.file_path}")

    window = slice(first, stop)
    sizes = index.sample_sizes[window]
    rel_offsets = index.sample_offsets[window] - sample_slice.byte_offset
    if int(rel_offsets.min()) != 0:
        raise ValueError("Sample slice must start at the minimum referenced sample offset")
    if int((rel_offsets + sizes).max()) > len(mdat_payload):
        raise ValueError("Sample slice does not cover all referenced samples")

    durations = index.sample_durations[window]
    in_window = (index.sync_samples >= first) & (index.sync_samples < stop)
    # Rebase sync sample numbers to the slice and make them one-based.
    sync = index.sync_samples[in_window] - first + 1
    # First pass only measures the moov; the second writes offsets that account
    # for ftyp + moov + the 8-byte mdat header preceding the payload.
    draft_moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
    header_size = len(index.ftyp) + len(draft_moov)
    final_moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
    return index.ftyp + final_moov + _box(b"mdat", mdat_payload)
|
||||
|
||||
|
||||
def iter_boxes(
    data: bytes,
    start: int,
    end: int,
    *,
    absolute_base: int = 0,
    allow_truncated: bool = False,
) -> Iterable[Box]:
    """Yield the sibling MP4 boxes found in ``data[start:end]``.

    Yielded offsets are positions in ``data`` shifted by ``absolute_base``. A
    32-bit size of 1 selects the 64-bit size field; a size of 0 means the box
    runs to ``end``. Iteration stops at the first malformed header. A box that
    overruns ``end`` is yielded only when ``allow_truncated`` is set, and is then
    the last one.
    """
    cursor = start
    while cursor + 8 <= end:
        (declared,) = struct.unpack_from(">I", data, cursor)
        kind = data[cursor + 4 : cursor + 8]
        header_len = 8
        if declared == 1:
            if cursor + 16 > end:
                return
            (declared,) = struct.unpack_from(">Q", data, cursor + 8)
            header_len = 16
        elif declared == 0:
            declared = end - cursor
        if declared < header_len:
            return
        stop = cursor + declared
        if stop > end and not allow_truncated:
            return
        yield Box(kind, absolute_base + cursor, header_len, absolute_base + stop)
        cursor = stop
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
    """Return the first ``trak`` box whose handler type is ``vide``, with its payload.

    Raises:
        ValueError: If no such track exists (or a ``trak`` lacks ``mdia/hdlr``).
    """
    for candidate in _children(moov, 0, len(moov)):
        if candidate.type == b"trak":
            body = moov[candidate.payload_start : candidate.end]
            handler = _find_descendant(body, [b"mdia", b"hdlr"])
            # Bytes 8..12 of the hdlr payload hold the handler type.
            if handler[8:12] == b"vide":
                return candidate, body
    raise ValueError("No video track found")
|
||||
|
||||
|
||||
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
    """Walk ``path`` box by box from ``data`` and return the payload of the last box."""
    payload = data
    for box_type in path:
        found = _find_child(payload, box_type)
        payload = payload[found.payload_start : found.end]
    return payload
|
||||
|
||||
|
||||
def _find_child(data: bytes, typ: bytes) -> Box:
    """Return the first direct child box of type ``typ`` in ``data``.

    Raises:
        ValueError: If no child of that type exists.
    """
    found = next((box for box in _children(data, 0, len(data)) if box.type == typ), None)
    if found is None:
        raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
    return found
|
||||
|
||||
|
||||
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
    """Iterate the complete boxes in ``data[start:end]`` with offsets relative to ``data``."""
    boxes = iter_boxes(data, start, end, absolute_base=0)
    return boxes
|
||||
|
||||
|
||||
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
|
||||
matches = [box for box in boxes if box.type == typ]
|
||||
if not matches and required:
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
return matches[0] if matches else None
|
||||
|
||||
|
||||
def _payload(parent: bytes, typ: bytes) -> bytes:
    """Return the payload bytes of the first child box of type ``typ`` inside ``parent``."""
    child = _find_child(parent, typ)
    begin, finish = child.payload_start, child.end
    return parent[begin:finish]
|
||||
|
||||
|
||||
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_tkhd(payload: bytes) -> dict[str, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
track_id = struct.unpack_from(">I", payload, 20)[0]
|
||||
duration = struct.unpack_from(">Q", payload, 28)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 88)
|
||||
else:
|
||||
track_id = struct.unpack_from(">I", payload, 12)[0]
|
||||
duration = struct.unpack_from(">I", payload, 20)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 76)
|
||||
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
|
||||
|
||||
|
||||
def _parse_stsd_codec(stsd_body: bytes) -> str:
    """Return bytes 12..16 of ``stsd_body`` decoded as latin-1.

    Bodies shorter than 16 bytes yield ``"unknown"``.
    """
    return stsd_body[12:16].decode("latin1") if len(stsd_body) >= 16 else "unknown"
|
||||
|
||||
|
||||
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
    """Decode the run table of an ``stts`` payload.

    Reads the big-endian 32-bit entry count at byte 4, then that many pairs of
    32-bit values starting at byte 8 (8 bytes per entry).
    """
    (entry_count,) = struct.unpack_from(">I", payload, 4)
    return [struct.unpack_from(">II", payload, 8 + 8 * entry) for entry in range(entry_count)]
|
||||
|
||||
|
||||
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
    """Expand ``(count, delta)`` runs into one int64 delta per sample.

    Raises:
        ValueError: if the run counts do not add up to ``sample_count``.
    """
    runs = np.array(entries, dtype=np.int64).reshape(-1, 2)
    run_counts, run_deltas = runs[:, 0], runs[:, 1]
    described = int(run_counts.sum())
    # Validate before expanding so a mismatched table never yields a partial result.
    if described != sample_count:
        raise ValueError(f"stts describes {described} samples, stsz describes {sample_count}")
    return np.repeat(run_deltas, run_counts)
|
||||
|
||||
|
||||
def _parse_stsz(payload: bytes) -> np.ndarray:
    """Return the size in bytes of every sample described by an ``stsz`` payload.

    The payload holds two big-endian 32-bit values at byte 4: a uniform sample
    size and the sample count. A non-zero uniform size applies to every sample;
    otherwise one 32-bit size per sample follows from byte 12.

    Returns:
        int64 array with one entry per sample.

    Raises:
        struct.error: if the payload is too short for the declared table.
    """
    uniform_size, sample_count = struct.unpack_from(">II", payload, 4)
    if uniform_size:
        return np.full(sample_count, uniform_size, dtype=np.int64)
    # Decode the whole table in one call instead of one struct call per sample.
    per_sample = struct.unpack_from(f">{sample_count}I", payload, 12)
    return np.array(per_sample, dtype=np.int64)
|
||||
|
||||
|
||||
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
    """Decode the entries of an ``stsc`` payload.

    Reads the big-endian 32-bit entry count at byte 4, then that many triples
    of 32-bit values starting at byte 8 (12 bytes per entry).
    """
    (entry_count,) = struct.unpack_from(">I", payload, 4)
    return [struct.unpack_from(">III", payload, 8 + 12 * entry) for entry in range(entry_count)]
|
||||
|
||||
|
||||
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
    """Return the chunk offsets stored in the ``stco`` or ``co64`` child of ``stbl``.

    When both boxes are present the 64-bit ``co64`` table wins. If a box type
    occurs more than once, the last occurrence is used.

    Raises:
        ValueError: if neither box is present.
    """
    tables: dict[bytes, bytes] = {}
    for child in _children(stbl, 0, len(stbl)):
        if child.type in (b"stco", b"co64"):
            tables[child.type] = stbl[child.payload_start : child.end]

    if b"co64" in tables:
        body, value_fmt, width = tables[b"co64"], ">Q", 8
    elif b"stco" in tables:
        body, value_fmt, width = tables[b"stco"], ">I", 4
    else:
        raise ValueError("Missing stco/co64 chunk offsets")

    (entry_count,) = struct.unpack_from(">I", body, 4)
    decoded = [struct.unpack_from(value_fmt, body, 8 + slot * width)[0] for slot in range(entry_count)]
    return np.array(decoded, dtype=np.int64)
|
||||
|
||||
|
||||
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
    """Return zero-based sync-sample indices from the ``stss`` child of ``stbl``.

    The stored 32-bit values are decremented by one. When ``stbl`` has no
    ``stss`` box, every index in ``range(sample_count)`` is returned.
    """
    stss_box = next((child for child in _children(stbl, 0, len(stbl)) if child.type == b"stss"), None)
    if stss_box is None:
        return np.arange(sample_count, dtype=np.int64)

    table = stbl[stss_box.payload_start : stss_box.end]
    (entry_count,) = struct.unpack_from(">I", table, 4)
    stored = [struct.unpack_from(">I", table, 8 + 4 * slot)[0] for slot in range(entry_count)]
    # Shift the stored sample numbers down to zero-based indices.
    return np.array(stored, dtype=np.int64) - 1
|
||||
|
||||
|
||||
def _sample_offsets(
    stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
) -> np.ndarray:
    """Compute the byte offset of every sample from the chunk layout tables.

    Each ``stsc`` entry ``(first_chunk, samples_per_chunk, description_index)``
    applies from its one-based ``first_chunk`` up to the chunk before the next
    entry's ``first_chunk``; the last entry runs through the final chunk.
    Samples inside a chunk are laid out back to back starting at the chunk's
    offset. Assignment stops as soon as every sample has an offset.

    Raises:
        ValueError: if ``stsc`` is empty, references a chunk number outside
            ``chunk_offsets``, or describes fewer samples than ``sample_sizes``.
    """
    if not stsc:
        raise ValueError("stsc is empty")

    total_samples = len(sample_sizes)
    chunk_count = len(chunk_offsets)
    result = np.empty(total_samples, dtype=np.int64)
    filled = 0

    # A run ends where the following one begins; the last run covers the remaining chunks.
    run_ends = [entry[0] for entry in stsc[1:]] + [chunk_count + 1]
    for (run_start, per_chunk, _desc_idx), run_end in zip(stsc, run_ends):
        for chunk_number in range(run_start, run_end):
            if not 1 <= chunk_number <= chunk_count:
                raise ValueError("stsc references a chunk outside stco/co64")
            cursor = int(chunk_offsets[chunk_number - 1])
            for _ in range(per_chunk):
                if filled >= total_samples:
                    return result
                result[filled] = cursor
                cursor += int(sample_sizes[filled])
                filled += 1

    if filled != total_samples:
        raise ValueError(f"stsc describes {filled} samples, stsz describes {total_samples}")
    return result
|
||||
|
||||
|
||||
def _make_moov(
    index: Mp4Index,
    durations: np.ndarray,
    sizes: np.ndarray,
    rel_offsets: np.ndarray,
    sync_samples: np.ndarray,
    *,
    mdat_data_offset: int,
) -> bytes:
    """Serialize a single-track ``moov`` box for the given samples.

    Args:
        index: source of the ``stsd`` body, timescale, track id, width and height.
        durations: per-sample durations; their sum becomes the track duration.
        sizes: per-sample sizes written to ``stsz``.
        rel_offsets: per-sample offsets, each shifted by ``mdat_data_offset``.
        sync_samples: values written to ``stss``; the box is omitted when empty.
        mdat_data_offset: added to every entry of ``rel_offsets``.

    Returns:
        The encoded ``moov`` box.
    """
    total_duration = int(durations.sum())

    absolute_offsets = [int(mdat_data_offset + rel) for rel in rel_offsets]
    # Fall back to 64-bit offsets only when a value does not fit in 32 bits.
    needs_64bit = any(offset > 0xFFFFFFFF for offset in absolute_offsets)
    chunk_table = _co64(absolute_offsets) if needs_64bit else _stco(absolute_offsets)

    stbl_parts = [
        _box(b"stsd", index.stsd_body),
        _stts(durations),
        _stsc_one_sample_per_chunk(len(sizes)),
        _stsz(sizes),
        chunk_table,
        _stss(sync_samples) if len(sync_samples) else b"",
    ]
    stbl = _box(b"stbl", b"".join(stbl_parts))

    minf = _box(b"minf", _vmhd() + _dinf() + stbl)
    mdia = _box(b"mdia", _mdhd(index.timescale, total_duration) + _hdlr() + minf)
    tkhd = _tkhd(index.track_id, total_duration, index.width, index.height)
    trak = _box(b"trak", tkhd + mdia)
    mvhd = _mvhd(index.timescale, total_duration, index.track_id + 1)
    return _box(b"moov", mvhd + trak)
|
||||
|
||||
|
||||
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
    """Build a box whose payload starts with a one-byte version and three-byte big-endian flags."""
    version_and_flags = bytes([version]) + flags.to_bytes(3, "big")
    return _box(typ, version_and_flags + payload)
|
||||
|
||||
|
||||
def _box(typ: bytes, payload: bytes) -> bytes:
    """Wrap ``payload`` in a box header of type ``typ``.

    The usual header is a big-endian 32-bit total size followed by the 4-byte
    type. If that size would not fit in 32 bits, the size field is set to 1 and
    a 64-bit size (covering the 16-byte header plus payload) follows the type.
    """
    compact_size = 8 + len(payload)
    if compact_size > 0xFFFFFFFF:
        # The extended header is 8 bytes longer than the compact one.
        return struct.pack(">I4sQ", 1, typ, compact_size + 8) + payload
    return struct.pack(">I4s", compact_size, typ) + payload
|
||||
|
||||
|
||||
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
    """Build a version-0 ``mvhd`` box with zeroed creation/modification times."""
    fields = [
        struct.pack(">IIII", 0, 0, timescale, duration),
        # 0x00010000 and 0x0100 are the fixed-point encodings of 1.0 (16.16 and 8.8).
        struct.pack(">IHH", 0x00010000, 0x0100, 0),
        b"\0" * 8,
        # Identity transformation matrix.
        struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000),
        b"\0" * 24,
        struct.pack(">I", next_track_id),
    ]
    return _full_box(b"mvhd", 0, 0, b"".join(fields))
|
||||
|
||||
|
||||
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
    """Build a version-0 ``tkhd`` box with flags 7 and zeroed timestamps.

    ``width`` and ``height`` are stored shifted left by 16 bits (16.16 fixed point).
    """
    fields = [
        struct.pack(">IIIII", 0, 0, track_id, 0, duration),
        b"\0" * 8,
        struct.pack(">hhhh", 0, 0, 0, 0),
        # Identity transformation matrix.
        struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000),
        struct.pack(">II", width << 16, height << 16),
    ]
    return _full_box(b"tkhd", 0, 7, b"".join(fields))
|
||||
|
||||
|
||||
def _mdhd(timescale: int, duration: int) -> bytes:
    """Build a version-0 ``mdhd`` box with zeroed timestamps.

    0x55C4 is the packed ISO-639-2 language code ``und``; two zero bytes follow it.
    """
    body = struct.pack(">IIIIHH", 0, 0, timescale, duration, 0x55C4, 0)
    return _full_box(b"mdhd", 0, 0, body)
|
||||
|
||||
|
||||
def _hdlr() -> bytes:
    """Build an ``hdlr`` box with handler type ``vide`` and the name ``VideoHandler``."""
    body = b"".join(
        [
            b"\0" * 4,
            b"vide",
            b"\0" * 12,
            b"VideoHandler\0",
        ]
    )
    return _full_box(b"hdlr", 0, 0, body)
|
||||
|
||||
|
||||
def _vmhd() -> bytes:
    """Build a ``vmhd`` box with flags 1 and four zeroed 16-bit fields."""
    zeroed_fields = struct.pack(">4H", 0, 0, 0, 0)
    return _full_box(b"vmhd", 0, 1, zeroed_fields)
|
||||
|
||||
|
||||
def _dinf() -> bytes:
    """Build a ``dinf`` box holding a ``dref`` with one ``url `` entry (flags 1, no payload)."""
    url_entry = _full_box(b"url ", 0, 1)
    entry_count = struct.pack(">I", 1)
    return _box(b"dinf", _full_box(b"dref", 0, 0, entry_count + url_entry))
|
||||
|
||||
|
||||
def _stts(durations: np.ndarray) -> bytes:
    """Build an ``stts`` box, run-length encoding consecutive equal durations.

    The payload is a 32-bit run count followed by one ``(count, delta)`` pair of
    32-bit values per run.
    """
    runs: list[list[int]] = []
    for raw in durations.tolist():
        delta = int(raw)
        if not runs or runs[-1][1] != delta:
            runs.append([1, delta])
        else:
            runs[-1][0] += 1

    chunks = [struct.pack(">I", len(runs))]
    chunks.extend(struct.pack(">II", count, delta) for count, delta in runs)
    return _full_box(b"stts", 0, 0, b"".join(chunks))
|
||||
|
||||
|
||||
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
    """Build an ``stsc`` box whose payload is four 32-bit values, all equal to 1.

    ``sample_count`` is accepted but not used: the output is the same for any value.
    """
    single_entry_table = struct.pack(">4I", 1, 1, 1, 1)
    return _full_box(b"stsc", 0, 0, single_entry_table)
|
||||
|
||||
|
||||
def _stsz(sizes: np.ndarray) -> bytes:
    """Build an ``stsz`` box with a zero uniform size and one 32-bit entry per sample."""
    header = struct.pack(">II", 0, len(sizes))
    entries = [struct.pack(">I", int(size)) for size in sizes.tolist()]
    return _full_box(b"stsz", 0, 0, header + b"".join(entries))
|
||||
|
||||
|
||||
def _stco(values: list[int]) -> bytes:
    """Build an ``stco`` box: a 32-bit entry count followed by one 32-bit offset per value."""
    count_field = struct.pack(">I", len(values))
    entries = [struct.pack(">I", offset) for offset in values]
    return _full_box(b"stco", 0, 0, count_field + b"".join(entries))
|
||||
|
||||
|
||||
def _co64(values: list[int]) -> bytes:
    """Build a ``co64`` box: a 32-bit entry count followed by one 64-bit offset per value."""
    count_field = struct.pack(">I", len(values))
    entries = [struct.pack(">Q", offset) for offset in values]
    return _full_box(b"co64", 0, 0, count_field + b"".join(entries))
|
||||
|
||||
|
||||
def _stss(values: np.ndarray) -> bytes:
    """Build an ``stss`` box: a 32-bit entry count followed by each value as 32 bits.

    Values are written unchanged; no index shifting is applied here.
    """
    count_field = struct.pack(">I", len(values))
    entries = [struct.pack(">I", int(value)) for value in values.tolist()]
    return _full_box(b"stss", 0, 0, count_field + b"".join(entries))
|
||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
    """Convert EnvConfig.features to the dict format expected by LeRobotDataset.create().

    Visual features become ``video`` entries named height/width/channel; every
    other feature becomes an unnamed ``float32`` entry. Three ``next.*`` entries
    (reward, success, done) of shape ``(1,)`` are always set at the end.
    """
    dataset_features: dict = {}
    for name, feature in env_features.items():
        is_visual = feature.type is FeatureType.VISUAL
        dataset_features[name] = {
            "dtype": "video" if is_visual else "float32",
            "shape": tuple(feature.shape),
            "names": ["height", "width", "channel"] if is_visual else None,
        }

    for name, dtype in (("next.reward", "float32"), ("next.success", "bool"), ("next.done", "bool")):
        dataset_features[name] = {"dtype": dtype, "shape": (1,), "names": None}
    return dataset_features
|
||||
|
||||
|
||||
def _build_raw_frame(
    raw_obs: dict,
    env_idx: int,
    action: np.ndarray,
    reward: float,
    success: bool,
    done: bool,
    task: str,
    env_features: dict,
) -> dict:
    """Build a dataset frame from raw env observations for one env index.

    Keys in the frame match the keys in env_features so they align with the
    dataset schema created by _env_features_to_dataset_features().

    Observation keys are resolved in this order: a per-camera image from a
    ``pixels`` dict, a single ``pixels`` array, then any numpy array stored
    under the key itself (float64 values are cast to float32). Keys with no
    matching observation are left out of the frame.
    """
    frame: dict[str, Any] = {}
    for key in env_features:
        # The action and the "next.*" entries are filled in explicitly below.
        if key == ACTION or key.startswith("next."):
            continue

        has_pixels = "pixels" in raw_obs
        if has_pixels and isinstance(raw_obs["pixels"], dict):
            # Multi-camera layout: match "<OBS_IMAGES>.<camera name>".
            for cam_name, img in raw_obs["pixels"].items():
                if f"{OBS_IMAGES}.{cam_name}" == key:
                    frame[key] = img[env_idx]
            if key in frame:
                continue
        elif has_pixels and key in ("pixels", OBS_IMAGE):
            # Single-camera layout: "pixels" is indexed directly by env.
            frame[key] = raw_obs["pixels"][env_idx]
            continue

        if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
            value = raw_obs[key][env_idx]
            frame[key] = value.astype(np.float32) if value.dtype == np.float64 else value

    frame[ACTION] = action
    frame["next.reward"] = np.atleast_1d(np.float32(reward))
    frame["next.success"] = np.atleast_1d(np.bool_(success))
    frame["next.done"] = np.atleast_1d(np.bool_(done))
    frame["task"] = task
    return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -105,6 +165,10 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -145,6 +209,33 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
recording_datasets: list[LeRobotDataset] | None = None
|
||||
raw_observation = None
|
||||
task_desc = ""
|
||||
if recording_dir is not None and env_features is not None:
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
recording_datasets = []
|
||||
for i in range(env.num_envs):
|
||||
multi_env = env.num_envs > 1
|
||||
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
|
||||
base_repo_id = recording_repo_id or "eval_recording"
|
||||
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
|
||||
recording_datasets.append(
|
||||
LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=True,
|
||||
)
|
||||
)
|
||||
raw_observation = deepcopy(observation)
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -217,6 +308,26 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_datasets is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_datasets[env_idx].features,
|
||||
)
|
||||
recording_datasets[env_idx].add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_datasets[env_idx].save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
@@ -255,6 +366,12 @@ def rollout(
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret[OBS_STR] = stacked_observations
|
||||
|
||||
if recording_datasets is not None:
|
||||
for ds in recording_datasets:
|
||||
ds.finalize()
|
||||
if recording_repo_id is not None:
|
||||
ds.push_to_hub(private=recording_private)
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
policy.use_original_modules()
|
||||
|
||||
@@ -273,6 +390,10 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -361,6 +482,10 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -563,6 +688,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -572,10 +701,15 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
recording_repo_id=cfg.eval.recording_repo_id,
|
||||
recording_private=cfg.eval.recording_private,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -618,6 +752,10 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -635,6 +773,10 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -661,6 +803,10 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -672,7 +818,13 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
task_recording_dir = None
|
||||
task_repo_id = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
if recording_repo_id is not None:
|
||||
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
|
||||
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
@@ -685,8 +837,12 @@ def run_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=task_recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=task_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -702,6 +858,10 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -761,6 +921,10 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -28,6 +28,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
    """Write a minimal dataset whose episodes all live in one parquet shard.

    Unlike build_annotation_dataset's one-file-per-episode layout, every episode
    is packed into ``data/chunk-000/file-000.parquet`` so the writer's rewrite
    must re-emit one row group per episode instead of collapsing them.
    """
    from lerobot.datasets.io_utils import write_tasks
    from lerobot.utils.io_utils import write_json

    data_dir = root / "data" / "chunk-000"
    data_dir.mkdir(parents=True, exist_ok=True)

    columns: dict[str, list] = {
        "episode_index": [],
        "frame_index": [],
        "timestamp": [],
        "task_index": [],
        "subtask_index": [],
    }
    for ep, length in enumerate(episode_lengths):
        columns["episode_index"].extend([ep] * length)
        columns["frame_index"].extend(range(length))
        columns["timestamp"].extend(round(i / fps, 6) for i in range(length))
        columns["task_index"].extend([0] * length)
        columns["subtask_index"].extend([0] * length)  # legacy column the writer must drop
    pd.DataFrame(columns).to_parquet(data_dir / "file-000.parquet", index=False)

    tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
    write_tasks(tasks_df, root)

    info = {"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)}
    write_json(info, root / "meta" / "info.json")
    return root
|
||||
|
||||
|
||||
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
    """Rewriting a packed shard must keep one row group per episode, not collapse
    every episode into a single giant row group."""
    episode_lengths = [4, 6, 5]  # unequal lengths, all in one shard
    root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
    shard = root / "data" / "chunk-000" / "file-000.parquet"
    assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"

    # Stage one subtask atom per episode so the writer has something to merge.
    staging_dir = tmp_path / "stage"
    for ep, _length in enumerate(episode_lengths):
        subtask_atom = {
            "role": "assistant",
            "content": f"subtask for ep {ep}",
            "style": "subtask",
            "timestamp": 0.0,
            "tool_calls": None,
        }
        _stage_episode(staging_dir, ep, plan=[subtask_atom])

    records = list(iter_episodes(root))
    LanguageColumnsWriter().write_all(records, staging_dir, root)

    # One row group per episode, with row counts matching the episode lengths.
    metadata = pq.ParquetFile(shard).metadata
    assert metadata.num_row_groups == len(episode_lengths)
    rows_per_group = [metadata.row_group(i).num_rows for i in range(metadata.num_row_groups)]
    assert rows_per_group == episode_lengths

    # Language columns are still present after the per-episode rewrite.
    column_names = pq.read_table(shard).column_names
    assert "language_persistent" in column_names
    assert "language_events" in column_names
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
|
||||
@@ -32,6 +32,26 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def assert_data_shards_one_row_group_per_episode(root):
    """Every aggregated DATA shard must have exactly one parquet row group per episode.

    Returns the number of distinct episodes, summed over all shards found under
    ``root / "data"``.
    """
    import pyarrow.parquet as pq

    column = "episode_index"
    shards = sorted((root / "data").rglob("*.parquet"))
    assert shards, f"no data shards found under {root}/data"

    n_episodes = 0
    for shard in shards:
        pf = pq.ParquetFile(shard)
        episodes = pf.read(columns=[column]).column(column).to_pylist()
        distinct_in_shard = len(set(episodes))
        assert pf.metadata.num_row_groups == distinct_in_shard, shard
        # Each row group must hold rows from exactly one episode.
        for i in range(pf.metadata.num_row_groups):
            rg_episodes = set(pf.read_row_group(i, columns=[column]).column(column).to_pylist())
            assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
        n_episodes += distinct_in_shard
    return n_episodes
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||
assert aggr_ds.num_episodes == expected_episodes, (
|
||||
@@ -566,6 +586,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
    """Aggregated DATA shards keep one row group per episode (not one collapsed group).

    Covers both the non-image (``df.to_parquet``) and image
    (``to_parquet_with_hf_images``) write branches, including the merge-into-
    existing-file branch via a low file-size threshold that forces packing.
    """
    # (total_episodes, total_frames) for the two source datasets, built in order.
    source_specs = [(3, 60), (4, 80)]
    sources = [
        lerobot_dataset_factory(
            root=tmp_path / f"rg_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_rg_{idx}",
            total_episodes=total_episodes,
            total_frames=total_frames,
            use_videos=use_videos,
        )
        for idx, (total_episodes, total_frames) in enumerate(source_specs)
    ]

    aggr_root = tmp_path / "rg_aggr"
    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in sources],
        roots=[ds.root for ds in sources],
        aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
        aggr_root=aggr_root,
    )

    n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
    assert n_episodes == sources[0].num_episodes + sources[1].num_episodes
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
    """A dataset created from another dataset's features must not share state with it."""
    # ``create`` must deep-copy features so a dataset built from another's features stays independent.
    source = empty_lerobot_dataset_factory(
        root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
    )
    derived = empty_lerobot_dataset_factory(
        root=tmp_path / "ds2", features=source.meta.features, use_videos=False
    )

    shape_before = source.meta.info.features["state"]["shape"]
    # Mutating the derived dataset's schema must leave the source untouched.
    derived.meta.info.features["state"]["shape"] = (999,)

    assert source.meta.info.features["state"]["shape"] == shape_before
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
|
||||
from lerobot.datasets.mp4 import (
|
||||
_box,
|
||||
_co64,
|
||||
_dinf,
|
||||
_hdlr,
|
||||
_mdhd,
|
||||
_mvhd,
|
||||
_stco,
|
||||
_stsc_one_sample_per_chunk,
|
||||
_stss,
|
||||
_stsz,
|
||||
_stts,
|
||||
_tkhd,
|
||||
_vmhd,
|
||||
parse_mp4_index,
|
||||
synthesize_mp4,
|
||||
)
|
||||
|
||||
|
||||
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
    """Assemble a three-sample test MP4 whose ``mdat`` payload starts at byte 10_000.

    Args:
        sample_offsets: chunk offsets written to ``stco`` (or ``co64``).
        use_co64: store the offsets in a 64-bit ``co64`` box instead of ``stco``.
    """
    mdat_payload_start = 10_000
    sizes = np.array([10, 10, 10], dtype=np.int64)
    durations = np.array([1000, 1000, 1000], dtype=np.int64)
    stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8

    build_offsets = _co64 if use_co64 else _stco
    stbl_children = [
        _box(b"stsd", stsd_body),
        _stts(durations),
        _stsc_one_sample_per_chunk(len(sizes)),
        _stsz(sizes),
        build_offsets(sample_offsets),
        _stss(np.array([1], dtype=np.int64)),
    ]
    stbl = _box(b"stbl", b"".join(stbl_children))
    minf = _box(b"minf", _vmhd() + _dinf() + stbl)
    mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
    trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
    moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
    ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")

    # Pad with a ``free`` box so the mdat payload lands exactly at mdat_payload_start.
    free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
    assert free_size >= 8
    free = _box(b"free", b"\0" * (free_size - 8))
    return ftyp + moov + free + _box(b"mdat", b"x" * 128)
|
||||
|
||||
|
||||
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
    """With chunk offsets out of order, the slice spans bytes 10_000..10_060."""
    index = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))

    selected = index.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)

    observed = (selected.byte_offset, selected.byte_length, selected.sample_lo, selected.sample_hi)
    assert observed[0] == 10_000
    assert observed[1] == 60
    assert observed[2] == 0
    assert observed[3] == 2
|
||||
|
||||
|
||||
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
    """Re-parsing a synthesized MP4 yields offsets rebased onto its own mdat payload."""
    source_index = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
    selected = source_index.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)

    mini_bytes = synthesize_mp4(source_index, selected, b"x" * selected.byte_length)
    mini_index = parse_mp4_index("mini.mp4", mini_bytes)

    # Original offsets relative to the slice start are 0, 50 and 25.
    expected_offsets = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
    np.testing.assert_array_equal(mini_index.sample_offsets, expected_offsets)
    np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
|
||||
|
||||
|
||||
def test_parser_accepts_co64_chunk_offsets():
    """Check that offsets written via the co64 path are parsed back unchanged."""
    wanted_offsets = [10_000, 10_050, 10_025]
    # use_co64=True makes the fixture emit its chunk offsets through _co64 instead of _stco.
    fixture = _minimal_mp4(wanted_offsets, use_co64=True)

    parsed = parse_mp4_index("test.mp4", fixture)

    np.testing.assert_array_equal(parsed.sample_offsets, np.array([10_000, 10_050, 10_025]))
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
    """Check the branch assertion passes when direct_url.json names the expected revision."""
    # Metadata a VCS install would record, including the requested branch.
    direct_url_payload = {
        "url": "https://github.com/huggingface/huggingface_hub.git",
        "vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
    }

    class FakeDist:
        """Stand-in distribution that only serves direct_url.json."""

        def read_text(self, name):
            assert name == "direct_url.json"
            return json.dumps(direct_url_payload)

    # Route the module's distribution lookup to the fake, whatever name is requested.
    monkeypatch.setattr(
        "lerobot.datasets.episode_video_streaming.metadata.distribution",
        lambda _: FakeDist(),
    )

    # Must not raise.
    assert_hf_hub_range_cache_branch()
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
    """Check the branch assertion fails when direct_url.json carries no revision info."""
    # Only the repository URL is recorded; there is no "vcs_info" entry.
    direct_url_payload = {"url": "https://github.com/huggingface/huggingface_hub.git"}

    class FakeDist:
        """Stand-in distribution that only serves direct_url.json."""

        def read_text(self, name):
            assert name == "direct_url.json"
            return json.dumps(direct_url_payload)

    # Route the module's distribution lookup to the fake, whatever name is requested.
    monkeypatch.setattr(
        "lerobot.datasets.episode_video_streaming.metadata.distribution",
        lambda _: FakeDist(),
    )

    with pytest.raises(AssertionError):
        assert_hf_hub_range_cache_branch()
|
||||
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
|
||||
@@ -1089,8 +1089,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?branch=main#06fcc085fcdd22fc5cc741954f6187dd879543b6" }
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1107,6 +1107,10 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -1143,7 +1147,7 @@ name = "decord"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and platform_machine != 's390x' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" },
|
||||
@@ -2046,8 +2050,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.20.0.dev0"
|
||||
source = { git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads#5319b287faa73239bb40df16d69c39e5d6daf0f7" }
|
||||
version = "1.19.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "filelock" },
|
||||
@@ -2060,6 +2064,10 @@ dependencies = [
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/27/629cfe58c582f92ded066c4a07d1a057ff617118ab7973200f770bd853cb/huggingface_hub-1.19.0.tar.gz", hash = "sha256:fd771622182d40977272a923953ee3b1b13538f9f8a7f5d78398f10af0f1c0bd", size = 824721, upload-time = "2026-06-11T12:33:18.665Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/a5/558da89f66464d8d0229ff497e8b8666977de2d8cf48c28a2862ecf1250f/huggingface_hub-1.19.0-py3-none-any.whl", hash = "sha256:1dc72e1f6b4d6df6b30eb72e57d00514ef453d660f04af2b87f0e67267f31ee0", size = 693398, upload-time = "2026-06-11T12:33:16.695Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hydra-core"
|
||||
@@ -3179,7 +3187,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?branch=main" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
@@ -3202,7 +3210,7 @@ requires-dist = [
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user