Compare commits

..

2 Commits

Author SHA1 Message Date
Khalil Meftah 6407a244c0 feat(envs): add generic observation passthrough
- Add generic observation passthrough in preprocess_observation() for
unhandled ndarray/tensor keys, replacing the pattern of adding per-env
hardcoded key handlers. Extra keys are forwarded as observation.<key>
and can be shaped by env-specific ProcessorSteps via get_env_processors().
2026-06-15 14:17:59 +02:00
Khalil Meftah 0511c12b8f feat(envs): add env plugin discovery
- Add 'lerobot_env_' to third-party plugin discovery prefixes, completing
the plugin system for all component types (robots, cameras, teleoperators,
policies, and now environments). External packages named lerobot_env_*
can self-register EnvConfig subclasses on import, enabling --env.type=
resolution without lerobot code changes.
2026-06-15 14:13:12 +02:00
40 changed files with 1372 additions and 4069 deletions
+8 -8
View File
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
**Compatible teleoperators:**
- `bi_openarm_mini` - Bimanual OpenArm Mini
- `openarm_mini` - OpenArm Mini
- `so_leader` - SO100 / SO101 leader arm
> [!IMPORTANT]
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
---
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_dataset \
--dataset.single_task="Fold the T-shirt properly" \
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
--robot.right_arm_config.port=can0 \
--robot.right_arm_config.side=right \
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
--teleop.type=bi_openarm_mini \
--teleop.left_arm_config.port=/dev/ttyACM0 \
--teleop.right_arm_config.port=/dev/ttyACM1 \
--teleop.type=openarm_mini \
--teleop.port_left=/dev/ttyACM0 \
--teleop.port_right=/dev/ttyACM1 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
--dataset.single_task="Fold the T-shirt properly" \
+1 -1
View File
@@ -117,7 +117,7 @@ lerobot-rollout \
--strategy.num_episodes=20 \
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
--robot.type=bi_openarm_follower \
--teleop.type=bi_openarm_mini \
--teleop.type=openarm_mini \
--dataset.repo_id=${HF_USER}/rollout_hil_data \
--dataset.single_task="Fold the T-shirt"
```
-3
View File
@@ -355,8 +355,6 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
[tool.setuptools.package-data]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
@@ -423,7 +421,6 @@ exclude_dirs = [
skips = ["B101", "B311", "B404", "B603", "B615"]
[tool.typos]
default.extend-words = { trak = "trak" }
default.extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
-860
View File
@@ -1,860 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import argparse
import random
import resource
import tempfile
import threading
import time
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import fsspec
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.episode_video_streaming import (
EpisodeByteCache,
EpisodeVideoManifest,
NativeHTTPRangeFetcher,
assert_hf_hub_range_cache_branch,
)
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
FULL_SIDECAR_NAME = "molmoact2-full.npz"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
parser.add_argument("--repo-id", default=DEFAULT_REPO)
parser.add_argument("--revision", default=DEFAULT_REVISION)
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
parser.add_argument(
"--strategy",
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
default="both",
help=argparse.SUPPRESS,
)
parser.add_argument(
"--range-backend",
choices=("fsspec", "native-http"),
default="fsspec",
help="Range reader used by indexed/full episode-pool fetch tracks.",
)
parser.add_argument("--num-episodes", type=int, default=512)
parser.add_argument(
"--manifest-episodes",
type=int,
default=None,
help="Limit manifest construction to the first N episodes for local smoke tests.",
)
parser.add_argument("--pool-size", type=int, default=16)
parser.add_argument("--workers", type=int, default=8)
parser.add_argument(
"--native-http-connections",
type=int,
default=None,
help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
)
parser.add_argument(
"--native-http-retries",
type=int,
default=8,
help="Retries per native HTTP range request.",
)
parser.add_argument(
"--native-http-timeout",
type=float,
default=120.0,
help="Timeout in seconds for native HTTP requests.",
)
parser.add_argument(
"--include-decode",
action="store_true",
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
)
parser.add_argument("--decode-workers", type=int, default=1)
parser.add_argument("--prefetch-ahead", type=int, default=8)
parser.add_argument("--frames-per-episode", type=int, default=16)
parser.add_argument("--max-probe-mb", type=int, default=64)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--byte-budget-gb", type=float, default=80)
parser.add_argument(
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
)
parser.add_argument("--no-hub-branch-assert", action="store_true")
return parser.parse_args()
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
rng = random.Random(seed)
upper = min(total, requested)
if pool_size > upper:
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
return rng.sample(range(upper), pool_size)
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
rng = random.Random(seed)
out: dict[tuple[int, str], list[float]] = {}
for ep in episodes:
for camera_key in manifest.video_keys:
span = manifest.lookup(ep, camera_key)
lo = span.first_pts
hi = max(span.last_pts, lo)
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
return out
def _timestamps_from_meta(
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
) -> dict[tuple[int, str], list[float]]:
rng = random.Random(seed)
out: dict[tuple[int, str], list[float]] = {}
for ep in episodes:
row = meta.episodes[ep]
for camera_key in meta.video_keys:
lo = float(row[f"videos/{camera_key}/from_timestamp"])
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
return out
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
total = 0
for ep in episodes:
for camera_key in manifest.video_keys:
total += manifest.lookup(ep, camera_key).mdat_length
return total
def _decode_all(
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
) -> float:
start = time.perf_counter()
items = list(timestamps.items())
if decode_workers <= 1:
for (ep, camera_key), ts in items:
cache.get_frames(ep, camera_key, ts)
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
for future in futures:
future.result()
return time.perf_counter() - start
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
start = time.perf_counter()
for ep in episodes:
cache.submit_prefetch(ep)
for ep in episodes:
cache.ensure_ready(ep)
return time.perf_counter() - start
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
if elapsed_s <= 0:
return float("inf")
return len(episodes) * frames_per_episode / elapsed_s
def _log(message: str) -> None:
print(message, flush=True)
def _format_duration(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.1f}s"
if seconds < 3600:
return f"{seconds / 60:.1f}m"
return f"{seconds / 3600:.1f}h"
def _current_rss_mib() -> float | None:
status_path = Path("/proc/self/status")
if not status_path.exists():
return None
for line in status_path.read_text().splitlines():
if line.startswith("VmRSS:"):
return float(line.split()[1]) / 1024
return None
def _peak_rss_mib() -> float:
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
# Linux reports KiB; macOS reports bytes.
if rss > 10**8:
return rss / 1024**2
return rss / 1024
def _memory_snapshot() -> dict[str, float | None]:
return {"rss_mib": _current_rss_mib(), "peak_rss_mib": _peak_rss_mib()}
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
start_rss = start["rss_mib"]
end_rss = end["rss_mib"]
delta = None if start_rss is None or end_rss is None else end_rss - start_rss
print()
print("| Memory | MiB |")
print("|---|---:|")
if start_rss is not None:
print(f"| rss start | {start_rss:.1f} |")
if end_rss is not None:
print(f"| rss end | {end_rss:.1f} |")
if delta is not None:
print(f"| rss delta | {delta:.1f} |")
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
def _root_join(data_root: str, relative_path: str) -> str:
if data_root.startswith("hf://"):
return f"{data_root.rstrip('/')}/{relative_path}"
return str(Path(data_root) / relative_path)
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
_ = manifest_episode_count
local = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
if _valid_sidecar(local):
return local
if local.exists():
print(f"mp4_sidecar_invalid_local: {local}")
local.unlink()
remote_relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
remote = _root_join(data_root, remote_relative)
protocol = "hf" if data_root.startswith("hf://") else "file"
fs = fsspec.filesystem(protocol)
if not fs.exists(remote):
return None
local.parent.mkdir(parents=True, exist_ok=True)
print(f"downloading_mp4_sidecar: {remote} -> {local}")
if data_root.startswith("hf://"):
_download_sidecar_native_http(data_root, remote_relative, local)
else:
fs.get(remote, str(local))
return local
def _valid_sidecar(path: Path) -> bool:
if not path.exists():
return False
try:
with np.load(path, allow_pickle=False) as data:
return "manifest_json" in data
except Exception:
return False
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
tmp = local.with_suffix(local.suffix + ".tmp")
try:
size = fetcher.info_size(relative_path)
chunk_size = 16 * 1024 * 1024
ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
with tmp.open("wb") as out_file:
out_file.truncate(size)
def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
offset, length = offset_length
return offset, fetcher.read_range(relative_path, offset, length)
start = time.perf_counter()
done = 0
with ThreadPoolExecutor(max_workers=8) as pool:
futures = [pool.submit(read_chunk, item) for item in ranges]
with tmp.open("r+b") as rw_file:
for future in futures:
offset, data = future.result()
rw_file.seek(offset)
rw_file.write(data)
done += len(data)
elapsed = max(time.perf_counter() - start, 1e-9)
print(
f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
f"({done / elapsed / 1024**2:.1f} MiB/s)",
flush=True,
)
tmp.replace(local)
finally:
fetcher.close()
class EpisodeParquetReader:
def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
self.meta = meta
self.data_root = data_root
protocol = "hf" if data_root.startswith("hf://") else "file"
self.fs = fsspec.filesystem(protocol)
self._episode_row_groups = self._build_episode_row_groups()
self._table_cache: dict[str, pa.Table] = {}
self._cache_lock = threading.Lock()
def read_episode(self, episode_index: int) -> None:
relative_path = str(self.meta.get_data_file_path(episode_index))
table = self._read_table(relative_path)
table.filter(pc.equal(table["episode_index"], episode_index))
def _read_table(self, relative_path: str) -> pa.Table:
with self._cache_lock:
table = self._table_cache.get(relative_path)
if table is not None:
return table
with self.fs.open(
_root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
) as f:
table = pq.ParquetFile(f).read()
with self._cache_lock:
return self._table_cache.setdefault(relative_path, table)
def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
return pool.submit(self.read_episode, episode_index)
def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
start = time.perf_counter()
if workers <= 1:
for ep in episodes:
self.read_episode(ep)
else:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = [pool.submit(self.read_episode, ep) for ep in episodes]
for future in futures:
future.result()
return time.perf_counter() - start
def _build_episode_row_groups(self) -> dict[int, int]:
counts: dict[tuple[int, int], int] = {}
row_groups = {}
for ep_idx in range(int(self.meta.total_episodes)):
ep = self.meta.episodes[ep_idx]
key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
row_groups[ep_idx] = counts.get(key, 0)
counts[key] = row_groups[ep_idx] + 1
return row_groups
def run_fetch_pool(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
byte_budget: int,
workers: int,
range_backend: str,
args: argparse.Namespace,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
native_http_connections=args.native_http_connections,
native_http_timeout=args.native_http_timeout,
native_http_retries=args.native_http_retries,
open_decoders=False,
) as cache:
elapsed = _fill_cache(cache, episodes)
timings = cache.timing_summary()
byte_count = _bytes_for(manifest, episodes)
episode_mb = byte_count / len(episodes) / 1024**2
job_count = max(timings["jobs"], 1.0)
result = {
"fetch_s": elapsed,
"fetch_mbps": byte_count / elapsed / 1024**2,
"fetch_episodes_s": len(episodes) / elapsed,
"episode_mb": episode_mb,
"avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
"jobs": timings["jobs"],
"lookup_ms": timings["lookup_s"] * 1000 / job_count,
"range_fetch_ms": timings["fetch_s"] * 1000 / job_count,
"synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
"store_ms": timings["store_s"] * 1000 / job_count,
}
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
return result
def run_parallel(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
byte_budget: int,
workers: int,
decode_workers: int,
frames_per_episode: int,
parquet_reader: EpisodeParquetReader,
range_backend: str,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
open_decoders=False,
) as cache:
parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
fetch_s = _fill_cache(cache, episodes)
decoder_start = time.perf_counter()
for ep in episodes:
for camera_key in manifest.video_keys:
cache.get_decoder(ep, camera_key)
decoder_s = time.perf_counter() - decoder_start
decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
byte_count = _bytes_for(manifest, episodes)
return {
"fetch_s": fetch_s,
"fetch_mbps": byte_count / fetch_s / 1024**2,
"fetch_episodes_s": len(episodes) / fetch_s,
"parquet_s": parquet_s,
"decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
"decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
}
def run_overlapped(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
byte_budget: int,
workers: int,
decode_workers: int,
frames_per_episode: int,
prefetch_ahead: int,
parquet_reader: EpisodeParquetReader,
range_backend: str,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
open_decoders=True,
) as cache:
start = time.perf_counter()
video_wait_decode_s = 0.0
parquet_wait_s = 0.0
parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
parquet_futures = {
ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
}
for ep in episodes[:prefetch_ahead]:
cache.submit_prefetch(ep)
try:
for idx, ep in enumerate(episodes):
next_idx = idx + prefetch_ahead
if next_idx < len(episodes):
next_ep = episodes[next_idx]
cache.submit_prefetch(next_ep)
parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
parquet_start = time.perf_counter()
parquet_futures.pop(ep).result()
parquet_wait_s += time.perf_counter() - parquet_start
video_start = time.perf_counter()
cache.ensure_ready(ep)
if decode_workers <= 1:
for camera_key in manifest.video_keys:
cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
futures = [
pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
for camera_key in manifest.video_keys
]
for future in futures:
future.result()
video_wait_decode_s += time.perf_counter() - video_start
finally:
parquet_pool.shutdown(wait=True)
elapsed = time.perf_counter() - start
return {
"samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
"video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
"parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
"wall_s": elapsed,
"video_wait_decode_s": video_wait_decode_s,
"parquet_wait_s": parquet_wait_s,
}
_remote_decoder_local = threading.local()
def _remote_decoder_cache() -> VideoDecoderCache:
cache = getattr(_remote_decoder_local, "cache", None)
if cache is None:
cache = VideoDecoderCache(max_size=None)
_remote_decoder_local.cache = cache
return cache
def _decode_remote_source(
meta: LeRobotDatasetMetadata,
data_root: str,
episode_index: int,
camera_key: str,
timestamps: list[float],
):
video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
return decode_video_frames_torchcodec(
video_path,
timestamps,
tolerance_s=1.0 / float(meta.fps),
decoder_cache=_remote_decoder_cache(),
return_uint8=True,
)
def run_remote_decoder(
meta: LeRobotDatasetMetadata,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
*,
frames_per_episode: int,
decode_workers: int,
parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
items = [
(ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
]
start = time.perf_counter()
for ep, camera_key, ts in items:
if camera_key == meta.video_keys[0]:
parquet_reader.read_episode(ep)
_decode_remote_source(meta, data_root, ep, camera_key, ts)
sequential_s = time.perf_counter() - start
start = time.perf_counter()
if decode_workers <= 1:
for ep, camera_key, ts in items:
if camera_key == meta.video_keys[0]:
parquet_reader.read_episode(ep)
_decode_remote_source(meta, data_root, ep, camera_key, ts)
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
futures = [
pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
for ep, camera_key, ts in items
]
for future in parquet_futures:
future.result()
for future in futures:
future.result()
parallel_s = time.perf_counter() - start
return {
"sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
"parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
}
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
range_jobs = fetch_pool.get("range_jobs", 0.0)
if range_jobs <= 0:
return
print()
print("| Range Read Stage | avg ms/range |")
print("|---|---:|")
for key, label in (
("range_open_s", "fsspec handle open/lookup"),
("range_seek_s", "fsspec seek"),
("range_read_s", "fsspec read"),
("range_resolve_s", "http URL resolve"),
("range_header_s", "http response headers"),
("range_first_byte_s", "http first body byte"),
("range_body_s", "http body drain"),
("range_retry_sleep_s", "http retry sleep"),
):
value = fetch_pool.get(key)
if value is not None:
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
if "range_retry_attempts" in fetch_pool:
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
if fetch_pool.get("range_failed_requests"):
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
print(f"| range reads | {range_jobs:.0f} |")
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
def run_indexed_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
args: argparse.Namespace,
parquet_reader: EpisodeParquetReader,
*,
range_backend: str = "fsspec",
label: str = "indexed",
sidecar_path: str | None = None,
) -> None:
_log(f"starting_strategy: {label}")
memory_start = _memory_snapshot()
manifest_start = time.perf_counter()
dataset_episode_count = int(meta.total_episodes)
manifest_episode_count = args.manifest_episodes or dataset_episode_count
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
manifest = EpisodeVideoManifest.build(
meta,
data_root,
episode_indices=range(manifest_episode_count),
range_backend=range_backend,
workers=args.workers,
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
sidecar_path=sidecar_path,
)
manifest_s = time.perf_counter() - manifest_start
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
byte_budget = int(args.byte_budget_gb * 1024**3)
byte_count = _bytes_for(manifest, episodes)
_log(
f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
)
_log(f"{label}: filling episode byte cache with {args.workers} workers")
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
print(f"manifest_build_s: {manifest_s:.2f}")
print(f"strategy: {label}")
print(f"range_backend: {range_backend}")
print(f"mp4_sidecar: {sidecar_path or 'none'}")
print(f"data_root: {data_root}")
print(f"dataset_episodes: {dataset_episode_count}")
print(f"benchmark_episodes: {benchmark_episode_count}")
print(f"pool_episodes: {len(episodes)}")
print(f"sampled_episodes: {episodes}")
print(f"cameras: {manifest.video_keys}")
print()
print(
"| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
)
print("|---|---:|---:|---:|---:|---:|---:|---|")
print(
f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
)
print()
print("| Camera Job Stage | avg ms/job |")
print("|---|---:|")
print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
_print_range_timing_summary(fetch_pool)
_print_memory_summary(memory_start, _memory_snapshot())
if args.include_decode:
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
_log(f"{label}: running parallel video fetch + decode-only")
parallel = run_parallel(
manifest,
data_root,
episodes,
timestamps,
byte_budget,
args.workers,
args.decode_workers,
args.frames_per_episode,
parquet_reader,
range_backend,
)
_log(f"{label}: running overlapped end-to-end")
overlapped = run_overlapped(
manifest,
data_root,
episodes,
timestamps,
byte_budget,
args.workers,
args.decode_workers,
args.frames_per_episode,
args.prefetch_ahead,
parquet_reader,
range_backend,
)
print(
f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
f"{parallel['fetch_s']:.2f} | "
f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
f"{fetch_pool['avg_mb_miss']:.1f} | "
f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
)
print(
f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
f"{fetch_pool['avg_mb_miss']:.1f} | "
f"{overlapped['samples_s']:.1f} samples/s; video+decode "
f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
)
def run_remote_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
args: argparse.Namespace,
parquet_reader: EpisodeParquetReader,
) -> None:
_log("starting_strategy: remote-decoder")
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
_log("remote-decoder: running direct source MP4 decoder")
result = run_remote_decoder(
meta,
data_root,
episodes,
timestamps,
frames_per_episode=args.frames_per_episode,
decode_workers=args.decode_workers,
parquet_reader=parquet_reader,
)
print("strategy: remote-decoder")
print(f"data_root: {data_root}")
print(f"episodes: {episodes}")
print(f"cameras: {list(meta.video_keys)}")
print()
print("| Track | samples/s | notes |")
print("|---|---:|---|")
print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
print(
f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
f"direct source MP4 decoder, {args.decode_workers} workers |"
)
def main() -> None:
args = parse_args()
if args.strategy == "full":
args.strategy = "both"
if args.strategy == "native-http":
args.range_backend = "native-http"
data_root = args.data_root
if data_root.startswith("hf://") and not args.no_hub_branch_assert:
assert_hf_hub_range_cache_branch()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
meta.ensure_readable()
parquet_reader = EpisodeParquetReader(meta, data_root)
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
if sidecar_path is not None:
print(f"using_mp4_sidecar: {sidecar_path}")
if sidecar_path is not None and args.strategy == "both":
if args.include_decode:
run_remote_strategy(meta, data_root, args, parquet_reader)
print()
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend=args.range_backend,
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "indexed":
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend=args.range_backend,
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "native-http":
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="native-http",
label="indexed-sidecar-native-http",
sidecar_path=str(sidecar_path),
)
return
if args.strategy == "both":
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
print(f"mp4_sidecar_missing_remote: {expected_remote}")
print(
"build_mp4_sidecar: "
"uv run --no-sync python scripts/build_mp4_sidecar.py "
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
)
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
print()
if args.strategy in ("both", "indexed"):
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="fsspec",
label="indexed",
sidecar_path=None,
)
if args.strategy == "both":
print()
if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode):
run_remote_strategy(meta, data_root, args, parquet_reader)
if args.strategy == "both" and args.include_decode:
print()
if args.strategy in ("both", "native-http"):
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="native-http",
label="indexed-native-http",
sidecar_path=None,
)
if __name__ == "__main__":
main()
-93
View File
@@ -1,93 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import argparse
import time
from pathlib import Path
import fsspec
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
parser.add_argument("--repo-id", default=DEFAULT_REPO)
parser.add_argument("--revision", default=DEFAULT_REVISION)
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
parser.add_argument("--output", required=True)
parser.add_argument("--episodes", type=int, default=None)
parser.add_argument("--workers", type=int, default=8)
parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
parser.add_argument("--max-probe-mb", type=int, default=64)
parser.add_argument(
"--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
)
parser.add_argument("--no-hub-branch-assert", action="store_true")
return parser.parse_args()
def push_sidecar(local_path: str, data_root: str) -> list[str]:
if not data_root.startswith("hf://"):
return []
local = Path(local_path)
fs = fsspec.filesystem("hf")
remote_dir = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
remote_paths = [f"{remote_dir}/{local.name}"]
for remote in remote_paths:
fs.put(str(local), remote)
return remote_paths
def main() -> None:
args = parse_args()
if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
assert_hf_hub_range_cache_branch()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
meta.ensure_readable()
total = (
int(meta.total_episodes) if args.episodes is None else min(args.episodes, int(meta.total_episodes))
)
rel_paths = sorted(
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in range(total) for key in meta.video_keys}
)
start = time.perf_counter()
EpisodeVideoManifest.write_file_sidecar(
args.output,
rel_paths,
args.data_root,
range_backend=args.range_backend,
workers=args.workers,
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
)
elapsed = time.perf_counter() - start
print(f"wrote {args.output}")
print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")
if args.no_push:
print("push_skipped: --no-push")
else:
pushed = push_sidecar(args.output, args.data_root)
for remote in pushed:
print(f"pushed {remote}")
if __name__ == "__main__":
main()
@@ -1,890 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import contextlib
import io
import json
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from importlib import metadata
from pathlib import Path
from typing import Any
from urllib.parse import quote, urljoin, urlparse
import fsspec
import httpx
import numpy as np
from huggingface_hub import HfApi, HfFileSystem, constants
from huggingface_hub.utils import hf_raise_for_status
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
@dataclass(frozen=True)
class EpisodeVideoSpan:
file_id: int
mdat_offset: int
mdat_length: int
first_pts: float
last_pts: float
frame_count: int
sample_lo: int
sample_hi: int
source_start_pts: float
@dataclass(frozen=True)
class VideoFileRecord:
file_path: str
file_size: int
mp4: Mp4Index
class ThreadLocalRangeFetcher:
"""Range reader that gives each worker thread independent file handles."""
def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
self.data_root = str(data_root).rstrip("/")
protocol = "hf" if self.data_root.startswith("hf://") else "file"
self.fs = fsspec.filesystem(protocol)
self.block_size = block_size
self.cache_type = cache_type
self._local = threading.local()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_open_s": 0.0,
"range_seek_s": 0.0,
"range_read_s": 0.0,
}
def _url(self, relative_path: str) -> str:
if self.data_root.startswith("hf://"):
return f"{self.data_root}/{relative_path}"
return str(Path(self.data_root) / relative_path)
def _handle(self, relative_path: str):
handles = getattr(self._local, "handles", None)
if handles is None:
handles = {}
self._local.handles = handles
handle = handles.get(relative_path)
if handle is None or getattr(handle, "closed", False):
handle = self.fs.open(
self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
)
handles[relative_path] = handle
return handle
def info_size(self, relative_path: str) -> int:
return int(self.fs.info(self._url(relative_path))["size"])
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
open_start = time.perf_counter()
handle = self._handle(relative_path)
open_s = time.perf_counter() - open_start
seek_start = time.perf_counter()
handle.seek(offset)
seek_s = time.perf_counter() - seek_start
read_start = time.perf_counter()
data = handle.read(length)
read_s = time.perf_counter() - read_start
self._record_timing(
range_jobs=1.0,
range_bytes=float(len(data)),
range_open_s=open_s,
range_seek_s=seek_s,
range_read_s=read_s,
)
return data
def _record_timing(self, **kwargs: float) -> None:
with self._timing_lock:
for key, value in kwargs.items():
self._timing_totals[key] += value
def timing_summary(self) -> dict[str, float]:
with self._timing_lock:
return dict(self._timing_totals)
def close(self) -> None:
handles = getattr(self._local, "handles", None)
if handles is None:
return
for handle in handles.values():
with contextlib.suppress(Exception):
handle.close()
handles.clear()
class NativeHTTPRangeFetcher:
"""Direct pooled HTTP range reader for hf:// paths."""
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
_GLOBAL_LOCK = threading.Lock()
_RETRYABLE_EXCEPTIONS = (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
httpx.PoolTimeout,
)
def __init__(
self,
data_root: str | Path,
*,
max_connections: int = 32,
timeout: float = 60.0,
max_retries: int = 4,
):
self.data_root = str(data_root).rstrip("/")
if not self.data_root.startswith("hf://"):
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
self.max_retries = max_retries
self.api = HfApi()
self.fs: HfFileSystem | None = None
self._bucket_id: str | None = None
self._bucket_prefix = ""
if self.data_root.startswith("hf://buckets/"):
bucket_root = self.data_root.removeprefix("hf://buckets/")
parts = bucket_root.split("/", 2)
if len(parts) < 2:
raise ValueError(f"Invalid bucket root: {self.data_root}")
self._bucket_id = f"{parts[0]}/{parts[1]}"
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
else:
self.fs = HfFileSystem()
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
follow_redirects=False,
)
self._resolved_urls: dict[str, str] = {}
self._source_urls: dict[str, str] = {}
self._sizes: dict[str, int] = {}
self._lock = threading.Lock()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_resolve_s": 0.0,
"range_header_s": 0.0,
"range_first_byte_s": 0.0,
"range_body_s": 0.0,
"range_retry_attempts": 0.0,
"range_retry_sleep_s": 0.0,
"range_failed_requests": 0.0,
}
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
return self.client.request(method, url, **kwargs)
except self._RETRYABLE_EXCEPTIONS as exc:
last_exc = exc
if attempt >= self.max_retries:
break
time.sleep(min(0.5 * 2**attempt, 5.0))
if last_exc is None:
raise RuntimeError("HTTP request failed without an exception")
raise last_exc
def _cache_key(self, relative_path: str) -> tuple[str, str]:
return self.data_root, relative_path
def _path(self, relative_path: str) -> str:
return f"{self.data_root}/{relative_path}"
def _bucket_path(self, relative_path: str) -> str:
if self._bucket_prefix:
return f"{self._bucket_prefix}/{relative_path}"
return relative_path
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
headers = self.api._build_hf_headers()
if urlparse(request_url).netloc != urlparse(source_url).netloc:
headers.pop("authorization", None)
headers.pop("Authorization", None)
return headers
def _source_url(self, relative_path: str) -> str:
with self._lock:
source = self._source_urls.get(relative_path)
if source is not None:
return source
key = self._cache_key(relative_path)
with self._GLOBAL_LOCK:
source = self._GLOBAL_SOURCE_URLS.get(key)
if source is None:
if self._bucket_id is not None:
source = (
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
f"{quote(self._bucket_path(relative_path))}"
)
else:
if self.fs is None:
raise RuntimeError("HfFileSystem fallback was not initialized")
source = self.fs.url(self._path(relative_path))
with self._GLOBAL_LOCK:
self._GLOBAL_SOURCE_URLS[key] = source
with self._lock:
self._source_urls[relative_path] = source
return source
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
with self._lock:
if not refresh and relative_path in self._resolved_urls:
return self._resolved_urls[relative_path]
key = self._cache_key(relative_path)
if not refresh:
with self._GLOBAL_LOCK:
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
size = self._GLOBAL_SIZES.get(key)
if resolved is not None:
with self._lock:
self._resolved_urls[relative_path] = resolved
if size is not None:
self._sizes[relative_path] = size
return resolved
source = self._source_url(relative_path)
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
try:
hf_raise_for_status(response)
location = response.headers.get("Location")
resolved = urljoin(source, location) if location else source
with self._lock:
self._resolved_urls[relative_path] = resolved
if "Content-Length" in response.headers:
self._sizes[relative_path] = int(response.headers["Content-Length"])
with self._GLOBAL_LOCK:
self._GLOBAL_RESOLVED_URLS[key] = resolved
if "Content-Length" in response.headers:
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
return resolved
finally:
response.close()
def info_size(self, relative_path: str) -> int:
with self._lock:
size = self._sizes.get(relative_path)
if size is not None:
return size
key = self._cache_key(relative_path)
with self._GLOBAL_LOCK:
size = self._GLOBAL_SIZES.get(key)
if size is not None:
with self._lock:
self._sizes[relative_path] = size
return size
resolved = self._resolve_url(relative_path)
source = self._source_url(relative_path)
response = self._request(
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
)
try:
hf_raise_for_status(response)
size = int(response.headers["Content-Length"])
with self._lock:
self._sizes[relative_path] = size
with self._GLOBAL_LOCK:
self._GLOBAL_SIZES[key] = size
return size
finally:
response.close()
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
resolve_start = time.perf_counter()
resolved = self._resolve_url(relative_path)
source = self._source_url(relative_path)
resolve_s = time.perf_counter() - resolve_start
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
payload, status_code, timings = self._read_range_response(resolved, headers)
if status_code == 403:
refresh_start = time.perf_counter()
resolved = self._resolve_url(relative_path, refresh=True)
resolve_s += time.perf_counter() - refresh_start
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
for key, value in retry_timings.items():
timings[key] += value
if status_code == 403:
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
self._record_timing(
range_jobs=1.0,
range_bytes=float(len(payload)),
range_resolve_s=resolve_s,
**timings,
)
return payload
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
last_exc: Exception | None = None
retry_attempts = 0.0
retry_sleep_s = 0.0
for attempt in range(self.max_retries + 1):
try:
payload, status_code, timings = self._read_range_response_once(url, headers)
timings["range_retry_attempts"] = retry_attempts
timings["range_retry_sleep_s"] = retry_sleep_s
return payload, status_code, timings
except self._RETRYABLE_EXCEPTIONS as exc:
last_exc = exc
if attempt >= self.max_retries:
break
retry_attempts += 1.0
sleep_s = min(0.5 * 2**attempt, 5.0)
retry_sleep_s += sleep_s
time.sleep(sleep_s)
self._record_timing(
range_failed_requests=1.0,
range_retry_attempts=retry_attempts,
range_retry_sleep_s=retry_sleep_s,
)
if last_exc is None:
raise RuntimeError("HTTP range request failed without an exception")
raise last_exc
def _read_range_response_once(
self, url: str, headers: dict[str, str]
) -> tuple[bytes, int, dict[str, float]]:
header_start = time.perf_counter()
with self.client.stream("GET", url, headers=headers) as response:
header_s = time.perf_counter() - header_start
if response.status_code == 403:
return (
b"",
response.status_code,
{
"range_header_s": header_s,
"range_first_byte_s": 0.0,
"range_body_s": 0.0,
},
)
hf_raise_for_status(response)
chunks = []
first_byte_s = 0.0
first_chunk = True
body_start = time.perf_counter()
for chunk in response.iter_bytes():
if first_chunk:
first_byte_s = time.perf_counter() - body_start
first_chunk = False
chunks.append(chunk)
body_s = time.perf_counter() - body_start
return (
b"".join(chunks),
response.status_code,
{
"range_header_s": header_s,
"range_first_byte_s": first_byte_s,
"range_body_s": body_s,
},
)
def _record_timing(self, **kwargs: float) -> None:
with self._timing_lock:
for key, value in kwargs.items():
self._timing_totals[key] += value
def timing_summary(self) -> dict[str, float]:
with self._timing_lock:
return dict(self._timing_totals)
def close(self) -> None:
self.client.close()
def make_range_fetcher(
data_root: str | Path,
*,
range_backend: str,
workers: int,
native_http_connections: int | None = None,
native_http_timeout: float = 60.0,
native_http_retries: int = 4,
):
if range_backend == "fsspec":
return ThreadLocalRangeFetcher(data_root)
if range_backend == "native-http":
max_connections = native_http_connections or max(8, workers)
return NativeHTTPRangeFetcher(
data_root,
max_connections=max_connections,
timeout=native_http_timeout,
max_retries=native_http_retries,
)
raise ValueError(f"Unknown range backend: {range_backend}")
class EpisodeVideoManifest:
_FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
_FILE_SIDECAR_CACHE_LOCK = threading.Lock()
def __init__(
self,
*,
video_keys: list[str],
files: list[VideoFileRecord],
spans: dict[str, np.ndarray],
):
self.video_keys = list(video_keys)
self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
self.files = files
self.spans = spans
@classmethod
def build(
cls,
meta: LeRobotDatasetMetadata,
data_root: str | Path,
*,
episode_indices: list[int] | range | None = None,
range_backend: str = "fsspec",
workers: int = 8,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
keyframe_pad_s: float = 0.1,
keyframe_pad_fraction: float = 0.05,
sidecar_path: str | Path | None = None,
) -> EpisodeVideoManifest:
meta.ensure_readable()
video_keys = list(meta.video_keys)
if episode_indices is None:
episode_indices = range(int(meta.total_episodes))
rel_paths = sorted(
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
)
path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
if sidecar_path is None:
files = cls._build_file_records(
rel_paths,
data_root,
range_backend=range_backend,
workers=workers,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
else:
records = cls.load_file_sidecar(sidecar_path)
missing = [path for path in rel_paths if path not in records]
if missing:
raise ValueError(
f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
)
files = [records[path] for path in rel_paths]
total = int(meta.total_episodes)
num_cameras = len(video_keys)
spans: dict[str, np.ndarray] = {
"file_id": np.zeros((total, num_cameras), dtype=np.int32),
"mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
"mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
"first_pts": np.zeros((total, num_cameras), dtype=np.float64),
"last_pts": np.zeros((total, num_cameras), dtype=np.float64),
"frame_count": np.zeros((total, num_cameras), dtype=np.int32),
"sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
"sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
"source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
}
for ep_idx in episode_indices:
ep = meta.episodes[ep_idx]
for cam_idx, key in enumerate(video_keys):
rel_path = str(meta.get_video_file_path(ep_idx, key))
file_id = path_to_id[rel_path]
mp4 = files[file_id].mp4
from_ts = float(ep[f"videos/{key}/from_timestamp"])
to_ts = float(ep[f"videos/{key}/to_timestamp"])
sample_slice = mp4.sample_slice(
from_ts,
to_ts,
keyframe_pad_s=keyframe_pad_s,
keyframe_pad_fraction=keyframe_pad_fraction,
file_size=files[file_id].file_size,
)
spans["file_id"][ep_idx, cam_idx] = file_id
spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
spans["first_pts"][ep_idx, cam_idx] = from_ts
spans["last_pts"][ep_idx, cam_idx] = to_ts
spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts
return cls(video_keys=video_keys, files=files, spans=spans)
@staticmethod
def _build_file_records(
rel_paths: list[str],
data_root: str | Path,
*,
range_backend: str,
workers: int,
header_probe_bytes: int,
max_probe_bytes: int,
) -> list[VideoFileRecord]:
fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
def build_file(path: str) -> VideoFileRecord:
file_size = fetcher.info_size(path)
mp4 = fetch_mp4_index(
path,
fetcher.read_range,
file_size=file_size,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
return VideoFileRecord(path, file_size, mp4)
try:
with ThreadPoolExecutor(max_workers=workers) as pool:
return list(pool.map(build_file, rel_paths))
finally:
fetcher.close()
@classmethod
def write_file_sidecar(
cls,
sidecar_path: str | Path,
rel_paths: list[str],
data_root: str | Path,
*,
range_backend: str = "native-http",
workers: int = 8,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
) -> None:
records = cls._build_file_records(
sorted(set(rel_paths)),
data_root,
range_backend=range_backend,
workers=workers,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
cls.save_file_sidecar(sidecar_path, records)
@staticmethod
def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
sidecar_path = Path(sidecar_path)
sidecar_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"version": 1,
"files": [
{"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
for record in records
],
}
arrays = {}
for file_idx, record in enumerate(records):
arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
@staticmethod
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
cache_key = str(Path(sidecar_path).expanduser())
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
if cached is not None:
return cached
with np.load(sidecar_path, allow_pickle=False) as data:
payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
records = {}
for file_idx, item in enumerate(payload["files"]):
arrays = {
name: data[f"{file_idx}/{name}"]
for name in [
"sample_pts",
"sample_durations",
"sample_sizes",
"sample_offsets",
"sync_samples",
]
}
mp4 = Mp4Index.from_dict(item["mp4"], arrays)
records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
return records
def camera_id(self, camera_key: str) -> int:
return self._camera_to_id[camera_key]
def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
cam = self.camera_id(camera_key)
return EpisodeVideoSpan(
file_id=int(self.spans["file_id"][episode_index, cam]),
mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
first_pts=float(self.spans["first_pts"][episode_index, cam]),
last_pts=float(self.spans["last_pts"][episode_index, cam]),
frame_count=int(self.spans["frame_count"][episode_index, cam]),
sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
)
def file_lookup(self, file_id: int) -> VideoFileRecord:
return self.files[file_id]
def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
return self.files[self.lookup(episode_index, camera_key).file_id].mp4
def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
span = self.lookup(episode_index, camera_key)
return Mp4SampleSlice(
sample_lo=span.sample_lo,
sample_hi=span.sample_hi,
byte_offset=span.mdat_offset,
byte_length=span.mdat_length,
source_start_pts=span.source_start_pts,
)
class EpisodeByteCache:
def __init__(
self,
manifest: EpisodeVideoManifest,
data_root: str | Path,
*,
byte_budget: int = 80 * 1024**3,
workers: int = 8,
range_backend: str = "fsspec",
native_http_connections: int | None = None,
native_http_timeout: float = 60.0,
native_http_retries: int = 4,
open_decoders: bool = True,
):
self.manifest = manifest
self.fetcher = make_range_fetcher(
data_root,
range_backend=range_backend,
workers=workers,
native_http_connections=native_http_connections,
native_http_timeout=native_http_timeout,
native_http_retries=native_http_retries,
)
self.byte_budget = byte_budget
self.open_decoders = open_decoders
self._pool = ThreadPoolExecutor(max_workers=workers)
self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
self._bytes = 0
self._lock = threading.Lock()
self._timing_totals = {
"lookup_s": 0.0,
"fetch_s": 0.0,
"synthesize_s": 0.0,
"store_s": 0.0,
"jobs": 0.0,
}
def close(self) -> None:
self._pool.shutdown(wait=True)
with self._lock:
self._cache.clear()
self._futures.clear()
self._bytes = 0
self.fetcher.close()
def __enter__(self) -> EpisodeByteCache:
return self
def __exit__(self, *_exc) -> None:
self.close()
def submit_prefetch(self, episode_index: int) -> None:
for camera_key in self.manifest.video_keys:
self._submit(episode_index, camera_key)
def ensure_ready(self, episode_index: int) -> None:
for camera_key in self.manifest.video_keys:
self.get_bytes(episode_index, camera_key)
def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
return self._get_entry(episode_index, camera_key)["bytes"]
def get_decoder(self, episode_index: int, camera_key: str):
entry = self._get_entry(episode_index, camera_key)
decoder = entry.get("decoder")
if decoder is None:
decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
entry["decoder"] = decoder
return decoder
def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
span = self.manifest.lookup(episode_index, camera_key)
local_ts = [ts - span.source_start_pts for ts in timestamps]
decoder = self.get_decoder(episode_index, camera_key)
if hasattr(decoder, "get_frames_played_at"):
return decoder.get_frames_played_at(local_ts).data
metadata = decoder.metadata
fps = getattr(metadata, "average_fps", None)
if fps is None:
duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9)
fps = metadata.num_frames / duration
return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data
def timing_summary(self) -> dict[str, float]:
with self._lock:
summary = dict(self._timing_totals)
fetcher_summary = getattr(self.fetcher, "timing_summary", None)
if fetcher_summary is not None:
summary.update(fetcher_summary())
return summary
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
key = (episode_index, camera_key)
with self._lock:
if key in self._cache:
future: Future[dict[str, Any]] = Future()
future.set_result(self._cache[key])
return future
future = self._futures.get(key)
if future is None:
future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
self._futures[key] = future
return future
def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
key = (episode_index, camera_key)
with self._lock:
entry = self._cache.get(key)
if entry is not None:
self._cache.move_to_end(key)
return entry
future = self._submit(episode_index, camera_key)
entry = future.result()
store_start = time.perf_counter()
with self._lock:
self._futures.pop(key, None)
existing = self._cache.get(key)
if existing is not None:
self._cache.move_to_end(key)
return existing
self._cache[key] = entry
self._bytes += len(entry["bytes"])
self._evict_locked()
timings = entry.pop("_timings", None)
if timings is not None:
self._timing_totals["lookup_s"] += timings["lookup_s"]
self._timing_totals["fetch_s"] += timings["fetch_s"]
self._timing_totals["synthesize_s"] += timings["synthesize_s"]
self._timing_totals["store_s"] += time.perf_counter() - store_start
self._timing_totals["jobs"] += 1
return entry
def _evict_locked(self) -> None:
while self._bytes > self.byte_budget and self._cache:
_key, entry = self._cache.popitem(last=False)
self._bytes -= len(entry["bytes"])
def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
lookup_start = time.perf_counter()
span = self.manifest.lookup(episode_index, camera_key)
file_record = self.manifest.file_lookup(span.file_id)
sample_slice = Mp4SampleSlice(
sample_lo=span.sample_lo,
sample_hi=span.sample_hi,
byte_offset=span.mdat_offset,
byte_length=span.mdat_length,
source_start_pts=span.source_start_pts,
)
lookup_s = time.perf_counter() - lookup_start
fetch_start = time.perf_counter()
payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
fetch_s = time.perf_counter() - fetch_start
if len(payload) != span.mdat_length:
raise OSError(
f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
)
synthesize_start = time.perf_counter()
mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
synthesize_s = time.perf_counter() - synthesize_start
entry: dict[str, Any] = {
"bytes": mp4_bytes,
"decoder": None,
"_timings": {
"lookup_s": lookup_s,
"fetch_s": fetch_s,
"synthesize_s": synthesize_s,
},
}
if self.open_decoders:
entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
return entry
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
if frame_mappings is not None:
raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
from torchcodec.decoders import VideoDecoder
return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
def assert_hf_hub_range_cache_branch() -> None:
"""Fail unless huggingface_hub was installed from the required range-cache branch."""
try:
dist = metadata.distribution("huggingface_hub")
except metadata.PackageNotFoundError as exc:
raise AssertionError("huggingface_hub is not installed") from exc
candidates = []
direct_url = dist.read_text("direct_url.json")
if direct_url:
candidates.append(direct_url)
with contextlib.suppress(json.JSONDecodeError):
parsed = json.loads(direct_url)
candidates.append(str(parsed.get("url", "")))
candidates.append(str(parsed.get("vcs_info", {}).get("requested_revision", "")))
candidates.append(str(parsed.get("vcs_info", {}).get("commit_id", "")))
text = "\n".join(candidates)
if "feat/hffs-cache-cdn-range-reads" not in text:
raise AssertionError(
"huggingface_hub must be installed from "
"git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
)
@dataclass
class StageTimer:
fetch_ms: float = 0.0
decode_ms: float = 0.0
bytes_read: int = 0
misses: int = 0
def record_fetch(self, start: float, byte_count: int) -> None:
self.fetch_ms += (time.perf_counter() - start) * 1000
self.bytes_read += byte_count
self.misses += 1
-666
View File
@@ -1,666 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import struct
from collections.abc import Callable, Iterable
from dataclasses import dataclass
import numpy as np
@dataclass(frozen=True)
class Box:
type: bytes
start: int
header_size: int
end: int
@property
def payload_start(self) -> int:
return self.start + self.header_size
@property
def size(self) -> int:
return self.end - self.start
@dataclass(frozen=True)
class Mp4SampleSlice:
sample_lo: int
sample_hi: int
byte_offset: int
byte_length: int
source_start_pts: float
@dataclass(frozen=True)
class Mp4Index:
file_path: str
file_size: int
ftyp: bytes
moov_offset: int
mdat_offset: int
mdat_payload_offset: int
mdat_payload_size: int
faststart: bool
codec: str
timescale: int
duration: int
track_id: int
width: int
height: int
stsd_body: bytes
sample_pts: np.ndarray
sample_durations: np.ndarray
sample_sizes: np.ndarray
sample_offsets: np.ndarray
sync_samples: np.ndarray
def sample_slice(
self,
from_ts: float,
to_ts: float,
*,
keyframe_pad_s: float = 0.1,
keyframe_pad_fraction: float = 0.05,
file_size: int | None = None,
) -> Mp4SampleSlice:
if to_ts < from_ts:
raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
if len(self.sample_pts) == 0:
raise ValueError(f"{self.file_path} contains no indexed samples")
pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
lo_ts = max(0.0, from_ts - pad)
hi_ts = to_ts + pad
lo = int(np.searchsorted(self.sample_pts, lo_ts, side="left"))
hi = int(np.searchsorted(self.sample_pts, hi_ts, side="right")) - 1
lo = min(max(lo, 0), len(self.sample_pts) - 1)
hi = min(max(hi, lo), len(self.sample_pts) - 1)
if len(self.sync_samples):
prev_sync = self.sync_samples[self.sync_samples <= lo]
if len(prev_sync):
lo = int(prev_sync[-1])
else:
lo = int(self.sync_samples[0])
if lo > hi:
hi = lo
offsets = self.sample_offsets[lo : hi + 1]
sizes = self.sample_sizes[lo : hi + 1]
slice_lo = int(offsets.min())
slice_hi = int((offsets + sizes).max())
if file_size is not None:
slice_hi = min(slice_hi, int(file_size))
return Mp4SampleSlice(
sample_lo=lo,
sample_hi=hi,
byte_offset=slice_lo,
byte_length=slice_hi - slice_lo,
source_start_pts=float(self.sample_pts[lo]),
)
def to_dict(self) -> dict:
return {
"file_path": self.file_path,
"file_size": self.file_size,
"ftyp": self.ftyp.hex(),
"moov_offset": self.moov_offset,
"mdat_offset": self.mdat_offset,
"mdat_payload_offset": self.mdat_payload_offset,
"mdat_payload_size": self.mdat_payload_size,
"faststart": self.faststart,
"codec": self.codec,
"timescale": self.timescale,
"duration": self.duration,
"track_id": self.track_id,
"width": self.width,
"height": self.height,
"stsd_body": self.stsd_body.hex(),
}
@classmethod
def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
return cls(
file_path=data["file_path"],
file_size=int(data["file_size"]),
ftyp=bytes.fromhex(data["ftyp"]),
moov_offset=int(data["moov_offset"]),
mdat_offset=int(data["mdat_offset"]),
mdat_payload_offset=int(data["mdat_payload_offset"]),
mdat_payload_size=int(data["mdat_payload_size"]),
faststart=bool(data["faststart"]),
codec=data["codec"],
timescale=int(data["timescale"]),
duration=int(data["duration"]),
track_id=int(data["track_id"]),
width=int(data["width"]),
height=int(data["height"]),
stsd_body=bytes.fromhex(data["stsd_body"]),
sample_pts=arrays["sample_pts"],
sample_durations=arrays["sample_durations"],
sample_sizes=arrays["sample_sizes"],
sample_offsets=arrays["sample_offsets"],
sync_samples=arrays["sync_samples"],
)
def fetch_mp4_index(
path: str,
read_range: Callable[[str, int, int], bytes],
*,
file_size: int,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
probe_size = min(header_probe_bytes, file_size)
while True:
data = read_range(path, 0, probe_size)
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
has_mdat = any(box.type == b"mdat" for box in top)
has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
if has_mdat and has_moov:
return parse_mp4_index(path, data, file_size=file_size)
if probe_size >= min(max_probe_bytes, file_size):
if has_mdat and not has_moov:
tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
if tail_index is not None:
return tail_index
missing = []
if not has_mdat:
missing.append("mdat")
if not has_moov:
missing.append("moov")
raise ValueError(
f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
)
probe_size = min(probe_size * 2, max_probe_bytes, file_size)
def _fetch_tail_moov_index(
path: str,
read_range: Callable[[str, int, int], bytes],
prefix: bytes,
top_boxes: list[Box],
file_size: int,
max_probe_bytes: int,
) -> Mp4Index | None:
mdat_box = _one(top_boxes, b"mdat")
if mdat_box is None or mdat_box.end >= file_size:
return None
tail_offset = mdat_box.end
tail_length = min(max_probe_bytes, file_size - tail_offset)
tail = read_range(path, tail_offset, tail_length)
tail_boxes = list(iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True))
moov_box = next(
(box for box in tail_boxes if box.type == b"moov" and box.end <= tail_offset + len(tail)), None
)
if moov_box is None:
return None
ftyp_box = _one(top_boxes, b"ftyp", required=False)
ftyp = (
prefix[ftyp_box.start : ftyp_box.end]
if ftyp_box is not None
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
)
moov_start = moov_box.payload_start - tail_offset
moov_end = moov_box.end - tail_offset
return _parse_mp4_index_from_layout(
path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_box.start,
moov=tail[moov_start:moov_end],
mdat_box=mdat_box,
)
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
if file_size is None:
file_size = len(data)
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
ftyp_box = _one(top, b"ftyp", required=False)
moov_box = _one(top, b"moov")
mdat_box = _one(top, b"mdat")
if moov_box.end > len(data):
raise ValueError(f"{path}: moov box is truncated")
moov = data[moov_box.payload_start : moov_box.end]
ftyp = (
data[ftyp_box.start : ftyp_box.end]
if ftyp_box is not None
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
)
return _parse_mp4_index_from_layout(
path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_box.start,
moov=moov,
mdat_box=mdat_box,
)
def _parse_mp4_index_from_layout(
path: str,
*,
file_size: int,
ftyp: bytes,
moov_offset: int,
moov: bytes,
mdat_box: Box,
) -> Mp4Index:
mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
trak_box, trak_payload = _find_video_trak(moov)
_ = trak_box
tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])
stsd = _find_child(stbl, b"stsd")
stsd_body = stbl[stsd.payload_start : stsd.end]
codec = _parse_stsd_codec(stsd_body)
stts = _parse_stts(_payload(stbl, b"stts"))
sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
stsc = _parse_stsc(_payload(stbl, b"stsc"))
chunk_offsets = _parse_chunk_offsets(stbl)
sync_samples = _parse_stss(stbl, len(sample_sizes))
sample_durations = _expand_stts(stts, len(sample_sizes))
sample_pts_units = np.empty(len(sample_durations), dtype=np.int64)
if len(sample_durations):
sample_pts_units[0] = 0
if len(sample_durations) > 1:
sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)
return Mp4Index(
file_path=path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_offset,
mdat_offset=mdat_box.start,
mdat_payload_offset=mdat_box.payload_start,
mdat_payload_size=mdat_box.end - mdat_box.payload_start
if mdat_box.end <= file_size
else file_size - mdat_box.payload_start,
faststart=moov_offset < mdat_box.start,
codec=codec,
timescale=mdhd_timescale,
duration=mdhd_duration or mvhd_duration,
track_id=tkhd["track_id"],
width=tkhd["width"],
height=tkhd["height"],
stsd_body=stsd_body,
sample_pts=sample_pts,
sample_durations=sample_durations,
sample_sizes=sample_sizes,
sample_offsets=sample_offsets,
sync_samples=sync_samples,
)
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
lo = sample_slice.sample_lo
hi = sample_slice.sample_hi + 1
if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")
offsets = index.sample_offsets[lo:hi]
sizes = index.sample_sizes[lo:hi]
rel_offsets = offsets - sample_slice.byte_offset
if int(rel_offsets.min()) != 0:
raise ValueError("Sample slice must start at the minimum referenced sample offset")
if int((rel_offsets + sizes).max()) > len(mdat_payload):
raise ValueError("Sample slice does not cover all referenced samples")
durations = index.sample_durations[lo:hi]
sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
header_size = len(index.ftyp) + len(moov)
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
return index.ftyp + moov + _box(b"mdat", mdat_payload)
def iter_boxes(
data: bytes,
start: int,
end: int,
*,
absolute_base: int = 0,
allow_truncated: bool = False,
) -> Iterable[Box]:
pos = start
while pos + 8 <= end:
size = struct.unpack_from(">I", data, pos)[0]
typ = data[pos + 4 : pos + 8]
header_size = 8
if size == 1:
if pos + 16 > end:
break
size = struct.unpack_from(">Q", data, pos + 8)[0]
header_size = 16
elif size == 0:
size = end - pos
if size < header_size:
break
box_end = pos + size
if box_end > end and not allow_truncated:
break
yield Box(typ, absolute_base + pos, header_size, absolute_base + box_end)
pos = box_end
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
for trak in _children(moov, 0, len(moov)):
if trak.type != b"trak":
continue
payload = moov[trak.payload_start : trak.end]
hdlr = _find_descendant(payload, [b"mdia", b"hdlr"])
if hdlr[8:12] == b"vide":
return trak, payload
raise ValueError("No video track found")
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
current = data
for typ in path:
box = _find_child(current, typ)
current = current[box.payload_start : box.end]
return current
def _find_child(data: bytes, typ: bytes) -> Box:
for box in _children(data, 0, len(data)):
if box.type == typ:
return box
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
return iter_boxes(data, start, end, absolute_base=0)
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
matches = [box for box in boxes if box.type == typ]
if not matches and required:
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
return matches[0] if matches else None
def _payload(parent: bytes, typ: bytes) -> bytes:
box = _find_child(parent, typ)
return parent[box.payload_start : box.end]
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
version = payload[0]
if version == 1:
return struct.unpack_from(">IQ", payload, 20)
return struct.unpack_from(">II", payload, 12)
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
version = payload[0]
if version == 1:
return struct.unpack_from(">IQ", payload, 20)
return struct.unpack_from(">II", payload, 12)
def _parse_tkhd(payload: bytes) -> dict[str, int]:
version = payload[0]
if version == 1:
track_id = struct.unpack_from(">I", payload, 20)[0]
duration = struct.unpack_from(">Q", payload, 28)[0]
width, height = struct.unpack_from(">II", payload, 88)
else:
track_id = struct.unpack_from(">I", payload, 12)[0]
duration = struct.unpack_from(">I", payload, 20)[0]
width, height = struct.unpack_from(">II", payload, 76)
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
def _parse_stsd_codec(stsd_body: bytes) -> str:
if len(stsd_body) < 16:
return "unknown"
return stsd_body[12:16].decode("latin1")
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
count = struct.unpack_from(">I", payload, 4)[0]
out = []
offset = 8
for _ in range(count):
out.append(struct.unpack_from(">II", payload, offset))
offset += 8
return out
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
values = np.empty(sample_count, dtype=np.int64)
pos = 0
for count, delta in entries:
values[pos : pos + count] = delta
pos += count
if pos != sample_count:
raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}")
return values
def _parse_stsz(payload: bytes) -> np.ndarray:
sample_size, sample_count = struct.unpack_from(">II", payload, 4)
if sample_size:
return np.full(sample_count, sample_size, dtype=np.int64)
offset = 12
values = np.empty(sample_count, dtype=np.int64)
for idx in range(sample_count):
values[idx] = struct.unpack_from(">I", payload, offset)[0]
offset += 4
return values
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
count = struct.unpack_from(">I", payload, 4)[0]
out = []
offset = 8
for _ in range(count):
out.append(struct.unpack_from(">III", payload, offset))
offset += 12
return out
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
with_stco = None
with_co64 = None
for box in _children(stbl, 0, len(stbl)):
if box.type == b"stco":
with_stco = stbl[box.payload_start : box.end]
elif box.type == b"co64":
with_co64 = stbl[box.payload_start : box.end]
if with_co64 is not None:
count = struct.unpack_from(">I", with_co64, 4)[0]
return np.array(
[struct.unpack_from(">Q", with_co64, 8 + idx * 8)[0] for idx in range(count)], dtype=np.int64
)
if with_stco is None:
raise ValueError("Missing stco/co64 chunk offsets")
count = struct.unpack_from(">I", with_stco, 4)[0]
return np.array(
[struct.unpack_from(">I", with_stco, 8 + idx * 4)[0] for idx in range(count)], dtype=np.int64
)
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
for box in _children(stbl, 0, len(stbl)):
if box.type == b"stss":
payload = stbl[box.payload_start : box.end]
count = struct.unpack_from(">I", payload, 4)[0]
return np.array(
[struct.unpack_from(">I", payload, 8 + idx * 4)[0] - 1 for idx in range(count)],
dtype=np.int64,
)
return np.arange(sample_count, dtype=np.int64)
def _sample_offsets(
stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
) -> np.ndarray:
if not stsc:
raise ValueError("stsc is empty")
offsets = np.empty(len(sample_sizes), dtype=np.int64)
sample_idx = 0
for entry_idx, (first_chunk, samples_per_chunk, _desc_idx) in enumerate(stsc):
next_first = stsc[entry_idx + 1][0] if entry_idx + 1 < len(stsc) else len(chunk_offsets) + 1
for chunk_number in range(first_chunk, next_first):
if chunk_number < 1 or chunk_number > len(chunk_offsets):
raise ValueError("stsc references a chunk outside stco/co64")
chunk_pos = int(chunk_offsets[chunk_number - 1])
for _ in range(samples_per_chunk):
if sample_idx >= len(sample_sizes):
return offsets
offsets[sample_idx] = chunk_pos
chunk_pos += int(sample_sizes[sample_idx])
sample_idx += 1
if sample_idx != len(sample_sizes):
raise ValueError(f"stsc describes {sample_idx} samples, stsz describes {len(sample_sizes)}")
return offsets
def _make_moov(
index: Mp4Index,
durations: np.ndarray,
sizes: np.ndarray,
rel_offsets: np.ndarray,
sync_samples: np.ndarray,
*,
mdat_data_offset: int,
) -> bytes:
duration = int(durations.sum())
stco_values = [int(mdat_data_offset + value) for value in rel_offsets]
if any(value > 0xFFFFFFFF for value in stco_values):
offset_box = _co64(stco_values)
else:
offset_box = _stco(stco_values)
stbl = _box(
b"stbl",
_box(b"stsd", index.stsd_body)
+ _stts(durations)
+ _stsc_one_sample_per_chunk(len(sizes))
+ _stsz(sizes)
+ offset_box
+ (_stss(sync_samples) if len(sync_samples) else b""),
)
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
mdia = _box(b"mdia", _mdhd(index.timescale, duration) + _hdlr() + minf)
trak = _box(b"trak", _tkhd(index.track_id, duration, index.width, index.height) + mdia)
return _box(b"moov", _mvhd(index.timescale, duration, index.track_id + 1) + trak)
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
return _box(typ, bytes([version]) + flags.to_bytes(3, "big") + payload)
def _box(typ: bytes, payload: bytes) -> bytes:
size = len(payload) + 8
if size <= 0xFFFFFFFF:
return struct.pack(">I4s", size, typ) + payload
return struct.pack(">I4sQ", 1, typ, size + 8) + payload
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
payload = (
struct.pack(">IIII", 0, 0, timescale, duration)
+ struct.pack(">IHH", 0x00010000, 0x0100, 0)
+ b"\0" * 8
+ matrix
+ b"\0" * 24
+ struct.pack(">I", next_track_id)
)
return _full_box(b"mvhd", 0, 0, payload)
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
payload = (
struct.pack(">IIIII", 0, 0, track_id, 0, duration)
+ b"\0" * 8
+ struct.pack(">hhhh", 0, 0, 0, 0)
+ matrix
+ struct.pack(">II", width << 16, height << 16)
)
return _full_box(b"tkhd", 0, 7, payload)
def _mdhd(timescale: int, duration: int) -> bytes:
return _full_box(b"mdhd", 0, 0, struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4) + b"\0\0")
def _hdlr() -> bytes:
return _full_box(b"hdlr", 0, 0, b"\0" * 4 + b"vide" + b"\0" * 12 + b"VideoHandler\0")
def _vmhd() -> bytes:
return _full_box(b"vmhd", 0, 1, struct.pack(">HHHH", 0, 0, 0, 0))
def _dinf() -> bytes:
url = _full_box(b"url ", 0, 1)
dref = _full_box(b"dref", 0, 0, struct.pack(">I", 1) + url)
return _box(b"dinf", dref)
def _stts(durations: np.ndarray) -> bytes:
runs = []
for duration in durations.tolist():
if runs and runs[-1][1] == int(duration):
runs[-1][0] += 1
else:
runs.append([1, int(duration)])
payload = struct.pack(">I", len(runs)) + b"".join(
struct.pack(">II", count, delta) for count, delta in runs
)
return _full_box(b"stts", 0, 0, payload)
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
return _full_box(b"stsc", 0, 0, struct.pack(">IIII", 1, 1, 1, 1))
def _stsz(sizes: np.ndarray) -> bytes:
return _full_box(
b"stsz",
0,
0,
struct.pack(">II", 0, len(sizes)) + b"".join(struct.pack(">I", int(size)) for size in sizes.tolist()),
)
def _stco(values: list[int]) -> bytes:
return _full_box(
b"stco", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">I", v) for v in values)
)
def _co64(values: list[int]) -> bytes:
return _full_box(
b"co64", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">Q", v) for v in values)
)
def _stss(values: np.ndarray) -> bytes:
return _full_box(
b"stss",
0,
0,
struct.pack(">I", len(values)) + b"".join(struct.pack(">I", int(value)) for value in values.tolist()),
)
+3 -5
View File
@@ -70,21 +70,19 @@ def aggregate_pipeline_dataset_features(
initial_features: dict[PipelineFeatureType, dict[str, Any]],
*,
use_videos: bool = True,
exclude_images: bool = False,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
This function transforms initial features using the pipeline, categorizes them as action or observations
(image or state), filters them based on `exclude_images` and `patterns`, and finally
(image or state), filters them based on `use_videos` and `patterns`, and finally
formats them for use with a Hugging Face LeRobot Dataset.
Args:
pipeline: The DataProcessorPipeline to apply.
initial_features: A dictionary of raw feature specs for actions and observations.
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
exclude_images: If True, image features are dropped entirely from the output.
use_videos: If False, image features are excluded.
patterns: A sequence of regex patterns to filter action and state features.
Image features are not affected by this filter.
@@ -122,7 +120,7 @@ def aggregate_pipeline_dataset_features(
)
# 2. Apply filtering rules.
if is_image and exclude_images:
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, compiled_patterns):
continue
+20
View File
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
# Pass through any remaining ndarray/tensor keys not already handled above,
# so env plugins can expose extra observation keys via get_env_processors().
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
for key, value in observations.items():
if key in _handled:
continue
target = f"{OBS_STR}.{key}"
if target in return_observations:
continue
if isinstance(value, np.ndarray):
val = torch.from_numpy(value).float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
elif isinstance(value, Tensor):
val = value.float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
return return_observations
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
logger = logging.getLogger(__name__)
class BiOpenArmFollower(BimanualMixin, Robot):
class BiOpenArmFollower(Robot):
"""
Bimanual OpenArm Follower Arms
"""
@@ -40,17 +39,15 @@ class BiOpenArmFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
# will only open the cameras assigned to it. Per-arm cameras are used
# as fallback when top-level cameras are empty.
if config.cameras:
left_cameras = config.cameras
right_cameras = {}
else:
left_cameras = config.left_arm_config.cameras
right_cameras = config.right_arm_config.cameras
left_arm_config = OpenArmFollowerConfig(
id=f"{config.id}_left" if config.id else None,
@@ -59,7 +56,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=left_cameras,
side=config.left_arm_config.side,
can_interface=config.left_arm_config.can_interface,
use_can_fd=config.left_arm_config.use_can_fd,
@@ -78,7 +75,7 @@ class BiOpenArmFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
max_relative_target=config.right_arm_config.max_relative_target,
cameras=config.right_arm_config.cameras,
cameras=right_cameras,
side=config.right_arm_config.side,
can_interface=config.right_arm_config.can_interface,
use_can_fd=config.right_arm_config.use_can_fd,
@@ -98,19 +95,22 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@property
def _motors_ft(self) -> dict[str, type]:
left_arm_motors_ft = self.left_arm._motors_ft
right_arm_motors_ft = self.right_arm._motors_ft
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
return {
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
}
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
# "right_wrist"), so we merge them directly — unlike motors which need the
# left_/right_ prefix to disambiguate identical per-arm joint names.
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -120,6 +120,27 @@ class BiOpenArmFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -127,15 +148,21 @@ class BiOpenArmFollower(BimanualMixin, Robot):
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Camera keys that should NOT get the arm prefix (they already have unique names)
left_cam_keys = set(self.left_arm.cameras.keys())
right_cam_keys = set(self.right_arm.cameras.keys())
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
# and the dataset feature names recorded during data collection.
right_obs = self.right_arm.get_observation()
for key, value in right_obs.items():
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
left_obs = self.left_arm.get_observation()
for key, value in left_obs.items():
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
return obs_dict
@@ -162,4 +189,9 @@ class BiOpenArmFollower(BimanualMixin, Robot):
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -32,7 +32,5 @@ class BiOpenArmFollowerConfig(RobotConfig):
left_arm_config: OpenArmFollowerConfigBase
right_arm_config: OpenArmFollowerConfigBase
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
# Top-level cameras shared across both arms.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
from ..robot import Robot
@@ -28,7 +27,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
logger = logging.getLogger(__name__)
class BiRebotB601Follower(BimanualMixin, Robot):
class BiRebotB601Follower(Robot):
"""Bimanual Seeed Studio reBot B601-DM follower.
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
@@ -42,18 +41,6 @@ class BiRebotB601Follower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = RebotB601FollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -62,7 +49,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
dm_serial_baud=config.left_arm_config.dm_serial_baud,
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
@@ -99,12 +86,10 @@ class BiRebotB601Follower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
return {
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -114,13 +99,32 @@ class BiRebotB601Follower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
for k, v in self.left_arm.get_observation().items():
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm.get_observation().items():
obs_dict[f"right_{k}"] = v
obs_dict = {}
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
return obs_dict
@check_if_not_connected
@@ -139,3 +143,8 @@ class BiRebotB601Follower(BimanualMixin, Robot):
**{f"left_{k}": v for k, v in sent_action_left.items()},
**{f"right_{k}": v for k, v in sent_action_right.items()},
}
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..rebot_b601_follower import RebotB601FollowerConfig
@@ -29,8 +27,3 @@ class BiRebotB601FollowerConfig(RobotConfig):
left_arm_config: RebotB601FollowerConfig
right_arm_config: RebotB601FollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
from ..so_follower import SOFollower, SOFollowerRobotConfig
@@ -28,7 +27,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
class BiSOFollower(BimanualMixin, Robot):
class BiSOFollower(Robot):
"""
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -40,18 +39,6 @@ class BiSOFollower(BimanualMixin, Robot):
super().__init__(config)
self.config = config
# Top-level cameras are opened by `left_arm` for convenience, but their
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
self._top_level_cam_keys = set(config.cameras)
_collisions = self._top_level_cam_keys & set(
config.left_arm_config.cameras
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
if _collisions:
raise ValueError(
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
)
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
left_arm_config = SOFollowerRobotConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
@@ -59,7 +46,7 @@ class BiSOFollower(BimanualMixin, Robot):
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
max_relative_target=config.left_arm_config.max_relative_target,
use_degrees=config.left_arm_config.use_degrees,
cameras=left_arm_cameras,
cameras=config.left_arm_config.cameras,
)
right_arm_config = SOFollowerRobotConfig(
@@ -90,12 +77,13 @@ class BiSOFollower(BimanualMixin, Robot):
@property
def _cameras_ft(self) -> dict[str, tuple]:
out: dict[str, tuple] = {}
for k, v in self.left_arm._cameras_ft.items():
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
for k, v in self.right_arm._cameras_ft.items():
out[f"right_{k}"] = v
return out
left_arm_cameras_ft = self.left_arm._cameras_ft
right_arm_cameras_ft = self.right_arm._cameras_ft
return {
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
@@ -105,21 +93,42 @@ class BiSOFollower(BimanualMixin, Robot):
def action_features(self) -> dict[str, type]:
return self._motors_ft
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_observation(self) -> RobotObservation:
obs_dict: RobotObservation = {}
obs_dict = {}
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
for key, value in self.left_arm.get_observation().items():
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
# Add "left_" prefix
left_obs = self.left_arm.get_observation()
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
# Add "right_" prefix
for key, value in self.right_arm.get_observation().items():
obs_dict[f"right_{key}"] = value
right_obs = self.right_arm.get_observation()
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
return obs_dict
@@ -142,3 +151,8 @@ class BiSOFollower(BimanualMixin, Robot):
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
@check_if_not_connected
def disconnect(self):
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from dataclasses import dataclass
from ..config import RobotConfig
from ..so_follower import SOFollowerConfig
@@ -29,8 +27,3 @@ class BiSOFollowerConfig(RobotConfig):
left_arm_config: SOFollowerConfig
right_arm_config: SOFollowerConfig
# Top-level cameras not attached to a specific side. Keys are kept as-is in
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
# `{left,right}_arm_config.cameras`) are prefixed.
cameras: dict[str, CameraConfig] = field(default_factory=dict)
-1
View File
@@ -54,7 +54,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -57,7 +57,6 @@ from lerobot.robots import ( # noqa: F401
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
-1
View File
@@ -137,7 +137,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
-1
View File
@@ -174,7 +174,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
homunculus,
@@ -41,7 +41,6 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
koch_leader,
@@ -89,7 +89,6 @@ from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
bi_openarm_leader,
bi_openarm_mini,
bi_rebot_102_leader,
bi_so_leader,
gamepad,
@@ -18,8 +18,7 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator
@@ -28,7 +27,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
logger = logging.getLogger(__name__)
class BiOpenArmLeader(BimanualMixin, Teleoperator):
class BiOpenArmLeader(Teleoperator):
"""
Bimanual OpenArm Leader Arms
"""
@@ -87,6 +86,27 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
raise NotImplementedError(
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
@@ -109,3 +129,8 @@ class BiOpenArmLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
@dataclass
class BiOpenArmLeaderConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Leader teleoperators."""
"""Configuration class for Bi OpenArm Follower robots."""
left_arm_config: OpenArmLeaderConfigBase
right_arm_config: OpenArmLeaderConfigBase
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_openarm_mini import BiOpenArmMini
from .config_bi_openarm_mini import BiOpenArmMiniConfig
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
@@ -1,101 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_mini import BiOpenArmMiniConfig
logger = logging.getLogger(__name__)
class BiOpenArmMini(BimanualMixin, Teleoperator):
"""Bimanual OpenArm Mini teleoperator.
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
bimanual leader can teleoperate a bimanual OpenArm follower.
"""
config_class = BiOpenArmMiniConfig
name = "bi_openarm_mini"
def __init__(self, config: BiOpenArmMiniConfig):
super().__init__(config)
self.config = config
# `side` is forced to match left/right regardless of what the user passed
# on the per-arm base config — the bimanual wrapper owns the side semantics.
left_arm_config = OpenArmMiniConfig(
id=f"{config.id}_left" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.left_arm_config.port,
side="left",
use_degrees=config.left_arm_config.use_degrees,
)
right_arm_config = OpenArmMiniConfig(
id=f"{config.id}_right" if config.id else None,
calibration_dir=config.calibration_dir,
port=config.right_arm_config.port,
side="right",
use_degrees=config.right_arm_config.use_degrees,
)
self.left_arm = OpenArmMini(left_arm_config)
self.right_arm = OpenArmMini(right_arm_config)
@cached_property
def action_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
}
@cached_property
def feedback_features(self) -> dict[str, type]:
return {
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
}
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
action: RobotAction = {}
for k, v in self.left_arm.get_action().items():
action[f"left_{k}"] = v
for k, v in self.right_arm.get_action().items():
action[f"right_{k}"] = v
return action
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
if left_fb:
self.left_arm.send_feedback(left_fb)
if right_fb:
self.right_arm.send_feedback(right_fb)
@@ -1,29 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
from ..openarm_mini import OpenArmMiniConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
"""Configuration class for Bi OpenArm Mini teleoperators."""
left_arm_config: OpenArmMiniConfigBase
right_arm_config: OpenArmMiniConfigBase
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .bi_rebot_102_leader import BiRebot102Leader
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .bi_rebot_102_leader import BiRebotArm102Leader
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
@@ -18,17 +18,16 @@ import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
logger = logging.getLogger(__name__)
class BiRebot102Leader(BimanualMixin, Teleoperator):
class BiRebotArm102Leader(Teleoperator):
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
@@ -36,10 +35,10 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
leader can teleoperate a bimanual reBot B601 follower.
"""
config_class = BiRebot102LeaderConfig
config_class = BiRebotArm102LeaderConfig
name = "bi_rebot_102_leader"
def __init__(self, config: BiRebot102LeaderConfig):
def __init__(self, config: BiRebotArm102LeaderConfig):
super().__init__(config)
self.config = config
@@ -77,6 +76,27 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def get_action(self) -> RobotAction:
action_dict = {}
@@ -86,3 +106,8 @@ class BiRebot102Leader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
@dataclass
class BiRebot102LeaderConfig(TeleoperatorConfig):
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
left_arm_config: RebotArm102LeaderConfig
@@ -17,9 +17,7 @@
import logging
from functools import cached_property
from lerobot.types import RobotAction
from lerobot.utils.bimanual import BimanualMixin
from lerobot.utils.decorators import check_if_not_connected
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator
@@ -28,7 +26,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
logger = logging.getLogger(__name__)
class BiSOLeader(BimanualMixin, Teleoperator):
class BiSOLeader(Teleoperator):
"""
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
@@ -69,12 +67,33 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
def setup_motors(self) -> None:
self.left_arm.setup_motors()
self.right_arm.setup_motors()
@check_if_not_connected
def get_action(self) -> RobotAction:
def get_action(self) -> dict[str, float]:
action_dict = {}
# Add "left_" prefix
@@ -90,3 +109,8 @@ class BiSOLeader(BimanualMixin, Teleoperator):
def send_feedback(self, feedback: dict[str, float]) -> None:
# TODO: Implement force feedback
raise NotImplementedError
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -19,21 +19,12 @@ from dataclasses import dataclass
from ..config import TeleoperatorConfig
@dataclass
class OpenArmMiniConfigBase:
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
port: str
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
# during readout. If `None`, no flipping is applied.
side: str | None = None
use_degrees: bool = True
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
pass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Per-side motor direction flips applied during readout.
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
}
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
# Leader joint 6 follower joint 7 (symmetric — its own inverse).
# Leader joint 6 maps to follower joint 7 and vice versa
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
GRIPPER_TELEOP_TO_DEGREES = -0.65
class OpenArmMini(Teleoperator):
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
@@ -56,12 +56,9 @@ class OpenArmMini(Teleoperator):
super().__init__(config)
self.config = config
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
norm_mode_body = MotorNormMode.DEGREES
motors = {
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
@@ -72,15 +69,46 @@ class OpenArmMini(Teleoperator):
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
self.bus = FeetechMotorsBus(
port=self.config.port,
motors=motors,
calibration=self.calibration,
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
return {f"{motor}.pos": float for motor in self.bus.motors}
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
@@ -88,12 +116,14 @@ class OpenArmMini(Teleoperator):
@property
def is_connected(self) -> bool:
return self.bus.is_connected
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting arm on {self.config.port}...")
self.bus.connect()
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
@@ -103,14 +133,14 @@ class OpenArmMini(Teleoperator):
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for a single OpenArm Mini arm.
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arm in hanging position with gripper closed
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
@@ -122,51 +152,70 @@ class OpenArmMini(Teleoperator):
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
self.bus.write_calibration(self.calibration)
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self.bus.disable_torque()
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
logger.info("Setting Phase to 12 for all motors...")
for motor in self.bus.motors:
self.bus.write("Phase", motor, 12)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
"\nCalibration: Zero Position\n"
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = self.bus.set_half_turn_homings()
logger.info("Arm zero position set.")
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print("\nSetting motor ranges\n")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in self.bus.motors.items():
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
"\nGripper Calibration\n"
"Step 1: CLOSE the gripper fully\n"
"Press ENTER when gripper is closed..."
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
@@ -179,16 +228,16 @@ class OpenArmMini(Teleoperator):
drive_mode = 1
logger.info(
f" {motor_name}: range set to [{range_min}, {range_max}] "
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[motor_name] = MotorCalibration(
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
@@ -196,68 +245,108 @@ class OpenArmMini(Teleoperator):
range_max=range_max,
)
self.bus.write_calibration(self.calibration)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus.disable_torque()
self.bus.configure_motors()
for motor in self.bus.motors:
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
for motor in reversed(self.bus.motors):
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
self.bus.setup_motor(motor)
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action (read positions from all motors)."""
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
positions = self.bus.sync_read("Present_Position")
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
# and the dataset feature names recorded during data collection.
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
# Per-side direction flip is applied based on the configured `side`.
action: dict[str, Any] = {}
for motor, val in positions.items():
for motor, val in right_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
target = JOINT_REMAP.get(motor, motor)
if motor == "gripper":
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
else:
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def enable_torque(self) -> None:
self.bus.enable_torque()
"""Enable torque on both arms for position control."""
self.bus_right.enable_torque()
self.bus_left.enable_torque()
def disable_torque(self) -> None:
self.bus.disable_torque()
"""Disable torque on both arms for free movement."""
self.bus_right.disable_torque()
self.bus_left.disable_torque()
def write_goal_positions(self, positions: dict[str, float]) -> None:
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
goals: dict[str, float] = {}
right_goals: dict[str, float] = {}
left_goals: dict[str, float] = {}
for key, val in positions.items():
if not key.endswith(".pos"):
continue
base = key.removesuffix(".pos")
# JOINT_REMAP is symmetric (its own inverse).
target = JOINT_REMAP.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
goals[target] = -val if target in self._motors_to_flip else val
motor_name = key.removesuffix(".pos")
if motor_name.startswith("right_"):
base = motor_name.removeprefix("right_")
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
# Un-flip using the ORIGINAL motor name (target = leader motor)
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
elif motor_name.startswith("left_"):
base = motor_name.removeprefix("left_")
target = JOINT_REMAP_REVERSE.get(base, base)
if base == "gripper":
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
else:
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
if goals:
self.bus.sync_write("Goal_Position", goals)
if right_goals:
self.bus_right.sync_write("Goal_Position", right_goals)
if left_goals:
self.bus_left.sync_write("Goal_Position", left_goals)
@check_if_not_connected
def send_feedback(self, feedback: dict[str, float]) -> None:
@@ -265,5 +354,6 @@ class OpenArmMini(Teleoperator):
@check_if_not_connected
def disconnect(self) -> None:
self.bus.disconnect()
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
+2 -6
View File
@@ -99,18 +99,14 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
elif config.type == "bi_openarm_mini":
from .bi_openarm_mini import BiOpenArmMini
return BiOpenArmMini(config)
elif config.type == "rebot_102_leader":
from .rebot_102_leader import RebotArm102Leader
return RebotArm102Leader(config)
elif config.type == "bi_rebot_102_leader":
from .bi_rebot_102_leader import BiRebot102Leader
from .bi_rebot_102_leader import BiRebotArm102Leader
return BiRebot102Leader(config)
return BiRebotArm102Leader(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
-63
View File
@@ -1,63 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
class BimanualMixin:
"""Lifecycle delegation for bimanual robots and teleoperators.
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
their own ``__init__``. They retain ownership of feature dicts and the
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
``send_feedback``), which vary per-embodiment.
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
take precedence in the MRO::
class BiFooFollower(BimanualMixin, Robot): ...
"""
left_arm: Any
right_arm: Any
@property
def is_connected(self) -> bool:
return self.left_arm.is_connected and self.right_arm.is_connected
@property
def is_calibrated(self) -> bool:
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
self.left_arm.connect(calibrate)
self.right_arm.connect(calibrate)
def calibrate(self) -> None:
self.left_arm.calibrate()
self.right_arm.calibrate()
def configure(self) -> None:
self.left_arm.configure()
self.right_arm.configure()
@check_if_not_connected
def disconnect(self) -> None:
self.left_arm.disconnect()
self.right_arm.disconnect()
+8 -2
View File
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
This function uses `importlib.metadata` to find packages installed in the environment
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
"""
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
prefixes = (
"lerobot_robot_",
"lerobot_camera_",
"lerobot_teleoperator_",
"lerobot_policy_",
"lerobot_env_",
)
imported: list[str] = []
failed: list[str] = []
@@ -1,121 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
import json
import struct
import numpy as np
import pytest
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
from lerobot.datasets.mp4 import (
_box,
_co64,
_dinf,
_hdlr,
_mdhd,
_mvhd,
_stco,
_stsc_one_sample_per_chunk,
_stss,
_stsz,
_stts,
_tkhd,
_vmhd,
parse_mp4_index,
synthesize_mp4,
)
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
sizes = np.array([10, 10, 10], dtype=np.int64)
durations = np.array([1000, 1000, 1000], dtype=np.int64)
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
stbl = _box(
b"stbl",
_box(b"stsd", stsd_body)
+ _stts(durations)
+ _stsc_one_sample_per_chunk(len(sizes))
+ _stsz(sizes)
+ offsets
+ _stss(np.array([1], dtype=np.int64)),
)
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
mdat_payload_start = 10_000
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
assert free_size >= 8
free = _box(b"free", b"\0" * (free_size - 8))
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
assert sample_slice.byte_offset == 10_000
assert sample_slice.byte_length == 60
assert sample_slice.sample_lo == 0
assert sample_slice.sample_hi == 2
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
mini = synthesize_mp4(mp4, sample_slice, b"x" * sample_slice.byte_length)
mini_index = parse_mp4_index("mini.mp4", mini)
expected = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
np.testing.assert_array_equal(mini_index.sample_offsets, expected)
np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
def test_parser_accepts_co64_chunk_offsets():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025], use_co64=True))
np.testing.assert_array_equal(mp4.sample_offsets, np.array([10_000, 10_050, 10_025]))
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
class FakeDist:
def read_text(self, name):
assert name == "direct_url.json"
return json.dumps(
{
"url": "https://github.com/huggingface/huggingface_hub.git",
"vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
}
)
monkeypatch.setattr(
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
)
assert_hf_hub_range_cache_branch()
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
class FakeDist:
def read_text(self, name):
assert name == "direct_url.json"
return json.dumps({"url": "https://github.com/huggingface/huggingface_hub.git"})
monkeypatch.setattr(
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
)
with pytest.raises(AssertionError):
assert_hf_hub_range_cache_branch()
+3 -21
View File
@@ -2370,32 +2370,14 @@ def test_aggregate_images_when_use_videos_false():
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=False, # images kept, stored as "image" dtype
use_videos=False, # expect "image" dtype
patterns=None,
)
key = f"{OBS_IMAGES}.back"
key_front = f"{OBS_IMAGES}.front"
assert key in out
assert key_front in out
assert out[key]["dtype"] == "image"
assert out[key_front]["dtype"] == "image"
assert out[key]["shape"] == initial["back"]
def test_aggregate_images_excluded():
rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
initial = {"back": (480, 640, 3)}
out = aggregate_pipeline_dataset_features(
pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
exclude_images=True,
patterns=None,
)
assert f"{OBS_IMAGES}.back" not in out
assert f"{OBS_IMAGES}.front" not in out
assert key not in out
assert key_front not in out
def test_aggregate_images_when_use_videos_true():
+3 -3
View File
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
from lerobot.teleoperators.rebot_102_leader import (
RebotArm102Leader,
RebotArm102LeaderConfig,
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebot102LeaderConfig(
cfg = BiRebotArm102LeaderConfig(
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
)
teleop = BiRebot102Leader(cfg)
teleop = BiRebotArm102Leader(cfg)
assert any(k.startswith("left_") for k in teleop.action_features)
assert any(k.startswith("right_") for k in teleop.action_features)
assert "left_gripper.pos" in teleop.action_features
Generated
+909 -950
View File
File diff suppressed because it is too large Load Diff