mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e069557228 | |||
| 58cf6c8710 | |||
| 36470d059e | |||
| 040a1df9d6 | |||
| 87ae050b28 | |||
| 3bec437d83 | |||
| 97f53732bf | |||
| b31837ffeb | |||
| fd822287e4 | |||
| 7e2d7024c4 | |||
| 240393d238 | |||
| 6407a244c0 | |||
| 0511c12b8f | |||
| 0efa3dc874 | |||
| 949f4fcbe9 | |||
| 0d1d5e0a86 | |||
| 84abfe5c60 | |||
| 2201401c99 | |||
| 64773e7b22 |
@@ -167,9 +167,9 @@ jobs:
|
||||
|
||||
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
||||
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
||||
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
|
||||
# immediately run eval inside the training loop (env_eval_freq=1, 1 episode).
|
||||
# Tests the full train→eval-within-training pipeline end-to-end.
|
||||
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
|
||||
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name libero-train-smoke --gpus all \
|
||||
@@ -196,7 +196,7 @@ jobs:
|
||||
--output_dir=/tmp/train-smoke \
|
||||
--steps=1 \
|
||||
--batch_size=1 \
|
||||
--eval_freq=1 \
|
||||
--env_eval_freq=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
@@ -58,7 +58,7 @@ test-act-ete-train:
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--env_eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--env_eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=2 \
|
||||
--eval_freq=2 \
|
||||
--env_eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_checkpoint=true \
|
||||
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
|
||||
--dataset.episodes="[0]" \
|
||||
--batch_size=2 \
|
||||
--steps=4 \
|
||||
--eval_freq=2 \
|
||||
--env_eval_freq=2 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1 \
|
||||
--save_freq=2 \
|
||||
|
||||
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
||||
"num_workers": 4,
|
||||
"steps": 5000,
|
||||
"log_freq": 10,
|
||||
"eval_freq": 1000,
|
||||
"env_eval_freq": 1000,
|
||||
"save_freq": 1000,
|
||||
"save_checkpoint": true,
|
||||
"seed": 2,
|
||||
|
||||
@@ -143,7 +143,7 @@ lerobot-train \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
--env_eval_freq=1000
|
||||
```
|
||||
|
||||
## Reproducing published results
|
||||
|
||||
@@ -173,7 +173,7 @@ lerobot-train \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
--env_eval_freq=1000
|
||||
```
|
||||
|
||||
## Relationship to LIBERO
|
||||
|
||||
@@ -120,11 +120,11 @@ lerobot-train \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
--env_eval_freq=1000
|
||||
```
|
||||
|
||||
## Practical tips
|
||||
|
||||
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
||||
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
|
||||
|
||||
@@ -103,7 +103,7 @@ accelerate launch \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--env_eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
@@ -142,7 +142,7 @@ accelerate launch \
|
||||
--batch_size=32 \
|
||||
--num_workers=4 \
|
||||
--log_freq=20 \
|
||||
--eval_freq=-1 \
|
||||
--env_eval_freq=-1 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=2000
|
||||
```
|
||||
|
||||
@@ -314,7 +314,7 @@ lerobot-train \
|
||||
--steps=30000 \
|
||||
--save_freq=1000 \
|
||||
--log_freq=100 \
|
||||
--eval_freq=1000 \
|
||||
--env_eval_freq=1000 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.horizon=32 \
|
||||
|
||||
@@ -166,7 +166,7 @@ lerobot-train \
|
||||
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval_freq=5000 \
|
||||
--env_eval_freq=5000 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=5 \
|
||||
--save_freq=10000
|
||||
|
||||
@@ -165,7 +165,7 @@ lerobot-train \
|
||||
--output_dir=./outputs/smolvla_vlabench_primitive \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval_freq=5000 \
|
||||
--env_eval_freq=5000 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--save_freq=10000
|
||||
|
||||
@@ -355,8 +355,6 @@ explicit = true
|
||||
[tool.uv.sources]
|
||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
|
||||
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
@@ -423,7 +421,6 @@ exclude_dirs = [
|
||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||
|
||||
[tool.typos]
|
||||
default.extend-words = { trak = "trak" }
|
||||
default.extend-ignore-re = [
|
||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||
|
||||
@@ -1,860 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
import resource
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import (
|
||||
EpisodeByteCache,
|
||||
EpisodeVideoManifest,
|
||||
NativeHTTPRangeFetcher,
|
||||
assert_hf_hub_range_cache_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||
|
||||
# Defaults for the --repo-id, --revision and --data-root command-line options.
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"

# File name of the MP4 sidecar archive and the local directory it is cached in.
FULL_SIDECAR_NAME = "molmoact2-full.npz"
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()).joinpath("lerobot-sidecars")
|
||||
|
||||
|
||||
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what a zero-argument call has always done.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")

    # Dataset location.
    parser.add_argument("--repo-id", default=DEFAULT_REPO)
    parser.add_argument("--revision", default=DEFAULT_REVISION)
    parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)

    # Strategy selection; hidden from --help via argparse.SUPPRESS.
    parser.add_argument(
        "--strategy",
        choices=("both", "full", "indexed", "remote-decoder", "native-http"),
        default="both",
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--range-backend",
        choices=("fsspec", "native-http"),
        default="fsspec",
        help="Range reader used by indexed/full episode-pool fetch tracks.",
    )

    # Workload size.
    parser.add_argument("--num-episodes", type=int, default=512)
    parser.add_argument(
        "--manifest-episodes",
        type=int,
        default=None,
        help="Limit manifest construction to the first N episodes for local smoke tests.",
    )
    parser.add_argument("--pool-size", type=int, default=16)
    parser.add_argument("--workers", type=int, default=8)

    # Native HTTP range-reader tuning.
    parser.add_argument(
        "--native-http-connections",
        type=int,
        default=None,
        help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
    )
    parser.add_argument(
        "--native-http-retries",
        type=int,
        default=8,
        help="Retries per native HTTP range request.",
    )
    parser.add_argument(
        "--native-http-timeout",
        type=float,
        default=120.0,
        help="Timeout in seconds for native HTTP requests.",
    )

    # Decode comparison tracks.
    parser.add_argument(
        "--include-decode",
        action="store_true",
        help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
    )
    parser.add_argument("--decode-workers", type=int, default=1)
    parser.add_argument("--prefetch-ahead", type=int, default=8)
    parser.add_argument("--frames-per-episode", type=int, default=16)
    parser.add_argument("--max-probe-mb", type=int, default=64)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--byte-budget-gb", type=float, default=80)
    parser.add_argument(
        "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
    )
    parser.add_argument("--no-hub-branch-assert", action="store_true")
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
|
||||
rng = random.Random(seed)
|
||||
upper = min(total, requested)
|
||||
if pool_size > upper:
|
||||
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
|
||||
return rng.sample(range(upper), pool_size)
|
||||
|
||||
|
||||
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
span = manifest.lookup(ep, camera_key)
|
||||
lo = span.first_pts
|
||||
hi = max(span.last_pts, lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _timestamps_from_meta(
|
||||
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
|
||||
) -> dict[tuple[int, str], list[float]]:
|
||||
rng = random.Random(seed)
|
||||
out: dict[tuple[int, str], list[float]] = {}
|
||||
for ep in episodes:
|
||||
row = meta.episodes[ep]
|
||||
for camera_key in meta.video_keys:
|
||||
lo = float(row[f"videos/{camera_key}/from_timestamp"])
|
||||
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
|
||||
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
|
||||
return out
|
||||
|
||||
|
||||
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
total = 0
|
||||
for ep in episodes:
|
||||
for camera_key in manifest.video_keys:
|
||||
total += manifest.lookup(ep, camera_key).mdat_length
|
||||
return total
|
||||
|
||||
|
||||
def _decode_all(
|
||||
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
|
||||
) -> float:
|
||||
start = time.perf_counter()
|
||||
items = list(timestamps.items())
|
||||
if decode_workers <= 1:
|
||||
for (ep, camera_key), ts in items:
|
||||
cache.get_frames(ep, camera_key, ts)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
|
||||
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
|
||||
for future in futures:
|
||||
future.result()
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
|
||||
start = time.perf_counter()
|
||||
for ep in episodes:
|
||||
cache.submit_prefetch(ep)
|
||||
for ep in episodes:
|
||||
cache.ensure_ready(ep)
|
||||
return time.perf_counter() - start
|
||||
|
||||
|
||||
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
|
||||
if elapsed_s <= 0:
|
||||
return float("inf")
|
||||
return len(episodes) * frames_per_episode / elapsed_s
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.1f}s"
|
||||
if seconds < 3600:
|
||||
return f"{seconds / 60:.1f}m"
|
||||
return f"{seconds / 3600:.1f}h"
|
||||
|
||||
|
||||
def _current_rss_mib() -> float | None:
|
||||
status_path = Path("/proc/self/status")
|
||||
if not status_path.exists():
|
||||
return None
|
||||
for line in status_path.read_text().splitlines():
|
||||
if line.startswith("VmRSS:"):
|
||||
return float(line.split()[1]) / 1024
|
||||
return None
|
||||
|
||||
|
||||
def _peak_rss_mib() -> float:
|
||||
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
# Linux reports KiB; macOS reports bytes.
|
||||
if rss > 10**8:
|
||||
return rss / 1024**2
|
||||
return rss / 1024
|
||||
|
||||
|
||||
def _memory_snapshot() -> dict[str, float | None]:
    """Capture the current and peak RSS readings under ``rss_mib`` and ``peak_rss_mib``."""
    snapshot: dict[str, float | None] = {}
    snapshot["rss_mib"] = _current_rss_mib()
    snapshot["peak_rss_mib"] = _peak_rss_mib()
    return snapshot
|
||||
|
||||
|
||||
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
|
||||
start_rss = start["rss_mib"]
|
||||
end_rss = end["rss_mib"]
|
||||
delta = None if start_rss is None or end_rss is None else end_rss - start_rss
|
||||
print()
|
||||
print("| Memory | MiB |")
|
||||
print("|---|---:|")
|
||||
if start_rss is not None:
|
||||
print(f"| rss start | {start_rss:.1f} |")
|
||||
if end_rss is not None:
|
||||
print(f"| rss end | {end_rss:.1f} |")
|
||||
if delta is not None:
|
||||
print(f"| rss delta | {delta:.1f} |")
|
||||
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
|
||||
|
||||
|
||||
def _root_join(data_root: str, relative_path: str) -> str:
|
||||
if data_root.startswith("hf://"):
|
||||
return f"{data_root.rstrip('/')}/{relative_path}"
|
||||
return str(Path(data_root) / relative_path)
|
||||
|
||||
|
||||
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Return the local path of the full MP4 sidecar, downloading it if needed.

    A valid cached copy is returned as is. An existing but invalid copy is
    reported and deleted. The sidecar is then looked up under
    ``meta/mp4-sidecars/`` in ``data_root``; ``None`` is returned when it is not
    there. ``manifest_episode_count`` is accepted but not used.
    """
    _ = manifest_episode_count
    cached = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
    if _valid_sidecar(cached):
        return cached
    if cached.exists():
        print(f"mp4_sidecar_invalid_local: {cached}")
        cached.unlink()

    on_hub = data_root.startswith("hf://")
    remote_relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
    remote = _root_join(data_root, remote_relative)
    fs = fsspec.filesystem("hf" if on_hub else "file")
    if not fs.exists(remote):
        return None

    cached.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {remote} -> {cached}")
    if on_hub:
        _download_sidecar_native_http(data_root, remote_relative, cached)
    else:
        fs.get(remote, str(cached))
    return cached
|
||||
|
||||
|
||||
def _valid_sidecar(path: Path) -> bool:
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
with np.load(path, allow_pickle=False) as data:
|
||||
return "manifest_json" in data
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
    """Download ``relative_path`` from ``data_root`` to ``local`` using parallel range reads.

    The file is fetched in 16 MiB ranges by 8 worker threads into a ``.tmp``
    sibling of ``local``, which is renamed over ``local`` once every range has
    been written. Progress is printed after each chunk.

    If anything fails, queued range reads are cancelled and the partial ``.tmp``
    file is removed so no half-written file is left behind; the exception is
    re-raised. The fetcher is always closed.
    """
    fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
    tmp = local.with_suffix(local.suffix + ".tmp")
    try:
        size = fetcher.info_size(relative_path)
        chunk_size = 16 * 1024 * 1024
        ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]

        def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
            offset, length = offset_length
            return offset, fetcher.read_range(relative_path, offset, length)

        start = time.perf_counter()
        done = 0
        try:
            # One handle both preallocates the file and receives the chunks.
            with tmp.open("wb") as out_file:
                out_file.truncate(size)
                with ThreadPoolExecutor(max_workers=8) as pool:
                    futures = [pool.submit(read_chunk, item) for item in ranges]
                    try:
                        for future in futures:
                            offset, data = future.result()
                            out_file.seek(offset)
                            out_file.write(data)
                            done += len(data)
                            elapsed = max(time.perf_counter() - start, 1e-9)
                            print(
                                f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
                                f"({done / elapsed / 1024**2:.1f} MiB/s)",
                                flush=True,
                            )
                    except BaseException:
                        # Stop range reads that have not started yet before the pool joins.
                        for pending in futures:
                            pending.cancel()
                        raise
            tmp.replace(local)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
    finally:
        fetcher.close()
|
||||
|
||||
|
||||
class EpisodeParquetReader:
    """Reads episode rows from the dataset's parquet files, caching whole tables per file."""

    def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
        """Bind the reader to ``meta`` and ``data_root`` and open the matching fsspec filesystem.

        The ``hf`` protocol is used when ``data_root`` starts with ``hf://``,
        otherwise ``file``.
        """
        self.meta = meta
        self.data_root = data_root
        self.fs = fsspec.filesystem("hf" if data_root.startswith("hf://") else "file")
        self._episode_row_groups = self._build_episode_row_groups()
        # Tables keyed by relative file path; guarded by ``_cache_lock``.
        self._table_cache: dict[str, pa.Table] = {}
        self._cache_lock = threading.Lock()

    def read_episode(self, episode_index: int) -> None:
        """Load the table holding ``episode_index`` and filter it to that episode.

        The filtered table is discarded; nothing is returned.
        """
        data_file = str(self.meta.get_data_file_path(episode_index))
        table = self._read_table(data_file)
        mask = pc.equal(table["episode_index"], episode_index)
        table.filter(mask)

    def _read_table(self, relative_path: str) -> pa.Table:
        """Return the parquet table at ``relative_path``, reading it on a cache miss.

        The file is read outside the lock, so concurrent misses may each read
        it; ``setdefault`` keeps whichever table was stored first.
        """
        with self._cache_lock:
            cached = self._table_cache.get(relative_path)
        if cached is not None:
            return cached
        location = _root_join(self.data_root, relative_path)
        with self.fs.open(location, "rb", block_size=2**20, cache_type="none") as handle:
            loaded = pq.ParquetFile(handle).read()
        with self._cache_lock:
            return self._table_cache.setdefault(relative_path, loaded)

    def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
        """Schedule ``read_episode(episode_index)`` on ``pool`` and return its future."""
        return pool.submit(self.read_episode, episode_index)

    def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
        """Read every episode and return the elapsed seconds.

        A thread pool of ``workers`` threads is used when ``workers`` exceeds 1;
        otherwise the episodes are read one after another.
        """
        began = time.perf_counter()
        if workers > 1:
            with ThreadPoolExecutor(max_workers=workers) as executor:
                pending = [executor.submit(self.read_episode, episode) for episode in episodes]
                for task in pending:
                    task.result()
        else:
            for episode in episodes:
                self.read_episode(episode)
        return time.perf_counter() - began

    def _build_episode_row_groups(self) -> dict[int, int]:
        """Map each episode index to its ordinal among episodes sharing the same data file.

        Files are identified by the ``(data/chunk_index, data/file_index)`` pair
        of the episode's metadata row.
        """
        # NOTE(review): presumably meant as a row-group index; nothing in this class reads it.
        seen_per_file: dict[tuple[int, int], int] = {}
        ordinals = {}
        for ep_idx in range(int(self.meta.total_episodes)):
            record = self.meta.episodes[ep_idx]
            file_key = (int(record["data/chunk_index"]), int(record["data/file_index"]))
            position = seen_per_file.get(file_key, 0)
            ordinals[ep_idx] = position
            seen_per_file[file_key] = position + 1
        return ordinals
|
||||
|
||||
|
||||
def run_fetch_pool(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    byte_budget: int,
    workers: int,
    range_backend: str,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Fill an ``EpisodeByteCache`` with ``episodes`` and report fetch throughput.

    The cache is created without decoders, using the native-HTTP settings from
    ``args``. The result holds wall time, MiB/s, episodes/s, average sizes, and
    per-job averages (in ms) of the cache's lookup/fetch/synthesize/store
    timings, plus every ``range_*`` entry of the cache's timing summary.
    """
    byte_cache = EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        native_http_connections=args.native_http_connections,
        native_http_timeout=args.native_http_timeout,
        native_http_retries=args.native_http_retries,
        open_decoders=False,
    )
    with byte_cache as cache:
        elapsed = _fill_cache(cache, episodes)
        timings = cache.timing_summary()

    total_bytes = _bytes_for(manifest, episodes)
    # Avoid dividing by zero when the cache reports no jobs.
    job_count = max(timings["jobs"], 1.0)

    def per_job_ms(stage_key: str) -> float:
        return timings[stage_key] * 1000 / job_count

    report = {
        "fetch_s": elapsed,
        "fetch_mbps": total_bytes / elapsed / 1024**2,
        "fetch_episodes_s": len(episodes) / elapsed,
        "episode_mb": total_bytes / len(episodes) / 1024**2,
        "avg_mb_miss": total_bytes / (len(episodes) * len(manifest.video_keys)) / 1024**2,
        "jobs": timings["jobs"],
        "lookup_ms": per_job_ms("lookup_s"),
        "range_fetch_ms": per_job_ms("fetch_s"),
        "synthesize_ms": per_job_ms("synthesize_s"),
        "store_ms": per_job_ms("store_s"),
    }
    for name, measured in timings.items():
        if name.startswith("range_"):
            report[name] = measured
    return report
|
||||
|
||||
|
||||
def run_parallel(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the phases one after another and time each.

    Order: parquet reads, cache fill, a ``get_decoder`` call for every
    (episode, camera) pair, then frame decoding. Returns fetch time and
    throughput, parquet time, average decoder time per pair in ms, and decode
    samples/s.
    """
    byte_cache = EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    )
    with byte_cache as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
        fetch_s = _fill_cache(cache, episodes)

        decoder_began = time.perf_counter()
        for episode, key in ((e, k) for e in episodes for k in manifest.video_keys):
            cache.get_decoder(episode, key)
        decoder_s = time.perf_counter() - decoder_began

        decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)

    total_bytes = _bytes_for(manifest, episodes)
    pair_count = len(episodes) * len(manifest.video_keys)
    return {
        "fetch_s": fetch_s,
        "fetch_mbps": total_bytes / fetch_s / 1024**2,
        "fetch_episodes_s": len(episodes) / fetch_s,
        "parquet_s": parquet_s,
        "decoder_ms_miss": decoder_s * 1000 / pair_count,
        "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
    }
