mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5bfb749a9b | |||
| 51c023a7a1 | |||
| 51ea18cb7a | |||
| 04ab43b8d2 | |||
| cdfe192491 | |||
| 3451e53452 | |||
| 30849ce74f | |||
| 7d6907c444 | |||
| d99e1fe89d | |||
| 7fcde61b69 | |||
| bdfe8f8ce9 | |||
| 34d0495d03 | |||
| 834c282631 | |||
| f132885cbc | |||
| d0686be2f5 | |||
| 38327fdc84 | |||
| 9555efc02c | |||
| d576c59afb |
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
|||||||
|
|
||||||
**Compatible teleoperators:**
|
**Compatible teleoperators:**
|
||||||
|
|
||||||
- `openarm_mini` - OpenArm Mini
|
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
||||||
- `so_leader` - SO100 / SO101 leader arm
|
- `so_leader` - SO100 / SO101 leader arm
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
||||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
|||||||
--robot.right_arm_config.port=can0 \
|
--robot.right_arm_config.port=can0 \
|
||||||
--robot.right_arm_config.side=right \
|
--robot.right_arm_config.side=right \
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
|||||||
--strategy.num_episodes=20 \
|
--strategy.num_episodes=20 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--teleop.type=openarm_mini \
|
--teleop.type=bi_openarm_mini \
|
||||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||||
--dataset.single_task="Fold the T-shirt"
|
--dataset.single_task="Fold the T-shirt"
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -355,6 +355,8 @@ explicit = true
|
|||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
|
||||||
|
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||||
@@ -421,6 +423,7 @@ exclude_dirs = [
|
|||||||
skips = ["B101", "B311", "B404", "B603", "B615"]
|
skips = ["B101", "B311", "B404", "B603", "B615"]
|
||||||
|
|
||||||
[tool.typos]
|
[tool.typos]
|
||||||
|
default.extend-words = { trak = "trak" }
|
||||||
default.extend-ignore-re = [
|
default.extend-ignore-re = [
|
||||||
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
|
||||||
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
|
||||||
|
|||||||
@@ -0,0 +1,893 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import resource
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fsspec
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
import pyarrow.compute as pc
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.episode_video_streaming import (
|
||||||
|
EpisodeByteCache,
|
||||||
|
EpisodeVideoManifest,
|
||||||
|
NativeHTTPRangeFetcher,
|
||||||
|
assert_hf_hub_range_cache_branch,
|
||||||
|
)
|
||||||
|
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
|
||||||
|
|
||||||
|
# Defaults for the --repo-id, --revision and --data-root command-line options.
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"

# Local cache directory (under the system temp dir) and file name of the MP4 sidecar.
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir(), "lerobot-sidecars")
FULL_SIDECAR_NAME = "molmoact2-full.npz"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Declare the benchmark's command-line options and parse ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
    add = parser.add_argument

    # Dataset selection.
    add("--repo-id", default=DEFAULT_REPO)
    add("--revision", default=DEFAULT_REVISION)
    add("--data-root", default=DEFAULT_DATA_ROOT)

    # Benchmark track selection; --strategy is hidden from --help output.
    add(
        "--strategy",
        choices=("both", "full", "indexed", "remote-decoder", "native-http"),
        default="both",
        help=argparse.SUPPRESS,
    )
    add(
        "--range-backend",
        choices=("fsspec", "native-http"),
        default="fsspec",
        help="Range reader used by indexed/full episode-pool fetch tracks.",
    )

    # Workload size.
    add("--num-episodes", type=int, default=512)
    add(
        "--manifest-episodes",
        type=int,
        default=None,
        help="Limit manifest construction to the first N episodes for local smoke tests.",
    )
    add("--pool-size", type=int, default=16)
    add("--workers", type=int, default=8)

    # Native HTTP range-reader tuning.
    add(
        "--native-http-connections",
        type=int,
        default=None,
        help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
    )
    add(
        "--native-http-retries",
        type=int,
        default=8,
        help="Retries per native HTTP range request.",
    )
    add(
        "--native-http-timeout",
        type=float,
        default=120.0,
        help="Timeout in seconds for native HTTP requests.",
    )

    # Reporting and decode options.
    add(
        "--progress-interval",
        type=float,
        default=10.0,
        help="Print episode-pool fill progress every N seconds. Set 0 to disable.",
    )
    add(
        "--include-decode",
        action="store_true",
        help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
    )
    add("--decode-workers", type=int, default=1)
    add("--prefetch-ahead", type=int, default=8)
    add("--frames-per-episode", type=int, default=16)
    add("--max-probe-mb", type=int, default=64)
    add("--seed", type=int, default=0)
    add("--byte-budget-gb", type=float, default=80)
    add(
        "--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
    )
    add("--no-hub-branch-assert", action="store_true")

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
    """Draw ``pool_size`` distinct episode indices with a seeded RNG.

    Indices are sampled without replacement from ``range(min(total, requested))``.

    Raises:
        ValueError: if ``pool_size`` is larger than that range.
    """
    available = min(total, requested)
    if available < pool_size:
        raise ValueError(f"pool-size={pool_size} exceeds available episodes={available}")
    return random.Random(seed).sample(range(available), pool_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
    """Sample sorted random timestamps for every (episode, camera) pair.

    For each pair, ``frames_per_episode`` values are drawn uniformly between the
    ``first_pts`` and ``last_pts`` of the span returned by ``manifest.lookup``.
    The upper bound is clamped so it is never below the lower bound.

    Returns:
        Dict mapping ``(episode, camera_key)`` to an ascending list of timestamps.
    """
    rng = random.Random(seed)
    sampled: dict[tuple[int, str], list[float]] = {}
    for ep in episodes:
        for camera_key in manifest.video_keys:
            span = manifest.lookup(ep, camera_key)
            lower = span.first_pts
            upper = max(span.last_pts, lower)
            draws = [rng.uniform(lower, upper) for _ in range(frames_per_episode)]
            draws.sort()
            sampled[(ep, camera_key)] = draws
    return sampled
|
||||||
|
|
||||||
|
|
||||||
|
def _timestamps_from_meta(
    meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
) -> dict[tuple[int, str], list[float]]:
    """Sample sorted random timestamps per (episode, camera) from dataset metadata.

    The sampling window for a camera is read from the episode row's
    ``videos/<camera>/from_timestamp`` and ``videos/<camera>/to_timestamp``
    entries. The upper bound is clamped so it is never below the lower bound.

    Returns:
        Dict mapping ``(episode, camera_key)`` to an ascending list of timestamps.
    """
    rng = random.Random(seed)
    sampled: dict[tuple[int, str], list[float]] = {}
    for ep in episodes:
        episode_row = meta.episodes[ep]
        for camera_key in meta.video_keys:
            prefix = f"videos/{camera_key}"
            lower = float(episode_row[f"{prefix}/from_timestamp"])
            upper = max(float(episode_row[f"{prefix}/to_timestamp"]), lower)
            draws = [rng.uniform(lower, upper) for _ in range(frames_per_episode)]
            draws.sort()
            sampled[(ep, camera_key)] = draws
    return sampled
|
||||||
|
|
||||||
|
|
||||||
|
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
    """Sum ``mdat_length`` over every (episode, camera) span of ``episodes``."""
    return sum(
        manifest.lookup(ep, camera_key).mdat_length
        for ep in episodes
        for camera_key in manifest.video_keys
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_all(
    cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
) -> float:
    """Call ``cache.get_frames`` for every entry of ``timestamps`` and time it.

    With ``decode_workers <= 1`` the calls run inline in dict order; otherwise
    they are submitted to a thread pool of that size and all results awaited.

    Returns:
        Elapsed wall-clock seconds.
    """
    began = time.perf_counter()
    jobs = [(ep, camera_key, ts) for (ep, camera_key), ts in timestamps.items()]
    if decode_workers > 1:
        with ThreadPoolExecutor(max_workers=decode_workers) as pool:
            pending = [pool.submit(cache.get_frames, *job) for job in jobs]
            # Awaiting each future re-raises any exception from a worker thread.
            for fut in pending:
                fut.result()
    else:
        for job in jobs:
            cache.get_frames(*job)
    return time.perf_counter() - began
|
||||||
|
|
||||||
|
|
||||||
|
def _fill_cache(
    cache: EpisodeByteCache, episodes: Sequence[int], *, progress_interval: float = 10.0
) -> float:
    """Prefetch every episode into ``cache``, wait for each, and time the fill.

    All prefetches are submitted up front, then episodes are awaited in order.
    When ``progress_interval`` is positive, a ``fill_progress`` line is logged
    once at least that many seconds have passed since the previous report.

    Returns:
        Elapsed wall-clock seconds for the whole fill.
    """
    began = time.perf_counter()
    for ep in episodes:
        cache.submit_prefetch(ep)

    last_report = began
    for ready_count, ep in enumerate(episodes, start=1):
        cache.ensure_ready(ep)
        now = time.perf_counter()
        report_due = progress_interval > 0 and now - last_report >= progress_interval
        if not report_due:
            continue

        stats = cache.timing_summary()
        fetched_bytes = stats.get("range_bytes", 0.0)
        # Clamp to avoid dividing by zero in the throughput figure.
        elapsed = max(now - began, 1e-9)
        finished_jobs = stats.get("jobs", 0.0)
        expected_jobs = len(episodes) * len(cache.manifest.video_keys)
        message = (
            "fill_progress: "
            f"episodes_ready={ready_count}/{len(episodes)} "
            f"camera_jobs={finished_jobs:.0f}/{expected_jobs} "
            f"fetched={fetched_bytes / 1024**3:.2f} GiB "
            f"fetch={fetched_bytes / elapsed / 1024**2:.1f} MiB/s "
            f"elapsed={_format_duration(elapsed)}"
        )
        _log(message)
        last_report = now
    return time.perf_counter() - began
|
||||||
|
|
||||||
|
|
||||||
|
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
    """Return ``len(episodes) * frames_per_episode / elapsed_s``.

    A non-positive ``elapsed_s`` yields ``inf`` instead of dividing.
    """
    return float("inf") if elapsed_s <= 0 else len(episodes) * frames_per_episode / elapsed_s
|
||||||
|
|
||||||
|
|
||||||
|
def _log(message: str) -> None:
    """Print ``message`` on its own line and flush stdout immediately."""
    print(message, end="\n", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_duration(seconds: float) -> str:
    """Format a duration with one decimal in seconds, minutes or hours.

    Values below 60 use ``s``, values below 3600 use ``m``, everything else ``h``.
    """
    for limit, divisor, suffix in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < limit:
            return f"{seconds / divisor:.1f}{suffix}"
    return f"{seconds / 3600:.1f}h"
|
||||||
|
|
||||||
|
|
||||||
|
def _current_rss_mib() -> float | None:
    """Read the current resident set size from ``/proc/self/status``.

    Returns:
        The number on the first ``VmRSS:`` line divided by 1024, or ``None``
        when the file does not exist or has no such line.
    """
    proc_status = Path("/proc/self/status")
    if not proc_status.exists():
        return None
    rss_lines = (entry for entry in proc_status.read_text().splitlines() if entry.startswith("VmRSS:"))
    first_match = next(rss_lines, None)
    if first_match is None:
        return None
    # Second whitespace-separated field is the numeric value.
    return float(first_match.split()[1]) / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def _peak_rss_mib() -> float:
    """Return this process's peak resident set size in MiB.

    ``ru_maxrss`` is reported in KiB on Linux and in bytes on macOS. The unit is
    chosen from ``sys.platform`` instead of guessing from the magnitude of the
    value: a magnitude threshold misreads a Linux peak above ~95 GiB as bytes
    and a macOS peak below 100 MB as KiB.
    """
    import sys  # local import keeps this block self-contained

    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / 1024**2
    return peak / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_snapshot() -> dict[str, float | None]:
    """Capture current and peak RSS as ``{"rss_mib": ..., "peak_rss_mib": ...}``."""
    current = _current_rss_mib()
    peak = _peak_rss_mib()
    return {"rss_mib": current, "peak_rss_mib": peak}
|
||||||
|
|
||||||
|
|
||||||
|
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
    """Print a Markdown table comparing two memory snapshots.

    Rows for the start RSS, end RSS and their difference are printed only when
    the value is available; the peak RSS row (taken from ``end``) is always
    printed.
    """
    rss_before, rss_after = start["rss_mib"], end["rss_mib"]
    rss_change = None if rss_before is None or rss_after is None else rss_after - rss_before

    print()
    print("| Memory | MiB |")
    print("|---|---:|")
    optional_rows = (("rss start", rss_before), ("rss end", rss_after), ("rss delta", rss_change))
    for label, value in optional_rows:
        if value is not None:
            print(f"| {label} | {value:.1f} |")
    print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
|
||||||
|
|
||||||
|
|
||||||
|
def _root_join(data_root: str, relative_path: str) -> str:
    """Join ``relative_path`` onto ``data_root``.

    ``hf://`` roots are joined with a single ``/`` after stripping trailing
    slashes from the root; any other root is joined as a local ``Path``.
    """
    if not data_root.startswith("hf://"):
        return str(Path(data_root) / relative_path)
    base = data_root.rstrip("/")
    return f"{base}/{relative_path}"
|
||||||
|
|
||||||
|
|
||||||
|
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
    """Return the local path of the MP4 sidecar, downloading it if needed.

    A valid cached copy is returned directly. An invalid cached copy is deleted.
    The sidecar is then looked up under ``meta/mp4-sidecars/`` of ``data_root``
    and copied into the local cache: over native HTTP for ``hf://`` roots,
    through fsspec otherwise.

    Args:
        data_root: Dataset root, either an ``hf://`` URL or a local path.
        manifest_episode_count: Unused.

    Returns:
        The local sidecar path, or ``None`` when no remote sidecar exists.
    """
    del manifest_episode_count  # intentionally unused

    cached = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
    if _valid_sidecar(cached):
        return cached
    if cached.exists():
        print(f"mp4_sidecar_invalid_local: {cached}")
        cached.unlink()

    uses_hub = data_root.startswith("hf://")
    sidecar_relpath = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
    remote = _root_join(data_root, sidecar_relpath)
    fs = fsspec.filesystem("hf" if uses_hub else "file")
    if not fs.exists(remote):
        return None

    cached.parent.mkdir(parents=True, exist_ok=True)
    print(f"downloading_mp4_sidecar: {remote} -> {cached}")
    if uses_hub:
        _download_sidecar_native_http(data_root, sidecar_relpath, cached)
    else:
        fs.get(remote, str(cached))
    return cached
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_sidecar(path: Path) -> bool:
    """Report whether ``path`` is a loadable NumPy archive with a ``manifest_json`` entry.

    A missing file, or any error while loading it, counts as invalid.
    """
    if not path.exists():
        return False
    try:
        with np.load(path, allow_pickle=False) as archive:
            has_manifest = "manifest_json" in archive
    except Exception:
        # Deliberately broad: any unreadable or corrupt file is just "not valid".
        return False
    return has_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
    """Download ``relative_path`` under ``data_root`` to ``local`` using parallel range reads.

    The file is split into 16 MiB ranges fetched by 8 worker threads. Chunks are
    written at their offsets into a pre-sized ``<local>.tmp`` file, which is
    renamed onto ``local`` only after every chunk has been written. A progress
    line is printed after each chunk.

    If anything fails (including an interrupt), range reads that have not
    started are cancelled and the partial temp file is removed before the
    exception is re-raised, so no full-size ``.tmp`` file is left behind.

    Args:
        data_root: Root passed to ``NativeHTTPRangeFetcher``.
        relative_path: Path of the remote file relative to ``data_root``.
        local: Destination path on the local filesystem.
    """
    fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
    tmp = local.with_suffix(local.suffix + ".tmp")
    try:
        size = fetcher.info_size(relative_path)
        chunk_size = 16 * 1024 * 1024
        ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]

        def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
            offset, length = offset_length
            return offset, fetcher.read_range(relative_path, offset, length)

        try:
            # Pre-size the temp file so chunks can be written at their final offsets.
            with tmp.open("wb") as out_file:
                out_file.truncate(size)

            start = time.perf_counter()
            done = 0
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = [pool.submit(read_chunk, item) for item in ranges]
                try:
                    with tmp.open("r+b") as rw_file:
                        for future in futures:
                            offset, data = future.result()
                            rw_file.seek(offset)
                            rw_file.write(data)
                            done += len(data)
                            elapsed = max(time.perf_counter() - start, 1e-9)
                            print(
                                f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
                                f"({done / elapsed / 1024**2:.1f} MiB/s)",
                                flush=True,
                            )
                except BaseException:
                    # Skip queued range reads; without this the pool would still
                    # download the rest of the file before exiting.
                    for future in futures:
                        future.cancel()
                    raise
            tmp.replace(local)
        except BaseException:
            # BaseException so that Ctrl-C also cleans up; always re-raised.
            tmp.unlink(missing_ok=True)
            raise
    finally:
        fetcher.close()
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeParquetReader:
    """Read per-episode Parquet data files for the benchmark, caching whole tables.

    Each data file is read at most once per path and kept in an in-memory cache
    guarded by a lock, so the reader can be shared between worker threads.
    """

    def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
        """Store the metadata and root, and pick the ``hf`` or ``file`` fsspec filesystem."""
        self.meta = meta
        self.data_root = data_root
        self.fs = fsspec.filesystem("hf" if data_root.startswith("hf://") else "file")
        self._episode_row_groups = self._build_episode_row_groups()
        # relative data-file path -> fully loaded table
        self._table_cache: dict[str, pa.Table] = {}
        self._cache_lock = threading.Lock()

    def read_episode(self, episode_index: int) -> None:
        """Load the episode's data file and filter it to that episode's rows.

        The filtered table is discarded; the call exists only to exercise the read.
        """
        data_file = str(self.meta.get_data_file_path(episode_index))
        table = self._read_table(data_file)
        table.filter(pc.equal(table["episode_index"], episode_index))

    def _read_table(self, relative_path: str) -> pa.Table:
        """Return the table for ``relative_path``, reading it on first use.

        The read itself happens outside the lock, so concurrent first reads of
        the same path may each load the file; ``setdefault`` keeps the first
        table stored and returns it to every caller.
        """
        with self._cache_lock:
            cached = self._table_cache.get(relative_path)
        if cached is not None:
            return cached

        full_path = _root_join(self.data_root, relative_path)
        with self.fs.open(full_path, "rb", block_size=2**20, cache_type="none") as handle:
            loaded = pq.ParquetFile(handle).read()

        with self._cache_lock:
            return self._table_cache.setdefault(relative_path, loaded)

    def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
        """Schedule ``read_episode(episode_index)`` on ``pool`` and return its future."""
        return pool.submit(self.read_episode, episode_index)

    def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
        """Read every episode, inline or on ``workers`` threads, and return elapsed seconds."""
        began = time.perf_counter()
        if workers > 1:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                pending = [pool.submit(self.read_episode, ep) for ep in episodes]
                for fut in pending:
                    fut.result()
        else:
            for ep in episodes:
                self.read_episode(ep)
        return time.perf_counter() - began

    def _build_episode_row_groups(self) -> dict[int, int]:
        """Map each episode index to its ordinal position within its data file.

        Episodes are visited in index order; each gets the count of earlier
        episodes that share its ``(data/chunk_index, data/file_index)`` pair.
        NOTE(review): the name suggests this ordinal is used as a Parquet
        row-group index; confirm against the writer.
        """
        seen_per_file: dict[tuple[int, int], int] = {}
        positions = {}
        for ep_idx in range(int(self.meta.total_episodes)):
            row = self.meta.episodes[ep_idx]
            file_key = (int(row["data/chunk_index"]), int(row["data/file_index"]))
            slot = seen_per_file.get(file_key, 0)
            positions[ep_idx] = slot
            seen_per_file[file_key] = slot + 1
        return positions
|
||||||
|
|
||||||
|
|
||||||
|
def run_fetch_pool(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    byte_budget: int,
    workers: int,
    range_backend: str,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Fill an ``EpisodeByteCache`` with ``episodes`` (no decoders) and report fetch metrics.

    Returns:
        Dict with the fill time, throughput (MiB/s and episodes/s), average
        MiB per episode and per camera job, the job count, per-job averages in
        milliseconds of the cache's lookup/fetch/synthesize/store timings, and
        every ``range_*`` entry of the cache's timing summary copied verbatim.
    """
    cache_options = {
        "byte_budget": byte_budget,
        "workers": workers,
        "range_backend": range_backend,
        "native_http_connections": args.native_http_connections,
        "native_http_timeout": args.native_http_timeout,
        "native_http_retries": args.native_http_retries,
        "open_decoders": False,
    }
    with EpisodeByteCache(manifest, data_root, **cache_options) as cache:
        elapsed = _fill_cache(cache, episodes, progress_interval=args.progress_interval)
        timings = cache.timing_summary()
        byte_count = _bytes_for(manifest, episodes)
        mib = 1024**2
        episode_mb = byte_count / len(episodes) / mib
        # Guard the per-job averages against a zero job count.
        job_count = max(timings["jobs"], 1.0)
        camera_jobs = len(episodes) * len(manifest.video_keys)

        result = {
            "fetch_s": elapsed,
            "fetch_mbps": byte_count / elapsed / mib,
            "fetch_episodes_s": len(episodes) / elapsed,
            "episode_mb": episode_mb,
            "avg_mb_miss": byte_count / camera_jobs / mib,
            "jobs": timings["jobs"],
            "lookup_ms": timings["lookup_s"] * 1000 / job_count,
            "range_fetch_ms": timings["fetch_s"] * 1000 / job_count,
            "synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
            "store_ms": timings["store_s"] * 1000 / job_count,
        }
        for key, value in timings.items():
            if key.startswith("range_"):
                result[key] = value
        return result
|
||||||
|
|
||||||
|
|
||||||
|
def run_parallel(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the benchmark as separate, back-to-back timed phases.

    Phases, in order: Parquet reads, cache fill, opening one decoder per
    (episode, camera), then frame decoding for all ``timestamps``.

    Returns:
        Dict with fetch time and throughput, Parquet read time, average
        decoder-open time in milliseconds per camera job, and decoded samples
        per second.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=False,
    ) as cache:
        parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
        fetch_s = _fill_cache(cache, episodes)

        # Decoders are opened here, after the fill, so their cost is timed separately.
        opened_at = time.perf_counter()
        for ep in episodes:
            for camera_key in manifest.video_keys:
                cache.get_decoder(ep, camera_key)
        decoder_s = time.perf_counter() - opened_at

        decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
        byte_count = _bytes_for(manifest, episodes)
        camera_jobs = len(episodes) * len(manifest.video_keys)
        return {
            "fetch_s": fetch_s,
            "fetch_mbps": byte_count / fetch_s / 1024**2,
            "fetch_episodes_s": len(episodes) / fetch_s,
            "parquet_s": parquet_s,
            "decoder_ms_miss": decoder_s * 1000 / camera_jobs,
            "decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
        }
|
||||||
|
|
||||||
|
|
||||||
|
def run_overlapped(
    manifest: EpisodeVideoManifest,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    byte_budget: int,
    workers: int,
    decode_workers: int,
    frames_per_episode: int,
    prefetch_ahead: int,
    parquet_reader: EpisodeParquetReader,
    range_backend: str,
) -> dict[str, float]:
    """Run the benchmark with fetching overlapped with consumption.

    Video prefetch and Parquet reads are kept ``prefetch_ahead`` episodes ahead
    of the episode being consumed. For each episode, in order, the function
    waits for its Parquet read, waits for its video bytes, then decodes the
    requested frames for every camera (inline, or on ``decode_workers`` threads).

    The Parquet executor is managed by a ``with`` block so it is shut down even
    when one of the initial prefetch submissions raises.

    Returns:
        Dict with overall, video-only and Parquet-only samples per second, the
        total wall time, and the accumulated video and Parquet wait times.
    """
    with EpisodeByteCache(
        manifest,
        data_root,
        byte_budget=byte_budget,
        workers=workers,
        range_backend=range_backend,
        open_decoders=True,
    ) as cache:
        start = time.perf_counter()
        video_wait_decode_s = 0.0
        parquet_wait_s = 0.0
        parquet_threads = max(1, min(workers, len(episodes)))
        with ThreadPoolExecutor(max_workers=parquet_threads) as parquet_pool:
            # Prime the pipeline with the first `prefetch_ahead` episodes.
            parquet_futures = {
                ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
            }
            for ep in episodes[:prefetch_ahead]:
                cache.submit_prefetch(ep)

            for idx, ep in enumerate(episodes):
                # Keep the window full: schedule the episode `prefetch_ahead` steps ahead.
                next_idx = idx + prefetch_ahead
                if next_idx < len(episodes):
                    next_ep = episodes[next_idx]
                    cache.submit_prefetch(next_ep)
                    parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)

                parquet_start = time.perf_counter()
                parquet_futures.pop(ep).result()
                parquet_wait_s += time.perf_counter() - parquet_start

                video_start = time.perf_counter()
                cache.ensure_ready(ep)
                if decode_workers <= 1:
                    for camera_key in manifest.video_keys:
                        cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
                else:
                    with ThreadPoolExecutor(max_workers=decode_workers) as pool:
                        futures = [
                            pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
                            for camera_key in manifest.video_keys
                        ]
                        for future in futures:
                            future.result()
                video_wait_decode_s += time.perf_counter() - video_start
        elapsed = time.perf_counter() - start
        return {
            "samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
            "video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
            "parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
            "wall_s": elapsed,
            "video_wait_decode_s": video_wait_decode_s,
            "parquet_wait_s": parquet_wait_s,
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Per-thread storage for the decoder cache used by the remote-decoder track.
_remote_decoder_local = threading.local()


def _remote_decoder_cache() -> VideoDecoderCache:
    """Return the calling thread's ``VideoDecoderCache``, creating it on first use."""
    existing = getattr(_remote_decoder_local, "cache", None)
    if existing is not None:
        return existing
    fresh = VideoDecoderCache(max_size=None)
    _remote_decoder_local.cache = fresh
    return fresh
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_remote_source(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episode_index: int,
    camera_key: str,
    timestamps: list[float],
):
    """Decode ``timestamps`` from the episode's source video file under ``data_root``.

    Uses the calling thread's decoder cache, a tolerance of one frame period
    (``1 / meta.fps``), and requests uint8 output.

    Returns:
        Whatever ``decode_video_frames_torchcodec`` returns for the request.
    """
    relative_video = str(meta.get_video_file_path(episode_index, camera_key))
    video_path = _root_join(data_root, relative_video)
    frame_period_s = 1.0 / float(meta.fps)
    return decode_video_frames_torchcodec(
        video_path,
        timestamps,
        tolerance_s=frame_period_s,
        decoder_cache=_remote_decoder_cache(),
        return_uint8=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def run_remote_decoder(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    episodes: Sequence[int],
    timestamps: dict[tuple[int, str], list[float]],
    *,
    frames_per_episode: int,
    decode_workers: int,
    parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
    """Benchmark decoding straight from the source video files.

    Two timed passes run over the same (episode, camera) jobs: first a
    sequential pass, then a second pass that runs on ``decode_workers`` threads
    when that is greater than one and is otherwise another sequential pass.
    Each episode's Parquet data is read once per pass.

    Returns:
        Dict with ``sequential_samples_s`` and ``parallel_samples_s``.
    """
    items = [
        (ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
    ]

    def _sequential_pass() -> float:
        """Run every job inline and return the elapsed seconds."""
        began = time.perf_counter()
        for ep, camera_key, ts in items:
            # Read the episode's Parquet data once, alongside its first camera.
            if camera_key == meta.video_keys[0]:
                parquet_reader.read_episode(ep)
            _decode_remote_source(meta, data_root, ep, camera_key, ts)
        return time.perf_counter() - began

    sequential_s = _sequential_pass()

    if decode_workers <= 1:
        parallel_s = _sequential_pass()
    else:
        began = time.perf_counter()
        with ThreadPoolExecutor(max_workers=decode_workers) as pool:
            parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
            decode_futures = [
                pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
                for ep, camera_key, ts in items
            ]
            for future in parquet_futures:
                future.result()
            for future in decode_futures:
                future.result()
        parallel_s = time.perf_counter() - began

    return {
        "sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
        "parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
    """Print a markdown table of per-range averages taken from ``fetch_pool``.

    Nothing is printed when ``range_jobs`` is missing or not positive. Stage
    rows are emitted only for keys present in ``fetch_pool``; keys starting
    with ``range_status_`` are folded into a single sorted status-count row.
    """
    jobs = fetch_pool.get("range_jobs", 0.0)
    if jobs <= 0:
        return

    stage_labels = {
        "range_open_s": "fsspec handle open/lookup",
        "range_seek_s": "fsspec seek",
        "range_read_s": "fsspec read",
        "range_resolve_s": "http URL resolve",
        "range_header_s": "http response headers",
        "range_first_byte_s": "http first body byte",
        "range_body_s": "http body drain",
        "range_retry_sleep_s": "http retry sleep",
    }

    print()
    print("| Range Read Stage | avg ms/range |")
    print("|---|---:|")
    for stage_key, stage_label in stage_labels.items():
        total_s = fetch_pool.get(stage_key)
        if total_s is None:
            continue
        # Totals are in seconds; report the per-range mean in milliseconds.
        print(f"| {stage_label} | {total_s * 1000 / jobs:.3f} |")

    if "range_retry_attempts" in fetch_pool:
        print(f"| http retries | {fetch_pool['range_retry_attempts'] / jobs:.3f} |")
    if fetch_pool.get("range_failed_requests"):
        print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")

    prefix = "range_status_"
    by_status = sorted(
        (name.removeprefix(prefix), count) for name, count in fetch_pool.items() if name.startswith(prefix)
    )
    if by_status:
        summary = ", ".join(f"{status}={count:.0f}" for status, count in by_status)
        print(f"| http status counts | {summary} |")

    print(f"| range reads | {jobs:.0f} |")
    print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / jobs / 1024**2:.1f} |")
|
||||||
|
|
||||||
|
|
||||||
|
def run_indexed_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
    *,
    range_backend: str = "fsspec",
    label: str = "indexed",
    sidecar_path: str | None = None,
) -> None:
    """Benchmark the manifest-indexed fetch strategy and print its result tables.

    Builds an ``EpisodeVideoManifest``, samples an episode pool, runs the
    fetch-only track via ``run_fetch_pool`` and prints throughput, per-stage and
    memory summaries. When ``args.include_decode`` is set, the decode-comparison
    (``run_parallel``) and overlapped end-to-end (``run_overlapped``) tracks are
    run and printed as well.

    Args:
        meta: Dataset metadata (``total_episodes`` is read here).
        data_root: Root location handed to the manifest builder and fetch tracks.
        args: Parsed CLI namespace; reads ``manifest_episodes``, ``num_episodes``,
            ``workers``, ``max_probe_mb``, ``pool_size``, ``seed``,
            ``byte_budget_gb``, ``include_decode``, ``frames_per_episode``,
            ``decode_workers`` and ``prefetch_ahead``.
        parquet_reader: Reader forwarded to the decode tracks.
        range_backend: Backend name forwarded to the manifest and fetch tracks.
        label: Prefix used in log lines and in the ``strategy:`` output line.
        sidecar_path: Optional MP4 sidecar path forwarded to the manifest builder.
    """
    _log(f"starting_strategy: {label}")
    memory_start = _memory_snapshot()
    manifest_start = time.perf_counter()
    dataset_episode_count = int(meta.total_episodes)
    # A falsy --manifest-episodes (None/0) means "whole dataset"; the result is
    # then capped by both the dataset size and the requested episode count.
    manifest_episode_count = args.manifest_episodes or dataset_episode_count
    manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
    manifest = EpisodeVideoManifest.build(
        meta,
        data_root,
        episode_indices=range(manifest_episode_count),
        range_backend=range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
        sidecar_path=sidecar_path,
    )
    manifest_s = time.perf_counter() - manifest_start
    _log(f"{label}: manifest_build_s={manifest_s:.2f}")

    benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
    # NOTE(review): the pool is drawn with dataset_episode_count while the manifest
    # only covers range(manifest_episode_count) — confirm _episode_pool never
    # returns an episode outside the manifest.
    episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
    byte_budget = int(args.byte_budget_gb * 1024**3)
    byte_count = _bytes_for(manifest, episodes)
    # NOTE(review): divides by len(episodes); an empty pool raises ZeroDivisionError.
    _log(
        f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
        f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
    )

    _log(f"{label}: filling episode byte cache with {args.workers} workers")
    fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
    # Extrapolate wall time from the measured episodes-per-second rate.
    estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
    estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]

    # Run header: key/value lines describing this benchmark configuration.
    print(f"manifest_build_s: {manifest_s:.2f}")
    print(f"strategy: {label}")
    print(f"range_backend: {range_backend}")
    print(f"mp4_sidecar: {sidecar_path or 'none'}")
    print(f"data_root: {data_root}")
    print(f"dataset_episodes: {dataset_episode_count}")
    print(f"benchmark_episodes: {benchmark_episode_count}")
    print(f"pool_episodes: {len(episodes)}")
    print(f"sampled_episodes: {episodes}")
    print(f"cameras: {manifest.video_keys}")
    print()
    # Track table (markdown); later rows are appended only when decode is enabled.
    print(
        "| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
    )
    print("|---|---:|---:|---:|---:|---:|---:|---|")
    print(
        f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
        f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
        f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
        f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
    )
    print()
    # Per-camera-job stage breakdown reported by run_fetch_pool.
    print("| Camera Job Stage | avg ms/job |")
    print("|---|---:|")
    print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
    print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
    print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
    print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
    print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
    _print_range_timing_summary(fetch_pool)
    _print_memory_summary(memory_start, _memory_snapshot())

    if args.include_decode:
        # Timestamps use seed + 1, a different seed from the episode pool.
        timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
        _log(f"{label}: running parallel video fetch + decode-only")
        parallel = run_parallel(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            parquet_reader,
            range_backend,
        )
        _log(f"{label}: running overlapped end-to-end")
        overlapped = run_overlapped(
            manifest,
            data_root,
            episodes,
            timestamps,
            byte_budget,
            args.workers,
            args.decode_workers,
            args.frames_per_episode,
            args.prefetch_ahead,
            parquet_reader,
            range_backend,
        )
        # Both rows reuse avg_mb_miss from the fetch-only track above.
        print(
            f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
            f"{parallel['fetch_s']:.2f} | "
            f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
            f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
            f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
        )
        print(
            f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
            f"{fetch_pool['avg_mb_miss']:.1f} | "
            f"{overlapped['samples_s']:.1f} samples/s; video+decode "
            f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def run_remote_strategy(
    meta: LeRobotDatasetMetadata,
    data_root: str,
    args: argparse.Namespace,
    parquet_reader: EpisodeParquetReader,
) -> None:
    """Run the direct remote-decoder benchmark and print its throughput table.

    Samples an episode pool and per-camera timestamps from ``meta``, times
    sequential and parallel decoding via ``run_remote_decoder``, then prints the
    run header followed by a two-row markdown table.
    """
    _log("starting_strategy: remote-decoder")
    total_episodes = int(meta.total_episodes)
    pool = _episode_pool(total_episodes, args.num_episodes, args.pool_size, args.seed)
    # Timestamps use seed + 1, a different seed from the episode pool.
    stamps = _timestamps_from_meta(meta, pool, args.frames_per_episode, args.seed + 1)

    _log("remote-decoder: running direct source MP4 decoder")
    throughput = run_remote_decoder(
        meta,
        data_root,
        pool,
        stamps,
        frames_per_episode=args.frames_per_episode,
        decode_workers=args.decode_workers,
        parquet_reader=parquet_reader,
    )

    print("strategy: remote-decoder")
    print(f"data_root: {data_root}")
    print(f"episodes: {pool}")
    print(f"cameras: {list(meta.video_keys)}")
    print()

    print("| Track | samples/s | notes |")
    print("|---|---:|---|")
    print(f"| REMOTE SEQUENTIAL | {throughput['sequential_samples_s']:.1f} | direct source MP4 decoder |")
    print(
        f"| REMOTE PARALLEL | {throughput['parallel_samples_s']:.1f} | direct source MP4 decoder, {args.decode_workers} workers |"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse CLI arguments and dispatch the selected benchmark strategies.

    ``full`` is an alias for ``both``; ``native-http`` forces the native HTTP
    range backend. When an MP4 sidecar is found, the indexed strategy runs once
    against it (preceded by the remote-decoder track for ``both`` with decode
    enabled) and the function returns. Otherwise the indexed variants build
    their MP4 indexes online.
    """
    args = parse_args()
    if args.strategy == "full":
        args.strategy = "both"
    if args.strategy == "native-http":
        args.range_backend = "native-http"

    data_root = args.data_root
    if data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()
    parquet_reader = EpisodeParquetReader(meta, data_root)

    def _indexed(backend: str, label: str, sidecar: str | None) -> None:
        # Every indexed run shares the same positional arguments.
        run_indexed_strategy(
            meta,
            data_root,
            args,
            parquet_reader,
            range_backend=backend,
            label=label,
            sidecar_path=sidecar,
        )

    sidecar_episode_cap = min(
        args.manifest_episodes or int(meta.total_episodes),
        int(meta.total_episodes),
        args.num_episodes,
    )
    sidecar_path = _find_or_download_sidecar(data_root, sidecar_episode_cap)

    if sidecar_path is not None:
        print(f"using_mp4_sidecar: {sidecar_path}")
        if args.strategy in ("both", "indexed", "native-http"):
            if args.strategy == "both" and args.include_decode:
                run_remote_strategy(meta, data_root, args, parquet_reader)
                print()
            # For "native-http" the backend was already forced above, so one
            # call covers all three sidecar-capable strategies.
            _indexed(args.range_backend, f"indexed-sidecar-{args.range_backend}", str(sidecar_path))
            return

    if args.strategy == "both":
        expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
        expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
        print(f"mp4_sidecar_missing_local: {expected_sidecar}")
        print(f"mp4_sidecar_missing_remote: {expected_remote}")
        print(
            "build_mp4_sidecar: "
            "uv run --no-sync python scripts/build_mp4_sidecar.py "
            f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
        )
        print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
        print()

    if args.strategy in ("both", "indexed"):
        _indexed("fsspec", "indexed", None)
        if args.strategy == "both":
            print()

    if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode):
        run_remote_strategy(meta, data_root, args, parquet_reader)
        if args.strategy == "both" and args.include_decode:
            print()

    if args.strategy in ("both", "native-http"):
        _indexed("native-http", "indexed-native-http", None)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fsspec
|
||||||
|
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
|
||||||
|
|
||||||
|
# Defaults for the --repo-id, --revision and --data-root command-line options.
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Build the sidecar-builder CLI and parse ``sys.argv``.

    Returns:
        The parsed namespace. ``--output`` is the only required option.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--repo-id", {"default": DEFAULT_REPO}),
        ("--revision", {"default": DEFAULT_REVISION}),
        ("--data-root", {"default": DEFAULT_DATA_ROOT}),
        ("--output", {"required": True}),
        ("--episodes", {"type": int, "default": None}),
        ("--workers", {"type": int, "default": 8}),
        ("--range-backend", {"choices": ("fsspec", "native-http"), "default": "native-http"}),
        ("--max-probe-mb", {"type": int, "default": 64}),
        (
            "--no-push",
            {"action": "store_true", "help": "Do not upload the sidecar to data_root/meta/mp4-sidecars."},
        ),
        ("--no-hub-branch-assert", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def push_sidecar(local_path: str, data_root: str) -> list[str]:
    """Upload a sidecar file to ``<data_root>/meta/mp4-sidecars/`` on the Hub.

    Args:
        local_path: Path of the sidecar file on local disk.
        data_root: Dataset root; only ``hf://`` roots are uploaded to.

    Returns:
        The remote paths written, or an empty list when ``data_root`` is not
        an ``hf://`` root.
    """
    if not data_root.startswith("hf://"):
        return []

    sidecar = Path(local_path)
    hub_fs = fsspec.filesystem("hf")
    # The remote name mirrors the local file name under the sidecar directory.
    destination = f"{data_root.rstrip('/')}/meta/mp4-sidecars/{sidecar.name}"
    hub_fs.put(str(sidecar), destination)
    return [destination]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Build the MP4 byte-index sidecar and, unless ``--no-push`` is set, upload it.

    Collects the distinct video file paths for the first ``--episodes`` episodes
    (all episodes when the option is omitted), writes the sidecar via
    ``EpisodeVideoManifest.write_file_sidecar`` and reports the elapsed time.
    """
    args = parse_args()
    if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
        assert_hf_hub_range_cache_branch()

    meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
    meta.ensure_readable()

    if args.episodes is None:
        total = int(meta.total_episodes)
    else:
        total = min(args.episodes, int(meta.total_episodes))

    # Several (episode, camera) pairs can map to one file, so de-duplicate first.
    unique_paths = set()
    for episode in range(total):
        for camera in meta.video_keys:
            unique_paths.add(str(meta.get_video_file_path(episode, camera)))
    rel_paths = sorted(unique_paths)

    began = time.perf_counter()
    EpisodeVideoManifest.write_file_sidecar(
        args.output,
        rel_paths,
        args.data_root,
        range_backend=args.range_backend,
        workers=args.workers,
        max_probe_bytes=args.max_probe_mb * 1024 * 1024,
    )
    elapsed = time.perf_counter() - began

    print(f"wrote {args.output}")
    print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")

    if args.no_push:
        print("push_skipped: --no-push")
        return
    for remote in push_sidecar(args.output, args.data_root):
        print(f"pushed {remote}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: build the sidecar only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,891 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from importlib import metadata
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import quote, urljoin, urlparse
|
||||||
|
|
||||||
|
import fsspec
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import HfApi, HfFileSystem, constants
|
||||||
|
from huggingface_hub.utils import hf_raise_for_status
|
||||||
|
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class EpisodeVideoSpan:
    """Immutable record locating one episode's samples inside a source video file.

    NOTE(review): the code that fills these fields is outside this view; the
    per-field notes below are inferred from names only and should be confirmed
    against the manifest builder.
    """

    file_id: int  # presumably an index identifying the source video file
    mdat_offset: int  # presumably the byte offset of the span in the source file
    mdat_length: int  # presumably the byte length of the span
    first_pts: float  # presumably the presentation timestamp of the first frame
    last_pts: float  # presumably the presentation timestamp of the last frame
    frame_count: int  # presumably the number of frames in the span
    sample_lo: int  # presumably the lower sample-index bound (inclusivity unconfirmed)
    sample_hi: int  # presumably the upper sample-index bound (inclusivity unconfirmed)
    source_start_pts: float  # presumably the span's start timestamp in the source file
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class VideoFileRecord:
    """Immutable record pairing a video file's path and size with its MP4 index."""

    file_path: str  # path of the video file (presumably relative to the data root — confirm)
    file_size: int  # file size (presumably in bytes — confirm)
    mp4: Mp4Index  # MP4 index for this file, as defined in lerobot.datasets.mp4
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadLocalRangeFetcher:
    """Range reader that gives each worker thread independent file handles.

    Handles are cached per thread and per relative path, so concurrent readers
    never share a seek position. Every per-thread handle table is also recorded
    in a shared registry so that ``close()`` can release handles opened by any
    thread, not only those of the thread that calls it.
    """

    def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
        """Create a fetcher rooted at ``data_root``.

        Args:
            data_root: ``hf://`` root or local directory; trailing ``/`` is stripped.
            block_size: Forwarded to ``fs.open`` as ``block_size``.
            cache_type: Forwarded to ``fs.open`` as ``cache_type``.
        """
        self.data_root = str(data_root).rstrip("/")
        protocol = "hf" if self.data_root.startswith("hf://") else "file"
        self.fs = fsspec.filesystem(protocol)
        self.block_size = block_size
        self.cache_type = cache_type
        self._local = threading.local()
        # Registry of every thread's handle table; threading.local alone would
        # hide worker-thread handles from the thread that calls close().
        self._handles_lock = threading.Lock()
        self._all_handles: list[dict[str, Any]] = []
        self._timing_lock = threading.Lock()
        self._timing_totals = {
            "range_jobs": 0.0,
            "range_bytes": 0.0,
            "range_open_s": 0.0,
            "range_seek_s": 0.0,
            "range_read_s": 0.0,
        }

    def _url(self, relative_path: str) -> str:
        """Return the full path for ``relative_path`` under the data root."""
        if self.data_root.startswith("hf://"):
            return f"{self.data_root}/{relative_path}"
        return str(Path(self.data_root) / relative_path)

    def _handle(self, relative_path: str):
        """Return this thread's open handle for ``relative_path``, opening it if needed."""
        handles = getattr(self._local, "handles", None)
        if handles is None:
            handles = {}
            self._local.handles = handles
            # Register once per thread so close() can reach this table later.
            with self._handles_lock:
                self._all_handles.append(handles)
        handle = handles.get(relative_path)
        if handle is None or getattr(handle, "closed", False):
            handle = self.fs.open(
                self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
            )
            handles[relative_path] = handle
        return handle

    def info_size(self, relative_path: str) -> int:
        """Return the ``size`` reported by the filesystem for ``relative_path``."""
        return int(self.fs.info(self._url(relative_path))["size"])

    def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
        """Read up to ``length`` bytes at ``offset`` and record per-stage timings.

        Args:
            relative_path: File path relative to the data root.
            offset: Absolute byte offset to seek to.
            length: Number of bytes requested from the handle.

        Returns:
            The bytes returned by the handle's ``read`` call.
        """
        open_start = time.perf_counter()
        handle = self._handle(relative_path)
        open_s = time.perf_counter() - open_start
        seek_start = time.perf_counter()
        handle.seek(offset)
        seek_s = time.perf_counter() - seek_start
        read_start = time.perf_counter()
        data = handle.read(length)
        read_s = time.perf_counter() - read_start
        self._record_timing(
            range_jobs=1.0,
            range_bytes=float(len(data)),
            range_open_s=open_s,
            range_seek_s=seek_s,
            range_read_s=read_s,
        )
        return data

    def _record_timing(self, **kwargs: float) -> None:
        """Add each keyword value to the matching running total under the timing lock."""
        with self._timing_lock:
            for key, value in kwargs.items():
                self._timing_totals[key] += value

    def timing_summary(self) -> dict[str, float]:
        """Return a snapshot copy of the accumulated timing totals."""
        with self._timing_lock:
            return dict(self._timing_totals)

    def close(self) -> None:
        """Close every cached handle, including those opened by other threads.

        Intended to be called once workers are idle; a thread that reads again
        afterwards transparently reopens its handles. Errors raised while
        closing an individual handle are suppressed (best-effort cleanup).
        """
        with self._handles_lock:
            tables = list(self._all_handles)
        for handles in tables:
            # Snapshot the values so a late writer cannot invalidate the iteration.
            for handle in list(handles.values()):
                with contextlib.suppress(Exception):
                    handle.close()
            handles.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class NativeHTTPRangeFetcher:
|
||||||
|
"""Direct pooled HTTP range reader for hf:// paths."""
|
||||||
|
|
||||||
|
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
|
||||||
|
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
|
||||||
|
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
|
||||||
|
_GLOBAL_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
_RETRYABLE_EXCEPTIONS = (
|
||||||
|
httpx.ConnectError,
|
||||||
|
httpx.ConnectTimeout,
|
||||||
|
httpx.ReadError,
|
||||||
|
httpx.ReadTimeout,
|
||||||
|
httpx.RemoteProtocolError,
|
||||||
|
httpx.PoolTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_root: str | Path,
|
||||||
|
*,
|
||||||
|
max_connections: int = 32,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
max_retries: int = 4,
|
||||||
|
):
|
||||||
|
self.data_root = str(data_root).rstrip("/")
|
||||||
|
if not self.data_root.startswith("hf://"):
|
||||||
|
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.api = HfApi()
|
||||||
|
self.fs: HfFileSystem | None = None
|
||||||
|
self._bucket_id: str | None = None
|
||||||
|
self._bucket_prefix = ""
|
||||||
|
if self.data_root.startswith("hf://buckets/"):
|
||||||
|
bucket_root = self.data_root.removeprefix("hf://buckets/")
|
||||||
|
parts = bucket_root.split("/", 2)
|
||||||
|
if len(parts) < 2:
|
||||||
|
raise ValueError(f"Invalid bucket root: {self.data_root}")
|
||||||
|
self._bucket_id = f"{parts[0]}/{parts[1]}"
|
||||||
|
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
|
||||||
|
else:
|
||||||
|
self.fs = HfFileSystem()
|
||||||
|
self.client = httpx.Client(
|
||||||
|
timeout=timeout,
|
||||||
|
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
self._resolved_urls: dict[str, str] = {}
|
||||||
|
self._source_urls: dict[str, str] = {}
|
||||||
|
self._sizes: dict[str, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._timing_lock = threading.Lock()
|
||||||
|
self._timing_totals = {
|
||||||
|
"range_jobs": 0.0,
|
||||||
|
"range_bytes": 0.0,
|
||||||
|
"range_resolve_s": 0.0,
|
||||||
|
"range_header_s": 0.0,
|
||||||
|
"range_first_byte_s": 0.0,
|
||||||
|
"range_body_s": 0.0,
|
||||||
|
"range_retry_attempts": 0.0,
|
||||||
|
"range_retry_sleep_s": 0.0,
|
||||||
|
"range_failed_requests": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(self.max_retries + 1):
|
||||||
|
try:
|
||||||
|
return self.client.request(method, url, **kwargs)
|
||||||
|
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt >= self.max_retries:
|
||||||
|
break
|
||||||
|
time.sleep(min(0.5 * 2**attempt, 5.0))
|
||||||
|
if last_exc is None:
|
||||||
|
raise RuntimeError("HTTP request failed without an exception")
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
def _cache_key(self, relative_path: str) -> tuple[str, str]:
|
||||||
|
return self.data_root, relative_path
|
||||||
|
|
||||||
|
def _path(self, relative_path: str) -> str:
|
||||||
|
return f"{self.data_root}/{relative_path}"
|
||||||
|
|
||||||
|
def _bucket_path(self, relative_path: str) -> str:
|
||||||
|
if self._bucket_prefix:
|
||||||
|
return f"{self._bucket_prefix}/{relative_path}"
|
||||||
|
return relative_path
|
||||||
|
|
||||||
|
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
|
||||||
|
headers = self.api._build_hf_headers()
|
||||||
|
if urlparse(request_url).netloc != urlparse(source_url).netloc:
|
||||||
|
headers.pop("authorization", None)
|
||||||
|
headers.pop("Authorization", None)
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _source_url(self, relative_path: str) -> str:
|
||||||
|
with self._lock:
|
||||||
|
source = self._source_urls.get(relative_path)
|
||||||
|
if source is not None:
|
||||||
|
return source
|
||||||
|
key = self._cache_key(relative_path)
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
source = self._GLOBAL_SOURCE_URLS.get(key)
|
||||||
|
if source is None:
|
||||||
|
if self._bucket_id is not None:
|
||||||
|
source = (
|
||||||
|
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
|
||||||
|
f"{quote(self._bucket_path(relative_path))}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.fs is None:
|
||||||
|
raise RuntimeError("HfFileSystem fallback was not initialized")
|
||||||
|
source = self.fs.url(self._path(relative_path))
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
self._GLOBAL_SOURCE_URLS[key] = source
|
||||||
|
with self._lock:
|
||||||
|
self._source_urls[relative_path] = source
|
||||||
|
return source
|
||||||
|
|
||||||
|
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
|
||||||
|
with self._lock:
|
||||||
|
if not refresh and relative_path in self._resolved_urls:
|
||||||
|
return self._resolved_urls[relative_path]
|
||||||
|
key = self._cache_key(relative_path)
|
||||||
|
if not refresh:
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
|
||||||
|
size = self._GLOBAL_SIZES.get(key)
|
||||||
|
if resolved is not None:
|
||||||
|
with self._lock:
|
||||||
|
self._resolved_urls[relative_path] = resolved
|
||||||
|
if size is not None:
|
||||||
|
self._sizes[relative_path] = size
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
source = self._source_url(relative_path)
|
||||||
|
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
|
||||||
|
try:
|
||||||
|
hf_raise_for_status(response)
|
||||||
|
location = response.headers.get("Location")
|
||||||
|
resolved = urljoin(source, location) if location else source
|
||||||
|
with self._lock:
|
||||||
|
self._resolved_urls[relative_path] = resolved
|
||||||
|
if "Content-Length" in response.headers:
|
||||||
|
self._sizes[relative_path] = int(response.headers["Content-Length"])
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
self._GLOBAL_RESOLVED_URLS[key] = resolved
|
||||||
|
if "Content-Length" in response.headers:
|
||||||
|
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
|
||||||
|
return resolved
|
||||||
|
finally:
|
||||||
|
response.close()
|
||||||
|
|
||||||
|
def info_size(self, relative_path: str) -> int:
|
||||||
|
with self._lock:
|
||||||
|
size = self._sizes.get(relative_path)
|
||||||
|
if size is not None:
|
||||||
|
return size
|
||||||
|
key = self._cache_key(relative_path)
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
size = self._GLOBAL_SIZES.get(key)
|
||||||
|
if size is not None:
|
||||||
|
with self._lock:
|
||||||
|
self._sizes[relative_path] = size
|
||||||
|
return size
|
||||||
|
|
||||||
|
resolved = self._resolve_url(relative_path)
|
||||||
|
source = self._source_url(relative_path)
|
||||||
|
response = self._request(
|
||||||
|
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
hf_raise_for_status(response)
|
||||||
|
size = int(response.headers["Content-Length"])
|
||||||
|
with self._lock:
|
||||||
|
self._sizes[relative_path] = size
|
||||||
|
with self._GLOBAL_LOCK:
|
||||||
|
self._GLOBAL_SIZES[key] = size
|
||||||
|
return size
|
||||||
|
finally:
|
||||||
|
response.close()
|
||||||
|
|
||||||
|
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
|
||||||
|
resolve_start = time.perf_counter()
|
||||||
|
resolved = self._resolve_url(relative_path)
|
||||||
|
source = self._source_url(relative_path)
|
||||||
|
resolve_s = time.perf_counter() - resolve_start
|
||||||
|
headers = self._headers_for(resolved, source)
|
||||||
|
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||||
|
payload, status_code, timings = self._read_range_response(resolved, headers)
|
||||||
|
if status_code == 403:
|
||||||
|
refresh_start = time.perf_counter()
|
||||||
|
resolved = self._resolve_url(relative_path, refresh=True)
|
||||||
|
resolve_s += time.perf_counter() - refresh_start
|
||||||
|
headers = self._headers_for(resolved, source)
|
||||||
|
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
|
||||||
|
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
|
||||||
|
for key, value in retry_timings.items():
|
||||||
|
timings[key] += value
|
||||||
|
if status_code == 403:
|
||||||
|
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
|
||||||
|
self._record_timing(
|
||||||
|
range_jobs=1.0,
|
||||||
|
range_bytes=float(len(payload)),
|
||||||
|
range_resolve_s=resolve_s,
|
||||||
|
**{f"range_status_{status_code}": 1.0},
|
||||||
|
**timings,
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
retry_attempts = 0.0
|
||||||
|
retry_sleep_s = 0.0
|
||||||
|
for attempt in range(self.max_retries + 1):
|
||||||
|
try:
|
||||||
|
payload, status_code, timings = self._read_range_response_once(url, headers)
|
||||||
|
timings["range_retry_attempts"] = retry_attempts
|
||||||
|
timings["range_retry_sleep_s"] = retry_sleep_s
|
||||||
|
return payload, status_code, timings
|
||||||
|
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt >= self.max_retries:
|
||||||
|
break
|
||||||
|
retry_attempts += 1.0
|
||||||
|
sleep_s = min(0.5 * 2**attempt, 5.0)
|
||||||
|
retry_sleep_s += sleep_s
|
||||||
|
time.sleep(sleep_s)
|
||||||
|
self._record_timing(
|
||||||
|
range_failed_requests=1.0,
|
||||||
|
range_retry_attempts=retry_attempts,
|
||||||
|
range_retry_sleep_s=retry_sleep_s,
|
||||||
|
)
|
||||||
|
if last_exc is None:
|
||||||
|
raise RuntimeError("HTTP range request failed without an exception")
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
def _read_range_response_once(
|
||||||
|
self, url: str, headers: dict[str, str]
|
||||||
|
) -> tuple[bytes, int, dict[str, float]]:
|
||||||
|
header_start = time.perf_counter()
|
||||||
|
with self.client.stream("GET", url, headers=headers) as response:
|
||||||
|
header_s = time.perf_counter() - header_start
|
||||||
|
if response.status_code == 403:
|
||||||
|
return (
|
||||||
|
b"",
|
||||||
|
response.status_code,
|
||||||
|
{
|
||||||
|
"range_header_s": header_s,
|
||||||
|
"range_first_byte_s": 0.0,
|
||||||
|
"range_body_s": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
hf_raise_for_status(response)
|
||||||
|
chunks = []
|
||||||
|
first_byte_s = 0.0
|
||||||
|
first_chunk = True
|
||||||
|
body_start = time.perf_counter()
|
||||||
|
for chunk in response.iter_bytes():
|
||||||
|
if first_chunk:
|
||||||
|
first_byte_s = time.perf_counter() - body_start
|
||||||
|
first_chunk = False
|
||||||
|
chunks.append(chunk)
|
||||||
|
body_s = time.perf_counter() - body_start
|
||||||
|
return (
|
||||||
|
b"".join(chunks),
|
||||||
|
response.status_code,
|
||||||
|
{
|
||||||
|
"range_header_s": header_s,
|
||||||
|
"range_first_byte_s": first_byte_s,
|
||||||
|
"range_body_s": body_s,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_timing(self, **kwargs: float) -> None:
|
||||||
|
with self._timing_lock:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
self._timing_totals[key] += value
|
||||||
|
|
||||||
|
def timing_summary(self) -> dict[str, float]:
    """Return a copy of the accumulated timing totals, taken under ``self._timing_lock``."""
    with self._timing_lock:
        snapshot = dict(self._timing_totals)
    return snapshot
|
||||||
|
|
||||||
|
def close(self) -> None:
    """Close ``self.client``, the client used to stream the range requests."""
    self.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
def make_range_fetcher(
    data_root: str | Path,
    *,
    range_backend: str,
    workers: int,
    native_http_connections: int | None = None,
    native_http_timeout: float = 60.0,
    native_http_retries: int = 4,
):
    """Build the range fetcher selected by ``range_backend``.

    Args:
        data_root: Dataset root handed to the fetcher.
        range_backend: ``"fsspec"`` or ``"native-http"``.
        workers: Worker count; for ``"native-http"`` it sizes the connection pool when
            ``native_http_connections`` is unset (or 0), with a floor of 8.
        native_http_connections: Explicit connection count for ``"native-http"``.
        native_http_timeout: Timeout forwarded to the native HTTP fetcher.
        native_http_retries: Retry count forwarded to the native HTTP fetcher.

    Raises:
        ValueError: If ``range_backend`` is not one of the two known names.
    """
    if range_backend == "native-http":
        pool_size = native_http_connections if native_http_connections else max(8, workers)
        return NativeHTTPRangeFetcher(
            data_root,
            max_connections=pool_size,
            timeout=native_http_timeout,
            max_retries=native_http_retries,
        )
    if range_backend != "fsspec":
        raise ValueError(f"Unknown range backend: {range_backend}")
    return ThreadLocalRangeFetcher(data_root)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeVideoManifest:
    """Lookup table from ``(episode_index, camera_key)`` to a byte span in a video file.

    ``files`` holds one ``VideoFileRecord`` per distinct video file, and ``spans`` holds
    parallel arrays of shape ``(total_episodes, num_cameras)`` (see :meth:`build`) that
    are indexed by episode index and camera id.
    """

    # Process-wide cache of parsed sidecar files, keyed by the ``~``-expanded sidecar path.
    _FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
    _FILE_SIDECAR_CACHE_LOCK = threading.Lock()

    def __init__(
        self,
        *,
        video_keys: list[str],
        files: list[VideoFileRecord],
        spans: dict[str, np.ndarray],
    ):
        """Store the camera keys, file records and span arrays; camera ids follow key order."""
        self.video_keys = list(video_keys)
        self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
        self.files = files
        self.spans = spans

    @classmethod
    def build(
        cls,
        meta: LeRobotDatasetMetadata,
        data_root: str | Path,
        *,
        episode_indices: list[int] | range | None = None,
        range_backend: str = "fsspec",
        workers: int = 8,
        header_probe_bytes: int = 4 * 1024 * 1024,
        max_probe_bytes: int = 64 * 1024 * 1024,
        keyframe_pad_s: float = 0.1,
        keyframe_pad_fraction: float = 0.05,
        sidecar_path: str | Path | None = None,
    ) -> EpisodeVideoManifest:
        """Build a manifest for the selected episodes of ``meta``.

        Args:
            meta: Dataset metadata providing video keys, episode rows and video paths.
            data_root: Dataset root used to fetch MP4 headers when no sidecar is given.
            episode_indices: Episodes to index; defaults to all ``meta.total_episodes``.
            range_backend: Fetcher backend used for header probing.
            workers: Thread count used for header probing.
            header_probe_bytes: Initial header probe size per file.
            max_probe_bytes: Upper bound on a single probe.
            keyframe_pad_s: Minimum padding (seconds) around each episode span.
            keyframe_pad_fraction: Padding as a fraction of the episode duration.
            sidecar_path: If given, file records are loaded from this sidecar instead
                of being probed remotely.

        Returns:
            A manifest whose span arrays have one row per dataset episode. Rows of
            episodes outside ``episode_indices`` are left at zero.

        Raises:
            ValueError: If the sidecar lacks a record for a required video file.
        """
        meta.ensure_readable()
        video_keys = list(meta.video_keys)
        if episode_indices is None:
            episode_indices = range(int(meta.total_episodes))
        rel_paths = sorted(
            {str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
        )
        path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
        if sidecar_path is None:
            files = cls._build_file_records(
                rel_paths,
                data_root,
                range_backend=range_backend,
                workers=workers,
                header_probe_bytes=header_probe_bytes,
                max_probe_bytes=max_probe_bytes,
            )
        else:
            records = cls.load_file_sidecar(sidecar_path)
            missing = [path for path in rel_paths if path not in records]
            if missing:
                raise ValueError(
                    f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
                )
            files = [records[path] for path in rel_paths]

        total = int(meta.total_episodes)
        num_cameras = len(video_keys)
        spans: dict[str, np.ndarray] = {
            "file_id": np.zeros((total, num_cameras), dtype=np.int32),
            "mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
            "mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
            "first_pts": np.zeros((total, num_cameras), dtype=np.float64),
            "last_pts": np.zeros((total, num_cameras), dtype=np.float64),
            "frame_count": np.zeros((total, num_cameras), dtype=np.int32),
            "sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
            "sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
            "source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
        }

        for ep_idx in episode_indices:
            ep = meta.episodes[ep_idx]
            for cam_idx, key in enumerate(video_keys):
                rel_path = str(meta.get_video_file_path(ep_idx, key))
                file_id = path_to_id[rel_path]
                mp4 = files[file_id].mp4
                from_ts = float(ep[f"videos/{key}/from_timestamp"])
                to_ts = float(ep[f"videos/{key}/to_timestamp"])
                sample_slice = mp4.sample_slice(
                    from_ts,
                    to_ts,
                    keyframe_pad_s=keyframe_pad_s,
                    keyframe_pad_fraction=keyframe_pad_fraction,
                    file_size=files[file_id].file_size,
                )
                spans["file_id"][ep_idx, cam_idx] = file_id
                spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
                spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
                spans["first_pts"][ep_idx, cam_idx] = from_ts
                spans["last_pts"][ep_idx, cam_idx] = to_ts
                # sample_lo / sample_hi are both inclusive, hence the + 1.
                spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
                spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
                spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
                spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts

        return cls(video_keys=video_keys, files=files, spans=spans)

    @staticmethod
    def _build_file_records(
        rel_paths: list[str],
        data_root: str | Path,
        *,
        range_backend: str,
        workers: int,
        header_probe_bytes: int,
        max_probe_bytes: int,
    ) -> list[VideoFileRecord]:
        """Probe every path in ``rel_paths`` in a thread pool and return records in input order."""
        fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)

        def build_file(path: str) -> VideoFileRecord:
            file_size = fetcher.info_size(path)
            mp4 = fetch_mp4_index(
                path,
                fetcher.read_range,
                file_size=file_size,
                header_probe_bytes=header_probe_bytes,
                max_probe_bytes=max_probe_bytes,
            )
            return VideoFileRecord(path, file_size, mp4)

        try:
            with ThreadPoolExecutor(max_workers=workers) as pool:
                return list(pool.map(build_file, rel_paths))
        finally:
            # Release the fetcher even when a probe raised.
            fetcher.close()

    @classmethod
    def write_file_sidecar(
        cls,
        sidecar_path: str | Path,
        rel_paths: list[str],
        data_root: str | Path,
        *,
        range_backend: str = "native-http",
        workers: int = 8,
        header_probe_bytes: int = 4 * 1024 * 1024,
        max_probe_bytes: int = 64 * 1024 * 1024,
    ) -> None:
        """Probe the de-duplicated, sorted ``rel_paths`` and save their records as a sidecar."""
        records = cls._build_file_records(
            sorted(set(rel_paths)),
            data_root,
            range_backend=range_backend,
            workers=workers,
            header_probe_bytes=header_probe_bytes,
            max_probe_bytes=max_probe_bytes,
        )
        cls.save_file_sidecar(sidecar_path, records)

    @staticmethod
    def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
        """Write ``records`` to a compressed NumPy archive at ``sidecar_path``.

        Scalar MP4 fields are stored as a JSON blob under ``manifest_json``; the
        per-sample arrays are stored under ``"<file_idx>/<array_name>"``. A leading
        ``~`` in the path is expanded, matching :meth:`load_file_sidecar`.

        NOTE: ``np.savez_compressed`` appends ``.npz`` to a path that lacks it, so pass
        a path ending in ``.npz`` if the same path is later given to the loader.
        """
        sidecar_path = Path(sidecar_path).expanduser()
        sidecar_path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "version": 1,
            "files": [
                {"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
                for record in records
            ],
        }
        arrays = {}
        for file_idx, record in enumerate(records):
            arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
            arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
            arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
            arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
            arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
        np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)

    @staticmethod
    def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
        """Load a sidecar written by :meth:`save_file_sidecar`, keyed by file path.

        Results are cached per process under the ``~``-expanded path; the cached dict
        itself is returned on later calls, so callers should treat it as read-only.
        """
        # Expand "~" once and use the same path for the cache key and the actual read;
        # previously only the key was expanded, so "~/..." paths could never be opened.
        expanded_path = Path(sidecar_path).expanduser()
        cache_key = str(expanded_path)
        with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
            cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
            if cached is not None:
                return cached

        with np.load(expanded_path, allow_pickle=False) as data:
            payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
            records = {}
            for file_idx, item in enumerate(payload["files"]):
                arrays = {
                    name: data[f"{file_idx}/{name}"]
                    for name in [
                        "sample_pts",
                        "sample_durations",
                        "sample_sizes",
                        "sample_offsets",
                        "sync_samples",
                    ]
                }
                mp4 = Mp4Index.from_dict(item["mp4"], arrays)
                records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
        with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
            EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
        return records

    def camera_id(self, camera_key: str) -> int:
        """Return the column index of ``camera_key``; raises ``KeyError`` if unknown."""
        return self._camera_to_id[camera_key]

    def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
        """Return the stored span of one episode/camera pair as plain Python scalars."""
        cam = self.camera_id(camera_key)
        return EpisodeVideoSpan(
            file_id=int(self.spans["file_id"][episode_index, cam]),
            mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
            mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
            first_pts=float(self.spans["first_pts"][episode_index, cam]),
            last_pts=float(self.spans["last_pts"][episode_index, cam]),
            frame_count=int(self.spans["frame_count"][episode_index, cam]),
            sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
            sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
            source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
        )

    def file_lookup(self, file_id: int) -> VideoFileRecord:
        """Return the file record stored at index ``file_id``."""
        return self.files[file_id]

    def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
        """Return the MP4 index of the file that holds this episode/camera pair."""
        return self.files[self.lookup(episode_index, camera_key).file_id].mp4

    def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
        """Rebuild the ``Mp4SampleSlice`` for this episode/camera pair from the stored span."""
        span = self.lookup(episode_index, camera_key)
        return Mp4SampleSlice(
            sample_lo=span.sample_lo,
            sample_hi=span.sample_hi,
            byte_offset=span.mdat_offset,
            byte_length=span.mdat_length,
            source_start_pts=span.source_start_pts,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeByteCache:
    """Byte-budgeted LRU cache of synthesized per-episode MP4s, filled by a thread pool.

    Entries are keyed by ``(episode_index, camera_key)`` and hold the synthesized MP4
    bytes plus an optional decoder. In-flight fetches are tracked in ``_futures`` so
    that concurrent requests for the same key share one download.
    """

    def __init__(
        self,
        manifest: EpisodeVideoManifest,
        data_root: str | Path,
        *,
        byte_budget: int = 80 * 1024**3,
        workers: int = 8,
        range_backend: str = "fsspec",
        native_http_connections: int | None = None,
        native_http_timeout: float = 60.0,
        native_http_retries: int = 4,
        open_decoders: bool = True,
    ):
        """Create the fetcher and worker pool.

        Args:
            manifest: Manifest used to resolve byte spans and file records.
            data_root: Dataset root passed to the range fetcher.
            byte_budget: Maximum total size of cached MP4 bytes before LRU eviction.
            workers: Size of the fetch thread pool.
            range_backend: Backend name understood by ``make_range_fetcher``.
            native_http_connections: Forwarded to ``make_range_fetcher``.
            native_http_timeout: Forwarded to ``make_range_fetcher``.
            native_http_retries: Forwarded to ``make_range_fetcher``.
            open_decoders: Open a decoder in the worker thread right after synthesis.
        """
        self.manifest = manifest
        self.fetcher = make_range_fetcher(
            data_root,
            range_backend=range_backend,
            workers=workers,
            native_http_connections=native_http_connections,
            native_http_timeout=native_http_timeout,
            native_http_retries=native_http_retries,
        )
        self.byte_budget = byte_budget
        self.open_decoders = open_decoders
        self._pool = ThreadPoolExecutor(max_workers=workers)
        self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
        self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
        self._bytes = 0
        self._lock = threading.Lock()
        self._timing_totals = {
            "lookup_s": 0.0,
            "fetch_s": 0.0,
            "synthesize_s": 0.0,
            "store_s": 0.0,
            "jobs": 0.0,
        }

    def close(self) -> None:
        """Wait for pending fetches, drop all cached state and close the fetcher."""
        self._pool.shutdown(wait=True)
        with self._lock:
            self._cache.clear()
            self._futures.clear()
            self._bytes = 0
        self.fetcher.close()

    def __enter__(self) -> EpisodeByteCache:
        return self

    def __exit__(self, *_exc) -> None:
        self.close()

    def submit_prefetch(self, episode_index: int) -> None:
        """Start background fetches for every camera of ``episode_index`` without waiting."""
        for camera_key in self.manifest.video_keys:
            self._submit(episode_index, camera_key)

    def ensure_ready(self, episode_index: int) -> None:
        """Block until every camera of ``episode_index`` has been fetched into the cache."""
        for camera_key in self.manifest.video_keys:
            self.get_bytes(episode_index, camera_key)

    def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
        """Return the synthesized MP4 bytes for one episode/camera pair, fetching if needed."""
        return self._get_entry(episode_index, camera_key)["bytes"]

    def get_decoder(self, episode_index: int, camera_key: str):
        """Return the entry's decoder, opening one over the cached bytes on first use."""
        entry = self._get_entry(episode_index, camera_key)
        decoder = entry.get("decoder")
        if decoder is None:
            decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
            entry["decoder"] = decoder
        return decoder

    def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
        """Decode the frames at ``timestamps``, given on the source file's timeline.

        Timestamps are shifted by the span's ``source_start_pts`` onto the synthesized
        clip's local timeline. Decoders without ``get_frames_played_at`` fall back to
        index lookup using the average (or derived) frame rate.
        """
        span = self.manifest.lookup(episode_index, camera_key)
        local_ts = [ts - span.source_start_pts for ts in timestamps]
        decoder = self.get_decoder(episode_index, camera_key)
        if hasattr(decoder, "get_frames_played_at"):
            return decoder.get_frames_played_at(local_ts).data
        video_meta = decoder.metadata
        fps = getattr(video_meta, "average_fps", None)
        if fps is None:
            # Guard against a zero/missing duration when deriving the frame rate.
            duration = max(getattr(video_meta, "end_stream_seconds", 0.0), 1e-9)
            fps = video_meta.num_frames / duration
        return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data

    def timing_summary(self) -> dict[str, float]:
        """Return cache timing totals merged with the fetcher's totals, if it exposes any."""
        with self._lock:
            summary = dict(self._timing_totals)
        fetcher_summary = getattr(self.fetcher, "timing_summary", None)
        if fetcher_summary is not None:
            summary.update(fetcher_summary())
        return summary

    def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
        """Return a future for the entry: already resolved, in flight, or newly scheduled."""
        key = (episode_index, camera_key)
        with self._lock:
            if key in self._cache:
                future: Future[dict[str, Any]] = Future()
                future.set_result(self._cache[key])
                return future
            future = self._futures.get(key)
            if future is None:
                future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
                self._futures[key] = future
            return future

    def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Return the cache entry for a key, waiting for (or starting) its fetch.

        Raises:
            Exception: Whatever the fetch job raised. The failed future is discarded
                first, so that a later call schedules a fresh fetch instead of
                re-raising the same stale exception forever.
        """
        key = (episode_index, camera_key)
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None:
                self._cache.move_to_end(key)
                return entry
        future = self._submit(episode_index, camera_key)
        try:
            entry = future.result()
        except Exception:
            with self._lock:
                # Only drop our own future; another thread may already have resubmitted.
                if self._futures.get(key) is future:
                    del self._futures[key]
            raise
        store_start = time.perf_counter()
        with self._lock:
            self._futures.pop(key, None)
            existing = self._cache.get(key)
            if existing is not None:
                # Another waiter on the same future already stored the entry.
                self._cache.move_to_end(key)
                return existing
            self._cache[key] = entry
            self._bytes += len(entry["bytes"])
            self._evict_locked()
            timings = entry.pop("_timings", None)
            if timings is not None:
                self._timing_totals["lookup_s"] += timings["lookup_s"]
                self._timing_totals["fetch_s"] += timings["fetch_s"]
                self._timing_totals["synthesize_s"] += timings["synthesize_s"]
                self._timing_totals["store_s"] += time.perf_counter() - store_start
                self._timing_totals["jobs"] += 1
            return entry

    def _evict_locked(self) -> None:
        """Evict least-recently-used entries until within budget; caller holds ``_lock``."""
        while self._bytes > self.byte_budget and self._cache:
            _key, entry = self._cache.popitem(last=False)
            self._bytes -= len(entry["bytes"])

    def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
        """Worker job: range-read the episode's bytes and synthesize a standalone MP4.

        Returns:
            An entry dict with ``bytes``, ``decoder`` (``None`` unless
            ``open_decoders`` is set) and ``_timings`` for the lookup, fetch and
            synthesis phases.

        Raises:
            OSError: If the fetcher returned fewer or more bytes than the span length.
        """
        lookup_start = time.perf_counter()
        span = self.manifest.lookup(episode_index, camera_key)
        file_record = self.manifest.file_lookup(span.file_id)
        sample_slice = Mp4SampleSlice(
            sample_lo=span.sample_lo,
            sample_hi=span.sample_hi,
            byte_offset=span.mdat_offset,
            byte_length=span.mdat_length,
            source_start_pts=span.source_start_pts,
        )
        lookup_s = time.perf_counter() - lookup_start
        fetch_start = time.perf_counter()
        payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
        fetch_s = time.perf_counter() - fetch_start
        if len(payload) != span.mdat_length:
            raise OSError(
                f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
            )
        synthesize_start = time.perf_counter()
        mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
        synthesize_s = time.perf_counter() - synthesize_start
        entry: dict[str, Any] = {
            "bytes": mp4_bytes,
            "decoder": None,
            "_timings": {
                "lookup_s": lookup_s,
                "fetch_s": fetch_s,
                "synthesize_s": synthesize_s,
            },
        }
        if self.open_decoders:
            entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
        return entry
|
||||||
|
|
||||||
|
|
||||||
|
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
    """Open a torchcodec ``VideoDecoder`` in approximate seek mode over the given source.

    Raises:
        ValueError: If ``frame_mappings`` is supplied; only ``None`` is accepted.
    """
    if frame_mappings is None:
        # Imported lazily so torchcodec is only needed when a decoder is actually opened.
        from torchcodec.decoders import VideoDecoder

        return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
    raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
|
||||||
|
|
||||||
|
|
||||||
|
def assert_hf_hub_range_cache_branch() -> None:
    """Fail unless huggingface_hub was installed from the required range-cache branch.

    The branch name is searched for in the distribution's ``direct_url.json``: in the
    raw text and, when it parses as JSON, in its ``url`` and ``vcs_info`` fields.

    Raises:
        AssertionError: If huggingface_hub is not installed or the branch name is absent.
    """
    try:
        dist = metadata.distribution("huggingface_hub")
    except metadata.PackageNotFoundError as exc:
        raise AssertionError("huggingface_hub is not installed") from exc

    haystack_parts = []
    raw_direct_url = dist.read_text("direct_url.json")
    if raw_direct_url:
        haystack_parts.append(raw_direct_url)
        try:
            parsed = json.loads(raw_direct_url)
        except json.JSONDecodeError:
            # Unparseable metadata: fall back to the raw text alone.
            pass
        else:
            haystack_parts.append(str(parsed.get("url", "")))
            haystack_parts.append(str(parsed.get("vcs_info", {}).get("requested_revision", "")))
            haystack_parts.append(str(parsed.get("vcs_info", {}).get("commit_id", "")))

    haystack = "\n".join(haystack_parts)
    if "feat/hffs-cache-cdn-range-reads" in haystack:
        return
    raise AssertionError(
        "huggingface_hub must be installed from "
        "git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StageTimer:
    """Running totals for the fetch and decode stages."""

    fetch_ms: float = 0.0
    decode_ms: float = 0.0
    bytes_read: int = 0
    misses: int = 0

    def record_fetch(self, start: float, byte_count: int) -> None:
        """Account for one fetch that began at ``start`` (a ``time.perf_counter()`` value).

        Adds the elapsed time in milliseconds to ``fetch_ms``, ``byte_count`` to
        ``bytes_read`` and one to ``misses``.
        """
        elapsed_s = time.perf_counter() - start
        self.fetch_ms += elapsed_s * 1000
        self.misses += 1
        self.bytes_read += byte_count
|
||||||
@@ -0,0 +1,666 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Box:
    """Location of one MP4 box: ``[start, end)`` with a header of ``header_size`` bytes."""

    type: bytes
    start: int
    header_size: int
    end: int

    @property
    def payload_start(self) -> int:
        """Offset of the first payload byte, i.e. just past the box header."""
        header_end = self.start + self.header_size
        return header_end

    @property
    def size(self) -> int:
        """Total box length in bytes, header included."""
        total = self.end - self.start
        return total
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Mp4SampleSlice:
    """A contiguous run of samples in one MP4 file plus the byte range covering them."""

    # Index of the first sample in the run (inclusive).
    sample_lo: int
    # Index of the last sample in the run (inclusive).
    sample_hi: int
    # Smallest sample offset within the run; presumably an absolute file offset - TODO confirm.
    byte_offset: int
    # Distance in bytes from byte_offset to the end of the furthest-reaching sample in the run.
    byte_length: int
    # Presentation timestamp of sample_lo on the source file's timeline
    # (sample_pts is ticks divided by the media timescale, so presumably seconds).
    source_start_pts: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Mp4Index:
    """Parsed layout and per-sample tables of the video track of one MP4 file."""

    file_path: str
    file_size: int
    ftyp: bytes
    moov_offset: int
    mdat_offset: int
    mdat_payload_offset: int
    mdat_payload_size: int
    faststart: bool
    codec: str
    timescale: int
    duration: int
    track_id: int
    width: int
    height: int
    stsd_body: bytes
    sample_pts: np.ndarray
    sample_durations: np.ndarray
    sample_sizes: np.ndarray
    sample_offsets: np.ndarray
    sync_samples: np.ndarray

    def sample_slice(
        self,
        from_ts: float,
        to_ts: float,
        *,
        keyframe_pad_s: float = 0.1,
        keyframe_pad_fraction: float = 0.05,
        file_size: int | None = None,
    ) -> Mp4SampleSlice:
        """Select the samples covering ``[from_ts, to_ts]`` and the bytes that hold them.

        The window is widened on both sides by
        ``max(keyframe_pad_s, duration * keyframe_pad_fraction)``, then the first sample
        is moved back to the closest sync sample at or before it (or to the first sync
        sample when none precedes it).

        Args:
            from_ts: Start timestamp on the file's timeline.
            to_ts: End timestamp on the file's timeline; must not precede ``from_ts``.
            keyframe_pad_s: Minimum padding on each side.
            keyframe_pad_fraction: Padding as a fraction of ``to_ts - from_ts``.
            file_size: If given, the end of the byte range is clamped to it.

        Raises:
            ValueError: If ``to_ts < from_ts`` or the index holds no samples.
        """
        if to_ts < from_ts:
            raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
        sample_count = len(self.sample_pts)
        if sample_count == 0:
            raise ValueError(f"{self.file_path} contains no indexed samples")

        last_index = sample_count - 1
        pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
        window_start = max(0.0, from_ts - pad)
        window_end = to_ts + pad
        first = int(np.searchsorted(self.sample_pts, window_start, side="left"))
        last = int(np.searchsorted(self.sample_pts, window_end, side="right")) - 1
        first = min(max(first, 0), last_index)
        last = min(max(last, first), last_index)

        if len(self.sync_samples):
            earlier_syncs = self.sync_samples[self.sync_samples <= first]
            sync_pick = earlier_syncs[-1] if len(earlier_syncs) else self.sync_samples[0]
            first = int(sync_pick)
            # Snapping forward to the first sync sample may overshoot the window end.
            if first > last:
                last = first

        run_offsets = self.sample_offsets[first : last + 1]
        run_sizes = self.sample_sizes[first : last + 1]
        byte_start = int(run_offsets.min())
        byte_stop = int((run_offsets + run_sizes).max())
        if file_size is not None:
            byte_stop = min(byte_stop, int(file_size))
        return Mp4SampleSlice(
            sample_lo=first,
            sample_hi=last,
            byte_offset=byte_start,
            byte_length=byte_stop - byte_start,
            source_start_pts=float(self.sample_pts[first]),
        )

    def to_dict(self) -> dict:
        """Return the scalar fields as a JSON-friendly dict; byte fields are hex-encoded.

        The per-sample arrays are not included.
        """
        hex_fields = ("ftyp", "stsd_body")
        ordered_fields = (
            "file_path",
            "file_size",
            "ftyp",
            "moov_offset",
            "mdat_offset",
            "mdat_payload_offset",
            "mdat_payload_size",
            "faststart",
            "codec",
            "timescale",
            "duration",
            "track_id",
            "width",
            "height",
            "stsd_body",
        )
        encoded = {}
        for field_name in ordered_fields:
            value = getattr(self, field_name)
            encoded[field_name] = value.hex() if field_name in hex_fields else value
        return encoded

    @classmethod
    def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
        """Rebuild an index from a ``to_dict`` payload plus the per-sample ``arrays``."""
        # (field name, decoder) pairs, in constructor order.
        scalar_decoders = (
            ("file_path", None),
            ("file_size", int),
            ("ftyp", bytes.fromhex),
            ("moov_offset", int),
            ("mdat_offset", int),
            ("mdat_payload_offset", int),
            ("mdat_payload_size", int),
            ("faststart", bool),
            ("codec", None),
            ("timescale", int),
            ("duration", int),
            ("track_id", int),
            ("width", int),
            ("height", int),
            ("stsd_body", bytes.fromhex),
        )
        fields = {}
        for field_name, decode in scalar_decoders:
            raw = data[field_name]
            fields[field_name] = raw if decode is None else decode(raw)
        for array_name in (
            "sample_pts",
            "sample_durations",
            "sample_sizes",
            "sample_offsets",
            "sync_samples",
        ):
            fields[array_name] = arrays[array_name]
        return cls(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_mp4_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    *,
    file_size: int,
    header_probe_bytes: int = 4 * 1024 * 1024,
    max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
    """Build an ``Mp4Index`` for a remote file by probing its head, then its tail.

    The first ``header_probe_bytes`` bytes are read and the probe is doubled until
    both an ``mdat`` box and a complete ``moov`` box are seen, or the probe reaches
    ``min(max_probe_bytes, file_size)``. If by then ``mdat`` was found but ``moov`` was
    not, the bytes following ``mdat`` are searched for a trailing ``moov``.

    Args:
        path: File path handed to ``read_range`` and recorded in the index.
        read_range: Callable ``(path, offset, length) -> bytes``.
        file_size: Total size of the file in bytes.
        header_probe_bytes: Initial probe size; must be positive.
        max_probe_bytes: Upper bound on a single probe read.

    Raises:
        ValueError: If ``header_probe_bytes`` is not positive, or if ``mdat``/``moov``
            cannot be located within the probe limits.
    """
    if header_probe_bytes <= 0:
        # A non-positive probe never grows when doubled, so the loop below would
        # re-issue the same range read forever.
        raise ValueError(f"header_probe_bytes must be positive, got {header_probe_bytes}")
    probe_size = min(header_probe_bytes, file_size)
    while True:
        data = read_range(path, 0, probe_size)
        top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
        has_mdat = any(box.type == b"mdat" for box in top)
        # moov only counts once it lies entirely inside the probed bytes.
        has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
        if has_mdat and has_moov:
            return parse_mp4_index(path, data, file_size=file_size)
        if probe_size >= min(max_probe_bytes, file_size):
            if has_mdat and not has_moov:
                # moov is not in the head of the file: look for it after mdat.
                tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
                if tail_index is not None:
                    return tail_index
            missing = []
            if not has_mdat:
                missing.append("mdat")
            if not has_moov:
                missing.append("moov")
            raise ValueError(
                f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
            )
        probe_size = min(probe_size * 2, max_probe_bytes, file_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_tail_moov_index(
    path: str,
    read_range: Callable[[str, int, int], bytes],
    prefix: bytes,
    top_boxes: list[Box],
    file_size: int,
    max_probe_bytes: int,
) -> Mp4Index | None:
    """Look for a complete ``moov`` box after ``mdat`` and build the index from it.

    Args:
        path: File path handed to ``read_range``.
        read_range: Callable ``(path, offset, length) -> bytes``.
        prefix: Bytes already read from the start of the file.
        top_boxes: Top-level boxes found in ``prefix``.
        file_size: Total size of the file in bytes.
        max_probe_bytes: Upper bound on the tail read.

    Returns:
        The parsed index, or ``None`` when ``mdat`` runs to the end of the file or no
        complete ``moov`` box lies within the tail bytes that were read.
    """
    mdat_box = _one(top_boxes, b"mdat")
    if mdat_box is None:
        return None
    tail_offset = mdat_box.end
    if tail_offset >= file_size:
        # Nothing follows mdat, so there is no tail to search.
        return None
    tail = read_range(path, tail_offset, min(max_probe_bytes, file_size - tail_offset))
    tail_end = tail_offset + len(tail)
    tail_boxes = list(iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True))
    moov_box = None
    for candidate in tail_boxes:
        # Accept only a moov that lies entirely inside the bytes we fetched.
        if candidate.type == b"moov" and candidate.end <= tail_end:
            moov_box = candidate
            break
    if moov_box is None:
        return None
    ftyp_box = _one(top_boxes, b"ftyp", required=False)
    if ftyp_box is not None:
        ftyp = prefix[ftyp_box.start : ftyp_box.end]
    else:
        # No ftyp in the source: synthesize a default one.
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    # Box offsets are absolute; convert them to positions inside ``tail``.
    payload_begin = moov_box.payload_start - tail_offset
    payload_end = moov_box.end - tail_offset
    return _parse_mp4_index_from_layout(
        path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=tail[payload_begin:payload_end],
        mdat_box=mdat_box,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
    """Parse an ``Mp4Index`` from bytes holding the file's ``moov`` box in full.

    Args:
        path: File path recorded in the index and used in error messages.
        data: Bytes from the start of the file; ``mdat`` may be truncated, ``moov``
            may not.
        file_size: Real file size; defaults to ``len(data)``.

    Raises:
        ValueError: If the ``moov`` box extends past the end of ``data``.
    """
    data_len = len(data)
    effective_size = data_len if file_size is None else file_size
    top = list(iter_boxes(data, 0, data_len, absolute_base=0, allow_truncated=True))
    ftyp_box = _one(top, b"ftyp", required=False)
    moov_box = _one(top, b"moov")
    mdat_box = _one(top, b"mdat")
    if moov_box.end > data_len:
        raise ValueError(f"{path}: moov box is truncated")

    if ftyp_box is not None:
        ftyp = data[ftyp_box.start : ftyp_box.end]
    else:
        # No ftyp in the source: synthesize a default one.
        ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
    return _parse_mp4_index_from_layout(
        path,
        file_size=effective_size,
        ftyp=ftyp,
        moov_offset=moov_box.start,
        moov=data[moov_box.payload_start : moov_box.end],
        mdat_box=mdat_box,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_mp4_index_from_layout(
    path: str,
    *,
    file_size: int,
    ftyp: bytes,
    moov_offset: int,
    moov: bytes,
    mdat_box: Box,
) -> Mp4Index:
    """Parse the ``moov`` payload of the video track into an ``Mp4Index``.

    Args:
        path: File path stored in the index.
        file_size: Total size of the file in bytes.
        ftyp: Complete ``ftyp`` box bytes (header included).
        moov_offset: Absolute file offset of the ``moov`` box.
        moov: Payload of the ``moov`` box (header excluded).
        mdat_box: Location of the ``mdat`` box, in absolute file offsets.

    Returns:
        Index holding per-sample pts, durations, sizes, offsets and sync samples.

    Raises:
        ValueError: If a required box is missing or the sample tables disagree.
    """
    mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
    _trak_box, trak_payload = _find_video_trak(moov)
    tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
    mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
    stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])

    stsd = _find_child(stbl, b"stsd")
    stsd_body = stbl[stsd.payload_start : stsd.end]
    codec = _parse_stsd_codec(stsd_body)
    stts = _parse_stts(_payload(stbl, b"stts"))
    sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
    stsc = _parse_stsc(_payload(stbl, b"stsc"))
    chunk_offsets = _parse_chunk_offsets(stbl)
    sync_samples = _parse_stss(stbl, len(sample_sizes))

    sample_durations = _expand_stts(stts, len(sample_sizes))
    # pts of sample i is the sum of the durations of samples 0..i-1.
    sample_pts_units = np.empty(len(sample_durations), dtype=np.int64)
    if len(sample_durations):
        sample_pts_units[0] = 0
    if len(sample_durations) > 1:
        sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
    sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
    sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)

    # The index stores `timescale=mdhd_timescale`, so the duration must be in
    # media-timescale units. The mvhd duration is in the movie timescale;
    # rescale it when it is used as a fallback.
    if mdhd_duration:
        duration = mdhd_duration
    elif mvhd_timescale and mvhd_timescale != mdhd_timescale:
        duration = mvhd_duration * mdhd_timescale // mvhd_timescale
    else:
        duration = mvhd_duration

    # When mdat runs past the end of the file, only count the bytes that exist.
    if mdat_box.end <= file_size:
        mdat_payload_size = mdat_box.end - mdat_box.payload_start
    else:
        mdat_payload_size = file_size - mdat_box.payload_start

    return Mp4Index(
        file_path=path,
        file_size=file_size,
        ftyp=ftyp,
        moov_offset=moov_offset,
        mdat_offset=mdat_box.start,
        mdat_payload_offset=mdat_box.payload_start,
        mdat_payload_size=mdat_payload_size,
        faststart=moov_offset < mdat_box.start,
        codec=codec,
        timescale=mdhd_timescale,
        duration=duration,
        track_id=tkhd["track_id"],
        width=tkhd["width"],
        height=tkhd["height"],
        stsd_body=stsd_body,
        sample_pts=sample_pts,
        sample_durations=sample_durations,
        sample_sizes=sample_sizes,
        sample_offsets=sample_offsets,
        sync_samples=sync_samples,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
    """Assemble a standalone MP4 (``ftyp`` + ``moov`` + ``mdat``) for a sample range.

    Args:
        index: Index of the source file.
        sample_slice: Range to extract. ``sample_lo``/``sample_hi`` are inclusive
            sample indices; ``byte_offset`` is the source-file offset that
            ``mdat_payload[0]`` corresponds to.
        mdat_payload: Raw sample bytes covering every sample in the range.

    Returns:
        The bytes of the synthesized MP4 file.

    Raises:
        ValueError: If the sample range is invalid, the payload does not start
            at the lowest referenced sample offset, or it does not cover all
            referenced samples.
    """
    lo = sample_slice.sample_lo
    hi = sample_slice.sample_hi + 1  # convert the inclusive upper bound to exclusive
    if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
        raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")

    offsets = index.sample_offsets[lo:hi]
    sizes = index.sample_sizes[lo:hi]
    rel_offsets = offsets - sample_slice.byte_offset
    if int(rel_offsets.min()) != 0:
        raise ValueError("Sample slice must start at the minimum referenced sample offset")
    if int((rel_offsets + sizes).max()) > len(mdat_payload):
        raise ValueError("Sample slice does not cover all referenced samples")

    durations = index.sample_durations[lo:hi]
    # stss stores 1-based sample numbers relative to the new track.
    sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1

    mdat = _box(b"mdat", mdat_payload)
    # 8 bytes normally, 16 when _box had to emit a 64-bit size.
    mdat_header_size = len(mdat) - len(mdat_payload)
    ftyp_size = len(index.ftyp)

    # The chunk offsets written into moov depend on the size of moov itself.
    # Adding the header can push an offset past 0xFFFFFFFF, which makes
    # _make_moov switch from stco to co64 and grow. Rebuild until the size the
    # offsets were computed from matches the size actually produced. Offsets
    # only ever grow, so this settles after at most one extra pass.
    moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
    while True:
        data_offset = ftyp_size + len(moov) + mdat_header_size
        rebuilt = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=data_offset)
        stable = len(rebuilt) == len(moov)
        moov = rebuilt
        if stable:
            break

    return index.ftyp + moov + mdat
|
||||||
|
|
||||||
|
|
||||||
|
def iter_boxes(
    data: bytes,
    start: int,
    end: int,
    *,
    absolute_base: int = 0,
    allow_truncated: bool = False,
) -> Iterable[Box]:
    """Yield the sibling boxes laid out back-to-back in ``data[start:end]``.

    Args:
        data: Buffer containing the boxes.
        start: Buffer offset of the first box header.
        end: Buffer offset at which scanning stops.
        absolute_base: Added to the buffer offsets stored in each yielded
            ``Box`` (``start`` and ``end`` fields).
        allow_truncated: When true, a box whose declared size runs past ``end``
            is still yielded (and ends the iteration); otherwise iteration
            stops before it.

    Yields:
        ``Box(type, start, header_size, end)`` for each box found.
    """
    cursor = start
    while end - cursor >= 8:
        declared = struct.unpack_from(">I", data, cursor)[0]
        box_type = data[cursor + 4 : cursor + 8]
        if declared == 1:
            # Size 1 means a 64-bit size follows the type field.
            if cursor + 16 > end:
                return
            header_len = 16
            box_len = struct.unpack_from(">Q", data, cursor + 8)[0]
        elif declared == 0:
            # Size 0 means the box runs to the end of the scanned range.
            header_len = 8
            box_len = end - cursor
        else:
            header_len = 8
            box_len = declared
        if box_len < header_len:
            # A box smaller than its own header is malformed; stop scanning.
            return
        stop = cursor + box_len
        if stop > end and not allow_truncated:
            return
        yield Box(box_type, absolute_base + cursor, header_len, absolute_base + stop)
        cursor = stop
|
||||||
|
|
||||||
|
|
||||||
|
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
    """Return the first ``trak`` box in ``moov`` whose handler type is ``vide``.

    Returns:
        The ``trak`` box (offsets relative to ``moov``) and its payload bytes.

    Raises:
        ValueError: If no video track exists, or a ``trak`` lacks ``mdia/hdlr``.
    """
    for candidate in _children(moov, 0, len(moov)):
        if candidate.type == b"trak":
            body = moov[candidate.payload_start : candidate.end]
            handler = _find_descendant(body, [b"mdia", b"hdlr"])
            # Bytes 8..12 of the hdlr payload are compared against the handler type.
            if handler[8:12] == b"vide":
                return candidate, body
    raise ValueError("No video track found")
|
||||||
|
|
||||||
|
|
||||||
|
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
    """Walk ``path`` (a list of box types) downward and return the final payload.

    At each level the first child of the requested type is taken. An empty
    ``path`` returns ``data`` unchanged.

    Raises:
        ValueError: If a box on the path is missing.
    """
    scope = data
    for box_type in path:
        child = _find_child(scope, box_type)
        scope = scope[child.payload_start : child.end]
    return scope
|
||||||
|
|
||||||
|
|
||||||
|
def _find_child(data: bytes, typ: bytes) -> Box:
    """Return the first box of type ``typ`` directly inside ``data``.

    Raises:
        ValueError: If no such box exists.
    """
    found = next((box for box in _children(data, 0, len(data)) if box.type == typ), None)
    if found is None:
        raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
    """Iterate the complete boxes in ``data[start:end]`` with buffer-relative offsets."""
    return iter_boxes(data, start, end, absolute_base=0, allow_truncated=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
    """Return the first box of type ``typ`` from ``boxes``.

    Returns:
        The first matching box, or ``None`` when there is none and
        ``required`` is false.

    Raises:
        ValueError: If there is no match and ``required`` is true.
    """
    for box in boxes:
        if box.type == typ:
            return box
    if required:
        raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _payload(parent: bytes, typ: bytes) -> bytes:
    """Return the payload bytes of the first ``typ`` child box of ``parent``.

    Raises:
        ValueError: If ``parent`` has no such child.
    """
    child = _find_child(parent, typ)
    return parent[child.payload_start : child.end]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
    """Read ``(timescale, duration)`` from an ``mvhd`` payload.

    Version 1 boxes are read as a 32-bit timescale and 64-bit duration at
    offset 20; any other version as two 32-bit values at offset 12.
    """
    layout, offset = (">IQ", 20) if payload[0] == 1 else (">II", 12)
    return struct.unpack_from(layout, payload, offset)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
    """Read ``(timescale, duration)`` from an ``mdhd`` payload.

    Version 1 boxes are read as a 32-bit timescale and 64-bit duration at
    offset 20; any other version as two 32-bit values at offset 12.
    """
    if payload[0] == 1:
        layout, offset = ">IQ", 20
    else:
        layout, offset = ">II", 12
    return struct.unpack_from(layout, payload, offset)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tkhd(payload: bytes) -> dict[str, int]:
    """Read track id, duration and pixel dimensions from a ``tkhd`` payload.

    Width and height are read as 32-bit values and shifted right by 16 bits,
    keeping only the upper 16 bits of each.

    Returns:
        Dict with keys ``track_id``, ``duration``, ``width`` and ``height``.
    """
    if payload[0] == 1:
        id_at, duration_layout, duration_at, dims_at = 20, ">Q", 28, 88
    else:
        id_at, duration_layout, duration_at, dims_at = 12, ">I", 20, 76

    (track_id,) = struct.unpack_from(">I", payload, id_at)
    (duration,) = struct.unpack_from(duration_layout, payload, duration_at)
    raw_width, raw_height = struct.unpack_from(">II", payload, dims_at)
    return {
        "track_id": track_id,
        "duration": duration,
        "width": raw_width >> 16,
        "height": raw_height >> 16,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stsd_codec(stsd_body: bytes) -> str:
    """Return bytes 12..16 of the ``stsd`` payload decoded as latin-1.

    Returns ``"unknown"`` when the payload is shorter than 16 bytes.
    """
    return stsd_body[12:16].decode("latin1") if len(stsd_body) >= 16 else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
    """Decode an ``stts`` payload into ``(sample_count, sample_delta)`` runs.

    The entry count is read at offset 4; entries are 8 bytes each from offset 8.
    """
    (entry_count,) = struct.unpack_from(">I", payload, 4)
    return [struct.unpack_from(">II", payload, 8 + 8 * i) for i in range(entry_count)]
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
    """Expand ``(count, delta)`` runs into one int64 duration per sample.

    Raises:
        ValueError: If the runs do not add up to exactly ``sample_count``.
    """
    described = sum(count for count, _ in entries)
    if described != sample_count:
        raise ValueError(f"stts describes {described} samples, stsz describes {sample_count}")
    if not entries:
        return np.empty(0, dtype=np.int64)
    run_lengths = np.array([count for count, _ in entries], dtype=np.int64)
    run_deltas = np.array([delta for _, delta in entries], dtype=np.int64)
    return np.repeat(run_deltas, run_lengths)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stsz(payload: bytes) -> np.ndarray:
    """Decode an ``stsz`` payload into an int64 array of per-sample sizes.

    A non-zero uniform size (read at offset 4) is repeated for every sample;
    otherwise one 32-bit size per sample is read starting at offset 12.
    """
    uniform_size, count = struct.unpack_from(">II", payload, 4)
    if uniform_size:
        return np.full(count, uniform_size, dtype=np.int64)
    # Decode the whole size table with a single struct call.
    return np.array(struct.unpack_from(f">{count}I", payload, 12), dtype=np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
    """Decode an ``stsc`` payload into a list of 3-tuples of 32-bit integers.

    The entry count is read at offset 4; entries are 12 bytes each from
    offset 8. ``_sample_offsets`` consumes them as
    ``(first_chunk, samples_per_chunk, description_index)``.
    """
    (entry_count,) = struct.unpack_from(">I", payload, 4)
    return [struct.unpack_from(">III", payload, 8 + 12 * i) for i in range(entry_count)]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
    """Return the chunk offsets stored in ``stbl`` as an int64 array.

    A ``co64`` box (64-bit offsets) takes precedence over ``stco`` (32-bit).
    If a box type occurs more than once, the last occurrence is used.

    Raises:
        ValueError: If neither ``stco`` nor ``co64`` is present.
    """
    narrow = None
    wide = None
    for child in _children(stbl, 0, len(stbl)):
        if child.type == b"co64":
            wide = stbl[child.payload_start : child.end]
        elif child.type == b"stco":
            narrow = stbl[child.payload_start : child.end]

    if wide is not None:
        table, code = wide, "Q"
    elif narrow is not None:
        table, code = narrow, "I"
    else:
        raise ValueError("Missing stco/co64 chunk offsets")

    (count,) = struct.unpack_from(">I", table, 4)
    # Decode the whole offset table with a single struct call.
    return np.array(struct.unpack_from(f">{count}{code}", table, 8), dtype=np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
    """Return the sync-sample indices of the track as an int64 array.

    The numbers stored in the first ``stss`` box are decremented by one. When
    ``stbl`` has no ``stss`` box, every index in ``range(sample_count)`` is
    returned.
    """
    for child in _children(stbl, 0, len(stbl)):
        if child.type != b"stss":
            continue
        table = stbl[child.payload_start : child.end]
        (count,) = struct.unpack_from(">I", table, 4)
        numbers = struct.unpack_from(f">{count}I", table, 8)
        return np.array(numbers, dtype=np.int64) - 1
    return np.arange(sample_count, dtype=np.int64)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_offsets(
    stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
) -> np.ndarray:
    """Compute the byte offset of every sample from the chunk tables.

    Each ``stsc`` entry applies from its first chunk (1-based) up to the chunk
    before the next entry's first chunk; the last entry runs through the final
    chunk. Samples inside a chunk are laid out back-to-back.

    Returns:
        int64 array with one offset per entry of ``sample_sizes``.

    Raises:
        ValueError: If ``stsc`` is empty, references a chunk outside
            ``chunk_offsets``, or accounts for fewer samples than ``sample_sizes``.
    """
    if not stsc:
        raise ValueError("stsc is empty")

    total = len(sample_sizes)
    chunk_count = len(chunk_offsets)
    offsets = np.empty(total, dtype=np.int64)
    # Exclusive end chunk of each run: the next entry's first chunk, or one
    # past the last chunk for the final entry.
    run_ends = [entry[0] for entry in stsc[1:]] + [chunk_count + 1]

    filled = 0
    for (first_chunk, per_chunk, _description), run_end in zip(stsc, run_ends):
        for chunk_no in range(first_chunk, run_end):
            if not 1 <= chunk_no <= chunk_count:
                raise ValueError("stsc references a chunk outside stco/co64")
            cursor = int(chunk_offsets[chunk_no - 1])
            for _ in range(per_chunk):
                if filled >= total:
                    # Every sample is placed; remaining chunk slots are ignored.
                    return offsets
                offsets[filled] = cursor
                cursor += int(sample_sizes[filled])
                filled += 1

    if filled != total:
        raise ValueError(f"stsc describes {filled} samples, stsz describes {total}")
    return offsets
|
||||||
|
|
||||||
|
|
||||||
|
def _make_moov(
    index: Mp4Index,
    durations: np.ndarray,
    sizes: np.ndarray,
    rel_offsets: np.ndarray,
    sync_samples: np.ndarray,
    *,
    mdat_data_offset: int,
) -> bytes:
    """Build a single-video-track ``moov`` box for a slice of samples.

    Args:
        index: Source index; supplies ``stsd_body``, ``timescale``, ``track_id``,
            ``width`` and ``height``.
        durations: Per-sample durations.
        sizes: Per-sample byte sizes.
        rel_offsets: Per-sample offsets relative to the start of the mdat payload.
        sync_samples: Sample numbers written verbatim into ``stss``.
        mdat_data_offset: File offset of the first mdat payload byte; added to
            every entry of ``rel_offsets``.

    Returns:
        The complete ``moov`` box bytes.
    """
    total_duration = int(durations.sum())
    chunk_positions = [int(mdat_data_offset + rel) for rel in rel_offsets]
    # Use 64-bit chunk offsets only when some offset does not fit in 32 bits.
    needs_co64 = any(position > 0xFFFFFFFF for position in chunk_positions)
    offsets_box = _co64(chunk_positions) if needs_co64 else _stco(chunk_positions)

    # NOTE(review): stss is omitted when the slice contains no sync samples;
    # readers may then treat every sample as a sync sample. Confirm intended.
    tables = [
        _box(b"stsd", index.stsd_body),
        _stts(durations),
        _stsc_one_sample_per_chunk(len(sizes)),
        _stsz(sizes),
        offsets_box,
        _stss(sync_samples) if len(sync_samples) else b"",
    ]
    stbl = _box(b"stbl", b"".join(tables))
    minf = _box(b"minf", _vmhd() + _dinf() + stbl)
    mdia = _box(b"mdia", _mdhd(index.timescale, total_duration) + _hdlr() + minf)
    trak = _box(b"trak", _tkhd(index.track_id, total_duration, index.width, index.height) + mdia)
    header = _mvhd(index.timescale, total_duration, index.track_id + 1)
    return _box(b"moov", header + trak)
|
||||||
|
|
||||||
|
|
||||||
|
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
    """Serialize a box whose payload begins with a 1-byte version and 3-byte big-endian flags."""
    version_and_flags = bytes([version]) + flags.to_bytes(3, "big")
    return _box(typ, version_and_flags + payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _box(typ: bytes, payload: bytes) -> bytes:
    """Serialize a box: size and type header followed by ``payload``.

    A 32-bit size is used when the box fits; otherwise the size field is set
    to 1 and a 64-bit size (counting the 16-byte header) follows the type.
    """
    compact_size = 8 + len(payload)
    if compact_size > 0xFFFFFFFF:
        header = struct.pack(">I4sQ", 1, typ, compact_size + 8)
    else:
        header = struct.pack(">I4s", compact_size, typ)
    return header + payload
|
||||||
|
|
||||||
|
|
||||||
|
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
    """Build a version-0 ``mvhd`` box with zeroed creation/modification times."""
    identity_matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
    fields = (
        struct.pack(">IIII", 0, 0, timescale, duration),
        struct.pack(">IHH", 0x00010000, 0x0100, 0),
        bytes(8),
        identity_matrix,
        bytes(24),
        struct.pack(">I", next_track_id),
    )
    return _full_box(b"mvhd", 0, 0, b"".join(fields))
|
||||||
|
|
||||||
|
|
||||||
|
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
    """Build a version-0 ``tkhd`` box with flags 7.

    ``width`` and ``height`` are shifted left by 16 bits before being written.
    """
    identity_matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
    fields = (
        struct.pack(">IIIII", 0, 0, track_id, 0, duration),
        bytes(8),
        struct.pack(">hhhh", 0, 0, 0, 0),
        identity_matrix,
        struct.pack(">II", width << 16, height << 16),
    )
    return _full_box(b"tkhd", 0, 7, b"".join(fields))
|
||||||
|
|
||||||
|
|
||||||
|
def _mdhd(timescale: int, duration: int) -> bytes:
    """Build a version-0 ``mdhd`` box with zeroed times and the 16-bit value 0x55C4 after the duration."""
    body = struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4) + bytes(2)
    return _full_box(b"mdhd", 0, 0, body)
|
||||||
|
|
||||||
|
|
||||||
|
def _hdlr() -> bytes:
    """Build an ``hdlr`` box declaring the ``vide`` handler, named ``VideoHandler``."""
    body = b"".join((bytes(4), b"vide", bytes(12), b"VideoHandler\0"))
    return _full_box(b"hdlr", 0, 0, body)
|
||||||
|
|
||||||
|
|
||||||
|
def _vmhd() -> bytes:
    """Build a ``vmhd`` box with flags 1 and four zeroed 16-bit fields."""
    zeroed_fields = struct.pack(">HHHH", 0, 0, 0, 0)
    return _full_box(b"vmhd", 0, 1, zeroed_fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _dinf() -> bytes:
    """Build a ``dinf`` box holding a ``dref`` with one ``url `` entry (flags 1, empty payload)."""
    url_entry = _full_box(b"url ", 0, 1)
    entry_count = struct.pack(">I", 1)
    return _box(b"dinf", _full_box(b"dref", 0, 0, entry_count + url_entry))
|
||||||
|
|
||||||
|
|
||||||
|
def _stts(durations: np.ndarray) -> bytes:
    """Build an ``stts`` box by run-length encoding consecutive equal durations."""
    runs: list[tuple[int, int]] = []
    run_delta = 0
    run_length = 0
    for raw in durations.tolist():
        delta = int(raw)
        if run_length and delta == run_delta:
            run_length += 1
            continue
        if run_length:
            # The value changed: close the run collected so far.
            runs.append((run_length, run_delta))
        run_delta, run_length = delta, 1
    if run_length:
        runs.append((run_length, run_delta))

    entries = b"".join(struct.pack(">II", length, delta) for length, delta in runs)
    return _full_box(b"stts", 0, 0, struct.pack(">I", len(runs)) + entries)
|
||||||
|
|
||||||
|
|
||||||
|
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
    """Build an ``stsc`` box with a single entry: from chunk 1, one sample per chunk, description 1.

    ``sample_count`` is accepted but not used.
    """
    entry_count = 1
    first_chunk, samples_per_chunk, description_index = 1, 1, 1
    body = struct.pack(">IIII", entry_count, first_chunk, samples_per_chunk, description_index)
    return _full_box(b"stsc", 0, 0, body)
|
||||||
|
|
||||||
|
|
||||||
|
def _stsz(sizes: np.ndarray) -> bytes:
    """Build an ``stsz`` box listing every sample size (uniform-size field set to 0)."""
    header = struct.pack(">II", 0, len(sizes))
    table = b"".join([struct.pack(">I", int(value)) for value in sizes.tolist()])
    return _full_box(b"stsz", 0, 0, header + table)
|
||||||
|
|
||||||
|
|
||||||
|
def _stco(values: list[int]) -> bytes:
    """Build an ``stco`` box: an entry count followed by 32-bit chunk offsets."""
    entry_count = struct.pack(">I", len(values))
    table = b"".join([struct.pack(">I", offset) for offset in values])
    return _full_box(b"stco", 0, 0, entry_count + table)
|
||||||
|
|
||||||
|
|
||||||
|
def _co64(values: list[int]) -> bytes:
    """Build a ``co64`` box: an entry count followed by 64-bit chunk offsets."""
    entry_count = struct.pack(">I", len(values))
    table = b"".join([struct.pack(">Q", offset) for offset in values])
    return _full_box(b"co64", 0, 0, entry_count + table)
|
||||||
|
|
||||||
|
|
||||||
|
def _stss(values: np.ndarray) -> bytes:
    """Build an ``stss`` box: an entry count followed by the given sample numbers as 32-bit values."""
    entry_count = struct.pack(">I", len(values))
    table = b"".join([struct.pack(">I", int(number)) for number in values.tolist()])
    return _full_box(b"stss", 0, 0, entry_count + table)
|
||||||
@@ -70,19 +70,21 @@ def aggregate_pipeline_dataset_features(
|
|||||||
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
initial_features: dict[PipelineFeatureType, dict[str, Any]],
|
||||||
*,
|
*,
|
||||||
use_videos: bool = True,
|
use_videos: bool = True,
|
||||||
|
exclude_images: bool = False,
|
||||||
patterns: Sequence[str] | None = None,
|
patterns: Sequence[str] | None = None,
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
Aggregates and filters pipeline features to create a dataset-ready features dictionary.
|
||||||
|
|
||||||
This function transforms initial features using the pipeline, categorizes them as action or observations
|
This function transforms initial features using the pipeline, categorizes them as action or observations
|
||||||
(image or state), filters them based on `use_videos` and `patterns`, and finally
|
(image or state), filters them based on `exclude_images` and `patterns`, and finally
|
||||||
formats them for use with a Hugging Face LeRobot Dataset.
|
formats them for use with a Hugging Face LeRobot Dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pipeline: The DataProcessorPipeline to apply.
|
pipeline: The DataProcessorPipeline to apply.
|
||||||
initial_features: A dictionary of raw feature specs for actions and observations.
|
initial_features: A dictionary of raw feature specs for actions and observations.
|
||||||
use_videos: If False, image features are excluded.
|
use_videos: Controls the storage dtype for image features. If True, images are stored as "video"; if False, they are stored as "image".
|
||||||
|
exclude_images: If True, image features are dropped entirely from the output.
|
||||||
patterns: A sequence of regex patterns to filter action and state features.
|
patterns: A sequence of regex patterns to filter action and state features.
|
||||||
Image features are not affected by this filter.
|
Image features are not affected by this filter.
|
||||||
|
|
||||||
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Apply filtering rules.
|
# 2. Apply filtering rules.
|
||||||
if is_image and not use_videos:
|
if is_image and exclude_images:
|
||||||
continue
|
continue
|
||||||
if not is_image and not should_keep(key, compiled_patterns):
|
if not is_image and not should_keep(key, compiled_patterns):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiOpenArmFollower(Robot):
|
class BiOpenArmFollower(BimanualMixin, Robot):
|
||||||
"""
|
"""
|
||||||
Bimanual OpenArm Follower Arms
|
Bimanual OpenArm Follower Arms
|
||||||
"""
|
"""
|
||||||
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||||
# will only open the cameras assigned to it. Per-arm cameras are used
|
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||||
# as fallback when top-level cameras are empty.
|
self._top_level_cam_keys = set(config.cameras)
|
||||||
if config.cameras:
|
_collisions = self._top_level_cam_keys & set(
|
||||||
left_cameras = config.cameras
|
config.left_arm_config.cameras
|
||||||
right_cameras = {}
|
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||||
else:
|
if _collisions:
|
||||||
left_cameras = config.left_arm_config.cameras
|
raise ValueError(
|
||||||
right_cameras = config.right_arm_config.cameras
|
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||||
|
)
|
||||||
|
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||||
|
|
||||||
left_arm_config = OpenArmFollowerConfig(
|
left_arm_config = OpenArmFollowerConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
|
|||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
cameras=left_cameras,
|
cameras=left_arm_cameras,
|
||||||
side=config.left_arm_config.side,
|
side=config.left_arm_config.side,
|
||||||
can_interface=config.left_arm_config.can_interface,
|
can_interface=config.left_arm_config.can_interface,
|
||||||
use_can_fd=config.left_arm_config.use_can_fd,
|
use_can_fd=config.left_arm_config.use_can_fd,
|
||||||
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
|
|||||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||||
max_relative_target=config.right_arm_config.max_relative_target,
|
max_relative_target=config.right_arm_config.max_relative_target,
|
||||||
cameras=right_cameras,
|
cameras=config.right_arm_config.cameras,
|
||||||
side=config.right_arm_config.side,
|
side=config.right_arm_config.side,
|
||||||
can_interface=config.right_arm_config.can_interface,
|
can_interface=config.right_arm_config.can_interface,
|
||||||
use_can_fd=config.right_arm_config.use_can_fd,
|
use_can_fd=config.right_arm_config.use_can_fd,
|
||||||
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _motors_ft(self) -> dict[str, type]:
|
def _motors_ft(self) -> dict[str, type]:
|
||||||
left_arm_motors_ft = self.left_arm._motors_ft
|
|
||||||
right_arm_motors_ft = self.right_arm._motors_ft
|
|
||||||
|
|
||||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
|
||||||
# and the dataset feature names recorded during data collection.
|
|
||||||
return {
|
return {
|
||||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
|
||||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
out: dict[str, tuple] = {}
|
||||||
# "right_wrist"), so we merge them directly — unlike motors which need the
|
for k, v in self.left_arm._cameras_ft.items():
|
||||||
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||||
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
for k, v in self.right_arm._cameras_ft.items():
|
||||||
|
out[f"right_{k}"] = v
|
||||||
|
return out
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||||
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
|
|||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict = {}
|
obs_dict: RobotObservation = {}
|
||||||
|
|
||||||
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||||
left_cam_keys = set(self.left_arm.cameras.keys())
|
for key, value in self.left_arm.get_observation().items():
|
||||||
right_cam_keys = set(self.right_arm.cameras.keys())
|
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||||
|
|
||||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
# Add "right_" prefix
|
||||||
# and the dataset feature names recorded during data collection.
|
for key, value in self.right_arm.get_observation().items():
|
||||||
right_obs = self.right_arm.get_observation()
|
obs_dict[f"right_{key}"] = value
|
||||||
for key, value in right_obs.items():
|
|
||||||
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
|
||||||
|
|
||||||
left_obs = self.left_arm.get_observation()
|
|
||||||
for key, value in left_obs.items():
|
|
||||||
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
|
|||||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||||
|
|
||||||
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
|
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self):
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
|||||||
left_arm_config: OpenArmFollowerConfigBase
|
left_arm_config: OpenArmFollowerConfigBase
|
||||||
right_arm_config: OpenArmFollowerConfigBase
|
right_arm_config: OpenArmFollowerConfigBase
|
||||||
|
|
||||||
# Top-level cameras shared across both arms.
|
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||||
|
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||||
|
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiRebotB601Follower(Robot):
|
class BiRebotB601Follower(BimanualMixin, Robot):
|
||||||
"""Bimanual Seeed Studio reBot B601-DM follower.
|
"""Bimanual Seeed Studio reBot B601-DM follower.
|
||||||
|
|
||||||
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
||||||
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||||
|
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||||
|
self._top_level_cam_keys = set(config.cameras)
|
||||||
|
_collisions = self._top_level_cam_keys & set(
|
||||||
|
config.left_arm_config.cameras
|
||||||
|
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||||
|
if _collisions:
|
||||||
|
raise ValueError(
|
||||||
|
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||||
|
)
|
||||||
|
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||||
|
|
||||||
left_arm_config = RebotB601FollowerRobotConfig(
|
left_arm_config = RebotB601FollowerRobotConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
calibration_dir=config.calibration_dir,
|
calibration_dir=config.calibration_dir,
|
||||||
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
|
|||||||
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
cameras=config.left_arm_config.cameras,
|
cameras=left_arm_cameras,
|
||||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||||
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
return {
|
out: dict[str, tuple] = {}
|
||||||
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
|
for k, v in self.left_arm._cameras_ft.items():
|
||||||
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
|
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||||
}
|
for k, v in self.right_arm._cameras_ft.items():
|
||||||
|
out[f"right_{k}"] = v
|
||||||
|
return out
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict = {}
|
obs_dict: RobotObservation = {}
|
||||||
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
|
for k, v in self.left_arm.get_observation().items():
|
||||||
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
|
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||||
|
for k, v in self.right_arm.get_observation().items():
|
||||||
|
obs_dict[f"right_{k}"] = v
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
|
|||||||
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
||||||
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -14,7 +14,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
from ..rebot_b601_follower import RebotB601FollowerConfig
|
from ..rebot_b601_follower import RebotB601FollowerConfig
|
||||||
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
|
|||||||
|
|
||||||
left_arm_config: RebotB601FollowerConfig
|
left_arm_config: RebotB601FollowerConfig
|
||||||
right_arm_config: RebotB601FollowerConfig
|
right_arm_config: RebotB601FollowerConfig
|
||||||
|
|
||||||
|
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||||
|
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||||
|
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||||
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
||||||
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiSOFollower(Robot):
|
class BiSOFollower(BimanualMixin, Robot):
|
||||||
"""
|
"""
|
||||||
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||||
"""
|
"""
|
||||||
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||||
|
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||||
|
self._top_level_cam_keys = set(config.cameras)
|
||||||
|
_collisions = self._top_level_cam_keys & set(
|
||||||
|
config.left_arm_config.cameras
|
||||||
|
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||||
|
if _collisions:
|
||||||
|
raise ValueError(
|
||||||
|
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||||
|
)
|
||||||
|
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||||
|
|
||||||
left_arm_config = SOFollowerRobotConfig(
|
left_arm_config = SOFollowerRobotConfig(
|
||||||
id=f"{config.id}_left" if config.id else None,
|
id=f"{config.id}_left" if config.id else None,
|
||||||
calibration_dir=config.calibration_dir,
|
calibration_dir=config.calibration_dir,
|
||||||
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
|
|||||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
use_degrees=config.left_arm_config.use_degrees,
|
use_degrees=config.left_arm_config.use_degrees,
|
||||||
cameras=config.left_arm_config.cameras,
|
cameras=left_arm_cameras,
|
||||||
)
|
)
|
||||||
|
|
||||||
right_arm_config = SOFollowerRobotConfig(
|
right_arm_config = SOFollowerRobotConfig(
|
||||||
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _cameras_ft(self) -> dict[str, tuple]:
|
def _cameras_ft(self) -> dict[str, tuple]:
|
||||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
out: dict[str, tuple] = {}
|
||||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
for k, v in self.left_arm._cameras_ft.items():
|
||||||
|
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||||
return {
|
for k, v in self.right_arm._cameras_ft.items():
|
||||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
out[f"right_{k}"] = v
|
||||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
return out
|
||||||
}
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def observation_features(self) -> dict[str, type | tuple]:
|
def observation_features(self) -> dict[str, type | tuple]:
|
||||||
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return self._motors_ft
|
return self._motors_ft
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
self.left_arm.setup_motors()
|
self.left_arm.setup_motors()
|
||||||
self.right_arm.setup_motors()
|
self.right_arm.setup_motors()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_observation(self) -> RobotObservation:
|
def get_observation(self) -> RobotObservation:
|
||||||
obs_dict = {}
|
obs_dict: RobotObservation = {}
|
||||||
|
|
||||||
# Add "left_" prefix
|
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||||
left_obs = self.left_arm.get_observation()
|
for key, value in self.left_arm.get_observation().items():
|
||||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||||
|
|
||||||
# Add "right_" prefix
|
# Add "right_" prefix
|
||||||
right_obs = self.right_arm.get_observation()
|
for key, value in self.right_arm.get_observation().items():
|
||||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
obs_dict[f"right_{key}"] = value
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
|
|||||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||||
|
|
||||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self):
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -14,7 +14,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
from ..so_follower import SOFollowerConfig
|
from ..so_follower import SOFollowerConfig
|
||||||
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
|
|||||||
|
|
||||||
left_arm_config: SOFollowerConfig
|
left_arm_config: SOFollowerConfig
|
||||||
right_arm_config: SOFollowerConfig
|
right_arm_config: SOFollowerConfig
|
||||||
|
|
||||||
|
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||||
|
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||||
|
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||||
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
from lerobot.teleoperators import ( # noqa: F401
|
from lerobot.teleoperators import ( # noqa: F401
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
gamepad,
|
gamepad,
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
homunculus,
|
homunculus,
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from lerobot.teleoperators import ( # noqa: F401
|
from lerobot.teleoperators import ( # noqa: F401
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
koch_leader,
|
koch_leader,
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
bi_openarm_leader,
|
bi_openarm_leader,
|
||||||
|
bi_openarm_mini,
|
||||||
bi_rebot_102_leader,
|
bi_rebot_102_leader,
|
||||||
bi_so_leader,
|
bi_so_leader,
|
||||||
gamepad,
|
gamepad,
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
from lerobot.types import RobotAction
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiOpenArmLeader(Teleoperator):
|
class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
||||||
"""
|
"""
|
||||||
Bimanual OpenArm Leader Arms
|
Bimanual OpenArm Leader Arms
|
||||||
"""
|
"""
|
||||||
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||||
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
|
|||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
# TODO: Implement force feedback
|
# TODO: Implement force feedback
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
|
|||||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||||
@dataclass
|
@dataclass
|
||||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||||
"""Configuration class for Bi OpenArm Follower robots."""
|
"""Configuration class for Bi OpenArm Leader teleoperators."""
|
||||||
|
|
||||||
left_arm_config: OpenArmLeaderConfigBase
|
left_arm_config: OpenArmLeaderConfigBase
|
||||||
right_arm_config: OpenArmLeaderConfigBase
|
right_arm_config: OpenArmLeaderConfigBase
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .bi_openarm_mini import BiOpenArmMini
|
||||||
|
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||||
|
|
||||||
|
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from lerobot.types import RobotAction
|
||||||
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
|
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
|
||||||
|
from ..teleoperator import Teleoperator
|
||||||
|
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BiOpenArmMini(BimanualMixin, Teleoperator):
    """Two OpenArm Mini leaders exposed as a single bimanual teleoperator.

    The wrapper owns one :class:`OpenArmMini` per side. Every action and
    feedback key of an arm is namespaced with ``left_`` or ``right_`` so the
    pair can drive a bimanual OpenArm follower.

    NOTE(review): connection and calibration handling is not defined in this
    class; it is presumably inherited from ``BimanualMixin`` -- confirm there.
    """

    config_class = BiOpenArmMiniConfig
    name = "bi_openarm_mini"

    def __init__(self, config: BiOpenArmMiniConfig):
        """Build the left and right single-arm teleoperators from ``config``.

        Args:
            config: Bimanual configuration holding one per-arm base config
                for each side.
        """
        super().__init__(config)
        self.config = config

        left_base = config.left_arm_config
        right_base = config.right_arm_config

        # Each arm gets its own id suffix, but only when the wrapper has an id.
        left_id = f"{config.id}_left" if config.id else None
        right_id = f"{config.id}_right" if config.id else None

        # The wrapper decides which arm is which: `side` is hard-coded here and
        # whatever the per-arm base config carries for it is not consulted.
        left_settings = OpenArmMiniConfig(
            id=left_id,
            calibration_dir=config.calibration_dir,
            port=left_base.port,
            side="left",
            use_degrees=left_base.use_degrees,
        )
        right_settings = OpenArmMiniConfig(
            id=right_id,
            calibration_dir=config.calibration_dir,
            port=right_base.port,
            side="right",
            use_degrees=right_base.use_degrees,
        )

        self.left_arm = OpenArmMini(left_settings)
        self.right_arm = OpenArmMini(right_settings)

    @cached_property
    def action_features(self) -> dict[str, type]:
        """Return both arms' action features, keyed with a side prefix."""
        features: dict[str, type] = {}
        for key, kind in self.left_arm.action_features.items():
            features[f"left_{key}"] = kind
        for key, kind in self.right_arm.action_features.items():
            features[f"right_{key}"] = kind
        return features

    @cached_property
    def feedback_features(self) -> dict[str, type]:
        """Return both arms' feedback features, keyed with a side prefix."""
        features: dict[str, type] = {}
        for key, kind in self.left_arm.feedback_features.items():
            features[f"left_{key}"] = kind
        for key, kind in self.right_arm.feedback_features.items():
            features[f"right_{key}"] = kind
        return features

    def setup_motors(self) -> None:
        """Run motor setup on the left arm, then on the right arm."""
        self.left_arm.setup_motors()
        self.right_arm.setup_motors()

    @check_if_not_connected
    def get_action(self) -> RobotAction:
        """Read both arms and merge their actions under side-prefixed keys.

        The left arm is queried before the right arm.
        """
        return {
            **{f"left_{key}": value for key, value in self.left_arm.get_action().items()},
            **{f"right_{key}": value for key, value in self.right_arm.get_action().items()},
        }

    @check_if_not_connected
    def send_feedback(self, feedback: dict[str, float]) -> None:
        """Split ``feedback`` by side prefix and forward each part to its arm.

        Keys that start with neither ``left_`` nor ``right_`` are ignored, and
        an arm is only called when at least one key was addressed to it.

        Args:
            feedback: Mapping from side-prefixed keys to feedback values.
        """
        left_part: dict[str, float] = {}
        right_part: dict[str, float] = {}
        for key, value in feedback.items():
            if key.startswith("left_"):
                left_part[key.removeprefix("left_")] = value
            elif key.startswith("right_"):
                right_part[key.removeprefix("right_")] = value

        if left_part:
            self.left_arm.send_feedback(left_part)
        if right_part:
            self.right_arm.send_feedback(right_part)
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..config import TeleoperatorConfig
|
||||||
|
from ..openarm_mini import OpenArmMiniConfigBase
|
||||||
|
|
||||||
|
|
||||||
|
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
@dataclass
class BiOpenArmMiniConfig(TeleoperatorConfig):
    """Configuration for the bimanual OpenArm Mini teleoperator.

    Registered with ``TeleoperatorConfig`` under the name ``bi_openarm_mini``
    and made up of one single-arm base configuration per side.
    """

    # Base settings for the left leader arm.
    left_arm_config: OpenArmMiniConfigBase

    # Base settings for the right leader arm.
    right_arm_config: OpenArmMiniConfigBase
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
from .bi_rebot_102_leader import BiRebot102Leader
|
||||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||||
|
|
||||||
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
|
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
|
||||||
|
|||||||
@@ -18,16 +18,17 @@ import logging
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.types import RobotAction
|
from lerobot.types import RobotAction
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiRebotArm102Leader(Teleoperator):
|
class BiRebot102Leader(BimanualMixin, Teleoperator):
|
||||||
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
||||||
|
|
||||||
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
||||||
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
|
|||||||
leader can teleoperate a bimanual reBot B601 follower.
|
leader can teleoperate a bimanual reBot B601 follower.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = BiRebotArm102LeaderConfig
|
config_class = BiRebot102LeaderConfig
|
||||||
name = "bi_rebot_102_leader"
|
name = "bi_rebot_102_leader"
|
||||||
|
|
||||||
def __init__(self, config: BiRebotArm102LeaderConfig):
|
def __init__(self, config: BiRebot102LeaderConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> RobotAction:
|
def get_action(self) -> RobotAction:
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
|
|||||||
|
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
|
|||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
||||||
@dataclass
|
@dataclass
|
||||||
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
|
class BiRebot102LeaderConfig(TeleoperatorConfig):
|
||||||
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
||||||
|
|
||||||
left_arm_config: RebotArm102LeaderConfig
|
left_arm_config: RebotArm102LeaderConfig
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
from lerobot.types import RobotAction
|
||||||
|
from lerobot.utils.bimanual import BimanualMixin
|
||||||
|
from lerobot.utils.decorators import check_if_not_connected
|
||||||
|
|
||||||
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
||||||
from ..teleoperator import Teleoperator
|
from ..teleoperator import Teleoperator
|
||||||
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BiSOLeader(Teleoperator):
|
class BiSOLeader(BimanualMixin, Teleoperator):
|
||||||
"""
|
"""
|
||||||
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||||
"""
|
"""
|
||||||
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
|
|||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
|
||||||
|
|
||||||
@check_if_already_connected
|
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
|
||||||
self.left_arm.connect(calibrate)
|
|
||||||
self.right_arm.connect(calibrate)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_calibrated(self) -> bool:
|
|
||||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
|
||||||
self.left_arm.calibrate()
|
|
||||||
self.right_arm.calibrate()
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
|
||||||
self.left_arm.configure()
|
|
||||||
self.right_arm.configure()
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
self.left_arm.setup_motors()
|
self.left_arm.setup_motors()
|
||||||
self.right_arm.setup_motors()
|
self.right_arm.setup_motors()
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> dict[str, float]:
|
def get_action(self) -> RobotAction:
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
|
|
||||||
# Add "left_" prefix
|
# Add "left_" prefix
|
||||||
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
|
|||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
# TODO: Implement force feedback
|
# TODO: Implement force feedback
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@check_if_not_connected
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
self.left_arm.disconnect()
|
|
||||||
self.right_arm.disconnect()
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .config_openarm_mini import OpenArmMiniConfig
|
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
|
||||||
from .openarm_mini import OpenArmMini
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
|
||||||
|
|||||||
@@ -19,12 +19,21 @@ from dataclasses import dataclass
|
|||||||
from ..config import TeleoperatorConfig
|
from ..config import TeleoperatorConfig
|
||||||
|
|
||||||
|
|
||||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
class OpenArmMiniConfigBase:
|
||||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
|
||||||
|
|
||||||
port_right: str = "/dev/ttyUSB0"
|
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
|
||||||
port_left: str = "/dev/ttyUSB1"
|
port: str
|
||||||
|
|
||||||
|
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
|
||||||
|
# during readout. If `None`, no flipping is applied.
|
||||||
|
side: str | None = None
|
||||||
|
|
||||||
use_degrees: bool = True
|
use_degrees: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||||
|
@dataclass
|
||||||
|
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Motors whose direction is inverted during readout
|
# Per-side motor direction flips applied during readout.
|
||||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
|
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
|
||||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
|
||||||
|
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
|
||||||
|
}
|
||||||
|
|
||||||
# Leader joint 6 maps to follower joint 7 and vice versa
|
# Leader joint 6 ↔ follower joint 7 (symmetric — its own inverse).
|
||||||
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
||||||
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
|
|
||||||
|
|
||||||
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
||||||
|
|
||||||
|
|
||||||
class OpenArmMini(Teleoperator):
|
class OpenArmMini(Teleoperator):
|
||||||
"""
|
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
|
||||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
|
||||||
|
|
||||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = OpenArmMiniConfig
|
config_class = OpenArmMiniConfig
|
||||||
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
|
||||||
|
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
|
||||||
|
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
|
||||||
|
|
||||||
norm_mode_body = MotorNormMode.DEGREES
|
norm_mode_body = MotorNormMode.DEGREES
|
||||||
|
motors = {
|
||||||
motors_right = {
|
|
||||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||||
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
|
|||||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
}
|
}
|
||||||
|
|
||||||
motors_left = {
|
self.bus = FeetechMotorsBus(
|
||||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
port=self.config.port,
|
||||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
motors=motors,
|
||||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
calibration=self.calibration,
|
||||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
|
||||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
|
||||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
|
||||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
|
||||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
|
||||||
}
|
|
||||||
|
|
||||||
cal_right = {
|
|
||||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
|
||||||
}
|
|
||||||
cal_left = {
|
|
||||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
|
||||||
}
|
|
||||||
|
|
||||||
self.bus_right = FeetechMotorsBus(
|
|
||||||
port=self.config.port_right,
|
|
||||||
motors=motors_right,
|
|
||||||
calibration=cal_right,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.bus_left = FeetechMotorsBus(
|
|
||||||
port=self.config.port_left,
|
|
||||||
motors=motors_left,
|
|
||||||
calibration=cal_left,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||||
# and the dataset feature names recorded during data collection.
|
|
||||||
features: dict[str, type] = {}
|
|
||||||
for motor in self.bus_right.motors:
|
|
||||||
features[f"right_{motor}.pos"] = float
|
|
||||||
for motor in self.bus_left.motors:
|
|
||||||
features[f"left_{motor}.pos"] = float
|
|
||||||
return features
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
return self.bus.is_connected
|
||||||
|
|
||||||
@check_if_already_connected
|
@check_if_already_connected
|
||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
logger.info(f"Connecting arm on {self.config.port}...")
|
||||||
self.bus_right.connect()
|
self.bus.connect()
|
||||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
|
||||||
self.bus_left.connect()
|
|
||||||
|
|
||||||
if calibrate:
|
if calibrate:
|
||||||
self.calibrate()
|
self.calibrate()
|
||||||
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
return self.bus.is_calibrated
|
||||||
|
|
||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
"""
|
"""
|
||||||
Run calibration procedure for OpenArm Mini.
|
Run calibration procedure for a single OpenArm Mini arm.
|
||||||
|
|
||||||
1. Disable torque
|
1. Disable torque
|
||||||
2. Ask user to position arms in hanging position with grippers closed
|
2. Ask user to position arm in hanging position with gripper closed
|
||||||
3. Set this as zero position via half-turn homing
|
3. Set this as zero position via half-turn homing
|
||||||
4. Interactive gripper calibration (open/close positions)
|
4. Interactive gripper calibration (open/close positions)
|
||||||
5. Save calibration
|
5. Save calibration
|
||||||
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
|
|||||||
)
|
)
|
||||||
if user_input.strip().lower() != "c":
|
if user_input.strip().lower() != "c":
|
||||||
logger.info(f"Using existing calibration for {self.id}")
|
logger.info(f"Using existing calibration for {self.id}")
|
||||||
cal_right = {
|
self.bus.write_calibration(self.calibration)
|
||||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
|
||||||
}
|
|
||||||
cal_left = {
|
|
||||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
|
||||||
}
|
|
||||||
self.bus_right.write_calibration(cal_right)
|
|
||||||
self.bus_left.write_calibration(cal_left)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"\nRunning calibration for {self}")
|
logger.info(f"\nRunning calibration for {self}")
|
||||||
|
|
||||||
self._calibrate_arm("right", self.bus_right)
|
self.bus.disable_torque()
|
||||||
self._calibrate_arm("left", self.bus_left)
|
|
||||||
|
|
||||||
self._save_calibration()
|
logger.info("Setting Phase to 12 for all motors...")
|
||||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
for motor in self.bus.motors:
|
||||||
|
self.bus.write("Phase", motor, 12)
|
||||||
|
|
||||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
for motor in self.bus.motors:
|
||||||
"""Calibrate a single arm with Feetech motors."""
|
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
|
||||||
|
|
||||||
bus.disable_torque()
|
|
||||||
|
|
||||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
|
||||||
for motor in bus.motors:
|
|
||||||
bus.write("Phase", motor, 12)
|
|
||||||
|
|
||||||
for motor in bus.motors:
|
|
||||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
|
||||||
|
|
||||||
input(
|
input(
|
||||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
"\nCalibration: Zero Position\n"
|
||||||
"Position the arm in the following configuration:\n"
|
"Position the arm in the following configuration:\n"
|
||||||
" - Arm hanging straight down\n"
|
" - Arm hanging straight down\n"
|
||||||
" - Gripper closed\n"
|
" - Gripper closed\n"
|
||||||
"Press ENTER when ready..."
|
"Press ENTER when ready..."
|
||||||
)
|
)
|
||||||
|
|
||||||
homing_offsets = bus.set_half_turn_homings()
|
homing_offsets = self.bus.set_half_turn_homings()
|
||||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
logger.info("Arm zero position set.")
|
||||||
|
|
||||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
print("\nSetting motor ranges\n")
|
||||||
|
|
||||||
if self.calibration is None:
|
if self.calibration is None:
|
||||||
self.calibration = {}
|
self.calibration = {}
|
||||||
|
|
||||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
|
||||||
max_res = motor_resolution - 1
|
max_res = motor_resolution - 1
|
||||||
|
|
||||||
for motor_name, motor in bus.motors.items():
|
for motor_name, motor in self.bus.motors.items():
|
||||||
prefixed_name = f"{arm_name}_{motor_name}"
|
|
||||||
|
|
||||||
if motor_name == "gripper":
|
if motor_name == "gripper":
|
||||||
input(
|
input(
|
||||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
"\nGripper Calibration\n"
|
||||||
f"Step 1: CLOSE the gripper fully\n"
|
"Step 1: CLOSE the gripper fully\n"
|
||||||
f"Press ENTER when gripper is closed..."
|
"Press ENTER when gripper is closed..."
|
||||||
)
|
)
|
||||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||||
|
|
||||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||||
|
|
||||||
if closed_pos < open_pos:
|
if closed_pos < open_pos:
|
||||||
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
|
|||||||
drive_mode = 1
|
drive_mode = 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
f" {motor_name}: range set to [{range_min}, {range_max}] "
|
||||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
range_min = 0
|
range_min = 0
|
||||||
range_max = max_res
|
range_max = max_res
|
||||||
drive_mode = 0
|
drive_mode = 0
|
||||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
|
||||||
|
|
||||||
self.calibration[prefixed_name] = MotorCalibration(
|
self.calibration[motor_name] = MotorCalibration(
|
||||||
id=motor.id,
|
id=motor.id,
|
||||||
drive_mode=drive_mode,
|
drive_mode=drive_mode,
|
||||||
homing_offset=homing_offsets[motor_name],
|
homing_offset=homing_offsets[motor_name],
|
||||||
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
|
|||||||
range_max=range_max,
|
range_max=range_max,
|
||||||
)
|
)
|
||||||
|
|
||||||
cal_for_bus = {
|
self.bus.write_calibration(self.calibration)
|
||||||
k.replace(f"{arm_name}_", ""): v
|
self._save_calibration()
|
||||||
for k, v in self.calibration.items()
|
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||||
if k.startswith(f"{arm_name}_")
|
|
||||||
}
|
|
||||||
bus.write_calibration(cal_for_bus)
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.bus_right.disable_torque()
|
self.bus.disable_torque()
|
||||||
self.bus_right.configure_motors()
|
self.bus.configure_motors()
|
||||||
for motor in self.bus_right.motors:
|
for motor in self.bus.motors:
|
||||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
self.bus_left.disable_torque()
|
|
||||||
self.bus_left.configure_motors()
|
|
||||||
for motor in self.bus_left.motors:
|
|
||||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
|
||||||
|
|
||||||
def setup_motors(self) -> None:
|
def setup_motors(self) -> None:
|
||||||
print("\nSetting up RIGHT arm motors...")
|
for motor in reversed(self.bus.motors):
|
||||||
for motor in reversed(self.bus_right.motors):
|
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
self.bus.setup_motor(motor)
|
||||||
self.bus_right.setup_motor(motor)
|
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
|
||||||
|
|
||||||
print("\nSetting up LEFT arm motors...")
|
|
||||||
for motor in reversed(self.bus_left.motors):
|
|
||||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
|
||||||
self.bus_left.setup_motor(motor)
|
|
||||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def get_action(self) -> RobotAction:
|
def get_action(self) -> RobotAction:
|
||||||
"""Get current action from both arms (read positions from all motors)."""
|
"""Get current action (read positions from all motors)."""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
right_positions = self.bus_right.sync_read("Present_Position")
|
positions = self.bus.sync_read("Present_Position")
|
||||||
left_positions = self.bus_left.sync_read("Present_Position")
|
|
||||||
|
|
||||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
|
||||||
# and the dataset feature names recorded during data collection.
|
|
||||||
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
||||||
|
# Per-side direction flip is applied based on the configured `side`.
|
||||||
action: dict[str, Any] = {}
|
action: dict[str, Any] = {}
|
||||||
for motor, val in right_positions.items():
|
for motor, val in positions.items():
|
||||||
target = JOINT_REMAP.get(motor, motor)
|
target = JOINT_REMAP.get(motor, motor)
|
||||||
if motor == "gripper":
|
if motor == "gripper":
|
||||||
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
||||||
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||||
else:
|
else:
|
||||||
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
|
||||||
for motor, val in left_positions.items():
|
|
||||||
target = JOINT_REMAP.get(motor, motor)
|
|
||||||
if motor == "gripper":
|
|
||||||
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
|
||||||
else:
|
|
||||||
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
|
||||||
|
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def enable_torque(self) -> None:
|
def enable_torque(self) -> None:
|
||||||
"""Enable torque on both arms for position control."""
|
self.bus.enable_torque()
|
||||||
self.bus_right.enable_torque()
|
|
||||||
self.bus_left.enable_torque()
|
|
||||||
|
|
||||||
def disable_torque(self) -> None:
|
def disable_torque(self) -> None:
|
||||||
"""Disable torque on both arms for free movement."""
|
self.bus.disable_torque()
|
||||||
self.bus_right.disable_torque()
|
|
||||||
self.bus_left.disable_torque()
|
|
||||||
|
|
||||||
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
||||||
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
||||||
right_goals: dict[str, float] = {}
|
goals: dict[str, float] = {}
|
||||||
left_goals: dict[str, float] = {}
|
|
||||||
|
|
||||||
for key, val in positions.items():
|
for key, val in positions.items():
|
||||||
if not key.endswith(".pos"):
|
if not key.endswith(".pos"):
|
||||||
continue
|
continue
|
||||||
motor_name = key.removesuffix(".pos")
|
base = key.removesuffix(".pos")
|
||||||
if motor_name.startswith("right_"):
|
# JOINT_REMAP is symmetric (its own inverse).
|
||||||
base = motor_name.removeprefix("right_")
|
target = JOINT_REMAP.get(base, base)
|
||||||
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
|
if base == "gripper":
|
||||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||||
if base == "gripper":
|
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
else:
|
||||||
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||||
else:
|
goals[target] = -val if target in self._motors_to_flip else val
|
||||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
|
||||||
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
|
|
||||||
elif motor_name.startswith("left_"):
|
|
||||||
base = motor_name.removeprefix("left_")
|
|
||||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
|
||||||
if base == "gripper":
|
|
||||||
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
|
||||||
else:
|
|
||||||
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
|
|
||||||
|
|
||||||
if right_goals:
|
if goals:
|
||||||
self.bus_right.sync_write("Goal_Position", right_goals)
|
self.bus.sync_write("Goal_Position", goals)
|
||||||
if left_goals:
|
|
||||||
self.bus_left.sync_write("Goal_Position", left_goals)
|
|
||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
|
|||||||
|
|
||||||
@check_if_not_connected
|
@check_if_not_connected
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.bus_right.disconnect()
|
self.bus.disconnect()
|
||||||
self.bus_left.disconnect()
|
|
||||||
logger.info(f"{self} disconnected.")
|
logger.info(f"{self} disconnected.")
|
||||||
|
|||||||
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
|||||||
from .openarm_mini import OpenArmMini
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
return OpenArmMini(config)
|
return OpenArmMini(config)
|
||||||
|
elif config.type == "bi_openarm_mini":
|
||||||
|
from .bi_openarm_mini import BiOpenArmMini
|
||||||
|
|
||||||
|
return BiOpenArmMini(config)
|
||||||
elif config.type == "rebot_102_leader":
|
elif config.type == "rebot_102_leader":
|
||||||
from .rebot_102_leader import RebotArm102Leader
|
from .rebot_102_leader import RebotArm102Leader
|
||||||
|
|
||||||
return RebotArm102Leader(config)
|
return RebotArm102Leader(config)
|
||||||
elif config.type == "bi_rebot_102_leader":
|
elif config.type == "bi_rebot_102_leader":
|
||||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
from .bi_rebot_102_leader import BiRebot102Leader
|
||||||
|
|
||||||
return BiRebotArm102Leader(config)
|
return BiRebot102Leader(config)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return cast("Teleoperator", make_device_from_device_class(config))
|
return cast("Teleoperator", make_device_from_device_class(config))
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
|
|
||||||
|
|
||||||
|
class BimanualMixin:
|
||||||
|
"""Lifecycle delegation for bimanual robots and teleoperators.
|
||||||
|
|
||||||
|
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
|
||||||
|
their own ``__init__``. They retain ownership of feature dicts and the
|
||||||
|
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
|
||||||
|
``send_feedback``), which vary per-embodiment.
|
||||||
|
|
||||||
|
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
|
||||||
|
take precedence in the MRO::
|
||||||
|
|
||||||
|
class BiFooFollower(BimanualMixin, Robot): ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
left_arm: Any
|
||||||
|
right_arm: Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
self.left_arm.connect(calibrate)
|
||||||
|
self.right_arm.connect(calibrate)
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
self.left_arm.calibrate()
|
||||||
|
self.right_arm.calibrate()
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.left_arm.configure()
|
||||||
|
self.right_arm.configure()
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.left_arm.disconnect()
|
||||||
|
self.right_arm.disconnect()
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
|
||||||
|
from lerobot.datasets.mp4 import (
|
||||||
|
_box,
|
||||||
|
_co64,
|
||||||
|
_dinf,
|
||||||
|
_hdlr,
|
||||||
|
_mdhd,
|
||||||
|
_mvhd,
|
||||||
|
_stco,
|
||||||
|
_stsc_one_sample_per_chunk,
|
||||||
|
_stss,
|
||||||
|
_stsz,
|
||||||
|
_stts,
|
||||||
|
_tkhd,
|
||||||
|
_vmhd,
|
||||||
|
parse_mp4_index,
|
||||||
|
synthesize_mp4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
|
||||||
|
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
|
||||||
|
sizes = np.array([10, 10, 10], dtype=np.int64)
|
||||||
|
durations = np.array([1000, 1000, 1000], dtype=np.int64)
|
||||||
|
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
|
||||||
|
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
|
||||||
|
stbl = _box(
|
||||||
|
b"stbl",
|
||||||
|
_box(b"stsd", stsd_body)
|
||||||
|
+ _stts(durations)
|
||||||
|
+ _stsc_one_sample_per_chunk(len(sizes))
|
||||||
|
+ _stsz(sizes)
|
||||||
|
+ offsets
|
||||||
|
+ _stss(np.array([1], dtype=np.int64)),
|
||||||
|
)
|
||||||
|
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
|
||||||
|
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
|
||||||
|
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
|
||||||
|
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
|
||||||
|
mdat_payload_start = 10_000
|
||||||
|
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
|
||||||
|
assert free_size >= 8
|
||||||
|
free = _box(b"free", b"\0" * (free_size - 8))
|
||||||
|
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
|
||||||
|
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
|
||||||
|
|
||||||
|
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
|
||||||
|
|
||||||
|
assert sample_slice.byte_offset == 10_000
|
||||||
|
assert sample_slice.byte_length == 60
|
||||||
|
assert sample_slice.sample_lo == 0
|
||||||
|
assert sample_slice.sample_hi == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
    """Offsets in a synthesized clip are re-expressed relative to its own mdat payload."""
    source_index = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
    window = source_index.sample_slice(
        0.0,
        2.0,
        keyframe_pad_s=0,
        keyframe_pad_fraction=0,
    )

    # Placeholder payload with exactly the number of bytes the slice covers.
    payload = b"x" * window.byte_length
    clip_bytes = synthesize_mp4(source_index, window, payload)
    clip_index = parse_mp4_index("mini.mp4", clip_bytes)

    # Each original offset minus the slice start (10_000), shifted to the
    # clip's payload start.
    relative_offsets = np.array([0, 50, 25], dtype=np.int64)
    expected = relative_offsets + clip_index.mdat_payload_offset

    np.testing.assert_array_equal(clip_index.sample_offsets, expected)
    np.testing.assert_array_equal(clip_index.sample_sizes, np.array([10, 10, 10]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_accepts_co64_chunk_offsets():
    """Offsets written via ``_co64`` must round-trip through the parser unchanged."""
    fixture = _minimal_mp4([10_000, 10_050, 10_025], use_co64=True)
    index = parse_mp4_index("test.mp4", fixture)

    expected_offsets = np.array([10_000, 10_050, 10_025])
    np.testing.assert_array_equal(index.sample_offsets, expected_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
    """The branch check passes when ``direct_url.json`` records the requested revision."""

    class _StubDistribution:
        """Stand-in distribution object that only serves ``direct_url.json``."""

        def read_text(self, name):
            assert name == "direct_url.json"
            direct_url = {
                "url": "https://github.com/huggingface/huggingface_hub.git",
                "vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
            }
            return json.dumps(direct_url)

    def _fake_distribution(_):
        # A fresh stub per lookup, regardless of the requested package name.
        return _StubDistribution()

    monkeypatch.setattr(
        "lerobot.datasets.episode_video_streaming.metadata.distribution",
        _fake_distribution,
    )

    # Must not raise.
    assert_hf_hub_range_cache_branch()
|
||||||
|
|
||||||
|
|
||||||
|
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
    """The branch check fails when ``direct_url.json`` carries no ``vcs_info``."""

    class _StubDistribution:
        """Stand-in distribution whose ``direct_url.json`` holds only a URL."""

        def read_text(self, name):
            assert name == "direct_url.json"
            direct_url = {"url": "https://github.com/huggingface/huggingface_hub.git"}
            return json.dumps(direct_url)

    def _fake_distribution(_):
        # A fresh stub per lookup, regardless of the requested package name.
        return _StubDistribution()

    monkeypatch.setattr(
        "lerobot.datasets.episode_video_streaming.metadata.distribution",
        _fake_distribution,
    )

    with pytest.raises(AssertionError):
        assert_hf_hub_range_cache_branch()
|
||||||
@@ -2370,14 +2370,32 @@ def test_aggregate_images_when_use_videos_false():
|
|||||||
out = aggregate_pipeline_dataset_features(
|
out = aggregate_pipeline_dataset_features(
|
||||||
pipeline=rp,
|
pipeline=rp,
|
||||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||||
use_videos=False, # expect "image" dtype
|
use_videos=False, # images kept, stored as "image" dtype
|
||||||
patterns=None,
|
patterns=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
key = f"{OBS_IMAGES}.back"
|
key = f"{OBS_IMAGES}.back"
|
||||||
key_front = f"{OBS_IMAGES}.front"
|
key_front = f"{OBS_IMAGES}.front"
|
||||||
assert key not in out
|
assert key in out
|
||||||
assert key_front not in out
|
assert key_front in out
|
||||||
|
assert out[key]["dtype"] == "image"
|
||||||
|
assert out[key_front]["dtype"] == "image"
|
||||||
|
assert out[key]["shape"] == initial["back"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_images_excluded():
    """With ``exclude_images=True`` no image feature key appears in the aggregated output."""
    pipeline = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)])
    observation_features = {"back": (480, 640, 3)}
    initial_features = {
        PipelineFeatureType.ACTION: {},
        PipelineFeatureType.OBSERVATION: observation_features,
    }

    aggregated = aggregate_pipeline_dataset_features(
        pipeline=pipeline,
        initial_features=initial_features,
        exclude_images=True,
        patterns=None,
    )

    # Both the initially supplied camera and the one requested from the
    # pipeline step must be absent.
    for camera in ("back", "front"):
        assert f"{OBS_IMAGES}.{camera}" not in aggregated
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_images_when_use_videos_true():
|
def test_aggregate_images_when_use_videos_true():
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
|
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
|
||||||
from lerobot.teleoperators.rebot_102_leader import (
|
from lerobot.teleoperators.rebot_102_leader import (
|
||||||
RebotArm102Leader,
|
RebotArm102Leader,
|
||||||
RebotArm102LeaderConfig,
|
RebotArm102LeaderConfig,
|
||||||
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
|
|||||||
|
|
||||||
def test_bimanual_prefixes_features():
|
def test_bimanual_prefixes_features():
|
||||||
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
||||||
cfg = BiRebotArm102LeaderConfig(
|
cfg = BiRebot102LeaderConfig(
|
||||||
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
||||||
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
||||||
)
|
)
|
||||||
teleop = BiRebotArm102Leader(cfg)
|
teleop = BiRebot102Leader(cfg)
|
||||||
assert any(k.startswith("left_") for k in teleop.action_features)
|
assert any(k.startswith("left_") for k in teleop.action_features)
|
||||||
assert any(k.startswith("right_") for k in teleop.action_features)
|
assert any(k.startswith("right_") for k in teleop.action_features)
|
||||||
assert "left_gripper.pos" in teleop.action_features
|
assert "left_gripper.pos" in teleop.action_features
|
||||||
|
|||||||
Reference in New Issue
Block a user