|
||||
|
||||
|
||||
def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Consume episodes in order while prefetching video bytes and parquet data ahead.

    The first ``prefetch_ahead`` episodes are queued up front. While episode
    ``i`` is consumed, episode ``i + prefetch_ahead`` is queued. For each
    episode the time spent waiting on its parquet read and the time spent in
    ``ensure_ready`` plus frame decoding are accumulated separately.

    Both thread pools are created once and shut down in a ``finally`` that
    covers all submissions, so neither leaks if queuing the initial prefetch
    raises. The decode pool is reused across episodes, so executor start-up is
    not counted in every episode's decode time.

    Returns:
        Samples/s overall, for the video wait/decode share and for the parquet
        wait share, together with the raw wall and wait times in seconds.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
        decode_pool = ThreadPoolExecutor(max_workers=decode_workers) if decode_workers > 1 else None
        try:
            parquet_futures = {
                ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
            }
            for ep in episodes[:prefetch_ahead]:
                cache.submit_prefetch(ep)

            for idx, ep in enumerate(episodes):
                # Keep the prefetch window full.
                next_idx = idx + prefetch_ahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)

                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start

                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_pool is None:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    futures = [
                        decode_pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                        for camera_key in manifest.video_keys
                    ]
                    for future in futures:
                        future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        finally:
            if decode_pool is not None:
                decode_pool.shutdown(wait=True)
            parquet_pool.shutdown(wait=True)
        elapsed = time.perf_counter() - start
    return {
        "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
        "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
        "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
        "wall_s": elapsed,
        "video_wait_decode_s": video_wait_decode_s,
        "parquet_wait_s": parquet_wait_s,
    }
|
||||
|
||||
|
||||
# Per-thread storage for the decoder cache used by the remote-decoder track.
_remote_decoder_local = threading.local()


def _remote_decoder_cache() -> VideoDecoderCache:
    """Return the calling thread's ``VideoDecoderCache``, creating it on first use."""
    if getattr(_remote_decoder_local, "cache", None) is None:
        _remote_decoder_local.cache = VideoDecoderCache(max_size=None)
    return _remote_decoder_local.cache
|
||||
|
||||
|
||||
def _decode_remote_source(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episode_index: int,
    camera_key: str,
    timestamps: list[float],
):
    """Decode ``timestamps`` from the episode's source video file under ``data_root``.

    Uses ``decode_video_frames_torchcodec`` with a tolerance of one frame period
    (``1 / meta.fps``), the calling thread's decoder cache, and
    ``return_uint8=True``.
    """
    relative_video = str(meta.get_video_file_path(episode_index, camera_key))
    video_path = _root_join(data_root, relative_video)
    one_frame_s = 1.0 / float(meta.fps)
    return decode_video_frames_torchcodec(
        video_path,
        timestamps,
        tolerance_s=one_frame_s,
        decoder_cache=_remote_decoder_cache(),
        return_uint8=True,
    )
|
||||
|
||||
|
||||
def run_remote_decoder(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    *,
    frames_per_episode: int,
    decode_workers: int,
    parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
    """Time decoding straight from the source videos, first sequentially and then again.

    The first pass always runs in the calling thread. The second pass repeats
    the same work, using a thread pool of ``decode_workers`` threads when that
    is greater than 1. In both passes each episode's parquet data is read once.

    Returns:
        ``sequential_samples_s`` and ``parallel_samples_s``.
    """
    items = [
        (ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
    ]

    def _serial_pass() -> None:
        for ep, camera_key, ts in items:
            # Read the parquet data once per episode, on its first camera.
            if camera_key == meta.video_keys[0]:
                parquet_reader.read_episode(ep)
            _decode_remote_source(meta, data_root, ep, camera_key, ts)

    began = time.perf_counter()
    _serial_pass()
    sequential_s = time.perf_counter() - began

    began = time.perf_counter()
    if decode_workers > 1:
        with ThreadPoolExecutor(max_workers=decode_workers) as executor:
            parquet_tasks = [executor.submit(parquet_reader.read_episode, ep) for ep in episodes]
            decode_tasks = [
                executor.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
                for ep, camera_key, ts in items
            ]
            for task in parquet_tasks:
                task.result()
            for task in decode_tasks:
                task.result()
    else:
        _serial_pass()
    parallel_s = time.perf_counter() - began

    return {
        "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
        "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
    }
|
||||
|
||||
|
||||
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
|
||||
range_jobs = fetch_pool.get("range_jobs", 0.0)
|
||||
if range_jobs <= 0:
|
||||
return
|
||||
|
||||
print()
|
||||
print("| Range Read Stage | avg ms/range |")
|
||||
print("|---|---:|")
|
||||
for key, label in (
|
||||
("range_open_s", "fsspec handle open/lookup"),
|
||||
("range_seek_s", "fsspec seek"),
|
||||
("range_read_s", "fsspec read"),
|
||||
("range_resolve_s", "http URL resolve"),
|
||||
("range_header_s", "http response headers"),
|
||||
("range_first_byte_s", "http first body byte"),
|
||||
("range_body_s", "http body drain"),
|
||||
("range_retry_sleep_s", "http retry sleep"),
|
||||
):
|
||||
value = fetch_pool.get(key)
|
||||
if value is not None:
|
||||
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
|
||||
if "range_retry_attempts" in fetch_pool:
|
||||
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
|
||||
if fetch_pool.get("range_failed_requests"):
|
||||
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
|
||||
print(f"| range reads | {range_jobs:.0f} |")
|
||||
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
|
||||
|
||||
|
||||
def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Benchmark the manifest-indexed fetch path and print markdown result tables.

    Steps, in order: build an ``EpisodeVideoManifest`` (timed), pick an episode
    pool, run the fetch-only pool benchmark, print the summary tables, and --
    only when ``args.include_decode`` is set -- run the parallel fetch+decode
    and overlapped end-to-end variants and print one row for each.

    Args:
        meta: Dataset metadata; ``total_episodes`` bounds every episode count.
        data_root: Root the manifest and fetchers read from.
        args: Parsed CLI options (``workers``, ``num_episodes``, ``pool_size``,
            ``seed``, ``byte_budget_gb``, ``max_probe_mb``, ``manifest_episodes``,
            ``include_decode``, ``decode_workers``, ``frames_per_episode``,
            ``prefetch_ahead``).
        parquet_reader: Passed through to the decode benchmarks.
        range_backend: Range-read backend name handed to the manifest builder
            and the fetch benchmarks.
        label: Prefix for log lines and the printed ``strategy`` field.
        sidecar_path: Optional MP4 sidecar handed to the manifest builder.
    """
    _log(f"starting_strategy: {label}")
    memory_start = _memory_snapshot()

    # --- manifest build (timed) -------------------------------------------
    manifest_start = time.perf_counter()
    dataset_episode_count = int(meta.total_episodes)
    # A falsy ``--manifest-episodes`` means "whole dataset"; the result is then
    # capped by both the dataset size and ``--num-episodes``.
    manifest_episode_count = min(
        args.manifest_episodes or dataset_episode_count,
        dataset_episode_count,
        args.num_episodes,
    )
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")

    # --- fetch-only benchmark ---------------------------------------------
    benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
    episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    planned_gib = byte_count / 1024**3
    per_episode_mib = byte_count / len(episodes) / 1024**2
    _log(
        f"{label}: planned_video_fetch={planned_gib:.2f} GiB per fetch track "
        f"({per_episode_mib:.1f} MiB/episode)"
    )

    _log(f"{label}: filling episode byte cache with {args.workers} workers")
    fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
    # Extrapolate wall time from the measured episodes-per-second rate.
    episodes_per_s = fetch_pool["fetch_episodes_s"]
    estimated_dataset_s = dataset_episode_count / episodes_per_s
    estimated_benchmark_s = benchmark_episode_count / episodes_per_s

    # --- run description ----------------------------------------------------
    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"dataset_episodes: {dataset_episode_count}")
    print(f"benchmark_episodes: {benchmark_episode_count}")
    print(f"pool_episodes: {len(episodes)}")
    print(f"sampled_episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()

    # --- track table --------------------------------------------------------
    track_header = (
        "| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
    )
    print(track_header)
    print("|---|---:|---:|---:|---:|---:|---:|---|")
    pool_row = (
        f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
        f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
        f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
        f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
    )
    print(pool_row)
    print()

    # --- per-camera-job stage table ----------------------------------------
    print("| Camera Job Stage | avg ms/job |")
    print("|---|---:|")
    print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
    print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
    print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
    print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
    print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
    _print_range_timing_summary(fetch_pool)
    _print_memory_summary(memory_start, _memory_snapshot())

    if not args.include_decode:
        return

    # --- optional decode benchmarks ----------------------------------------
    timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
    # Leading positional arguments shared by both decode benchmarks.
    shared_args = (
        manifest,
        data_root,
        episodes,
        timestamps,
        byte_budget,
        args.workers,
        args.decode_workers,
        args.frames_per_episode,
    )
    _log(f"{label}: running parallel video fetch + decode-only")
    parallel = run_parallel(*shared_args, parquet_reader, range_backend)
    _log(f"{label}: running overlapped end-to-end")
    overlapped = run_overlapped(*shared_args, args.prefetch_ahead, parquet_reader, range_backend)

    decode_row = (
        f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
        f"{parallel['fetch_s']:.2f} | "
        f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
        f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
        f"{fetch_pool['avg_mb_miss']:.1f} | "
        f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
        f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
    )
    print(decode_row)
    overlapped_row = (
        f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
        f"{fetch_pool['avg_mb_miss']:.1f} | "
        f"{overlapped['samples_s']:.1f} samples/s; video+decode "
        f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
    )
    print(overlapped_row)
|
||||
|
||||
|
||||
def run_remote_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
) -> None:
    """Benchmark decoding straight from the source MP4s and print a results table.

    Picks the episode pool and per-episode timestamps from ``meta`` (no
    manifest is built), runs ``run_remote_decoder`` and prints its sequential
    and parallel samples-per-second figures.
    """
    _log("starting_strategy: remote-decoder")
    total_episodes = int(meta.total_episodes)
    episodes = _episode_pool(total_episodes, args.num_episodes, args.pool_size, args.seed)
    # ``seed + 1`` gives timestamp sampling a different seed from episode selection.
    timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)

    _log("remote-decoder: running direct source MP4 decoder")
    result = run_remote_decoder(
        meta,
        data_root,
        episodes,
        timestamps,
        frames_per_episode=args.frames_per_episode,
        decode_workers=args.decode_workers,
        parquet_reader=parquet_reader,
    )

    # Run description.
    print("strategy: remote-decoder")
    print(f"data_root: {data_root}")
    print(f"episodes: {episodes}")
    print(f"cameras: {list(meta.video_keys)}")
    print()

    # Results table.
    print("| Track | samples/s | notes |")
    print("|---|---:|---|")
    print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
    parallel_row = (
        f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
        f"direct source MP4 decoder, {args.decode_workers} workers |"
    )
    print(parallel_row)
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: parse CLI options and run the selected benchmark strategies.

    Strategy handling:
      * ``full`` is an alias for ``both``.
      * ``native-http`` forces ``args.range_backend`` to ``"native-http"``.
      * When an MP4 sidecar is found, the indexed strategies (``both``,
        ``indexed``, ``native-http``) run once against that sidecar and return.
      * Otherwise the indexed variants run without a sidecar; for ``both`` a
        hint on how to build the sidecar is printed first.
    """
    args = parse_args()
    if args.strategy == "full":
        args.strategy = "both"
    if args.strategy == "native-http":
        args.range_backend = "native-http"

    data_root = args.data_root
    if data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)

    # Same cap that run_indexed_strategy applies when building its manifest.
    manifest_episode_count = min(
        args.manifest_episodes or int(meta.total_episodes),
        int(meta.total_episodes),
        args.num_episodes,
    )
    sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
    have_sidecar = sidecar_path is not None
    run_both = args.strategy == "both"

    if have_sidecar:
        print(f"using_mp4_sidecar: {sidecar_path}")

    # --- sidecar available: one indexed run against the sidecar, then done ---
    if have_sidecar and args.strategy in ("both", "indexed", "native-http"):
        if run_both and args.include_decode:
            run_remote_strategy(meta, data_root, args, parquet_reader)
            print()
        # For the "native-http" strategy args.range_backend was already forced
        # to "native-http" above, so one call covers all three strategies.
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend=args.range_backend,
            label=f"indexed-sidecar-{args.range_backend}",
            sidecar_path=str(sidecar_path),
        )
        return

    # --- no sidecar (or the remote-decoder strategy) ---------------------------
    if run_both:
        expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
        expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
        print(f"mp4_sidecar_missing_local: {expected_sidecar}")
        print(f"mp4_sidecar_missing_remote: {expected_remote}")
        build_hint = (
            "build_mp4_sidecar: "
            "uv run --no-sync python scripts/build_mp4_sidecar.py "
            f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
        )
        print(build_hint)
        print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
        print()

    if run_both or args.strategy == "indexed":
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="fsspec",
            label="indexed",
            sidecar_path=None,
        )
        if run_both:
            print()  # blank line between strategy reports

    decode_in_both = run_both and args.include_decode
    if args.strategy == "remote-decoder" or decode_in_both:
        run_remote_strategy(meta, data_root, args, parquet_reader)
        if decode_in_both:
            print()

    if run_both or args.strategy == "native-http":
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend="native-http",
            label="indexed-native-http",
            sidecar_path=None,
        )


if __name__ == "__main__":
    main()
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import fsspec
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||
|
||||
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
|
||||
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
|
||||
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the sidecar builder.

    ``--output`` is the only required option. ``--episodes`` defaults to
    ``None``, which ``main`` treats as "all episodes".
    """
    parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
    add = parser.add_argument

    # Which dataset to index and where its files live.
    add("--repo-id", default=DEFAULT_REPO)
    add("--revision", default=DEFAULT_REVISION)
    add("--data-root", default=DEFAULT_DATA_ROOT)

    # Where to write the sidecar and how much of the dataset to cover.
    add("--output", required=True)
    add("--episodes", type=int, default=None)

    # How the MP4 files are probed.
    add("--workers", type=int, default=8)
    add("--range-backend", choices=("fsspec", "native-http"), default="native-http")
    add("--max-probe-mb", type=int, default=64)

    # Switches that disable optional steps.
    add("--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars.")
    add("--no-hub-branch-assert", action="store_true")

    return parser.parse_args()
|
||||
|
||||
|
||||
def push_sidecar(local_path: str, data_root: str) -> list[str]:
    """Upload a sidecar file to ``<data_root>/meta/mp4-sidecars/``.

    Args:
        local_path: Path of the sidecar file on local disk.
        data_root: Dataset root. Only ``hf://`` roots are uploaded to.

    Returns:
        The remote paths written: a one-element list for ``hf://`` roots, an
        empty list otherwise (nothing is uploaded).
    """
    if data_root.startswith("hf://"):
        local = Path(local_path)
        fs = fsspec.filesystem("hf")
        # The remote file keeps the local file's name.
        destination = f"{data_root.rstrip('/')}/meta/mp4-sidecars/{local.name}"
        fs.put(str(local), destination)
        return [destination]
    return []
|
||||
|
||||
|
||||
def main() -> None:
    """Build the MP4 byte-index sidecar and, unless ``--no-push`` is given, upload it.

    The sidecar covers every distinct video file referenced by the first
    ``total`` episodes, where ``total`` is the whole dataset or
    ``--episodes`` capped at the dataset size.
    """
    args = parse_args()
    if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()

    if args.episodes is None:
        total = int(meta.total_episodes)
    else:
        total = min(args.episodes, int(meta.total_episodes))

    # Collect into a set first: the same relative path can come up for more
    # than one (episode, camera) pair and must be listed once.
    unique_paths = set()
    for ep_idx in range(total):
        for key in meta.video_keys:
            unique_paths.add(str(meta.get_video_file_path(ep_idx, key)))
    rel_paths = sorted(unique_paths)

    start = time.perf_counter()
    EpisodeVideoManifest.write_file_sidecar(
        args.output,
        rel_paths,
        args.data_root,
        range_backend=args.range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
    )
    elapsed = time.perf_counter() - start

    print(f"wrote {args.output}")
    print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")

    if args.no_push:
        print("push_skipped: --no-push")
        return
    for remote in push_sidecar(args.output, args.data_root):
        print(f"pushed {remote}")


if __name__ == "__main__":
    main()
|
||||
@@ -180,24 +180,32 @@ class WandBLogger:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
batch_data = {}
|
||||
for k, v in d.items():
|
||||
# Skip the custom step key here, it's added to the batch below.
|
||||
if custom_step_key is not None and k == custom_step_key:
|
||||
continue
|
||||
|
||||
if isinstance(v, list):
|
||||
for i, elem in enumerate(v):
|
||||
if isinstance(elem, (int | float)):
|
||||
batch_data[f"{mode}/{k}_{i}"] = elem
|
||||
continue
|
||||
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
batch_data[f"{mode}/{k}"] = v
|
||||
|
||||
if batch_data:
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||
self._wandb.log(batch_data)
|
||||
else:
|
||||
self._wandb.log(data=batch_data, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -39,6 +39,8 @@ class DatasetConfig:
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
streaming: bool = False
|
||||
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
|
||||
eval_split: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
@@ -73,6 +75,8 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot v3.0 dataset on disk.
|
||||
recording: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.batch_size == 0:
|
||||
|
||||
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -100,8 +100,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
prefetch_factor: int = 4
|
||||
persistent_workers: bool = True
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
|
||||
env_eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
|
||||
eval_steps: int = 0
|
||||
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
|
||||
max_eval_samples: int = 0
|
||||
tolerance_s: float = 1e-4
|
||||
save_checkpoint: bool = True
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
|
||||
@@ -35,7 +35,7 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
@@ -89,6 +89,7 @@ __all__ = [
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
"make_dataset",
|
||||
"make_train_eval_datasets",
|
||||
"merge_datasets",
|
||||
"modify_features",
|
||||
"modify_tasks",
|
||||
|
||||
@@ -1,890 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from importlib import metadata
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import quote, urljoin, urlparse
|
||||
|
||||
import fsspec
|
||||
import httpx
|
||||
import numpy as np
|
||||
from huggingface_hub import HfApi, HfFileSystem, constants
|
||||
from huggingface_hub.utils import hf_raise_for_status
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EpisodeVideoSpan:
    """Immutable record locating one span of video samples inside a video file.

    NOTE(review): the per-field notes below are inferred from the field names
    only; the code that fills these fields is not visible here. Confirm
    against the manifest builder before relying on them.
    """

    # presumably an index identifying the containing video file -- TODO confirm
    file_id: int
    # presumably the byte offset and byte length of the span's sample data
    # within the file's ``mdat`` payload -- TODO confirm
    mdat_offset: int
    mdat_length: int
    # presumably presentation timestamps of the first / last frame -- TODO
    # confirm units and whether they are relative to the file or the episode
    first_pts: float
    last_pts: float
    # presumably the number of frames in the span -- TODO confirm
    frame_count: int
    # presumably bounds of the span in the file's sample table -- TODO confirm
    # whether ``sample_hi`` is inclusive or exclusive
    sample_lo: int
    sample_hi: int
    # presumably the timestamp in the source file where the span starts --
    # TODO confirm
    source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VideoFileRecord:
    """Immutable record pairing a video file with its parsed MP4 index."""

    # presumably the path relative to the data root -- TODO confirm
    file_path: str
    # presumably the total file size in bytes -- TODO confirm
    file_size: int
    # Parsed MP4 index for the file (``Mp4Index`` from ``lerobot.datasets.mp4``).
    mp4: Mp4Index
|
||||
|
||||
|
||||
class ThreadLocalRangeFetcher:
    """Range reader that gives each worker thread independent file handles.

    Handles are opened lazily, one per (thread, relative path), so concurrent
    ``seek``/``read`` calls from different threads never share a file
    position. Per-call timings are accumulated under a lock and exposed by
    ``timing_summary``.
    """

    def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
        """Create a fetcher rooted at ``data_root``.

        Args:
            data_root: ``hf://...`` root (read through the ``hf`` fsspec
                filesystem) or a local directory (``file`` filesystem).
            block_size: Passed to ``fs.open`` for every handle.
            cache_type: Passed to ``fs.open`` for every handle.
        """
        self.data_root = str(data_root).rstrip("/")
        protocol = "hf" if self.data_root.startswith("hf://") else "file"
        self.fs = fsspec.filesystem(protocol)
        self.block_size = block_size
        self.cache_type = cache_type
        self._local = threading.local()
        # Every thread's handle map is also registered here so close() can
        # reach handles opened by other threads; thread-local storage alone is
        # only visible from the thread that owns it.
        self._handles_lock = threading.Lock()
        self._all_handles: list[dict] = []
        self._timing_lock = threading.Lock()
        self._timing_totals = {
            "range_jobs": 0.0,
            "range_bytes": 0.0,
            "range_open_s": 0.0,
            "range_seek_s": 0.0,
            "range_read_s": 0.0,
        }

    def _url(self, relative_path: str) -> str:
        """Return the full URL / filesystem path for ``relative_path``."""
        if self.data_root.startswith("hf://"):
            return f"{self.data_root}/{relative_path}"
        return str(Path(self.data_root) / relative_path)

    def _handle(self, relative_path: str):
        """Return the calling thread's open handle for ``relative_path``.

        Opens a new handle when the thread has none for this path yet, or
        when the cached one reports ``closed``.
        """
        handles = getattr(self._local, "handles", None)
        if handles is None:
            handles = {}
            self._local.handles = handles
            with self._handles_lock:
                self._all_handles.append(handles)
        handle = handles.get(relative_path)
        if handle is None or getattr(handle, "closed", False):
            handle = self.fs.open(
                self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
            )
            handles[relative_path] = handle
        return handle

    def info_size(self, relative_path: str) -> int:
        """Return the size in bytes reported by the filesystem for ``relative_path``."""
        return int(self.fs.info(self._url(relative_path))["size"])

    def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
        """Read ``length`` bytes starting at ``offset`` and record stage timings.

        Returns whatever ``handle.read(length)`` returns, which may be shorter
        than ``length`` near the end of the file.
        """
        open_start = time.perf_counter()
        handle = self._handle(relative_path)
        open_s = time.perf_counter() - open_start
        seek_start = time.perf_counter()
        handle.seek(offset)
        seek_s = time.perf_counter() - seek_start
        read_start = time.perf_counter()
        data = handle.read(length)
        read_s = time.perf_counter() - read_start
        self._record_timing(
            range_jobs=1.0,
            range_bytes=float(len(data)),
            range_open_s=open_s,
            range_seek_s=seek_s,
            range_read_s=read_s,
        )
        return data

    def _record_timing(self, **kwargs: float) -> None:
        """Add each keyword value to the matching running total."""
        with self._timing_lock:
            for key, value in kwargs.items():
                self._timing_totals[key] += value

    def timing_summary(self) -> dict[str, float]:
        """Return a snapshot copy of the accumulated timing totals."""
        with self._timing_lock:
            return dict(self._timing_totals)

    def close(self) -> None:
        """Close every handle opened through this fetcher, on any thread.

        Errors raised while closing an individual handle are suppressed
        (best effort). The fetcher stays usable: later reads reopen handles
        on demand. Callers should make sure no other thread is in the middle
        of a read when calling this.
        """
        with self._handles_lock:
            per_thread_maps = list(self._all_handles)
        for handles in per_thread_maps:
            # Snapshot the values so closing does not iterate a dict that is
            # cleared right after.
            for handle in list(handles.values()):
                with contextlib.suppress(Exception):
                    handle.close()
            handles.clear()
|
||||
|
||||
|
||||
class NativeHTTPRangeFetcher:
|
||||
"""Direct pooled HTTP range reader for hf:// paths."""
|
||||
|
||||
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
|
||||
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
|
||||
_GLOBAL_LOCK = threading.Lock()
|
||||
|
||||
_RETRYABLE_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.PoolTimeout,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_root: str | Path,
|
||||
*,
|
||||
max_connections: int = 32,
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 4,
|
||||
):
|
||||
self.data_root = str(data_root).rstrip("/")
|
||||
if not self.data_root.startswith("hf://"):
|
||||
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
|
||||
self.max_retries = max_retries
|
||||
self.api = HfApi()
|
||||
self.fs: HfFileSystem | None = None
|
||||
self._bucket_id: str | None = None
|
||||
self._bucket_prefix = ""
|
||||
if self.data_root.startswith("hf://buckets/"):
|
||||
bucket_root = self.data_root.removeprefix("hf://buckets/")
|
||||
parts = bucket_root.split("/", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid bucket root: {self.data_root}")
|
||||
self._bucket_id = f"{parts[0]}/{parts[1]}"
|
||||
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
|
||||
else:
|
||||
self.fs = HfFileSystem()
|
||||
self.client = httpx.Client(
|
||||
timeout=timeout,
|
||||
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
|
||||
follow_redirects=False,
|
||||
)
|
||||
self._resolved_urls: dict[str, str] = {}
|
||||
self._source_urls: dict[str, str] = {}
|
||||
self._sizes: dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._timing_lock = threading.Lock()
|
||||
self._timing_totals = {
|
||||
"range_jobs": 0.0,
|
||||
"range_bytes": 0.0,
|
||||
"range_resolve_s": 0.0,
|
||||
"range_header_s": 0.0,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
"range_retry_attempts": 0.0,
|
||||
"range_retry_sleep_s": 0.0,
|
||||
"range_failed_requests": 0.0,
|
||||
}
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self.client.request(method, url, **kwargs)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _cache_key(self, relative_path: str) -> tuple[str, str]:
|
||||
return self.data_root, relative_path
|
||||
|
||||
def _path(self, relative_path: str) -> str:
|
||||
return f"{self.data_root}/{relative_path}"
|
||||
|
||||
def _bucket_path(self, relative_path: str) -> str:
|
||||
if self._bucket_prefix:
|
||||
return f"{self._bucket_prefix}/{relative_path}"
|
||||
return relative_path
|
||||
|
||||
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
|
||||
headers = self.api._build_hf_headers()
|
||||
if urlparse(request_url).netloc != urlparse(source_url).netloc:
|
||||
headers.pop("authorization", None)
|
||||
headers.pop("Authorization", None)
|
||||
return headers
|
||||
|
||||
def _source_url(self, relative_path: str) -> str:
|
||||
with self._lock:
|
||||
source = self._source_urls.get(relative_path)
|
||||
if source is not None:
|
||||
return source
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
source = self._GLOBAL_SOURCE_URLS.get(key)
|
||||
if source is None:
|
||||
if self._bucket_id is not None:
|
||||
source = (
|
||||
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
|
||||
f"{quote(self._bucket_path(relative_path))}"
|
||||
)
|
||||
else:
|
||||
if self.fs is None:
|
||||
raise RuntimeError("HfFileSystem fallback was not initialized")
|
||||
source = self.fs.url(self._path(relative_path))
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SOURCE_URLS[key] = source
|
||||
with self._lock:
|
||||
self._source_urls[relative_path] = source
|
||||
return source
|
||||
|
||||
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
|
||||
with self._lock:
|
||||
if not refresh and relative_path in self._resolved_urls:
|
||||
return self._resolved_urls[relative_path]
|
||||
key = self._cache_key(relative_path)
|
||||
if not refresh:
|
||||
with self._GLOBAL_LOCK:
|
||||
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if resolved is not None:
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if size is not None:
|
||||
self._sizes[relative_path] = size
|
||||
return resolved
|
||||
|
||||
source = self._source_url(relative_path)
|
||||
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
location = response.headers.get("Location")
|
||||
resolved = urljoin(source, location) if location else source
|
||||
with self._lock:
|
||||
self._resolved_urls[relative_path] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._sizes[relative_path] = int(response.headers["Content-Length"])
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_RESOLVED_URLS[key] = resolved
|
||||
if "Content-Length" in response.headers:
|
||||
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
|
||||
return resolved
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def info_size(self, relative_path: str) -> int:
|
||||
with self._lock:
|
||||
size = self._sizes.get(relative_path)
|
||||
if size is not None:
|
||||
return size
|
||||
key = self._cache_key(relative_path)
|
||||
with self._GLOBAL_LOCK:
|
||||
size = self._GLOBAL_SIZES.get(key)
|
||||
if size is not None:
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
return size
|
||||
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
response = self._request(
|
||||
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||
)
|
||||
try:
|
||||
hf_raise_for_status(response)
|
||||
size = int(response.headers["Content-Length"])
|
||||
with self._lock:
|
||||
self._sizes[relative_path] = size
|
||||
with self._GLOBAL_LOCK:
|
||||
self._GLOBAL_SIZES[key] = size
|
||||
return size
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||
resolve_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path)
|
||||
source = self._source_url(relative_path)
|
||||
resolve_s = time.perf_counter() - resolve_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
payload, status_code, timings = self._read_range_response(resolved, headers)
|
||||
if status_code == 403:
|
||||
refresh_start = time.perf_counter()
|
||||
resolved = self._resolve_url(relative_path, refresh=True)
|
||||
resolve_s += time.perf_counter() - refresh_start
|
||||
headers = self._headers_for(resolved, source)
|
||||
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
|
||||
for key, value in retry_timings.items():
|
||||
timings[key] += value
|
||||
if status_code == 403:
|
||||
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
|
||||
self._record_timing(
|
||||
range_jobs=1.0,
|
||||
range_bytes=float(len(payload)),
|
||||
range_resolve_s=resolve_s,
|
||||
**timings,
|
||||
)
|
||||
return payload
|
||||
|
||||
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
|
||||
last_exc: Exception | None = None
|
||||
retry_attempts = 0.0
|
||||
retry_sleep_s = 0.0
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
payload, status_code, timings = self._read_range_response_once(url, headers)
|
||||
timings["range_retry_attempts"] = retry_attempts
|
||||
timings["range_retry_sleep_s"] = retry_sleep_s
|
||||
return payload, status_code, timings
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_exc = exc
|
||||
if attempt >= self.max_retries:
|
||||
break
|
||||
retry_attempts += 1.0
|
||||
sleep_s = min(0.5 * 2**attempt, 5.0)
|
||||
retry_sleep_s += sleep_s
|
||||
time.sleep(sleep_s)
|
||||
self._record_timing(
|
||||
range_failed_requests=1.0,
|
||||
range_retry_attempts=retry_attempts,
|
||||
range_retry_sleep_s=retry_sleep_s,
|
||||
)
|
||||
if last_exc is None:
|
||||
raise RuntimeError("HTTP range request failed without an exception")
|
||||
raise last_exc
|
||||
|
||||
def _read_range_response_once(
|
||||
self, url: str, headers: dict[str, str]
|
||||
) -> tuple[bytes, int, dict[str, float]]:
|
||||
header_start = time.perf_counter()
|
||||
with self.client.stream("GET", url, headers=headers) as response:
|
||||
header_s = time.perf_counter() - header_start
|
||||
if response.status_code == 403:
|
||||
return (
|
||||
b"",
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": 0.0,
|
||||
"range_body_s": 0.0,
|
||||
},
|
||||
)
|
||||
hf_raise_for_status(response)
|
||||
chunks = []
|
||||
first_byte_s = 0.0
|
||||
first_chunk = True
|
||||
body_start = time.perf_counter()
|
||||
for chunk in response.iter_bytes():
|
||||
if first_chunk:
|
||||
first_byte_s = time.perf_counter() - body_start
|
||||
first_chunk = False
|
||||
chunks.append(chunk)
|
||||
body_s = time.perf_counter() - body_start
|
||||
return (
|
||||
b"".join(chunks),
|
||||
response.status_code,
|
||||
{
|
||||
"range_header_s": header_s,
|
||||
"range_first_byte_s": first_byte_s,
|
||||
"range_body_s": body_s,
|
||||
},
|
||||
)
|
||||
|
||||
def _record_timing(self, **kwargs: float) -> None:
|
||||
with self._timing_lock:
|
||||
for key, value in kwargs.items():
|
||||
self._timing_totals[key] += value
|
||||
|
||||
def timing_summary(self) -> dict[str, float]:
    """Return a plain-dict copy of the accumulated timing counters.

    The copy is taken while holding ``self._timing_lock``.
    """
    with self._timing_lock:
        snapshot = {name: total for name, total in self._timing_totals.items()}
    return snapshot
|
||||
|
||||
def close(self) -> None:
    """Close the underlying ``self.client``."""
    http_client = self.client
    http_client.close()
|
||||
|
||||
|
||||
def make_range_fetcher(
|
||||
data_root: str | Path,
|
||||
*,
|
||||
range_backend: str,
|
||||
workers: int,
|
||||
native_http_connections: int | None = None,
|
||||
native_http_timeout: float = 60.0,
|
||||
native_http_retries: int = 4,
|
||||
):
|
||||
if range_backend == "fsspec":
|
||||
return ThreadLocalRangeFetcher(data_root)
|
||||
if range_backend == "native-http":
|
||||
max_connections = native_http_connections or max(8, workers)
|
||||
return NativeHTTPRangeFetcher(
|
||||
data_root,
|
||||
max_connections=max_connections,
|
||||
timeout=native_http_timeout,
|
||||
max_retries=native_http_retries,
|
||||
)
|
||||
raise ValueError(f"Unknown range backend: {range_backend}")
|
||||
|
||||
|
||||
class EpisodeVideoManifest:
    """Lookup table from (episode, camera) to a sample/byte span in a video file.

    ``files`` is a list of ``VideoFileRecord`` indexed by file id, and ``spans``
    maps a column name to an array of shape ``(total_episodes, num_cameras)``.
    Parsed per-file indexes can be persisted to, and reloaded from, a
    "sidecar" archive written with ``np.savez_compressed``.
    """

    # Process-wide cache of parsed sidecars, keyed by the user-expanded sidecar path.
    _FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
    _FILE_SIDECAR_CACHE_LOCK = threading.Lock()

    def __init__(
        self,
        *,
        video_keys: list[str],
        files: list[VideoFileRecord],
        spans: dict[str, np.ndarray],
    ):
        """Store the camera keys, file records and span arrays as given."""
        self.video_keys = list(video_keys)
        self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
        self.files = files
        self.spans = spans

    @classmethod
    def build(
        cls,
        meta: LeRobotDatasetMetadata,
        data_root: str | Path,
        *,
        episode_indices: list[int] | range | None = None,
        range_backend: str = "fsspec",
        workers: int = 8,
        header_probe_bytes: int = 4 * 1024 * 1024,
        max_probe_bytes: int = 64 * 1024 * 1024,
        keyframe_pad_s: float = 0.1,
        keyframe_pad_fraction: float = 0.05,
        sidecar_path: str | Path | None = None,
    ) -> EpisodeVideoManifest:
        """Build a manifest for ``episode_indices`` (default: every episode).

        File indexes are fetched remotely when ``sidecar_path`` is ``None``;
        otherwise they are loaded from the sidecar, which must cover every
        video file the selected episodes reference.

        Raises:
            ValueError: If the sidecar lacks a record for a required file.
        """
        meta.ensure_readable()
        video_keys = list(meta.video_keys)
        if episode_indices is None:
            episode_indices = range(int(meta.total_episodes))
        rel_paths = sorted(
            {str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
        )
        path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
        if sidecar_path is None:
            files = cls._build_file_records(
                rel_paths,
                data_root,
                range_backend=range_backend,
                workers=workers,
                header_probe_bytes=header_probe_bytes,
                max_probe_bytes=max_probe_bytes,
            )
        else:
            records = cls.load_file_sidecar(sidecar_path)
            missing = [path for path in rel_paths if path not in records]
            if missing:
                raise ValueError(
                    f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
                )
            files = [records[path] for path in rel_paths]

        # Span arrays are indexed by absolute episode index, so they are sized
        # for every episode; rows of unselected episodes stay zero.
        total = int(meta.total_episodes)
        num_cameras = len(video_keys)
        spans: dict[str, np.ndarray] = {
            "file_id": np.zeros((total, num_cameras), dtype=np.int32),
            "mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
            "mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
            "first_pts": np.zeros((total, num_cameras), dtype=np.float64),
            "last_pts": np.zeros((total, num_cameras), dtype=np.float64),
            "frame_count": np.zeros((total, num_cameras), dtype=np.int32),
            "sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
            "sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
            "source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
        }

        for ep_idx in episode_indices:
            ep = meta.episodes[ep_idx]
            for cam_idx, key in enumerate(video_keys):
                rel_path = str(meta.get_video_file_path(ep_idx, key))
                file_id = path_to_id[rel_path]
                mp4 = files[file_id].mp4
                from_ts = float(ep[f"videos/{key}/from_timestamp"])
                to_ts = float(ep[f"videos/{key}/to_timestamp"])
                sample_slice = mp4.sample_slice(
                    from_ts,
                    to_ts,
                    keyframe_pad_s=keyframe_pad_s,
                    keyframe_pad_fraction=keyframe_pad_fraction,
                    file_size=files[file_id].file_size,
                )
                spans["file_id"][ep_idx, cam_idx] = file_id
                spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
                spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
                spans["first_pts"][ep_idx, cam_idx] = from_ts
                spans["last_pts"][ep_idx, cam_idx] = to_ts
                spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
                spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
                spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
                spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts

        return cls(video_keys=video_keys, files=files, spans=spans)

    @staticmethod
    def _build_file_records(
        rel_paths: list[str],
        data_root: str | Path,
        *,
        range_backend: str,
        workers: int,
        header_probe_bytes: int,
        max_probe_bytes: int,
    ) -> list[VideoFileRecord]:
        """Fetch and parse the MP4 index of every path, ``workers`` at a time.

        Results are returned in the order of ``rel_paths``. The fetcher is
        closed whether or not indexing succeeds.
        """
        fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)

        def build_file(path: str) -> VideoFileRecord:
            file_size = fetcher.info_size(path)
            mp4 = fetch_mp4_index(
                path,
                fetcher.read_range,
                file_size=file_size,
                header_probe_bytes=header_probe_bytes,
                max_probe_bytes=max_probe_bytes,
            )
            return VideoFileRecord(path, file_size, mp4)

        try:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                return list(pool.map(build_file, rel_paths))
        finally:
            fetcher.close()

    @classmethod
    def write_file_sidecar(
        cls,
        sidecar_path: str | Path,
        rel_paths: list[str],
        data_root: str | Path,
        *,
        range_backend: str = "native-http",
        workers: int = 8,
        header_probe_bytes: int = 4 * 1024 * 1024,
        max_probe_bytes: int = 64 * 1024 * 1024,
    ) -> None:
        """Index the de-duplicated, sorted ``rel_paths`` and save them as a sidecar."""
        records = cls._build_file_records(
            sorted(set(rel_paths)),
            data_root,
            range_backend=range_backend,
            workers=workers,
            header_probe_bytes=header_probe_bytes,
            max_probe_bytes=max_probe_bytes,
        )
        cls.save_file_sidecar(sidecar_path, records)

    @staticmethod
    def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
        """Write ``records`` to ``sidecar_path`` as a compressed NumPy archive.

        The archive holds a JSON manifest (``manifest_json``) plus five sample
        arrays per file under ``"<file_idx>/<name>"`` keys. Parent directories
        are created as needed.
        """
        sidecar_path = Path(sidecar_path)
        sidecar_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "version": 1,
            "files": [
                {"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
                for record in records
            ],
        }
        arrays = {}
        for file_idx, record in enumerate(records):
            arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
            arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
            arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
            arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
            arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
        # Write through an open handle: given a path, NumPy appends ".npz" when
        # the suffix is missing, and load_file_sidecar() would then look for a
        # file that does not exist under the name the caller passed.
        with open(sidecar_path, "wb") as handle:
            np.savez_compressed(handle, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
        # Drop any cached parse of this path so the next load sees the new file.
        cache_key = str(sidecar_path.expanduser())
        with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
            EpisodeVideoManifest._FILE_SIDECAR_CACHE.pop(cache_key, None)

    @staticmethod
    def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
        """Load a sidecar into a ``{file_path: VideoFileRecord}`` mapping.

        Results are cached per expanded path for the life of the process; the
        cached mapping itself is returned, so callers should not mutate it.
        """
        cache_key = str(Path(sidecar_path).expanduser())
        with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
            cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
        if cached is not None:
            return cached

        with np.load(sidecar_path, allow_pickle=False) as data:
            payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
            records = {}
            for file_idx, item in enumerate(payload["files"]):
                arrays = {
                    name: data[f"{file_idx}/{name}"]
                    for name in [
                        "sample_pts",
                        "sample_durations",
                        "sample_sizes",
                        "sample_offsets",
                        "sync_samples",
                    ]
                }
                mp4 = Mp4Index.from_dict(item["mp4"], arrays)
                records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
        with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
            EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
        return records

    def camera_id(self, camera_key: str) -> int:
        """Return the column index of ``camera_key`` (``KeyError`` if unknown)."""
        return self._camera_to_id[camera_key]

    def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
        """Return the stored span for one episode/camera pair as plain Python scalars."""
        cam = self.camera_id(camera_key)
        return EpisodeVideoSpan(
            file_id=int(self.spans["file_id"][episode_index, cam]),
            mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
            mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
            first_pts=float(self.spans["first_pts"][episode_index, cam]),
            last_pts=float(self.spans["last_pts"][episode_index, cam]),
            frame_count=int(self.spans["frame_count"][episode_index, cam]),
            sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
            sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
            source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
        )

    def file_lookup(self, file_id: int) -> VideoFileRecord:
        """Return the file record stored at ``file_id``."""
        return self.files[file_id]

    def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
        """Return the MP4 index of the file backing one episode/camera pair."""
        return self.files[self.lookup(episode_index, camera_key).file_id].mp4

    def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
        """Rebuild the ``Mp4SampleSlice`` for one episode/camera pair from its stored span."""
        span = self.lookup(episode_index, camera_key)
        return Mp4SampleSlice(
            sample_lo=span.sample_lo,
            sample_hi=span.sample_hi,
            byte_offset=span.mdat_offset,
            byte_length=span.mdat_length,
            source_start_pts=span.source_start_pts,
        )
|
||||
|
||||
|
||||
class EpisodeByteCache:
    """Byte-budgeted LRU cache of synthesized per-episode MP4 buffers.

    Each ``(episode_index, camera_key)`` entry is produced on a thread pool by
    fetching the byte range recorded in the manifest and passing it to
    ``synthesize_mp4``. Entries are evicted oldest-first once the total size of
    cached buffers exceeds ``byte_budget``.
    """

    def __init__(
        self,
        manifest: EpisodeVideoManifest,
        data_root: str | Path,
        *,
        byte_budget: int = 80 * 1024**3,
        workers: int = 8,
        range_backend: str = "fsspec",
        native_http_connections: int | None = None,
        native_http_timeout: float = 60.0,
        native_http_retries: int = 4,
        open_decoders: bool = True,
    ):
        """Create the range fetcher and worker pool.

        Args:
            manifest: Source of spans and file records.
            data_root: Forwarded to ``make_range_fetcher``.
            byte_budget: Upper bound on the summed size of cached buffers.
            workers: Thread-pool size; also forwarded to the fetcher factory.
            range_backend: Forwarded to ``make_range_fetcher``.
            native_http_connections: Forwarded to ``make_range_fetcher``.
            native_http_timeout: Forwarded to ``make_range_fetcher``.
            native_http_retries: Forwarded to ``make_range_fetcher``.
            open_decoders: When true, a decoder is opened as part of each fetch
                job instead of lazily in ``get_decoder``.
        """
        self.manifest = manifest
        self.fetcher = make_range_fetcher(
            data_root,
            range_backend=range_backend,
            workers=workers,
            native_http_connections=native_http_connections,
            native_http_timeout=native_http_timeout,
            native_http_retries=native_http_retries,
        )
        self.byte_budget = byte_budget
        self.open_decoders = open_decoders
        self._pool = ThreadPoolExecutor(max_workers=workers)
        self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
        self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
        self._bytes = 0
        self._lock = threading.Lock()
        self._timing_totals = {
            "lookup_s": 0.0,
            "fetch_s": 0.0,
            "synthesize_s": 0.0,
            "store_s": 0.0,
            "jobs": 0.0,
        }

    def close(self) -> None:
        """Wait for pending jobs, drop all cached state and close the fetcher."""
        self._pool.shutdown(wait=True)
        with self._lock:
            self._cache.clear()
            self._futures.clear()
            self._bytes = 0
        self.fetcher.close()

    def __enter__(self) -> EpisodeByteCache:
        return self

    def __exit__(self, *_exc) -> None:
        self.close()

    def submit_prefetch(self, episode_index: int) -> None:
        """Start background fetches for every camera of ``episode_index``."""
        for camera_key in self.manifest.video_keys:
            self._submit(episode_index, camera_key)

    def ensure_ready(self, episode_index: int) -> None:
        """Block until every camera of ``episode_index`` is cached."""
        for camera_key in self.manifest.video_keys:
            self.get_bytes(episode_index, camera_key)

    def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
        """Return the synthesized MP4 bytes, fetching them if necessary."""
        return self._get_entry(episode_index, camera_key)["bytes"]

    def get_decoder(self, episode_index: int, camera_key: str):
        """Return the entry's decoder, opening and memoizing one if absent."""
        entry = self._get_entry(episode_index, camera_key)
        decoder = entry.get("decoder")
        if decoder is None:
            decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
            entry["decoder"] = decoder
        return decoder

    def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
        """Decode frames at ``timestamps`` given on the source file's timeline.

        Timestamps are shifted by the span's ``source_start_pts`` before being
        handed to the decoder. Decoders without ``get_frames_played_at`` are
        queried by frame index, using ``average_fps`` or, if that is missing,
        ``num_frames / end_stream_seconds``.
        """
        span = self.manifest.lookup(episode_index, camera_key)
        local_ts = [ts - span.source_start_pts for ts in timestamps]
        decoder = self.get_decoder(episode_index, camera_key)
        if hasattr(decoder, "get_frames_played_at"):
            return decoder.get_frames_played_at(local_ts).data
        video_meta = decoder.metadata
        fps = getattr(video_meta, "average_fps", None)
        if fps is None:
            # Floor the duration so a zero/absent value cannot divide by zero.
            duration = max(getattr(video_meta, "end_stream_seconds", 0.0), 1e-9)
            fps = video_meta.num_frames / duration
        return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data

    def timing_summary(self) -> dict[str, float]:
        """Return cache timing totals merged with the fetcher's, when it exposes any."""
        with self._lock:
            summary = dict(self._timing_totals)
        fetcher_summary = getattr(self.fetcher, "timing_summary", None)
        if fetcher_summary is not None:
            summary.update(fetcher_summary())
        return summary

    def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
        """Return a future for the entry, reusing the cache or an in-flight job."""
        key = (episode_index, camera_key)
        with self._lock:
            if key in self._cache:
                future: Future[dict[str, Any]] = Future()
                future.set_result(self._cache[key])
                return future
            future = self._futures.get(key)
            if future is None:
                future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
                self._futures[key] = future
            return future

    def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Return the cached entry, waiting on (and then storing) a fetch if needed.

        Raises:
            Whatever the fetch job raised. The failed job is forgotten first,
            so a later call starts a fresh fetch instead of re-raising the
            remembered failure.
        """
        key = (episode_index, camera_key)
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None:
                self._cache.move_to_end(key)
                return entry
        future = self._submit(episode_index, camera_key)
        try:
            entry = future.result()
        except BaseException:
            # Without this, the failed future would stay registered and every
            # subsequent request for this key would fail without retrying.
            with self._lock:
                if self._futures.get(key) is future:
                    del self._futures[key]
            raise
        store_start = time.perf_counter()
        with self._lock:
            self._futures.pop(key, None)
            existing = self._cache.get(key)
            if existing is not None:
                # Another caller stored the entry while we were waiting.
                self._cache.move_to_end(key)
                return existing
            self._cache[key] = entry
            self._bytes += len(entry["bytes"])
            self._evict_locked()
            timings = entry.pop("_timings", None)
            if timings is not None:
                self._timing_totals["lookup_s"] += timings["lookup_s"]
                self._timing_totals["fetch_s"] += timings["fetch_s"]
                self._timing_totals["synthesize_s"] += timings["synthesize_s"]
            self._timing_totals["store_s"] += time.perf_counter() - store_start
            self._timing_totals["jobs"] += 1
            return entry

    def _evict_locked(self) -> None:
        """Pop oldest entries until under budget; caller must hold ``self._lock``."""
        while self._bytes > self.byte_budget and self._cache:
            _key, entry = self._cache.popitem(last=False)
            self._bytes -= len(entry["bytes"])

    def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Fetch one span's bytes and synthesize a standalone MP4 (runs on the pool).

        Returns:
            A dict with ``bytes``, ``decoder`` (``None`` unless
            ``open_decoders`` is set) and per-phase ``_timings``.

        Raises:
            OSError: If the fetcher returns a different number of bytes than
                the span's ``mdat_length``.
        """
        lookup_start = time.perf_counter()
        span = self.manifest.lookup(episode_index, camera_key)
        file_record = self.manifest.file_lookup(span.file_id)
        sample_slice = Mp4SampleSlice(
            sample_lo=span.sample_lo,
            sample_hi=span.sample_hi,
            byte_offset=span.mdat_offset,
            byte_length=span.mdat_length,
            source_start_pts=span.source_start_pts,
        )
        lookup_s = time.perf_counter() - lookup_start
        fetch_start = time.perf_counter()
        payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
        fetch_s = time.perf_counter() - fetch_start
        if len(payload) != span.mdat_length:
            raise OSError(
                f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
            )
        synthesize_start = time.perf_counter()
        mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
        synthesize_s = time.perf_counter() - synthesize_start
        entry: dict[str, Any] = {
            "bytes": mp4_bytes,
            "decoder": None,
            "_timings": {
                "lookup_s": lookup_s,
                "fetch_s": fetch_s,
                "synthesize_s": synthesize_s,
            },
        }
        if self.open_decoders:
            entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
        return entry
|
||||
|
||||
|
||||
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
    """Open a torchcodec ``VideoDecoder`` in approximate seek mode.

    ``frame_mappings`` exists only to reject callers that pass one: any value
    other than ``None`` raises ``ValueError``. torchcodec is imported lazily,
    on the first successful call.
    """
    if frame_mappings is None:
        from torchcodec.decoders import VideoDecoder

        return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
    raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
|
||||
|
||||
|
||||
def assert_hf_hub_range_cache_branch() -> None:
    """Fail unless huggingface_hub was installed from the required range-cache branch.

    The check looks for the branch name in the distribution's
    ``direct_url.json``: in the raw text and, when it parses as JSON, in its
    ``url``, ``vcs_info.requested_revision`` and ``vcs_info.commit_id`` values.

    Raises:
        AssertionError: If huggingface_hub is not installed, or if none of the
            inspected values contains the branch name.
    """
    try:
        dist = metadata.distribution("huggingface_hub")
    except metadata.PackageNotFoundError as exc:
        raise AssertionError("huggingface_hub is not installed") from exc

    evidence = []
    raw_direct_url = dist.read_text("direct_url.json")
    if raw_direct_url:
        evidence.append(raw_direct_url)
        try:
            parsed = json.loads(raw_direct_url)
        except json.JSONDecodeError:
            # Unparseable metadata: fall back to the raw text alone.
            pass
        else:
            evidence.append(str(parsed.get("url", "")))
            for field in ("requested_revision", "commit_id"):
                evidence.append(str(parsed.get("vcs_info", {}).get(field, "")))

    haystack = "\n".join(evidence)
    if "feat/hffs-cache-cdn-range-reads" in haystack:
        return
    raise AssertionError(
        "huggingface_hub must be installed from "
        "git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
    )
|
||||
|
||||
|
||||
@dataclass
class StageTimer:
    """Running totals for fetch/decode stages."""

    # Accumulated fetch time in milliseconds.
    fetch_ms: float = 0.0
    # Decode time counter; not updated by record_fetch.
    decode_ms: float = 0.0
    # Total bytes reported to record_fetch.
    bytes_read: int = 0
    # Number of record_fetch calls.
    misses: int = 0

    def record_fetch(self, start: float, byte_count: int) -> None:
        """Account for one fetch that began at ``start`` (a ``time.perf_counter()`` value)."""
        elapsed_s = time.perf_counter() - start
        self.fetch_ms = self.fetch_ms + elapsed_s * 1000
        self.bytes_read = self.bytes_read + byte_count
        self.misses = self.misses + 1
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def make_train_eval_datasets(
    cfg: TrainPipelineConfig,
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
    """Build the training dataset and, when requested, a held-out eval dataset.

    Episodes are grouped by their first task string (``""`` when an episode has
    no task) and the last ``ceil(n * eval_split)`` episodes of each group are
    held out. With ``eval_split == 0.0`` the full dataset is returned together
    with ``None``.

    Raises:
        ValueError: If the split leaves no training episodes.
    """
    full_dataset = make_dataset(cfg)
    eval_split = cfg.dataset.eval_split
    if eval_split == 0.0:
        return full_dataset, None

    if full_dataset.episodes is not None:
        base_episodes = full_dataset.episodes
    else:
        base_episodes = list(range(full_dataset.num_episodes))

    # Group episode indices by task, preserving first-seen task order.
    episode_tasks = full_dataset.meta.episodes["tasks"]
    task_to_episodes: dict[str, list[int]] = {}
    for ep_idx in base_episodes:
        task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
        if task_key not in task_to_episodes:
            task_to_episodes[task_key] = []
        task_to_episodes[task_key].append(ep_idx)

    train_episodes: list[int] = []
    eval_episodes: list[int] = []
    for group in task_to_episodes.values():
        cut = len(group) - math.ceil(len(group) * eval_split)
        train_episodes.extend(group[:cut])
        eval_episodes.extend(group[cut:])

    if not train_episodes:
        raise ValueError(
            f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
        )

    logging.info(
        f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
        f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
    )

    delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)

    if cfg.dataset.image_transforms.enable:
        train_image_transforms = ImageTransforms(cfg.dataset.image_transforms)
    else:
        train_image_transforms = None

    def _build_subset(episodes, image_transforms):
        # Both subsets share every setting except episodes and transforms.
        return LeRobotDataset(
            cfg.dataset.repo_id,
            root=cfg.dataset.root,
            episodes=episodes,
            delta_timestamps=delta_timestamps,
            image_transforms=image_transforms,
            revision=cfg.dataset.revision,
            video_backend=cfg.dataset.video_backend,
            return_uint8=True,
            tolerance_s=cfg.tolerance_s,
        )

    train_dataset = _build_subset(train_episodes, train_image_transforms)
    # Eval data is never augmented.
    eval_dataset = _build_subset(eval_episodes, None)

    if cfg.dataset.use_imagenet_stats:
        for subset in (train_dataset, eval_dataset):
            for camera_key in subset.meta.camera_keys:
                for stats_type, stats in IMAGENET_STATS.items():
                    subset.meta.stats[camera_key][stats_type] = torch.tensor(stats, dtype=torch.float32)

    return train_dataset, eval_dataset
|
||||
|
||||
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
|
||||
@@ -474,6 +474,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if reader.hf_dataset is None:
|
||||
# One-shot load after finalize()
|
||||
reader.load_and_activate()
|
||||
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
|
||||
idx = reader._absolute_to_relative_idx[idx]
|
||||
return reader.get_item(idx)
|
||||
|
||||
def select_columns(self, column_names: str | list[str]):
|
||||
|
||||
@@ -1,666 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Box:
    """Location of one MP4 box: its four-byte type and byte extent."""

    type: bytes
    start: int
    header_size: int
    end: int

    @property
    def payload_start(self) -> int:
        """Offset of the first byte after the box header."""
        header_end = self.start + self.header_size
        return header_end

    @property
    def size(self) -> int:
        """Total box length in bytes, header included."""
        first, last = self.start, self.end
        return last - first
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Mp4SampleSlice:
    """A contiguous run of samples and the byte range that contains them."""

    # Index of the first sample in the run.
    sample_lo: int
    # Index of the last sample in the run (inclusive).
    sample_hi: int
    # Byte offset where the run's data starts.
    byte_offset: int
    # Number of bytes covered by the run.
    byte_length: int
    # Presentation timestamp of sample ``sample_lo``.
    source_start_pts: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Mp4Index:
    """Parsed layout and per-sample tables of one MP4 video track."""

    file_path: str
    file_size: int
    ftyp: bytes
    moov_offset: int
    mdat_offset: int
    mdat_payload_offset: int
    mdat_payload_size: int
    faststart: bool
    codec: str
    timescale: int
    duration: int
    track_id: int
    width: int
    height: int
    stsd_body: bytes
    sample_pts: np.ndarray
    sample_durations: np.ndarray
    sample_sizes: np.ndarray
    sample_offsets: np.ndarray
    sync_samples: np.ndarray

    def sample_slice(
        self,
        from_ts: float,
        to_ts: float,
        *,
        keyframe_pad_s: float = 0.1,
        keyframe_pad_fraction: float = 0.05,
        file_size: int | None = None,
    ) -> Mp4SampleSlice:
        """Select the samples covering ``[from_ts, to_ts]`` plus padding.

        The window is widened on both sides by
        ``max(keyframe_pad_s, span * keyframe_pad_fraction)``, and the first
        sample is then moved back to the closest entry of ``sync_samples`` at
        or before it (or to the first sync sample if none precedes it).

        Raises:
            ValueError: If ``to_ts < from_ts`` or the index has no samples.
        """
        if to_ts < from_ts:
            raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
        sample_count = len(self.sample_pts)
        if sample_count == 0:
            raise ValueError(f"{self.file_path} contains no indexed samples")

        pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
        window_start = max(0.0, from_ts - pad)
        window_end = to_ts + pad
        last_index = sample_count - 1
        lo = int(np.searchsorted(self.sample_pts, window_start, side="left"))
        hi = int(np.searchsorted(self.sample_pts, window_end, side="right")) - 1
        # Clamp both ends into the valid sample range, keeping hi >= lo.
        lo = min(max(lo, 0), last_index)
        hi = min(max(hi, lo), last_index)

        if len(self.sync_samples):
            earlier_syncs = self.sync_samples[self.sync_samples <= lo]
            lo = int(earlier_syncs[-1]) if len(earlier_syncs) else int(self.sync_samples[0])
            hi = max(hi, lo)

        picked = slice(lo, hi + 1)
        starts = self.sample_offsets[picked]
        ends = starts + self.sample_sizes[picked]
        byte_start = int(starts.min())
        byte_end = int(ends.max())
        if file_size is not None:
            byte_end = min(byte_end, int(file_size))
        return Mp4SampleSlice(
            sample_lo=lo,
            sample_hi=hi,
            byte_offset=byte_start,
            byte_length=byte_end - byte_start,
            source_start_pts=float(self.sample_pts[lo]),
        )

    def to_dict(self) -> dict:
        """Return the scalar fields as a JSON-friendly dict (bytes as hex, no arrays)."""
        hex_fields = ("ftyp", "stsd_body")
        scalar_fields = (
            "file_path",
            "file_size",
            "ftyp",
            "moov_offset",
            "mdat_offset",
            "mdat_payload_offset",
            "mdat_payload_size",
            "faststart",
            "codec",
            "timescale",
            "duration",
            "track_id",
            "width",
            "height",
            "stsd_body",
        )
        out = {}
        for name in scalar_fields:
            value = getattr(self, name)
            out[name] = value.hex() if name in hex_fields else value
        return out

    @classmethod
    def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
        """Rebuild an index from ``to_dict`` output plus the five sample arrays."""
        int_fields = (
            "file_size",
            "moov_offset",
            "mdat_offset",
            "mdat_payload_offset",
            "mdat_payload_size",
            "timescale",
            "duration",
            "track_id",
            "width",
            "height",
        )
        array_fields = ("sample_pts", "sample_durations", "sample_sizes", "sample_offsets", "sync_samples")
        fields = {
            "file_path": data["file_path"],
            "codec": data["codec"],
            "faststart": bool(data["faststart"]),
        }
        for name in int_fields:
            fields[name] = int(data[name])
        for name in ("ftyp", "stsd_body"):
            fields[name] = bytes.fromhex(data[name])
        for name in array_fields:
            fields[name] = arrays[name]
        return cls(**fields)
|
||||
|
||||
|
||||
def fetch_mp4_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    *,
    file_size: int,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
    """Index an MP4 by probing a growing prefix of the file.

    The prefix starts at ``min(header_probe_bytes, file_size)`` bytes and
    doubles until it contains an ``mdat`` box and a complete ``moov`` box, or
    until it reaches ``min(max_probe_bytes, file_size)``. At that limit, if
    ``mdat`` was seen but ``moov`` was not, one more read after the end of
    ``mdat`` is tried before giving up.

    Args:
        path: File identifier passed to ``read_range``.
        read_range: Callable ``(path, offset, length) -> bytes``.
        file_size: Total size of the file in bytes.
        header_probe_bytes: Initial prefix size; must be positive.
        max_probe_bytes: Upper bound on any single probe.

    Raises:
        ValueError: If ``header_probe_bytes`` is not positive, or if the
            required boxes cannot be located within the probe limits.
    """
    if header_probe_bytes <= 0:
        # A zero or negative probe never grows when doubled, so the loop below
        # would spin forever on any non-empty file.
        raise ValueError(f"header_probe_bytes must be positive, got {header_probe_bytes}")

    probe_size = min(header_probe_bytes, file_size)
    while True:
        data = read_range(path, 0, probe_size)
        top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
        has_mdat = any(box.type == b"mdat" for box in top)
        # moov only counts once the whole box is inside the probed prefix.
        has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
        if has_mdat and has_moov:
            return parse_mp4_index(path, data, file_size=file_size)
        if probe_size >= min(max_probe_bytes, file_size):
            if has_mdat and not has_moov:
                tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
                if tail_index is not None:
                    return tail_index
            missing = []
            if not has_mdat:
                missing.append("mdat")
            if not has_moov:
                missing.append("moov")
            raise ValueError(
                f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
            )
        probe_size = min(probe_size * 2, max_probe_bytes, file_size)
|
||||
|
||||
|
||||
def _fetch_tail_moov_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    prefix: bytes,
    top_boxes: list[Box],
    file_size: int,
    max_probe_bytes: int,
) -> Mp4Index | None:
    """Look for a ``moov`` box in the bytes that follow ``mdat``.

    Reads up to ``max_probe_bytes`` starting at the end of the ``mdat`` box
    found in ``top_boxes``. The ``ftyp`` box is copied from ``prefix`` when one
    was found there; otherwise a default one is generated.

    Returns:
        The parsed index, or ``None`` when ``mdat`` runs to the end of the
        file or no complete ``moov`` box lies within the tail that was read.
    """
    mdat_box = _one(top_boxes, b"mdat")
    if mdat_box is None or mdat_box.end >= file_size:
        return None

    tail_offset = mdat_box.end
    tail_length = min(max_probe_bytes, file_size - tail_offset)
    tail = read_range(path, tail_offset, tail_length)
    tail_end = tail_offset + len(tail)

    # Boxes are reported with absolute file offsets (absolute_base=tail_offset).
    moov_box = None
    for candidate in iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True):
        if moov_box is None and candidate.type == b"moov" and candidate.end <= tail_end:
            moov_box = candidate
    if moov_box is None:
        return None

    ftyp_box = _one(top_boxes, b"ftyp", required=False)
    if ftyp_box is not None:
        ftyp = prefix[ftyp_box.start : ftyp_box.end]
    else:
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")

    # Convert the absolute moov extent back to offsets inside ``tail``.
    payload_lo = moov_box.payload_start - tail_offset
    payload_hi = moov_box.end - tail_offset
    return _parse_mp4_index_from_layout(
        path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=tail[payload_lo:payload_hi],
        mdat_box=mdat_box,
    )
|
||||
|
||||
|
||||
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
    """Parse an MP4 index from ``data``.

    Args:
        path: File identifier used in error messages and stored in the index.
        data: Bytes to scan for top-level boxes; the ``moov`` box must be fully
            contained in it, other boxes may be truncated.
        file_size: Total file size; defaults to ``len(data)``.

    Raises:
        ValueError: If ``moov`` or ``mdat`` is missing, or ``moov`` is truncated.
    """
    total_size = len(data) if file_size is None else file_size
    boxes = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))

    ftyp_box = _one(boxes, b"ftyp", required=False)
    moov_box = _one(boxes, b"moov")
    mdat_box = _one(boxes, b"mdat")
    if moov_box.end > len(data):
        raise ValueError(f"{path}: moov box is truncated")

    if ftyp_box is None:
        # No ftyp in the source file: fall back to a generic one.
        ftyp_bytes = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    else:
        ftyp_bytes = data[ftyp_box.start : ftyp_box.end]

    return _parse_mp4_index_from_layout(
        path,
        file_size=total_size,
        ftyp=ftyp_bytes,
        moov_offset=moov_box.start,
        moov=data[moov_box.payload_start : moov_box.end],
        mdat_box=mdat_box,
    )
|
||||
|
||||
|
||||
def _parse_mp4_index_from_layout(
    path: str,
    *,
    file_size: int,
    ftyp: bytes,
    moov_offset: int,
    moov: bytes,
    mdat_box: Box,
) -> Mp4Index:
    """Build an ``Mp4Index`` from an already-located ``moov`` payload and ``mdat`` box.

    Args:
        path: File identifier stored in the index.
        file_size: Total file size in bytes.
        ftyp: Serialized ``ftyp`` box stored verbatim in the index.
        moov_offset: Absolute offset of the ``moov`` box in the file.
        moov: Payload of the ``moov`` box (without its header).
        mdat_box: The ``mdat`` box; it may extend past ``file_size``.

    Raises:
        ValueError: Propagated from the helpers when a required box is missing
            or the sample tables disagree.
    """
    mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
    _, trak_payload = _find_video_trak(moov)
    tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
    mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
    stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])

    stsd = _find_child(stbl, b"stsd")
    stsd_body = stbl[stsd.payload_start : stsd.end]
    codec = _parse_stsd_codec(stsd_body)
    stts = _parse_stts(_payload(stbl, b"stts"))
    sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
    stsc = _parse_stsc(_payload(stbl, b"stsc"))
    chunk_offsets = _parse_chunk_offsets(stbl)
    sync_samples = _parse_stss(stbl, len(sample_sizes))

    sample_durations = _expand_stts(stts, len(sample_sizes))
    # Start time of each sample = sum of the durations of all earlier samples
    # (exclusive prefix sum), in media timescale units.
    sample_pts_units = np.cumsum(sample_durations, dtype=np.int64) - sample_durations
    sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
    sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)

    # Prefer the media duration. The movie-header duration is expressed in the
    # mvhd timescale, so rescale it to the media timescale reported alongside it
    # before using it as a fallback.
    duration = mdhd_duration
    if not duration:
        duration = mvhd_duration
        if mvhd_timescale and mvhd_timescale != mdhd_timescale:
            duration = (mvhd_duration * mdhd_timescale + mvhd_timescale // 2) // mvhd_timescale

    # Clamp the payload size to the bytes that actually exist in the file.
    mdat_payload_size = min(mdat_box.end, file_size) - mdat_box.payload_start

    return Mp4Index(
        file_path=path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_offset,
        mdat_offset=mdat_box.start,
        mdat_payload_offset=mdat_box.payload_start,
        mdat_payload_size=mdat_payload_size,
        faststart=moov_offset < mdat_box.start,
        codec=codec,
        timescale=mdhd_timescale,
        duration=duration,
        track_id=tkhd["track_id"],
        width=tkhd["width"],
        height=tkhd["height"],
        stsd_body=stsd_body,
        sample_pts=sample_pts,
        sample_durations=sample_durations,
        sample_sizes=sample_sizes,
        sample_offsets=sample_offsets,
        sync_samples=sync_samples,
    )
|
||||
|
||||
|
||||
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
    """Assemble a standalone MP4 (``ftyp`` + ``moov`` + ``mdat``) for a sample range.

    Args:
        index: Index of the source file.
        sample_slice: Inclusive sample range (``sample_lo``..``sample_hi``) and the
            source byte offset (``byte_offset``) at which ``mdat_payload`` starts.
        mdat_payload: Raw bytes covering every sample in the range.

    Raises:
        ValueError: If the range is empty or out of bounds, if the payload does
            not start at the lowest referenced sample offset, or if it is too
            short to contain all referenced samples.
    """
    lo = sample_slice.sample_lo
    hi = sample_slice.sample_hi + 1  # convert inclusive upper bound to exclusive
    if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
        raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")

    offsets = index.sample_offsets[lo:hi]
    sizes = index.sample_sizes[lo:hi]
    rel_offsets = offsets - sample_slice.byte_offset
    if int(rel_offsets.min()) != 0:
        raise ValueError("Sample slice must start at the minimum referenced sample offset")
    if int((rel_offsets + sizes).max()) > len(mdat_payload):
        raise ValueError("Sample slice does not cover all referenced samples")

    durations = index.sample_durations[lo:hi]
    # Re-base sync sample numbers to the slice; stss entries are 1-based.
    sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1

    # Build mdat first so the real header size is known: _box() switches to a
    # 16-byte header when the payload does not fit a 32-bit size field.
    mdat = _box(b"mdat", mdat_payload)
    mdat_header_size = len(mdat) - len(mdat_payload)

    # Chunk offsets inside moov depend on the moov size itself. Start from a
    # provisional moov and rebuild until its size is stable; the size can only
    # grow (stco -> co64), so this settles after at most a couple of rounds.
    moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
    for _ in range(3):
        data_offset = len(index.ftyp) + len(moov) + mdat_header_size
        rebuilt = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=data_offset)
        stable = len(rebuilt) == len(moov)
        moov = rebuilt
        if stable:
            break
    return index.ftyp + moov + mdat
|
||||
|
||||
|
||||
def iter_boxes(
    data: bytes,
    start: int,
    end: int,
    *,
    absolute_base: int = 0,
    allow_truncated: bool = False,
) -> Iterable[Box]:
    """Yield the consecutive boxes found in ``data[start:end]``.

    Args:
        data: Buffer containing the boxes.
        start: Position in ``data`` of the first box header.
        end: Position in ``data`` where scanning stops.
        absolute_base: Added to the start/end positions reported in each ``Box``.
        allow_truncated: When true, a box whose declared size runs past ``end``
            is still yielded (and iteration stops after it); otherwise iteration
            stops before it.

    Iteration also stops on a header that does not fit in the remaining bytes
    or on a declared size smaller than the header.
    """
    cursor = start
    while cursor + 8 <= end:
        (declared,) = struct.unpack_from(">I", data, cursor)
        box_type = data[cursor + 4 : cursor + 8]
        header_len = 8
        if declared == 1:
            # Size 1 means a 64-bit size follows the type field.
            if cursor + 16 > end:
                return
            (declared,) = struct.unpack_from(">Q", data, cursor + 8)
            header_len = 16
        elif declared == 0:
            # Size 0 means the box runs to the end of the scanned range.
            declared = end - cursor
        if declared < header_len:
            return
        stop = cursor + declared
        if stop > end and not allow_truncated:
            return
        yield Box(box_type, absolute_base + cursor, header_len, absolute_base + stop)
        cursor = stop
|
||||
|
||||
|
||||
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
    """Return the first ``trak`` box whose handler type is ``vide``, with its payload.

    Raises:
        ValueError: If no such track exists, or (from the lookup helper) if a
            ``trak`` lacks ``mdia``/``hdlr``.
    """
    for candidate in _children(moov, 0, len(moov)):
        if candidate.type == b"trak":
            body = moov[candidate.payload_start : candidate.end]
            handler = _find_descendant(body, [b"mdia", b"hdlr"])
            # Bytes 8..12 of the hdlr payload hold the handler type.
            if handler[8:12] == b"vide":
                return candidate, body
    raise ValueError("No video track found")
|
||||
|
||||
|
||||
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
    """Walk ``path`` (a list of box types) down from ``data`` and return the final payload.

    At each level the first child of the requested type is entered. An empty
    ``path`` returns ``data`` unchanged.

    Raises:
        ValueError: From ``_find_child`` when a level is missing.
    """
    payload = data
    for box_type in path:
        child = _find_child(payload, box_type)
        payload = payload[child.payload_start : child.end]
    return payload
|
||||
|
||||
|
||||
def _find_child(data: bytes, typ: bytes) -> Box:
    """Return the first direct child box of type ``typ`` inside ``data``.

    Raises:
        ValueError: If no child of that type is present.
    """
    found = next((child for child in _children(data, 0, len(data)) if child.type == typ), None)
    if found is None:
        raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
    return found
|
||||
|
||||
|
||||
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
    """Iterate the complete boxes in ``data[start:end]`` with offsets relative to ``data``."""
    return iter_boxes(data, start, end, absolute_base=0, allow_truncated=False)
|
||||
|
||||
|
||||
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
|
||||
matches = [box for box in boxes if box.type == typ]
|
||||
if not matches and required:
|
||||
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
|
||||
return matches[0] if matches else None
|
||||
|
||||
|
||||
def _payload(parent: bytes, typ: bytes) -> bytes:
    """Return the payload of the first direct child of type ``typ`` in ``parent``.

    Raises:
        ValueError: If ``parent`` has no such child.
    """
    return _find_descendant(parent, [typ])
|
||||
|
||||
|
||||
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
return struct.unpack_from(">IQ", payload, 20)
|
||||
return struct.unpack_from(">II", payload, 12)
|
||||
|
||||
|
||||
def _parse_tkhd(payload: bytes) -> dict[str, int]:
|
||||
version = payload[0]
|
||||
if version == 1:
|
||||
track_id = struct.unpack_from(">I", payload, 20)[0]
|
||||
duration = struct.unpack_from(">Q", payload, 28)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 88)
|
||||
else:
|
||||
track_id = struct.unpack_from(">I", payload, 12)[0]
|
||||
duration = struct.unpack_from(">I", payload, 20)[0]
|
||||
width, height = struct.unpack_from(">II", payload, 76)
|
||||
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
|
||||
|
||||
|
||||
def _parse_stsd_codec(stsd_body: bytes) -> str:
|
||||
if len(stsd_body) < 16:
|
||||
return "unknown"
|
||||
return stsd_body[12:16].decode("latin1")
|
||||
|
||||
|
||||
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">II", payload, offset))
|
||||
offset += 8
|
||||
return out
|
||||
|
||||
|
||||
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
pos = 0
|
||||
for count, delta in entries:
|
||||
values[pos : pos + count] = delta
|
||||
pos += count
|
||||
if pos != sample_count:
|
||||
raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}")
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsz(payload: bytes) -> np.ndarray:
|
||||
sample_size, sample_count = struct.unpack_from(">II", payload, 4)
|
||||
if sample_size:
|
||||
return np.full(sample_count, sample_size, dtype=np.int64)
|
||||
offset = 12
|
||||
values = np.empty(sample_count, dtype=np.int64)
|
||||
for idx in range(sample_count):
|
||||
values[idx] = struct.unpack_from(">I", payload, offset)[0]
|
||||
offset += 4
|
||||
return values
|
||||
|
||||
|
||||
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
|
||||
count = struct.unpack_from(">I", payload, 4)[0]
|
||||
out = []
|
||||
offset = 8
|
||||
for _ in range(count):
|
||||
out.append(struct.unpack_from(">III", payload, offset))
|
||||
offset += 12
|
||||
return out
|
||||
|
||||
|
||||
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
    """Return the chunk offsets of an ``stbl`` payload as an int64 array.

    A ``co64`` table (64-bit offsets) takes precedence over ``stco`` (32-bit).
    If a table type occurs more than once, the last occurrence is used.

    Raises:
        ValueError: If neither ``stco`` nor ``co64`` is present.
    """
    tables: dict[bytes, bytes | None] = {b"stco": None, b"co64": None}
    for child in _children(stbl, 0, len(stbl)):
        if child.type in tables:
            tables[child.type] = stbl[child.payload_start : child.end]

    if tables[b"co64"] is not None:
        table, item_code = tables[b"co64"], "Q"
    elif tables[b"stco"] is not None:
        table, item_code = tables[b"stco"], "I"
    else:
        raise ValueError("Missing stco/co64 chunk offsets")

    (entry_count,) = struct.unpack_from(">I", table, 4)
    return np.array(struct.unpack_from(f">{entry_count}{item_code}", table, 8), dtype=np.int64)
|
||||
|
||||
|
||||
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
    """Return zero-based sync sample indices for an ``stbl`` payload.

    Uses the first ``stss`` child when present (its entries are stored 1-based
    and are shifted down by one). Without an ``stss`` box every sample index in
    ``range(sample_count)`` is returned.
    """
    stss_box = next((child for child in _children(stbl, 0, len(stbl)) if child.type == b"stss"), None)
    if stss_box is None:
        return np.arange(sample_count, dtype=np.int64)

    table = stbl[stss_box.payload_start : stss_box.end]
    (entry_count,) = struct.unpack_from(">I", table, 4)
    one_based = struct.unpack_from(f">{entry_count}I", table, 8)
    return np.array([number - 1 for number in one_based], dtype=np.int64)
|
||||
|
||||
|
||||
def _sample_offsets(
|
||||
stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
|
||||
) -> np.ndarray:
|
||||
if not stsc:
|
||||
raise ValueError("stsc is empty")
|
||||
offsets = np.empty(len(sample_sizes), dtype=np.int64)
|
||||
sample_idx = 0
|
||||
for entry_idx, (first_chunk, samples_per_chunk, _desc_idx) in enumerate(stsc):
|
||||
next_first = stsc[entry_idx + 1][0] if entry_idx + 1 < len(stsc) else len(chunk_offsets) + 1
|
||||
for chunk_number in range(first_chunk, next_first):
|
||||
if chunk_number < 1 or chunk_number > len(chunk_offsets):
|
||||
raise ValueError("stsc references a chunk outside stco/co64")
|
||||
chunk_pos = int(chunk_offsets[chunk_number - 1])
|
||||
for _ in range(samples_per_chunk):
|
||||
if sample_idx >= len(sample_sizes):
|
||||
return offsets
|
||||
offsets[sample_idx] = chunk_pos
|
||||
chunk_pos += int(sample_sizes[sample_idx])
|
||||
sample_idx += 1
|
||||
if sample_idx != len(sample_sizes):
|
||||
raise ValueError(f"stsc describes {sample_idx} samples, stsz describes {len(sample_sizes)}")
|
||||
return offsets
|
||||
|
||||
|
||||
def _make_moov(
    index: Mp4Index,
    durations: np.ndarray,
    sizes: np.ndarray,
    rel_offsets: np.ndarray,
    sync_samples: np.ndarray,
    *,
    mdat_data_offset: int,
) -> bytes:
    """Serialize a single-track ``moov`` box describing the given samples.

    Args:
        index: Source index supplying timescale, track id, dimensions and the
            ``stsd`` payload.
        durations: Per-sample durations; their sum becomes the track duration.
        sizes: Per-sample sizes in bytes.
        rel_offsets: Per-sample offsets relative to the start of the sample data.
        sync_samples: Values written to ``stss``; the box is omitted when empty.
        mdat_data_offset: Added to every relative offset to form the chunk offsets.
    """
    total_duration = int(durations.sum())

    # One chunk per sample; use 64-bit offsets only when a value needs them.
    chunk_offsets = [int(mdat_data_offset + rel) for rel in rel_offsets]
    needs_64bit = max(chunk_offsets, default=0) > 0xFFFFFFFF
    offset_table = _co64(chunk_offsets) if needs_64bit else _stco(chunk_offsets)

    stbl_parts = [
        _box(b"stsd", index.stsd_body),
        _stts(durations),
        _stsc_one_sample_per_chunk(len(sizes)),
        _stsz(sizes),
        offset_table,
    ]
    if len(sync_samples):
        stbl_parts.append(_stss(sync_samples))
    stbl = _box(b"stbl", b"".join(stbl_parts))

    minf = _box(b"minf", _vmhd() + _dinf() + stbl)
    mdia = _box(b"mdia", _mdhd(index.timescale, total_duration) + _hdlr() + minf)
    tkhd = _tkhd(index.track_id, total_duration, index.width, index.height)
    trak = _box(b"trak", tkhd + mdia)
    mvhd = _mvhd(index.timescale, total_duration, index.track_id + 1)
    return _box(b"moov", mvhd + trak)
|
||||
|
||||
|
||||
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
    """Serialize a box whose payload starts with a 1-byte version and 3-byte flags."""
    version_and_flags = bytes((version,)) + flags.to_bytes(3, "big")
    return _box(typ, version_and_flags + payload)
|
||||
|
||||
|
||||
def _box(typ: bytes, payload: bytes) -> bytes:
|
||||
size = len(payload) + 8
|
||||
if size <= 0xFFFFFFFF:
|
||||
return struct.pack(">I4s", size, typ) + payload
|
||||
return struct.pack(">I4sQ", 1, typ, size + 8) + payload
|
||||
|
||||
|
||||
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
    """Serialize a version-0 ``mvhd`` box with zeroed timestamps."""
    fields = struct.pack(
        ">4IIHH",
        0,  # creation time
        0,  # modification time
        timescale,
        duration,
        0x00010000,  # rate field
        0x0100,  # volume field
        0,  # reserved
    )
    matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
    body = b"".join(
        (
            fields,
            b"\0" * 8,
            matrix,
            b"\0" * 24,
            struct.pack(">I", next_track_id),
        )
    )
    return _full_box(b"mvhd", 0, 0, body)
|
||||
|
||||
|
||||
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
    """Serialize a version-0 ``tkhd`` box with flags 7 and zeroed timestamps.

    ``width`` and ``height`` are shifted left by 16 bits before being written.
    """
    matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
    body = b"".join(
        (
            struct.pack(">IIIII", 0, 0, track_id, 0, duration),
            b"\0" * 8,
            struct.pack(">4h", 0, 0, 0, 0),
            matrix,
            struct.pack(">II", width << 16, height << 16),
        )
    )
    return _full_box(b"tkhd", 0, 7, body)
|
||||
|
||||
|
||||
def _mdhd(timescale: int, duration: int) -> bytes:
    """Serialize a version-0 ``mdhd`` box with zeroed timestamps and language field 0x55C4."""
    fields = struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4)
    body = fields + b"\0\0"
    return _full_box(b"mdhd", 0, 0, body)
|
||||
|
||||
|
||||
def _hdlr() -> bytes:
    """Serialize an ``hdlr`` box with handler type ``vide`` and name ``VideoHandler``."""
    body = b"".join(
        (
            b"\0" * 4,
            b"vide",
            b"\0" * 12,
            b"VideoHandler\0",
        )
    )
    return _full_box(b"hdlr", 0, 0, body)
|
||||
|
||||
|
||||
def _vmhd() -> bytes:
    """Serialize a ``vmhd`` box with flags 1 and four zeroed 16-bit fields."""
    body = struct.pack(">4H", 0, 0, 0, 0)
    return _full_box(b"vmhd", 0, 1, body)
|
||||
|
||||
|
||||
def _dinf() -> bytes:
    """Serialize a ``dinf`` box holding a ``dref`` with one empty ``url `` entry (flags 1)."""
    entry_count = struct.pack(">I", 1)
    url_entry = _full_box(b"url ", 0, 1)
    return _box(b"dinf", _full_box(b"dref", 0, 0, entry_count + url_entry))
|
||||
|
||||
|
||||
def _stts(durations: np.ndarray) -> bytes:
    """Serialize an ``stts`` box, run-length encoding consecutive equal durations."""
    runs: list[tuple[int, int]] = []
    for value in durations.tolist():
        delta = int(value)
        if runs and runs[-1][1] == delta:
            # Same delta as the previous sample: extend the current run.
            runs[-1] = (runs[-1][0] + 1, delta)
        else:
            runs.append((1, delta))
    table = b"".join(struct.pack(">II", run_length, delta) for run_length, delta in runs)
    return _full_box(b"stts", 0, 0, struct.pack(">I", len(runs)) + table)
|
||||
|
||||
|
||||
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
    """Serialize an ``stsc`` box with a single entry ``(1, 1, 1)``.

    ``sample_count`` is accepted but not used: the single entry does not depend on it.
    """
    single_entry = struct.pack(">4I", 1, 1, 1, 1)  # entry count + the entry's three fields
    return _full_box(b"stsc", 0, 0, single_entry)
|
||||
|
||||
|
||||
def _stsz(sizes: np.ndarray) -> bytes:
    """Serialize an ``stsz`` box listing every sample size explicitly (uniform size 0)."""
    values = [int(size) for size in sizes.tolist()]
    body = struct.pack(f">II{len(values)}I", 0, len(sizes), *values)
    return _full_box(b"stsz", 0, 0, body)
|
||||
|
||||
|
||||
def _stco(values: list[int]) -> bytes:
    """Serialize an ``stco`` box: entry count followed by 32-bit chunk offsets."""
    body = struct.pack(f">I{len(values)}I", len(values), *values)
    return _full_box(b"stco", 0, 0, body)
|
||||
|
||||
|
||||
def _co64(values: list[int]) -> bytes:
    """Serialize a ``co64`` box: entry count followed by 64-bit chunk offsets."""
    body = struct.pack(f">I{len(values)}Q", len(values), *values)
    return _full_box(b"co64", 0, 0, body)
|
||||
|
||||
|
||||
def _stss(values: np.ndarray) -> bytes:
    """Serialize an ``stss`` box: entry count followed by the given 32-bit sample numbers."""
    numbers = [int(value) for value in values.tolist()]
    body = struct.pack(f">I{len(numbers)}I", len(values), *numbers)
    return _full_box(b"stss", 0, 0, body)
|
||||
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
|
||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
exclude_images: bool = False,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||
|
||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
||||
formats them for use with a Hugging Face LeRobot Dataset.
|
||||
|
||||
Args:
|
||||
pipeline: The DataProcessorPipeline to apply.
|
||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
||||
exclude_images: If True, image features are dropped entirely from the output.
|
||||
use_videos: If False, image features are excluded.
|
||||
patterns: A sequence of regex patterns to filter action and state features.
|
||||
Image features are not affected by this filter.
|
||||
|
||||
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
|
||||
)
|
||||
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and exclude_images:
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if "camera_obs" in observations:
|
||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||
|
||||
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||
# so env plugins can expose extra observation keys via get_env_processors().
|
||||
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||
for key, value in observations.items():
|
||||
if key in _handled:
|
||||
continue
|
||||
target = f"{OBS_STR}.{key}"
|
||||
if target in return_observations:
|
||||
continue
|
||||
if isinstance(value, np.ndarray):
|
||||
val = torch.from_numpy(value).float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
elif isinstance(value, Tensor):
|
||||
val = value.float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
if self.config.use_vae and log_sigma_x2_hat is not None:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
|
||||
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
"""Predict a chunk of actions given environment observations.
|
||||
|
||||
Supports two modes:
|
||||
- Online (queues populated via select_action): stacks observations from internal queues.
|
||||
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
|
||||
"""
|
||||
queues_populated = any(len(q) > 0 for q in self._queues.values())
|
||||
if queues_populated:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
else:
|
||||
batch = dict(batch)
|
||||
if self.config.image_features:
|
||||
for key in self.config.image_features:
|
||||
if batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -252,6 +252,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
pretrained_revision: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -309,6 +310,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -318,6 +320,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -557,6 +560,7 @@ def make_policy(
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
|
||||
@@ -124,6 +124,7 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -95,6 +96,81 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
    """Translate env feature specs into the plain-dict schema passed to LeRobotDataset.create().

    Visual features become ``"video"`` entries and everything else ``"float32"``.
    When ``raw_obs`` is given, shapes are taken from the matching observation
    array (minus its leading batch axis) instead of the configured shape, so the
    schema follows the real observation resolution. ``next.reward``,
    ``next.success`` and ``next.done`` entries are always appended.
    """
    have_obs = raw_obs is not None

    def observed_shape(name: str):
        """Shape of ``raw_obs[name]`` without the batch axis, or None if it is not an ndarray entry."""
        if have_obs and name in raw_obs and isinstance(raw_obs[name], np.ndarray):
            return raw_obs[name].shape[1:]
        return None

    schema = {}
    for key, ft in env_features.items():
        shape = tuple(ft.shape)
        visual = ft.type is FeatureType.VISUAL
        direct = observed_shape(key)
        if direct is not None:
            shape = direct
        elif visual and have_obs and "pixels" in raw_obs:
            pixels = raw_obs["pixels"]
            if isinstance(pixels, dict):
                # Per-camera images: match either the prefixed or the bare camera name.
                for cam_name, img in pixels.items():
                    if key in (f"{OBS_IMAGES}.{cam_name}", cam_name):
                        shape = img.shape[1:]  # strip batch dim
            elif key in ("pixels", OBS_IMAGE):
                shape = pixels.shape[1:]  # strip batch dim

        if visual:
            schema[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
        else:
            schema[key] = {"dtype": "float32", "shape": shape, "names": None}

    for extra_key, extra_dtype in (("next.reward", "float32"), ("next.success", "bool"), ("next.done", "bool")):
        schema[extra_key] = {"dtype": extra_dtype, "shape": (1,), "names": None}
    return schema
|
||||
|
||||
|
||||
def _build_raw_frame(
    raw_obs: dict,
    env_idx: int,
    action: np.ndarray,
    reward: float,
    success: bool,
    done: bool,
    task: str,
    env_features: dict,
) -> dict:
    """Assemble one dataset frame for env ``env_idx`` from raw batched observations.

    Observation entries are keyed exactly like ``env_features`` so the frame
    lines up with the schema from ``_env_features_to_dataset_features()``. Keys
    with no matching observation are left out. The action, reward, success,
    done and task fields are always set.
    """
    has_pixels = "pixels" in raw_obs
    pixels = raw_obs["pixels"] if has_pixels else None
    per_camera = has_pixels and isinstance(pixels, dict)
    single_image = has_pixels and not per_camera

    frame: dict[str, Any] = {}
    for key in env_features:
        if key == ACTION:
            continue  # written explicitly below

        if per_camera:
            # NOTE(review): only the prefixed camera key is matched here, while the
            # schema builder also accepts the bare camera name -- confirm intended.
            for cam_name, img in pixels.items():
                if f"{OBS_IMAGES}.{cam_name}" == key:
                    frame[key] = img[env_idx]
            if key in frame:
                continue

        if single_image and key in ("pixels", OBS_IMAGE):
            frame[key] = pixels[env_idx]
            continue

        if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
            value = raw_obs[key][env_idx]
            # Store doubles as float32.
            frame[key] = value.astype(np.float32) if value.dtype == np.float64 else value

    frame[ACTION] = action
    frame["next.reward"] = np.atleast_1d(np.float32(reward))
    frame["next.success"] = np.atleast_1d(np.bool_(success))
    frame["next.done"] = np.atleast_1d(np.bool_(done))
    frame["task"] = task
    return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -105,6 +181,7 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -145,6 +222,14 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
raw_observation = deepcopy(observation) if recording_dataset is not None else None
|
||||
task_desc = ""
|
||||
if recording_dataset is not None:
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -217,6 +302,26 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_dataset is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_dataset.save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
@@ -273,6 +378,7 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -361,6 +467,7 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -563,6 +670,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -572,10 +683,13 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -618,6 +732,7 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dataset: Any | None = None,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -635,6 +750,7 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -661,6 +777,8 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -672,21 +790,39 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
recording_dataset = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=str(task_recording_dir),
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
finally:
|
||||
if recording_dataset is not None:
|
||||
recording_dataset.finalize()
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -702,6 +838,8 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -761,6 +899,8 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -45,7 +45,8 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
|
||||
from lerobot.datasets.factory import make_train_eval_datasets
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -244,19 +245,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Other ranks read from the shared copy populated by the main process.
|
||||
if not is_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||
if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
@@ -345,6 +346,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
@@ -455,6 +457,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
# Build eval dataloader if a held-out split exists
|
||||
eval_dataloader = None
|
||||
if eval_dataset is not None:
|
||||
eval_ds = eval_dataset
|
||||
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
|
||||
task_indices = eval_dataset.hf_dataset["task_index"]
|
||||
unique_tasks = sorted(set(task_indices))
|
||||
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
|
||||
selected: list[int] = []
|
||||
for t in unique_tasks:
|
||||
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
|
||||
selected.extend(frames)
|
||||
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
|
||||
|
||||
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
eval_dataloader = torch.utils.data.DataLoader(
|
||||
eval_ds,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=eval_collate_fn,
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -534,7 +561,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
|
||||
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
|
||||
|
||||
if is_log_step:
|
||||
# Collective reduce must run on every rank, before the main-process gate below.
|
||||
@@ -557,6 +585,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if is_eval_step:
|
||||
policy.eval()
|
||||
eval_loss_sum = 0.0
|
||||
n_eval_batches = 0
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for eval_batch in eval_dataloader:
|
||||
for cam_key in dataset.meta.camera_keys:
|
||||
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
|
||||
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
eval_batch = preprocessor(eval_batch)
|
||||
loss, _ = policy.forward(eval_batch)
|
||||
eval_loss_sum += loss.item()
|
||||
n_eval_batches += 1
|
||||
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
|
||||
policy.train()
|
||||
|
||||
if is_main_process:
|
||||
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
|
||||
if wandb_logger:
|
||||
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
@@ -579,7 +628,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
if cfg.env and is_env_eval_step:
|
||||
if is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
|
||||
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
|
||||
|
||||
This function uses `importlib.metadata` to find packages installed in the environment
|
||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
||||
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
prefixes = (
|
||||
"lerobot_robot_",
|
||||
"lerobot_camera_",
|
||||
"lerobot_teleoperator_",
|
||||
"lerobot_policy_",
|
||||
"lerobot_env_",
|
||||
)
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
|
||||
from lerobot.datasets.mp4 import (
|
||||
_box,
|
||||
_co64,
|
||||
_dinf,
|
||||
_hdlr,
|
||||
_mdhd,
|
||||
_mvhd,
|
||||
_stco,
|
||||
_stsc_one_sample_per_chunk,
|
||||
_stss,
|
||||
_stsz,
|
||||
_stts,
|
||||
_tkhd,
|
||||
_vmhd,
|
||||
parse_mp4_index,
|
||||
synthesize_mp4,
|
||||
)
|
||||
|
||||
|
||||
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
|
||||
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||
sizes = np.array([10, 10, 10], dtype=np.int64)
|
||||
durations = np.array([1000, 1000, 1000], dtype=np.int64)
|
||||
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
|
||||
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
|
||||
stbl = _box(
|
||||
b"stbl",
|
||||
_box(b"stsd", stsd_body)
|
||||
+ _stts(durations)
|
||||
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||
+ _stsz(sizes)
|
||||
+ offsets
|
||||
+ _stss(np.array([1], dtype=np.int64)),
|
||||
)
|
||||
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
|
||||
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
|
||||
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
|
||||
mdat_payload_start = 10_000
|
||||
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
|
||||
assert free_size >= 8
|
||||
free = _box(b"free", b"\0" * (free_size - 8))
|
||||
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
|
||||
|
||||
|
||||
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
assert sample_slice.byte_offset == 10_000
|
||||
assert sample_slice.byte_length == 60
|
||||
assert sample_slice.sample_lo == 0
|
||||
assert sample_slice.sample_hi == 2
|
||||
|
||||
|
||||
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||
|
||||
mini = synthesize_mp4(mp4, sample_slice, b"x" * sample_slice.byte_length)
|
||||
mini_index = parse_mp4_index("mini.mp4", mini)
|
||||
|
||||
expected = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
|
||||
np.testing.assert_array_equal(mini_index.sample_offsets, expected)
|
||||
np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
|
||||
|
||||
|
||||
def test_parser_accepts_co64_chunk_offsets():
|
||||
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025], use_co64=True))
|
||||
|
||||
np.testing.assert_array_equal(mp4.sample_offsets, np.array([10_000, 10_050, 10_025]))
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps(
|
||||
{
|
||||
"url": "https://github.com/huggingface/huggingface_hub.git",
|
||||
"vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
assert_hf_hub_range_cache_branch()
|
||||
|
||||
|
||||
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
|
||||
class FakeDist:
|
||||
def read_text(self, name):
|
||||
assert name == "direct_url.json"
|
||||
return json.dumps({"url": "https://github.com/huggingface/huggingface_hub.git"})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
assert_hf_hub_range_cache_branch()
|
||||
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=False, # images kept, stored as "image" dtype
|
||||
use_videos=False, # expect "image" dtype
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = f"{OBS_IMAGES}.back"
|
||||
key_front = f"{OBS_IMAGES}.front"
|
||||
assert key in out
|
||||
assert key_front in out
|
||||
assert out[key]["dtype"] == "image"
|
||||
assert out[key_front]["dtype"] == "image"
|
||||
assert out[key]["shape"] == initial["back"]
|
||||
|
||||
|
||||
def test_aggregate_images_excluded():
|
||||
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
|
||||
initial = {"back": (480, 640, 3)}
|
||||
|
||||
out = aggregate_pipeline_dataset_features(
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
exclude_images=True,
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
assert f"{OBS_IMAGES}.back" not in out
|
||||
assert f"{OBS_IMAGES}.front" not in out
|
||||
assert key not in out
|
||||
assert key_front not in out
|
||||
|
||||
|
||||
def test_aggregate_images_when_use_videos_true():
|
||||
|
||||
@@ -134,7 +134,7 @@ class TestMultiGPUTraining:
|
||||
f"--output_dir={output_dir}",
|
||||
"--batch_size=4",
|
||||
"--steps=10",
|
||||
"--eval_freq=-1",
|
||||
"--env_eval_freq=-1",
|
||||
"--log_freq=5",
|
||||
"--save_freq=10",
|
||||
"--seed=42",
|
||||
@@ -177,7 +177,7 @@ class TestMultiGPUTraining:
|
||||
f"--output_dir={output_dir}",
|
||||
"--batch_size=4",
|
||||
"--steps=20",
|
||||
"--eval_freq=-1",
|
||||
"--env_eval_freq=-1",
|
||||
"--log_freq=5",
|
||||
"--save_freq=10",
|
||||
"--seed=42",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
|
||||
@@ -1089,8 +1089,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "5.0.1.dev0"
|
||||
source = { git = "https://github.com/huggingface/datasets.git?branch=main#06fcc085fcdd22fc5cc741954f6187dd879543b6" }
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
{ name = "filelock" },
|
||||
@@ -1107,6 +1107,10 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
@@ -1143,7 +1147,7 @@ name = "decord"
|
||||
version = "0.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and platform_machine != 's390x' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "numpy", marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" },
|
||||
@@ -2046,8 +2050,8 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.20.0.dev0"
|
||||
source = { git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads#5319b287faa73239bb40df16d69c39e5d6daf0f7" }
|
||||
version = "1.19.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "filelock" },
|
||||
@@ -2060,6 +2064,10 @@ dependencies = [
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/27/629cfe58c582f92ded066c4a07d1a057ff617118ab7973200f770bd853cb/huggingface_hub-1.19.0.tar.gz", hash = "sha256:fd771622182d40977272a923953ee3b1b13538f9f8a7f5d78398f10af0f1c0bd", size = 824721, upload-time = "2026-06-11T12:33:18.665Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/a5/558da89f66464d8d0229ff497e8b8666977de2d8cf48c28a2862ecf1250f/huggingface_hub-1.19.0-py3-none-any.whl", hash = "sha256:1dc72e1f6b4d6df6b30eb72e57d00514ef453d660f04af2b87f0e67267f31ee0", size = 693398, upload-time = "2026-06-11T12:33:16.695Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hydra-core"
|
||||
@@ -3179,7 +3187,7 @@ requires-dist = [
|
||||
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
|
||||
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
|
||||
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?branch=main" },
|
||||
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
|
||||
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
|
||||
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
|
||||
@@ -3202,7 +3210,7 @@ requires-dist = [
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
|
||||
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user