Compare commits

..

14 Commits

Author SHA1 Message Date
Pepijn 51c023a7a1 Tune native HTTP range diagnostics 2026-06-17 21:50:05 +02:00
Pepijn 51ea18cb7a Allow native HTTP sidecar range diagnostics 2026-06-17 21:36:57 +02:00
Pepijn 04ab43b8d2 Report range read timing breakdown 2026-06-17 21:20:08 +02:00
Pepijn cdfe192491 Remove random frame benchmark path 2026-06-17 21:14:42 +02:00
Pepijn 3451e53452 Use HfFileSystem for sidecar episode benchmark 2026-06-17 21:01:43 +02:00
Pepijn 30849ce74f Report memory usage in cache benchmarks 2026-06-17 20:54:12 +02:00
Pepijn 7d6907c444 Add random frame range fetch benchmark 2026-06-17 20:48:46 +02:00
Pepijn d99e1fe89d Report episode cache fill stage timings 2026-06-17 20:29:57 +02:00
Pepijn 7fcde61b69 Report full dataset estimate in episode cache benchmark 2026-06-17 20:25:21 +02:00
Pepijn bdfe8f8ce9 Use full MP4 sidecar for episode cache benchmark 2026-06-17 20:22:04 +02:00
Pepijn 34d0495d03 Retry transient native HTTP range failures 2026-06-17 20:19:54 +02:00
Pepijn 834c282631 Make episode cache benchmark fetch-only by default 2026-06-17 20:16:30 +02:00
Pepijn f132885cbc Pin Hub range cache and datasets main sources 2026-06-17 19:46:41 +02:00
Pepijn d0686be2f5 Add episode video streaming byte cache 2026-06-17 19:31:02 +02:00
36 changed files with 2662 additions and 1830 deletions
-174
View File
@@ -1,174 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Load an SMPL motion clip and expose it in SONIC's encoder format.
SONIC's whole-body tracking mode (``encode_mode == 2``) consumes a flat
720-vector ``smpl_joints_10frame_step1`` = 10 consecutive frames x 24 SMPL
joints x 3 (xyz) at 50 Hz.
IMPORTANT - frame convention: the encoder expects each frame's joints with the
body's *root orientation removed* (per-frame canonical), exactly like the live
deploy stream's ``smpl_joints_local`` (see ``process_smpl_joints`` in the GEAR
PICO teleop and ``smpl_joints_multi_future_local`` in training). The reference
``smpl_filtered`` clips instead store **world-frame** joints (heading retained),
so feeding them raw makes the robot move but mis-track / never face-forward.
This loader therefore canonicalizes on load using the clip's per-frame root
orientation (``pose_aa[:, :3]``):
A = Rx(+90deg) * rotvec(pose_aa[:, :3]) # y-up -> z-up root quat
local = base120 * A^-1 * joints # remove root orient
with ``base120 = quat(0.5,0.5,0.5,0.5)`` (SMPL base rotation). This reproduces
the deployed transform (verified: per-frame hip-heading std -> 0).
Clip is read from a numpy ``.npz``. Expected keys:
smpl_joints : (T, 24, 3) float32 -- world-frame joint positions, 50 fps
pose_aa : (T, 72) float32 -- SMPL axis-angle (root = [:, :3])
transl : (T, 3) float32 -- global root translation (optional)
fps : scalar
Example:
python examples/unitree_g1/motion_loader.py \
--motion examples/unitree_g1/motions/walk_forward.npz
"""
import argparse
import numpy as np
WINDOW = 10 # frames per encoder window (smpl_joints_10frame_step1)
N_JOINTS = 24
JOINT_DIM = 3
SMPL_OBS_DIM = WINDOW * N_JOINTS * JOINT_DIM # 720
def canonicalize_smpl_joints(smpl_joints: np.ndarray, root_aa: np.ndarray) -> np.ndarray:
"""Remove per-frame root orientation -> SONIC ``smpl_joints_local`` format.
Args:
smpl_joints: (T, 24, 3) world-frame (z-up) SMPL joint positions.
root_aa: (T, 3) SMPL global-orient axis-angle (y-up convention).
Returns:
(T, 24, 3) per-frame root-orientation-removed joints.
"""
from scipy.spatial.transform import Rotation as R
rx90 = R.from_euler("x", 90, degrees=True) # smpl_root_ytoz_up
base120 = R.from_quat([0.5, 0.5, 0.5, 0.5]) # remove_smpl_base_rot
a = rx90 * R.from_rotvec(root_aa) # z-up root quat (left-mult)
b_inv = base120 * a.inv() # inv(remove_smpl_base_rot(a))
return np.einsum("tij,tkj->tki", b_inv.as_matrix(), smpl_joints).astype(np.float32)
class SmplMotion:
"""A single SMPL clip with SONIC-format windowing."""
def __init__(self, path: str, loop: bool = True, canonicalize: bool = True):
data = np.load(path)
smpl_joints = data["smpl_joints"].astype(np.float32) # (T, 24, 3)
self.pose_aa = data["pose_aa"].astype(np.float32) if "pose_aa" in data.files else None
self.transl = data["transl"].astype(np.float32) if "transl" in data.files else None
self.fps = float(data["fps"]) if "fps" in data.files else 50.0
self.loop = loop
if smpl_joints.ndim != 3 or smpl_joints.shape[1:] != (N_JOINTS, JOINT_DIM):
raise ValueError(
f"Expected smpl_joints (T, {N_JOINTS}, {JOINT_DIM}), got {smpl_joints.shape}"
)
# Reference clips store world-frame joints; the encoder wants per-frame
# root-orientation-removed joints. Canonicalize when we have the root pose.
self.canonicalized = False
if canonicalize and self.pose_aa is not None:
smpl_joints = canonicalize_smpl_joints(smpl_joints, self.pose_aa[:, :3])
self.canonicalized = True
self.smpl_joints = smpl_joints
self.num_frames = self.smpl_joints.shape[0]
self._cursor = 0
def window(self, start: int) -> np.ndarray:
"""Return the 720-vector for the 10-frame window beginning at ``start``.
Frames are laid out oldest->newest, joint-major within a frame:
[f0_j0_xyz, f0_j1_xyz, ..., f9_j23_xyz].
"""
idx = np.arange(start, start + WINDOW)
if self.loop:
idx = np.mod(idx, self.num_frames)
else:
idx = np.clip(idx, 0, self.num_frames - 1)
return self.smpl_joints[idx].reshape(-1).astype(np.float32)
def reset(self):
self._cursor = 0
def step(self) -> np.ndarray:
"""Advance one frame and return the current 720-vector window."""
w = self.window(self._cursor)
self._cursor += 1
if self.loop:
self._cursor %= self.num_frames
return w
@property
def done(self) -> bool:
return (not self.loop) and (self._cursor + WINDOW >= self.num_frames)
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--motion", required=True, help="Path to motion .npz")
parser.add_argument("--no-loop", action="store_true")
parser.add_argument("--no-canon", action="store_true",
help="Skip canonicalization (feed raw stored joints)")
args = parser.parse_args()
m = SmplMotion(args.motion, loop=not args.no_loop, canonicalize=not args.no_canon)
duration = m.num_frames / m.fps
print(f"Loaded '{args.motion}'")
print(f" frames={m.num_frames} fps={m.fps:.1f} duration={duration:.1f}s")
print(f" smpl_joints={m.smpl_joints.shape} canonicalized={m.canonicalized} "
f"pose_aa={None if m.pose_aa is None else m.pose_aa.shape} "
f"transl={None if m.transl is None else m.transl.shape}")
# Sanity: after canonicalization the per-frame body heading should be fixed.
j = m.smpl_joints
v = (j[:, 2, :2] - j[:, 1, :2]) # R_hip - L_hip, horizontal
a = np.arctan2(v[:, 1], v[:, 0])
rlen = np.clip(np.hypot(np.cos(a).mean(), np.sin(a).mean()), 1e-9, 1.0)
circ_std = np.degrees(np.sqrt(-2 * np.log(rlen)))
print(f" hip-heading circ-std={circ_std:.1f} deg "
f"(~0 => orientation removed; large => world-frame)")
w0 = m.window(0)
print(f" window(0): shape={w0.shape} (expected {SMPL_OBS_DIM}) "
f"min={w0.min():.3f} max={w0.max():.3f}")
assert w0.shape == (SMPL_OBS_DIM,), "window must be 720-dim for obs[922:1642]"
# Simulate a few control ticks.
print(" stepping 5 ticks:")
for t in range(5):
w = m.step()
print(f" t={t} cursor={m._cursor} window_norm={np.linalg.norm(w):.2f}")
print("OK: motion loads and yields SONIC-format 720-vec windows.")
if __name__ == "__main__":
main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
-88
View File
@@ -1,88 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a GEAR-SONIC / BONES-SEED ``smpl_filtered`` clip (.pkl) to .npz.
The reference clips are zlib-compressed joblib pickles holding a dict with
``pose_aa`` (T, 72), ``transl`` (T, 3), ``smpl_joints`` (T, 24, 3), ``fps``.
``motion_loader.SmplMotion`` consumes the .npz form so the runtime needs no
joblib dependency. Canonicalization (root-orientation removal) happens at load
time in ``motion_loader``, so this converter just repackages the raw arrays.
Run this in an environment that has ``joblib`` (e.g. the sonic teleop venv):
python examples/unitree_g1/pkl_to_npz.py \
--pkl sample_data/smpl_filtered/walk_forward_amateur_001__A001.pkl \
--out examples/unitree_g1/motions/walk_forward.npz
"""
import argparse
from pathlib import Path
import numpy as np
def load_pkl(path: str) -> dict:
try:
import joblib
return joblib.load(path)
except Exception:
# joblib clips are zlib-compressed pickles; fall back to manual inflate.
import pickle
import zlib
with open(path, "rb") as f:
raw = f.read()
try:
raw = zlib.decompress(raw)
except zlib.error:
pass
return pickle.loads(raw)
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--pkl", required=True, help="Input smpl_filtered .pkl")
parser.add_argument("--out", required=True, help="Output .npz path")
args = parser.parse_args()
d = load_pkl(args.pkl)
if not isinstance(d, dict) or "smpl_joints" not in d:
raise ValueError(f"Unexpected pkl structure; keys={list(d) if isinstance(d, dict) else type(d)}")
smpl_joints = np.asarray(d["smpl_joints"], np.float32)
if smpl_joints.ndim != 3 or smpl_joints.shape[1:] != (24, 3):
raise ValueError(f"smpl_joints must be (T,24,3), got {smpl_joints.shape}")
out = {"smpl_joints": smpl_joints, "fps": np.float32(d.get("fps", 50.0))}
if "pose_aa" in d:
out["pose_aa"] = np.asarray(d["pose_aa"], np.float32)
else:
print("[warn] no pose_aa -> loader cannot canonicalize (will feed raw)")
if "transl" in d:
out["transl"] = np.asarray(d["transl"], np.float32)
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(args.out, **out)
dur = smpl_joints.shape[0] / float(out["fps"])
print(f"Wrote {args.out}")
print(f" frames={smpl_joints.shape[0]} fps={float(out['fps']):.1f} duration={dur:.1f}s "
f"keys={sorted(out)}")
if __name__ == "__main__":
main()
-255
View File
@@ -1,255 +0,0 @@
#!/usr/bin/env python
"""
SONIC planner with full mode control.
Keyboard controls:
N / P - next / previous motion set
1-8 - select mode within current set
WASD - movement direction
Q / E - rotate facing left / right
9 / 0 - decrease / increase speed
- / = - decrease / increase height
R - force replan
M - toggle SMPL motion playback <-> locomotion (needs --motion-file)
Space - emergency stop -> IDLE
Esc - quit
Gamepad controls (Unitree wireless controller):
Left stick Y - speed (forward = fast, back = stop)
Left stick X - movement direction (offset from facing)
Right stick X - facing direction (incremental rotation)
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
Buttons - unused (mode selection is keyboard-only)
For teleop integration use --robot.controller=SonicWholeBodyController instead.
"""
import argparse
import faulthandler
import gc
import sys
import time
import numpy as np
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.controllers.sonic_whole_body import SonicRuntime
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEFAULT_ANGLES,
LM,
MOTION_SETS,
RawKeyboard,
compute_kp_kd,
drain_keyboard,
)
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
from motion_loader import SmplMotion
def main():
parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
parser.add_argument("--ip", type=str, default=None,
help="Robot IP for real hardware (e.g. 192.168.123.164). "
"Omit for simulation.")
parser.add_argument("--log-csv", action="store_true",
help="Write /tmp/sonic_pose_log.csv (disabled by default for teleop perf)")
parser.add_argument("--cpu", action="store_true",
help="Force CPU ONNX Runtime (skip CUDA even if onnxruntime-gpu is installed)")
parser.add_argument("--headless", action="store_true",
help="Ignored for sim (stock UnitreeG1 uses hub MuJoCo defaults)")
parser.add_argument("--gamepad", action="store_true",
help="Read Unitree wireless gamepad in sim (default: keyboard-only in sim)")
parser.add_argument("--keyboard-only", action="store_true",
help="Ignore wireless gamepad (terminal keyboard only)")
parser.add_argument("--motion-file", type=str, default=None,
help="Play an SMPL motion clip (.npz) via SONIC whole-body mode "
"(encode_mode=2) instead of locomotion planning.")
parser.add_argument("--no-loop", action="store_true",
help="With --motion-file, play once instead of looping")
args = parser.parse_args()
# Surface native crashes (onnxruntime / mujoco) with a real traceback, and
# avoid losing buffered diagnostics if the process dies mid-loop.
faulthandler.enable()
try:
sys.stdout.reconfigure(line_buffering=True)
except Exception:
pass
print("=" * 60)
print("SONIC planner - full mode control")
print(" N/P cycle sets | 1-8 select mode | WASD move")
print(" Q/E rotate | 9/0 speed | -/= height")
print(" R replan | Space IDLE | Esc quit")
if args.ip:
print(f" Robot IP: {args.ip}")
else:
print(" Mode: simulation")
print("=" * 60 + "\n")
cfg = UnitreeG1Config(controller=None) # full-body SONIC; standalone loop owns publish
if args.ip:
cfg.is_simulation = False
cfg.robot_ip = args.ip
else:
cfg.is_simulation = True
if args.headless:
print("[Note] --headless ignored: sim uses stock UnitreeG1 + hub env")
robot = UnitreeG1(cfg)
robot.connect()
kp, kd = compute_kp_kd()
robot.kp = kp.copy()
robot.kd = kd.copy()
runtime = SonicRuntime(force_cpu=args.cpu)
controller = runtime.controller
ms = runtime.ms
motion = None
if args.motion_file:
motion = SmplMotion(args.motion_file, loop=not args.no_loop)
controller.smpl_motion = motion # lets 'M' key toggle playback
controller.encode_mode = 2 # start in SONIC whole-body SMPL imitation
dur = motion.num_frames / motion.fps
print(f"\n[Motion] SMPL whole-body playback: {args.motion_file}")
print(f" frames={motion.num_frames} fps={motion.fps:.1f} "
f"duration={dur:.1f}s loop={not args.no_loop} encode_mode=2")
print(" Press 'M' to toggle SMPL playback <-> locomotion at runtime.")
runtime.controller.print_input_diagnostics()
print(f"\nStarting: {MOTION_SETS[0][0]} (default mode: {LM(ms.mode).name})")
[print(f" {i+1}: {m.name}") for i, m in enumerate(MOTION_SETS[0][1])]
print("\n[Ready] Click THIS terminal, then W/A/S/D to move. "
"1-6 change mode, 9/0 speed, Esc quit.\n", flush=True)
# Sim hub publishes wireless_remote bytes that can fight terminal WASD.
base_joystick = not args.keyboard_only and (args.gamepad or args.ip is not None)
with RawKeyboard() as kb:
try:
gc.disable()
gc_timer = 0.0
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
time.sleep(1.0)
last_status = time.time() - 2.1
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
stall_src = ""
did_blend = False
prev_end = time.time()
t_start = time.time()
log_path = "/tmp/sonic_pose_log.csv"
jnames = [m.name for m in G1_29_JointIndex]
log_ctx = open(log_path, "w") if args.log_csv else None
if log_ctx:
log_ctx.write("t,step,cursor,ts,blend,mode," +
",".join(f"q{i}" for i in range(29)) + "," +
",".join(f"ref{i}" for i in range(29)) + "," +
",".join(f"act{i}" for i in range(29)) +
",delta_max,action_norm,token_norm\n")
try:
while not robot._shutdown_event.is_set():
t0 = time.time()
if drain_keyboard(kb, ms, controller):
break
obs = robot.get_observation()
t_obs = time.time()
obs_t.append(1000 * (t_obs - t0))
if not obs:
runtime.tick({}, use_joystick=False)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
continue
# SMPL playback only while in whole-body mode; 'M' toggles it.
motion_active = motion is not None and controller.encode_mode == 2
if motion_active:
controller.smpl_joints_10frame_step1 = motion.step()
if motion.done:
print("\n[Motion] clip finished")
break
step_before = runtime.step
t_step = time.time()
action = runtime.tick(obs, use_joystick=base_joystick and not motion_active)
step_ms = 1000 * (time.time() - t_step)
do_enc = step_before % 5 == 0
(enc_t if do_enc else dec_t).append(step_ms)
t_act = time.time()
robot.send_action(action)
act_t.append(1000 * (time.time() - t_act))
if log_ctx and runtime.step % 5 == 0:
t_rel = time.time() - t_start
q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
cur, ts = controller.ref_cursor, controller.motion_timesteps
q_ref = controller.motion_joint_positions[min(cur, ts - 1)] if ts > 0 else np.zeros(29)
log_ctx.write(f"{t_rel:.4f},{runtime.step},{cur},{ts},{int(did_blend)},{ms.mode}," +
",".join(f"{v:.6f}" for v in q_r) + "," +
",".join(f"{v:.6f}" for v in q_ref) + "," +
",".join(f"{v:.6f}" for v in a_v) + "," +
f"{np.max(np.abs(a_v - q_r)):.6f},"
f"{np.linalg.norm(a_v):.6f},"
f"{np.linalg.norm(controller.token):.6f}\n")
did_blend = False
now = time.time()
loop_ms = 1000 * (now - t0)
if loop_ms > 50:
stall_src = (f"[STALL] {loop_ms:.0f}ms: "
f"obs={obs_t[-1]:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
if loop_ms > CONTROL_DT * 1500:
slow_n += 1
if now - last_status > 2.0:
def _avg(lst):
return sum(lst) / len(lst) if lst else 0
hz = 1000 / _avg(loop_t) if _avg(loop_t) else 0
print(f"\r {ms.status_line()} step={runtime.step} "
f"ref={controller.ref_cursor}/{controller.motion_timesteps} "
f"loop={_avg(loop_t):.1f}ms(max={max(loop_t, default=0):.1f}) hz={hz:.0f} "
f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
f"slow={slow_n} blends={blend_n}", end="", flush=True)
if stall_src:
print(f"\n {stall_src}")
stall_src = ""
last_status = now
loop_t = enc_t = dec_t = obs_t = act_t = []
slow_n = blend_n = 0
prev_end = time.time()
gc_timer += CONTROL_DT
if gc_timer >= 10.0:
gc.collect()
gc_timer = 0.0
loop_t.append(loop_ms)
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
finally:
if log_ctx:
log_ctx.close()
except KeyboardInterrupt:
pass
finally:
gc.enable()
if args.log_csv:
print(f"\n[Log] Saved to {log_path}")
runtime.shutdown()
print("\nStopping...")
if robot.is_connected:
robot.disconnect()
print("Done.")
if __name__ == "__main__":
main()
+3
View File
@@ -355,6 +355,8 @@ explicit = true
[tool.uv.sources]
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
huggingface-hub = { git = "https://github.com/huggingface/huggingface_hub.git", branch = "feat/hffs-cache-cdn-range-reads" }
datasets = { git = "https://github.com/huggingface/datasets.git", branch = "main" }
[tool.setuptools.package-data]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
@@ -421,6 +423,7 @@ exclude_dirs = [
skips = ["B101", "B311", "B404", "B603", "B615"]
[tool.typos]
default.extend-words = { trak = "trak" }
default.extend-ignore-re = [
"(?Rm)^.*(#|//)\\s*spellchecker:disable-line$", # spellchecker:disable-line
"(?s)(#|//)\\s*spellchecker:off.*?\\n\\s*(#|//)\\s*spellchecker:on", # spellchecker:<on|off>
+860
View File
@@ -0,0 +1,860 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import argparse
import random
import resource
import tempfile
import threading
import time
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import fsspec
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.episode_video_streaming import (
EpisodeByteCache,
EpisodeVideoManifest,
NativeHTTPRangeFetcher,
assert_hf_hub_range_cache_branch,
)
from lerobot.datasets.video_utils import VideoDecoderCache, decode_video_frames_torchcodec
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
SIDECAR_CACHE_DIR = Path(tempfile.gettempdir()) / "lerobot-sidecars"
FULL_SIDECAR_NAME = "molmoact2-full.npz"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Benchmark episode-level streaming mini-MP4 cache.")
parser.add_argument("--repo-id", default=DEFAULT_REPO)
parser.add_argument("--revision", default=DEFAULT_REVISION)
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
parser.add_argument(
"--strategy",
choices=("both", "full", "indexed", "remote-decoder", "native-http"),
default="both",
help=argparse.SUPPRESS,
)
parser.add_argument(
"--range-backend",
choices=("fsspec", "native-http"),
default="fsspec",
help="Range reader used by indexed/full episode-pool fetch tracks.",
)
parser.add_argument("--num-episodes", type=int, default=512)
parser.add_argument(
"--manifest-episodes",
type=int,
default=None,
help="Limit manifest construction to the first N episodes for local smoke tests.",
)
parser.add_argument("--pool-size", type=int, default=16)
parser.add_argument("--workers", type=int, default=8)
parser.add_argument(
"--native-http-connections",
type=int,
default=None,
help="Max HTTP connections for --range-backend native-http. Defaults to --workers.",
)
parser.add_argument(
"--native-http-retries",
type=int,
default=8,
help="Retries per native HTTP range request.",
)
parser.add_argument(
"--native-http-timeout",
type=float,
default=120.0,
help="Timeout in seconds for native HTTP requests.",
)
parser.add_argument(
"--include-decode",
action="store_true",
help="Also run decoder-opening/frame-decode comparison tracks. Fetch-only is the default.",
)
parser.add_argument("--decode-workers", type=int, default=1)
parser.add_argument("--prefetch-ahead", type=int, default=8)
parser.add_argument("--frames-per-episode", type=int, default=16)
parser.add_argument("--max-probe-mb", type=int, default=64)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--byte-budget-gb", type=float, default=80)
parser.add_argument(
"--in-memory", action="store_true", help="Accepted for compatibility; manifest is always in memory."
)
parser.add_argument("--no-hub-branch-assert", action="store_true")
return parser.parse_args()
def _episode_pool(total: int, requested: int, pool_size: int, seed: int) -> list[int]:
rng = random.Random(seed)
upper = min(total, requested)
if pool_size > upper:
raise ValueError(f"pool-size={pool_size} exceeds available episodes={upper}")
return rng.sample(range(upper), pool_size)
def _timestamps(manifest: EpisodeVideoManifest, episodes: Sequence[int], frames_per_episode: int, seed: int):
rng = random.Random(seed)
out: dict[tuple[int, str], list[float]] = {}
for ep in episodes:
for camera_key in manifest.video_keys:
span = manifest.lookup(ep, camera_key)
lo = span.first_pts
hi = max(span.last_pts, lo)
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
return out
def _timestamps_from_meta(
meta: LeRobotDatasetMetadata, episodes: Sequence[int], frames_per_episode: int, seed: int
) -> dict[tuple[int, str], list[float]]:
rng = random.Random(seed)
out: dict[tuple[int, str], list[float]] = {}
for ep in episodes:
row = meta.episodes[ep]
for camera_key in meta.video_keys:
lo = float(row[f"videos/{camera_key}/from_timestamp"])
hi = max(float(row[f"videos/{camera_key}/to_timestamp"]), lo)
out[(ep, camera_key)] = sorted(rng.uniform(lo, hi) for _ in range(frames_per_episode))
return out
def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
total = 0
for ep in episodes:
for camera_key in manifest.video_keys:
total += manifest.lookup(ep, camera_key).mdat_length
return total
def _decode_all(
cache: EpisodeByteCache, timestamps: dict[tuple[int, str], list[float]], *, decode_workers: int
) -> float:
start = time.perf_counter()
items = list(timestamps.items())
if decode_workers <= 1:
for (ep, camera_key), ts in items:
cache.get_frames(ep, camera_key, ts)
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
futures = [pool.submit(cache.get_frames, ep, camera_key, ts) for (ep, camera_key), ts in items]
for future in futures:
future.result()
return time.perf_counter() - start
def _fill_cache(cache: EpisodeByteCache, episodes: Sequence[int]) -> float:
start = time.perf_counter()
for ep in episodes:
cache.submit_prefetch(ep)
for ep in episodes:
cache.ensure_ready(ep)
return time.perf_counter() - start
def _samples_per_s(elapsed_s: float, episodes: Sequence[int], frames_per_episode: int) -> float:
if elapsed_s <= 0:
return float("inf")
return len(episodes) * frames_per_episode / elapsed_s
def _log(message: str) -> None:
print(message, flush=True)
def _format_duration(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.1f}s"
if seconds < 3600:
return f"{seconds / 60:.1f}m"
return f"{seconds / 3600:.1f}h"
def _current_rss_mib() -> float | None:
status_path = Path("/proc/self/status")
if not status_path.exists():
return None
for line in status_path.read_text().splitlines():
if line.startswith("VmRSS:"):
return float(line.split()[1]) / 1024
return None
def _peak_rss_mib() -> float:
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
# Linux reports KiB; macOS reports bytes.
if rss > 10**8:
return rss / 1024**2
return rss / 1024
def _memory_snapshot() -> dict[str, float | None]:
return {"rss_mib": _current_rss_mib(), "peak_rss_mib": _peak_rss_mib()}
def _print_memory_summary(start: dict[str, float | None], end: dict[str, float | None]) -> None:
start_rss = start["rss_mib"]
end_rss = end["rss_mib"]
delta = None if start_rss is None or end_rss is None else end_rss - start_rss
print()
print("| Memory | MiB |")
print("|---|---:|")
if start_rss is not None:
print(f"| rss start | {start_rss:.1f} |")
if end_rss is not None:
print(f"| rss end | {end_rss:.1f} |")
if delta is not None:
print(f"| rss delta | {delta:.1f} |")
print(f"| peak rss | {end['peak_rss_mib']:.1f} |")
def _root_join(data_root: str, relative_path: str) -> str:
if data_root.startswith("hf://"):
return f"{data_root.rstrip('/')}/{relative_path}"
return str(Path(data_root) / relative_path)
def _find_or_download_sidecar(data_root: str, manifest_episode_count: int) -> Path | None:
_ = manifest_episode_count
local = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
if _valid_sidecar(local):
return local
if local.exists():
print(f"mp4_sidecar_invalid_local: {local}")
local.unlink()
remote_relative = f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}"
remote = _root_join(data_root, remote_relative)
protocol = "hf" if data_root.startswith("hf://") else "file"
fs = fsspec.filesystem(protocol)
if not fs.exists(remote):
return None
local.parent.mkdir(parents=True, exist_ok=True)
print(f"downloading_mp4_sidecar: {remote} -> {local}")
if data_root.startswith("hf://"):
_download_sidecar_native_http(data_root, remote_relative, local)
else:
fs.get(remote, str(local))
return local
def _valid_sidecar(path: Path) -> bool:
if not path.exists():
return False
try:
with np.load(path, allow_pickle=False) as data:
return "manifest_json" in data
except Exception:
return False
def _download_sidecar_native_http(data_root: str, relative_path: str, local: Path) -> None:
fetcher = NativeHTTPRangeFetcher(data_root, max_connections=16)
tmp = local.with_suffix(local.suffix + ".tmp")
try:
size = fetcher.info_size(relative_path)
chunk_size = 16 * 1024 * 1024
ranges = [(offset, min(chunk_size, size - offset)) for offset in range(0, size, chunk_size)]
with tmp.open("wb") as out_file:
out_file.truncate(size)
def read_chunk(offset_length: tuple[int, int]) -> tuple[int, bytes]:
offset, length = offset_length
return offset, fetcher.read_range(relative_path, offset, length)
start = time.perf_counter()
done = 0
with ThreadPoolExecutor(max_workers=8) as pool:
futures = [pool.submit(read_chunk, item) for item in ranges]
with tmp.open("r+b") as rw_file:
for future in futures:
offset, data = future.result()
rw_file.seek(offset)
rw_file.write(data)
done += len(data)
elapsed = max(time.perf_counter() - start, 1e-9)
print(
f"sidecar_download: {done / 1024**2:.1f}/{size / 1024**2:.1f} MiB "
f"({done / elapsed / 1024**2:.1f} MiB/s)",
flush=True,
)
tmp.replace(local)
finally:
fetcher.close()
class EpisodeParquetReader:
def __init__(self, meta: LeRobotDatasetMetadata, data_root: str):
self.meta = meta
self.data_root = data_root
protocol = "hf" if data_root.startswith("hf://") else "file"
self.fs = fsspec.filesystem(protocol)
self._episode_row_groups = self._build_episode_row_groups()
self._table_cache: dict[str, pa.Table] = {}
self._cache_lock = threading.Lock()
def read_episode(self, episode_index: int) -> None:
relative_path = str(self.meta.get_data_file_path(episode_index))
table = self._read_table(relative_path)
table.filter(pc.equal(table["episode_index"], episode_index))
def _read_table(self, relative_path: str) -> pa.Table:
with self._cache_lock:
table = self._table_cache.get(relative_path)
if table is not None:
return table
with self.fs.open(
_root_join(self.data_root, relative_path), "rb", block_size=2**20, cache_type="none"
) as f:
table = pq.ParquetFile(f).read()
with self._cache_lock:
return self._table_cache.setdefault(relative_path, table)
def submit_read_episode(self, pool: ThreadPoolExecutor, episode_index: int):
return pool.submit(self.read_episode, episode_index)
def read_episodes(self, episodes: Sequence[int], *, workers: int) -> float:
start = time.perf_counter()
if workers <= 1:
for ep in episodes:
self.read_episode(ep)
else:
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = [pool.submit(self.read_episode, ep) for ep in episodes]
for future in futures:
future.result()
return time.perf_counter() - start
def _build_episode_row_groups(self) -> dict[int, int]:
counts: dict[tuple[int, int], int] = {}
row_groups = {}
for ep_idx in range(int(self.meta.total_episodes)):
ep = self.meta.episodes[ep_idx]
key = (int(ep["data/chunk_index"]), int(ep["data/file_index"]))
row_groups[ep_idx] = counts.get(key, 0)
counts[key] = row_groups[ep_idx] + 1
return row_groups
def run_fetch_pool(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
byte_budget: int,
workers: int,
range_backend: str,
args: argparse.Namespace,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
native_http_connections=args.native_http_connections,
native_http_timeout=args.native_http_timeout,
native_http_retries=args.native_http_retries,
open_decoders=False,
) as cache:
elapsed = _fill_cache(cache, episodes)
timings = cache.timing_summary()
byte_count = _bytes_for(manifest, episodes)
episode_mb = byte_count / len(episodes) / 1024**2
job_count = max(timings["jobs"], 1.0)
result = {
"fetch_s": elapsed,
"fetch_mbps": byte_count / elapsed / 1024**2,
"fetch_episodes_s": len(episodes) / elapsed,
"episode_mb": episode_mb,
"avg_mb_miss": byte_count / (len(episodes) * len(manifest.video_keys)) / 1024**2,
"jobs": timings["jobs"],
"lookup_ms": timings["lookup_s"] * 1000 / job_count,
"range_fetch_ms": timings["fetch_s"] * 1000 / job_count,
"synthesize_ms": timings["synthesize_s"] * 1000 / job_count,
"store_ms": timings["store_s"] * 1000 / job_count,
}
result.update({key: value for key, value in timings.items() if key.startswith("range_")})
return result
def run_parallel(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
byte_budget: int,
workers: int,
decode_workers: int,
frames_per_episode: int,
parquet_reader: EpisodeParquetReader,
range_backend: str,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
open_decoders=False,
) as cache:
parquet_s = parquet_reader.read_episodes(episodes, workers=workers)
fetch_s = _fill_cache(cache, episodes)
decoder_start = time.perf_counter()
for ep in episodes:
for camera_key in manifest.video_keys:
cache.get_decoder(ep, camera_key)
decoder_s = time.perf_counter() - decoder_start
decode_s = _decode_all(cache, timestamps, decode_workers=decode_workers)
byte_count = _bytes_for(manifest, episodes)
return {
"fetch_s": fetch_s,
"fetch_mbps": byte_count / fetch_s / 1024**2,
"fetch_episodes_s": len(episodes) / fetch_s,
"parquet_s": parquet_s,
"decoder_ms_miss": decoder_s * 1000 / (len(episodes) * len(manifest.video_keys)),
"decode_samples_s": _samples_per_s(decode_s, episodes, frames_per_episode),
}
def run_overlapped(
manifest: EpisodeVideoManifest,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
byte_budget: int,
workers: int,
decode_workers: int,
frames_per_episode: int,
prefetch_ahead: int,
parquet_reader: EpisodeParquetReader,
range_backend: str,
) -> dict[str, float]:
with EpisodeByteCache(
manifest,
data_root,
byte_budget=byte_budget,
workers=workers,
range_backend=range_backend,
open_decoders=True,
) as cache:
start = time.perf_counter()
video_wait_decode_s = 0.0
parquet_wait_s = 0.0
parquet_pool = ThreadPoolExecutor(max_workers=max(1, min(workers, len(episodes))))
parquet_futures = {
ep: parquet_reader.submit_read_episode(parquet_pool, ep) for ep in episodes[:prefetch_ahead]
}
for ep in episodes[:prefetch_ahead]:
cache.submit_prefetch(ep)
try:
for idx, ep in enumerate(episodes):
next_idx = idx + prefetch_ahead
if next_idx < len(episodes):
next_ep = episodes[next_idx]
cache.submit_prefetch(next_ep)
parquet_futures[next_ep] = parquet_reader.submit_read_episode(parquet_pool, next_ep)
parquet_start = time.perf_counter()
parquet_futures.pop(ep).result()
parquet_wait_s += time.perf_counter() - parquet_start
video_start = time.perf_counter()
cache.ensure_ready(ep)
if decode_workers <= 1:
for camera_key in manifest.video_keys:
cache.get_frames(ep, camera_key, timestamps[(ep, camera_key)])
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
futures = [
pool.submit(cache.get_frames, ep, camera_key, timestamps[(ep, camera_key)])
for camera_key in manifest.video_keys
]
for future in futures:
future.result()
video_wait_decode_s += time.perf_counter() - video_start
finally:
parquet_pool.shutdown(wait=True)
elapsed = time.perf_counter() - start
return {
"samples_s": _samples_per_s(elapsed, episodes, frames_per_episode),
"video_samples_s": _samples_per_s(video_wait_decode_s, episodes, frames_per_episode),
"parquet_samples_s": _samples_per_s(parquet_wait_s, episodes, frames_per_episode),
"wall_s": elapsed,
"video_wait_decode_s": video_wait_decode_s,
"parquet_wait_s": parquet_wait_s,
}
_remote_decoder_local = threading.local()
def _remote_decoder_cache() -> VideoDecoderCache:
cache = getattr(_remote_decoder_local, "cache", None)
if cache is None:
cache = VideoDecoderCache(max_size=None)
_remote_decoder_local.cache = cache
return cache
def _decode_remote_source(
meta: LeRobotDatasetMetadata,
data_root: str,
episode_index: int,
camera_key: str,
timestamps: list[float],
):
video_path = _root_join(data_root, str(meta.get_video_file_path(episode_index, camera_key)))
return decode_video_frames_torchcodec(
video_path,
timestamps,
tolerance_s=1.0 / float(meta.fps),
decoder_cache=_remote_decoder_cache(),
return_uint8=True,
)
def run_remote_decoder(
meta: LeRobotDatasetMetadata,
data_root: str,
episodes: Sequence[int],
timestamps: dict[tuple[int, str], list[float]],
*,
frames_per_episode: int,
decode_workers: int,
parquet_reader: EpisodeParquetReader,
) -> dict[str, float]:
items = [
(ep, camera_key, timestamps[(ep, camera_key)]) for ep in episodes for camera_key in meta.video_keys
]
start = time.perf_counter()
for ep, camera_key, ts in items:
if camera_key == meta.video_keys[0]:
parquet_reader.read_episode(ep)
_decode_remote_source(meta, data_root, ep, camera_key, ts)
sequential_s = time.perf_counter() - start
start = time.perf_counter()
if decode_workers <= 1:
for ep, camera_key, ts in items:
if camera_key == meta.video_keys[0]:
parquet_reader.read_episode(ep)
_decode_remote_source(meta, data_root, ep, camera_key, ts)
else:
with ThreadPoolExecutor(max_workers=decode_workers) as pool:
parquet_futures = [pool.submit(parquet_reader.read_episode, ep) for ep in episodes]
futures = [
pool.submit(_decode_remote_source, meta, data_root, ep, camera_key, ts)
for ep, camera_key, ts in items
]
for future in parquet_futures:
future.result()
for future in futures:
future.result()
parallel_s = time.perf_counter() - start
return {
"sequential_samples_s": _samples_per_s(sequential_s, episodes, frames_per_episode),
"parallel_samples_s": _samples_per_s(parallel_s, episodes, frames_per_episode),
}
def _print_range_timing_summary(fetch_pool: dict[str, float]) -> None:
range_jobs = fetch_pool.get("range_jobs", 0.0)
if range_jobs <= 0:
return
print()
print("| Range Read Stage | avg ms/range |")
print("|---|---:|")
for key, label in (
("range_open_s", "fsspec handle open/lookup"),
("range_seek_s", "fsspec seek"),
("range_read_s", "fsspec read"),
("range_resolve_s", "http URL resolve"),
("range_header_s", "http response headers"),
("range_first_byte_s", "http first body byte"),
("range_body_s", "http body drain"),
("range_retry_sleep_s", "http retry sleep"),
):
value = fetch_pool.get(key)
if value is not None:
print(f"| {label} | {value * 1000 / range_jobs:.3f} |")
if "range_retry_attempts" in fetch_pool:
print(f"| http retries | {fetch_pool['range_retry_attempts'] / range_jobs:.3f} |")
if fetch_pool.get("range_failed_requests"):
print(f"| http failed requests | {fetch_pool['range_failed_requests']:.0f} |")
print(f"| range reads | {range_jobs:.0f} |")
print(f"| avg MiB/range | {fetch_pool.get('range_bytes', 0.0) / range_jobs / 1024**2:.1f} |")
def run_indexed_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
args: argparse.Namespace,
parquet_reader: EpisodeParquetReader,
*,
range_backend: str = "fsspec",
label: str = "indexed",
sidecar_path: str | None = None,
) -> None:
_log(f"starting_strategy: {label}")
memory_start = _memory_snapshot()
manifest_start = time.perf_counter()
dataset_episode_count = int(meta.total_episodes)
manifest_episode_count = args.manifest_episodes or dataset_episode_count
manifest_episode_count = min(manifest_episode_count, dataset_episode_count, args.num_episodes)
manifest = EpisodeVideoManifest.build(
meta,
data_root,
episode_indices=range(manifest_episode_count),
range_backend=range_backend,
workers=args.workers,
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
sidecar_path=sidecar_path,
)
manifest_s = time.perf_counter() - manifest_start
_log(f"{label}: manifest_build_s={manifest_s:.2f}")
benchmark_episode_count = min(dataset_episode_count, args.num_episodes)
episodes = _episode_pool(dataset_episode_count, args.num_episodes, args.pool_size, args.seed)
byte_budget = int(args.byte_budget_gb * 1024**3)
byte_count = _bytes_for(manifest, episodes)
_log(
f"{label}: planned_video_fetch={byte_count / 1024**3:.2f} GiB per fetch track "
f"({byte_count / len(episodes) / 1024**2:.1f} MiB/episode)"
)
_log(f"{label}: filling episode byte cache with {args.workers} workers")
fetch_pool = run_fetch_pool(manifest, data_root, episodes, byte_budget, args.workers, range_backend, args)
estimated_dataset_s = dataset_episode_count / fetch_pool["fetch_episodes_s"]
estimated_benchmark_s = benchmark_episode_count / fetch_pool["fetch_episodes_s"]
print(f"manifest_build_s: {manifest_s:.2f}")
print(f"strategy: {label}")
print(f"range_backend: {range_backend}")
print(f"mp4_sidecar: {sidecar_path or 'none'}")
print(f"data_root: {data_root}")
print(f"dataset_episodes: {dataset_episode_count}")
print(f"benchmark_episodes: {benchmark_episode_count}")
print(f"pool_episodes: {len(episodes)}")
print(f"sampled_episodes: {episodes}")
print(f"cameras: {manifest.video_keys}")
print()
print(
"| Track | fetch MB/s | fetch eps/s | wall s | est benchmark | est full dataset | avg MB/camera | notes |"
)
print("|---|---:|---:|---:|---:|---:|---:|---|")
print(
f"| EPISODE POOL FETCH | {fetch_pool['fetch_mbps']:.1f} | "
f"{fetch_pool['fetch_episodes_s']:.2f} | {fetch_pool['fetch_s']:.2f} | "
f"{_format_duration(estimated_benchmark_s)} | {_format_duration(estimated_dataset_s)} | "
f"{fetch_pool['avg_mb_miss']:.1f} | {args.workers} workers, no decoder open/frame decode |"
)
print()
print("| Camera Job Stage | avg ms/job |")
print("|---|---:|")
print(f"| manifest lookup | {fetch_pool['lookup_ms']:.3f} |")
print(f"| remote byte-range fetch | {fetch_pool['range_fetch_ms']:.3f} |")
print(f"| synthesize mini-MP4 | {fetch_pool['synthesize_ms']:.3f} |")
print(f"| store in shared cache | {fetch_pool['store_ms']:.3f} |")
print(f"| camera jobs | {fetch_pool['jobs']:.0f} |")
_print_range_timing_summary(fetch_pool)
_print_memory_summary(memory_start, _memory_snapshot())
if args.include_decode:
timestamps = _timestamps(manifest, episodes, args.frames_per_episode, args.seed + 1)
_log(f"{label}: running parallel video fetch + decode-only")
parallel = run_parallel(
manifest,
data_root,
episodes,
timestamps,
byte_budget,
args.workers,
args.decode_workers,
args.frames_per_episode,
parquet_reader,
range_backend,
)
_log(f"{label}: running overlapped end-to-end")
overlapped = run_overlapped(
manifest,
data_root,
episodes,
timestamps,
byte_budget,
args.workers,
args.decode_workers,
args.frames_per_episode,
args.prefetch_ahead,
parquet_reader,
range_backend,
)
print(
f"| DECODE COMPARISON | {parallel['fetch_mbps']:.1f} | {parallel['fetch_episodes_s']:.2f} | "
f"{parallel['fetch_s']:.2f} | "
f"{_format_duration(benchmark_episode_count / parallel['fetch_episodes_s'])} | "
f"{_format_duration(dataset_episode_count / parallel['fetch_episodes_s'])} | "
f"{fetch_pool['avg_mb_miss']:.1f} | "
f"decoder open {parallel['decoder_ms_miss']:.1f} ms/miss, "
f"decode {parallel['decode_samples_s']:.1f} samples/s, parquet {parallel['parquet_s']:.2f}s |"
)
print(
f"| OVERLAPPED E2E | - | - | {overlapped['wall_s']:.2f} | - | - | "
f"{fetch_pool['avg_mb_miss']:.1f} | "
f"{overlapped['samples_s']:.1f} samples/s; video+decode "
f"{overlapped['video_wait_decode_s']:.2f}s, parquet {overlapped['parquet_wait_s']:.2f}s |"
)
def run_remote_strategy(
meta: LeRobotDatasetMetadata,
data_root: str,
args: argparse.Namespace,
parquet_reader: EpisodeParquetReader,
) -> None:
_log("starting_strategy: remote-decoder")
episodes = _episode_pool(int(meta.total_episodes), args.num_episodes, args.pool_size, args.seed)
timestamps = _timestamps_from_meta(meta, episodes, args.frames_per_episode, args.seed + 1)
_log("remote-decoder: running direct source MP4 decoder")
result = run_remote_decoder(
meta,
data_root,
episodes,
timestamps,
frames_per_episode=args.frames_per_episode,
decode_workers=args.decode_workers,
parquet_reader=parquet_reader,
)
print("strategy: remote-decoder")
print(f"data_root: {data_root}")
print(f"episodes: {episodes}")
print(f"cameras: {list(meta.video_keys)}")
print()
print("| Track | samples/s | notes |")
print("|---|---:|---|")
print(f"| REMOTE SEQUENTIAL | {result['sequential_samples_s']:.1f} | direct source MP4 decoder |")
print(
f"| REMOTE PARALLEL | {result['parallel_samples_s']:.1f} | "
f"direct source MP4 decoder, {args.decode_workers} workers |"
)
def main() -> None:
args = parse_args()
if args.strategy == "full":
args.strategy = "both"
if args.strategy == "native-http":
args.range_backend = "native-http"
data_root = args.data_root
if data_root.startswith("hf://") and not args.no_hub_branch_assert:
assert_hf_hub_range_cache_branch()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
meta.ensure_readable()
parquet_reader = EpisodeParquetReader(meta, data_root)
manifest_episode_count = args.manifest_episodes or int(meta.total_episodes)
manifest_episode_count = min(manifest_episode_count, int(meta.total_episodes), args.num_episodes)
sidecar_path = _find_or_download_sidecar(data_root, manifest_episode_count)
if sidecar_path is not None:
print(f"using_mp4_sidecar: {sidecar_path}")
if sidecar_path is not None and args.strategy == "both":
if args.include_decode:
run_remote_strategy(meta, data_root, args, parquet_reader)
print()
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend=args.range_backend,
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "indexed":
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend=args.range_backend,
label=f"indexed-sidecar-{args.range_backend}",
sidecar_path=str(sidecar_path),
)
return
if sidecar_path is not None and args.strategy == "native-http":
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="native-http",
label="indexed-sidecar-native-http",
sidecar_path=str(sidecar_path),
)
return
if args.strategy == "both":
expected_sidecar = SIDECAR_CACHE_DIR / FULL_SIDECAR_NAME
expected_remote = _root_join(data_root, f"meta/mp4-sidecars/{FULL_SIDECAR_NAME}")
print(f"mp4_sidecar_missing_local: {expected_sidecar}")
print(f"mp4_sidecar_missing_remote: {expected_remote}")
print(
"build_mp4_sidecar: "
"uv run --no-sync python scripts/build_mp4_sidecar.py "
f"--workers {args.workers} --range-backend native-http --output {expected_sidecar}"
)
print("running_without_mp4_sidecar: indexed variants will build MP4 indexes online")
print()
if args.strategy in ("both", "indexed"):
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="fsspec",
label="indexed",
sidecar_path=None,
)
if args.strategy == "both":
print()
if args.strategy == "remote-decoder" or (args.strategy == "both" and args.include_decode):
run_remote_strategy(meta, data_root, args, parquet_reader)
if args.strategy == "both" and args.include_decode:
print()
if args.strategy in ("both", "native-http"):
run_indexed_strategy(
meta,
data_root,
args,
parquet_reader,
range_backend="native-http",
label="indexed-native-http",
sidecar_path=None,
)
if __name__ == "__main__":
main()
+93
View File
@@ -0,0 +1,93 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import argparse
import time
from pathlib import Path
import fsspec
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.episode_video_streaming import EpisodeVideoManifest, assert_hf_hub_range_cache_branch
DEFAULT_REPO = "allenai/MolmoAct2-BimanualYAM-Dataset"
DEFAULT_REVISION = "e9f21ae15074330839f2ac25ed4b49d76dfa1f9c"
DEFAULT_DATA_ROOT = "hf://buckets/pepijn223/MolmoAct2-BimanualYAM-Dataset-bucket"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Build a reusable MP4 byte-index sidecar for streaming.")
parser.add_argument("--repo-id", default=DEFAULT_REPO)
parser.add_argument("--revision", default=DEFAULT_REVISION)
parser.add_argument("--data-root", default=DEFAULT_DATA_ROOT)
parser.add_argument("--output", required=True)
parser.add_argument("--episodes", type=int, default=None)
parser.add_argument("--workers", type=int, default=8)
parser.add_argument("--range-backend", choices=("fsspec", "native-http"), default="native-http")
parser.add_argument("--max-probe-mb", type=int, default=64)
parser.add_argument(
"--no-push", action="store_true", help="Do not upload the sidecar to data_root/meta/mp4-sidecars."
)
parser.add_argument("--no-hub-branch-assert", action="store_true")
return parser.parse_args()
def push_sidecar(local_path: str, data_root: str) -> list[str]:
if not data_root.startswith("hf://"):
return []
local = Path(local_path)
fs = fsspec.filesystem("hf")
remote_dir = f"{data_root.rstrip('/')}/meta/mp4-sidecars"
remote_paths = [f"{remote_dir}/{local.name}"]
for remote in remote_paths:
fs.put(str(local), remote)
return remote_paths
def main() -> None:
args = parse_args()
if args.data_root.startswith("hf://") and not args.no_hub_branch_assert:
assert_hf_hub_range_cache_branch()
meta = LeRobotDatasetMetadata(args.repo_id, revision=args.revision)
meta.ensure_readable()
total = (
int(meta.total_episodes) if args.episodes is None else min(args.episodes, int(meta.total_episodes))
)
rel_paths = sorted(
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in range(total) for key in meta.video_keys}
)
start = time.perf_counter()
EpisodeVideoManifest.write_file_sidecar(
args.output,
rel_paths,
args.data_root,
range_backend=args.range_backend,
workers=args.workers,
max_probe_bytes=args.max_probe_mb * 1024 * 1024,
)
elapsed = time.perf_counter() - start
print(f"wrote {args.output}")
print(f"episodes={total} files={len(rel_paths)} elapsed_s={elapsed:.2f}")
if args.no_push:
print("push_skipped: --no-push")
else:
pushed = push_sidecar(args.output, args.data_root)
for remote in pushed:
print(f"pushed {remote}")
if __name__ == "__main__":
main()
@@ -54,7 +54,6 @@ from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
from lerobot.datasets.language import (
EVENT_ONLY_STYLES,
LANGUAGE_EVENTS,
@@ -275,11 +274,12 @@ class LanguageColumnsWriter:
new_table = self._materialize_table(
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
)
# Re-emit one row group per episode (a bulk pq.write_table would collapse
# them into one). Write to a sibling tmp path and atomically rename so a
# crash mid-write can't leave a half-written shard.
# Atomic replace: write to a sibling tmp path and rename so a crash
# mid-write can't leave a half-written shard that ``pq.read_table``
# would then fail to open. ``Path.replace`` is atomic on POSIX +
# Windows when source and target sit on the same filesystem.
tmp_path = path.with_suffix(path.suffix + ".tmp")
write_table_one_row_group_per_episode(new_table, tmp_path)
pq.write_table(new_table, tmp_path)
tmp_path.replace(path)
def _materialize_table(
-9
View File
@@ -32,7 +32,6 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_one_row_group_per_episode,
to_parquet_with_hf_images,
write_info,
write_stats,
@@ -552,7 +551,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
aggr_root=dst_meta.root,
hf_features=hf_features,
concatenate=concatenate_data,
one_row_group_per_episode=True,
)
# Record the mapping from source to actual destination
@@ -630,7 +628,6 @@ def append_or_create_parquet_file(
aggr_root: Path = None,
hf_features: datasets.Features | None = None,
concatenate: bool = True,
one_row_group_per_episode: bool = False,
) -> tuple[dict[str, int], tuple[int, int]]:
"""Appends data to an existing parquet file or creates a new one based on size constraints.
@@ -648,8 +645,6 @@ def append_or_create_parquet_file(
aggr_root: Root path for the aggregated dataset.
hf_features: Optional HuggingFace Features schema for proper image typing.
concatenate: When False, always rotate to a new file instead of appending to the current one.
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
the episodes-metadata parquet (already one row per episode).
Returns:
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
@@ -662,8 +657,6 @@ def append_or_create_parquet_file(
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, (dst_chunk, dst_file)
@@ -690,8 +683,6 @@ def append_or_create_parquet_file(
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
elif one_row_group_per_episode:
to_parquet_one_row_group_per_episode(final_df, target_path)
else:
final_df.to_parquet(target_path)
@@ -0,0 +1,890 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import contextlib
import io
import json
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from importlib import metadata
from pathlib import Path
from typing import Any
from urllib.parse import quote, urljoin, urlparse
import fsspec
import httpx
import numpy as np
from huggingface_hub import HfApi, HfFileSystem, constants
from huggingface_hub.utils import hf_raise_for_status
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.mp4 import Mp4Index, Mp4SampleSlice, fetch_mp4_index, synthesize_mp4
@dataclass(frozen=True)
class EpisodeVideoSpan:
file_id: int
mdat_offset: int
mdat_length: int
first_pts: float
last_pts: float
frame_count: int
sample_lo: int
sample_hi: int
source_start_pts: float
@dataclass(frozen=True)
class VideoFileRecord:
file_path: str
file_size: int
mp4: Mp4Index
class ThreadLocalRangeFetcher:
"""Range reader that gives each worker thread independent file handles."""
def __init__(self, data_root: str | Path, *, block_size: int = 2**20, cache_type: str = "none"):
self.data_root = str(data_root).rstrip("/")
protocol = "hf" if self.data_root.startswith("hf://") else "file"
self.fs = fsspec.filesystem(protocol)
self.block_size = block_size
self.cache_type = cache_type
self._local = threading.local()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_open_s": 0.0,
"range_seek_s": 0.0,
"range_read_s": 0.0,
}
def _url(self, relative_path: str) -> str:
if self.data_root.startswith("hf://"):
return f"{self.data_root}/{relative_path}"
return str(Path(self.data_root) / relative_path)
def _handle(self, relative_path: str):
handles = getattr(self._local, "handles", None)
if handles is None:
handles = {}
self._local.handles = handles
handle = handles.get(relative_path)
if handle is None or getattr(handle, "closed", False):
handle = self.fs.open(
self._url(relative_path), "rb", block_size=self.block_size, cache_type=self.cache_type
)
handles[relative_path] = handle
return handle
def info_size(self, relative_path: str) -> int:
return int(self.fs.info(self._url(relative_path))["size"])
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
open_start = time.perf_counter()
handle = self._handle(relative_path)
open_s = time.perf_counter() - open_start
seek_start = time.perf_counter()
handle.seek(offset)
seek_s = time.perf_counter() - seek_start
read_start = time.perf_counter()
data = handle.read(length)
read_s = time.perf_counter() - read_start
self._record_timing(
range_jobs=1.0,
range_bytes=float(len(data)),
range_open_s=open_s,
range_seek_s=seek_s,
range_read_s=read_s,
)
return data
def _record_timing(self, **kwargs: float) -> None:
with self._timing_lock:
for key, value in kwargs.items():
self._timing_totals[key] += value
def timing_summary(self) -> dict[str, float]:
with self._timing_lock:
return dict(self._timing_totals)
def close(self) -> None:
handles = getattr(self._local, "handles", None)
if handles is None:
return
for handle in handles.values():
with contextlib.suppress(Exception):
handle.close()
handles.clear()
class NativeHTTPRangeFetcher:
"""Direct pooled HTTP range reader for hf:// paths."""
_GLOBAL_SOURCE_URLS: dict[tuple[str, str], str] = {}
_GLOBAL_RESOLVED_URLS: dict[tuple[str, str], str] = {}
_GLOBAL_SIZES: dict[tuple[str, str], int] = {}
_GLOBAL_LOCK = threading.Lock()
_RETRYABLE_EXCEPTIONS = (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
httpx.PoolTimeout,
)
def __init__(
self,
data_root: str | Path,
*,
max_connections: int = 32,
timeout: float = 60.0,
max_retries: int = 4,
):
self.data_root = str(data_root).rstrip("/")
if not self.data_root.startswith("hf://"):
raise ValueError("NativeHTTPRangeFetcher only supports hf:// roots")
self.max_retries = max_retries
self.api = HfApi()
self.fs: HfFileSystem | None = None
self._bucket_id: str | None = None
self._bucket_prefix = ""
if self.data_root.startswith("hf://buckets/"):
bucket_root = self.data_root.removeprefix("hf://buckets/")
parts = bucket_root.split("/", 2)
if len(parts) < 2:
raise ValueError(f"Invalid bucket root: {self.data_root}")
self._bucket_id = f"{parts[0]}/{parts[1]}"
self._bucket_prefix = parts[2].strip("/") if len(parts) == 3 else ""
else:
self.fs = HfFileSystem()
self.client = httpx.Client(
timeout=timeout,
limits=httpx.Limits(max_connections=max_connections, max_keepalive_connections=max_connections),
follow_redirects=False,
)
self._resolved_urls: dict[str, str] = {}
self._source_urls: dict[str, str] = {}
self._sizes: dict[str, int] = {}
self._lock = threading.Lock()
self._timing_lock = threading.Lock()
self._timing_totals = {
"range_jobs": 0.0,
"range_bytes": 0.0,
"range_resolve_s": 0.0,
"range_header_s": 0.0,
"range_first_byte_s": 0.0,
"range_body_s": 0.0,
"range_retry_attempts": 0.0,
"range_retry_sleep_s": 0.0,
"range_failed_requests": 0.0,
}
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
return self.client.request(method, url, **kwargs)
except self._RETRYABLE_EXCEPTIONS as exc:
last_exc = exc
if attempt >= self.max_retries:
break
time.sleep(min(0.5 * 2**attempt, 5.0))
if last_exc is None:
raise RuntimeError("HTTP request failed without an exception")
raise last_exc
def _cache_key(self, relative_path: str) -> tuple[str, str]:
return self.data_root, relative_path
def _path(self, relative_path: str) -> str:
return f"{self.data_root}/{relative_path}"
def _bucket_path(self, relative_path: str) -> str:
if self._bucket_prefix:
return f"{self._bucket_prefix}/{relative_path}"
return relative_path
def _headers_for(self, request_url: str, source_url: str) -> dict[str, str]:
headers = self.api._build_hf_headers()
if urlparse(request_url).netloc != urlparse(source_url).netloc:
headers.pop("authorization", None)
headers.pop("Authorization", None)
return headers
def _source_url(self, relative_path: str) -> str:
with self._lock:
source = self._source_urls.get(relative_path)
if source is not None:
return source
key = self._cache_key(relative_path)
with self._GLOBAL_LOCK:
source = self._GLOBAL_SOURCE_URLS.get(key)
if source is None:
if self._bucket_id is not None:
source = (
f"{constants.ENDPOINT}/buckets/{self._bucket_id}/resolve/"
f"{quote(self._bucket_path(relative_path))}"
)
else:
if self.fs is None:
raise RuntimeError("HfFileSystem fallback was not initialized")
source = self.fs.url(self._path(relative_path))
with self._GLOBAL_LOCK:
self._GLOBAL_SOURCE_URLS[key] = source
with self._lock:
self._source_urls[relative_path] = source
return source
def _resolve_url(self, relative_path: str, *, refresh: bool = False) -> str:
with self._lock:
if not refresh and relative_path in self._resolved_urls:
return self._resolved_urls[relative_path]
key = self._cache_key(relative_path)
if not refresh:
with self._GLOBAL_LOCK:
resolved = self._GLOBAL_RESOLVED_URLS.get(key)
size = self._GLOBAL_SIZES.get(key)
if resolved is not None:
with self._lock:
self._resolved_urls[relative_path] = resolved
if size is not None:
self._sizes[relative_path] = size
return resolved
source = self._source_url(relative_path)
response = self._request("HEAD", source, headers=self.api._build_hf_headers(), follow_redirects=False)
try:
hf_raise_for_status(response)
location = response.headers.get("Location")
resolved = urljoin(source, location) if location else source
with self._lock:
self._resolved_urls[relative_path] = resolved
if "Content-Length" in response.headers:
self._sizes[relative_path] = int(response.headers["Content-Length"])
with self._GLOBAL_LOCK:
self._GLOBAL_RESOLVED_URLS[key] = resolved
if "Content-Length" in response.headers:
self._GLOBAL_SIZES[key] = int(response.headers["Content-Length"])
return resolved
finally:
response.close()
def info_size(self, relative_path: str) -> int:
with self._lock:
size = self._sizes.get(relative_path)
if size is not None:
return size
key = self._cache_key(relative_path)
with self._GLOBAL_LOCK:
size = self._GLOBAL_SIZES.get(key)
if size is not None:
with self._lock:
self._sizes[relative_path] = size
return size
resolved = self._resolve_url(relative_path)
source = self._source_url(relative_path)
response = self._request(
"HEAD", resolved, headers=self._headers_for(resolved, source), follow_redirects=True
)
try:
hf_raise_for_status(response)
size = int(response.headers["Content-Length"])
with self._lock:
self._sizes[relative_path] = size
with self._GLOBAL_LOCK:
self._GLOBAL_SIZES[key] = size
return size
finally:
response.close()
def read_range(self, relative_path: str, offset: int, length: int) -> bytes:
resolve_start = time.perf_counter()
resolved = self._resolve_url(relative_path)
source = self._source_url(relative_path)
resolve_s = time.perf_counter() - resolve_start
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
payload, status_code, timings = self._read_range_response(resolved, headers)
if status_code == 403:
refresh_start = time.perf_counter()
resolved = self._resolve_url(relative_path, refresh=True)
resolve_s += time.perf_counter() - refresh_start
headers = self._headers_for(resolved, source)
headers["Range"] = f"bytes={offset}-{offset + length - 1}"
payload, status_code, retry_timings = self._read_range_response(resolved, headers)
for key, value in retry_timings.items():
timings[key] += value
if status_code == 403:
raise PermissionError(f"HTTP range request returned 403 after URL refresh: {relative_path}")
self._record_timing(
range_jobs=1.0,
range_bytes=float(len(payload)),
range_resolve_s=resolve_s,
**timings,
)
return payload
def _read_range_response(self, url: str, headers: dict[str, str]) -> tuple[bytes, int, dict[str, float]]:
last_exc: Exception | None = None
retry_attempts = 0.0
retry_sleep_s = 0.0
for attempt in range(self.max_retries + 1):
try:
payload, status_code, timings = self._read_range_response_once(url, headers)
timings["range_retry_attempts"] = retry_attempts
timings["range_retry_sleep_s"] = retry_sleep_s
return payload, status_code, timings
except self._RETRYABLE_EXCEPTIONS as exc:
last_exc = exc
if attempt >= self.max_retries:
break
retry_attempts += 1.0
sleep_s = min(0.5 * 2**attempt, 5.0)
retry_sleep_s += sleep_s
time.sleep(sleep_s)
self._record_timing(
range_failed_requests=1.0,
range_retry_attempts=retry_attempts,
range_retry_sleep_s=retry_sleep_s,
)
if last_exc is None:
raise RuntimeError("HTTP range request failed without an exception")
raise last_exc
def _read_range_response_once(
self, url: str, headers: dict[str, str]
) -> tuple[bytes, int, dict[str, float]]:
header_start = time.perf_counter()
with self.client.stream("GET", url, headers=headers) as response:
header_s = time.perf_counter() - header_start
if response.status_code == 403:
return (
b"",
response.status_code,
{
"range_header_s": header_s,
"range_first_byte_s": 0.0,
"range_body_s": 0.0,
},
)
hf_raise_for_status(response)
chunks = []
first_byte_s = 0.0
first_chunk = True
body_start = time.perf_counter()
for chunk in response.iter_bytes():
if first_chunk:
first_byte_s = time.perf_counter() - body_start
first_chunk = False
chunks.append(chunk)
body_s = time.perf_counter() - body_start
return (
b"".join(chunks),
response.status_code,
{
"range_header_s": header_s,
"range_first_byte_s": first_byte_s,
"range_body_s": body_s,
},
)
def _record_timing(self, **kwargs: float) -> None:
with self._timing_lock:
for key, value in kwargs.items():
self._timing_totals[key] += value
def timing_summary(self) -> dict[str, float]:
with self._timing_lock:
return dict(self._timing_totals)
def close(self) -> None:
self.client.close()
def make_range_fetcher(
data_root: str | Path,
*,
range_backend: str,
workers: int,
native_http_connections: int | None = None,
native_http_timeout: float = 60.0,
native_http_retries: int = 4,
):
if range_backend == "fsspec":
return ThreadLocalRangeFetcher(data_root)
if range_backend == "native-http":
max_connections = native_http_connections or max(8, workers)
return NativeHTTPRangeFetcher(
data_root,
max_connections=max_connections,
timeout=native_http_timeout,
max_retries=native_http_retries,
)
raise ValueError(f"Unknown range backend: {range_backend}")
class EpisodeVideoManifest:
_FILE_SIDECAR_CACHE: dict[str, dict[str, VideoFileRecord]] = {}
_FILE_SIDECAR_CACHE_LOCK = threading.Lock()
def __init__(
self,
*,
video_keys: list[str],
files: list[VideoFileRecord],
spans: dict[str, np.ndarray],
):
self.video_keys = list(video_keys)
self._camera_to_id = {key: idx for idx, key in enumerate(self.video_keys)}
self.files = files
self.spans = spans
@classmethod
def build(
cls,
meta: LeRobotDatasetMetadata,
data_root: str | Path,
*,
episode_indices: list[int] | range | None = None,
range_backend: str = "fsspec",
workers: int = 8,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
keyframe_pad_s: float = 0.1,
keyframe_pad_fraction: float = 0.05,
sidecar_path: str | Path | None = None,
) -> EpisodeVideoManifest:
meta.ensure_readable()
video_keys = list(meta.video_keys)
if episode_indices is None:
episode_indices = range(int(meta.total_episodes))
rel_paths = sorted(
{str(meta.get_video_file_path(ep_idx, key)) for ep_idx in episode_indices for key in video_keys}
)
path_to_id = {path: idx for idx, path in enumerate(rel_paths)}
if sidecar_path is None:
files = cls._build_file_records(
rel_paths,
data_root,
range_backend=range_backend,
workers=workers,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
else:
records = cls.load_file_sidecar(sidecar_path)
missing = [path for path in rel_paths if path not in records]
if missing:
raise ValueError(
f"Sidecar {sidecar_path} is missing {len(missing)} files, first: {missing[0]}"
)
files = [records[path] for path in rel_paths]
total = int(meta.total_episodes)
num_cameras = len(video_keys)
spans: dict[str, np.ndarray] = {
"file_id": np.zeros((total, num_cameras), dtype=np.int32),
"mdat_offset": np.zeros((total, num_cameras), dtype=np.int64),
"mdat_length": np.zeros((total, num_cameras), dtype=np.int64),
"first_pts": np.zeros((total, num_cameras), dtype=np.float64),
"last_pts": np.zeros((total, num_cameras), dtype=np.float64),
"frame_count": np.zeros((total, num_cameras), dtype=np.int32),
"sample_lo": np.zeros((total, num_cameras), dtype=np.int32),
"sample_hi": np.zeros((total, num_cameras), dtype=np.int32),
"source_start_pts": np.zeros((total, num_cameras), dtype=np.float64),
}
for ep_idx in episode_indices:
ep = meta.episodes[ep_idx]
for cam_idx, key in enumerate(video_keys):
rel_path = str(meta.get_video_file_path(ep_idx, key))
file_id = path_to_id[rel_path]
mp4 = files[file_id].mp4
from_ts = float(ep[f"videos/{key}/from_timestamp"])
to_ts = float(ep[f"videos/{key}/to_timestamp"])
sample_slice = mp4.sample_slice(
from_ts,
to_ts,
keyframe_pad_s=keyframe_pad_s,
keyframe_pad_fraction=keyframe_pad_fraction,
file_size=files[file_id].file_size,
)
spans["file_id"][ep_idx, cam_idx] = file_id
spans["mdat_offset"][ep_idx, cam_idx] = sample_slice.byte_offset
spans["mdat_length"][ep_idx, cam_idx] = sample_slice.byte_length
spans["first_pts"][ep_idx, cam_idx] = from_ts
spans["last_pts"][ep_idx, cam_idx] = to_ts
spans["frame_count"][ep_idx, cam_idx] = sample_slice.sample_hi - sample_slice.sample_lo + 1
spans["sample_lo"][ep_idx, cam_idx] = sample_slice.sample_lo
spans["sample_hi"][ep_idx, cam_idx] = sample_slice.sample_hi
spans["source_start_pts"][ep_idx, cam_idx] = sample_slice.source_start_pts
return cls(video_keys=video_keys, files=files, spans=spans)
@staticmethod
def _build_file_records(
rel_paths: list[str],
data_root: str | Path,
*,
range_backend: str,
workers: int,
header_probe_bytes: int,
max_probe_bytes: int,
) -> list[VideoFileRecord]:
fetcher = make_range_fetcher(data_root, range_backend=range_backend, workers=workers)
def build_file(path: str) -> VideoFileRecord:
file_size = fetcher.info_size(path)
mp4 = fetch_mp4_index(
path,
fetcher.read_range,
file_size=file_size,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
return VideoFileRecord(path, file_size, mp4)
try:
with ThreadPoolExecutor(max_workers=workers) as pool:
return list(pool.map(build_file, rel_paths))
finally:
fetcher.close()
@classmethod
def write_file_sidecar(
cls,
sidecar_path: str | Path,
rel_paths: list[str],
data_root: str | Path,
*,
range_backend: str = "native-http",
workers: int = 8,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
) -> None:
records = cls._build_file_records(
sorted(set(rel_paths)),
data_root,
range_backend=range_backend,
workers=workers,
header_probe_bytes=header_probe_bytes,
max_probe_bytes=max_probe_bytes,
)
cls.save_file_sidecar(sidecar_path, records)
@staticmethod
def save_file_sidecar(sidecar_path: str | Path, records: list[VideoFileRecord]) -> None:
sidecar_path = Path(sidecar_path)
sidecar_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"version": 1,
"files": [
{"file_path": record.file_path, "file_size": record.file_size, "mp4": record.mp4.to_dict()}
for record in records
],
}
arrays = {}
for file_idx, record in enumerate(records):
arrays[f"{file_idx}/sample_pts"] = record.mp4.sample_pts
arrays[f"{file_idx}/sample_durations"] = record.mp4.sample_durations
arrays[f"{file_idx}/sample_sizes"] = record.mp4.sample_sizes
arrays[f"{file_idx}/sample_offsets"] = record.mp4.sample_offsets
arrays[f"{file_idx}/sync_samples"] = record.mp4.sync_samples
np.savez_compressed(sidecar_path, manifest_json=json.dumps(payload).encode("utf-8"), **arrays)
@staticmethod
def load_file_sidecar(sidecar_path: str | Path) -> dict[str, VideoFileRecord]:
cache_key = str(Path(sidecar_path).expanduser())
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
cached = EpisodeVideoManifest._FILE_SIDECAR_CACHE.get(cache_key)
if cached is not None:
return cached
with np.load(sidecar_path, allow_pickle=False) as data:
payload = json.loads(bytes(data["manifest_json"]).decode("utf-8"))
records = {}
for file_idx, item in enumerate(payload["files"]):
arrays = {
name: data[f"{file_idx}/{name}"]
for name in [
"sample_pts",
"sample_durations",
"sample_sizes",
"sample_offsets",
"sync_samples",
]
}
mp4 = Mp4Index.from_dict(item["mp4"], arrays)
records[item["file_path"]] = VideoFileRecord(item["file_path"], int(item["file_size"]), mp4)
with EpisodeVideoManifest._FILE_SIDECAR_CACHE_LOCK:
EpisodeVideoManifest._FILE_SIDECAR_CACHE[cache_key] = records
return records
def camera_id(self, camera_key: str) -> int:
return self._camera_to_id[camera_key]
def lookup(self, episode_index: int, camera_key: str) -> EpisodeVideoSpan:
cam = self.camera_id(camera_key)
return EpisodeVideoSpan(
file_id=int(self.spans["file_id"][episode_index, cam]),
mdat_offset=int(self.spans["mdat_offset"][episode_index, cam]),
mdat_length=int(self.spans["mdat_length"][episode_index, cam]),
first_pts=float(self.spans["first_pts"][episode_index, cam]),
last_pts=float(self.spans["last_pts"][episode_index, cam]),
frame_count=int(self.spans["frame_count"][episode_index, cam]),
sample_lo=int(self.spans["sample_lo"][episode_index, cam]),
sample_hi=int(self.spans["sample_hi"][episode_index, cam]),
source_start_pts=float(self.spans["source_start_pts"][episode_index, cam]),
)
def file_lookup(self, file_id: int) -> VideoFileRecord:
return self.files[file_id]
def mp4_index(self, episode_index: int, camera_key: str) -> Mp4Index:
return self.files[self.lookup(episode_index, camera_key).file_id].mp4
def sample_slice(self, episode_index: int, camera_key: str) -> Mp4SampleSlice:
span = self.lookup(episode_index, camera_key)
return Mp4SampleSlice(
sample_lo=span.sample_lo,
sample_hi=span.sample_hi,
byte_offset=span.mdat_offset,
byte_length=span.mdat_length,
source_start_pts=span.source_start_pts,
)
class EpisodeByteCache:
def __init__(
self,
manifest: EpisodeVideoManifest,
data_root: str | Path,
*,
byte_budget: int = 80 * 1024**3,
workers: int = 8,
range_backend: str = "fsspec",
native_http_connections: int | None = None,
native_http_timeout: float = 60.0,
native_http_retries: int = 4,
open_decoders: bool = True,
):
self.manifest = manifest
self.fetcher = make_range_fetcher(
data_root,
range_backend=range_backend,
workers=workers,
native_http_connections=native_http_connections,
native_http_timeout=native_http_timeout,
native_http_retries=native_http_retries,
)
self.byte_budget = byte_budget
self.open_decoders = open_decoders
self._pool = ThreadPoolExecutor(max_workers=workers)
self._cache: OrderedDict[tuple[int, str], dict[str, Any]] = OrderedDict()
self._futures: dict[tuple[int, str], Future[dict[str, Any]]] = {}
self._bytes = 0
self._lock = threading.Lock()
self._timing_totals = {
"lookup_s": 0.0,
"fetch_s": 0.0,
"synthesize_s": 0.0,
"store_s": 0.0,
"jobs": 0.0,
}
def close(self) -> None:
self._pool.shutdown(wait=True)
with self._lock:
self._cache.clear()
self._futures.clear()
self._bytes = 0
self.fetcher.close()
def __enter__(self) -> EpisodeByteCache:
return self
def __exit__(self, *_exc) -> None:
self.close()
def submit_prefetch(self, episode_index: int) -> None:
for camera_key in self.manifest.video_keys:
self._submit(episode_index, camera_key)
def ensure_ready(self, episode_index: int) -> None:
for camera_key in self.manifest.video_keys:
self.get_bytes(episode_index, camera_key)
def get_bytes(self, episode_index: int, camera_key: str) -> bytes:
return self._get_entry(episode_index, camera_key)["bytes"]
def get_decoder(self, episode_index: int, camera_key: str):
entry = self._get_entry(episode_index, camera_key)
decoder = entry.get("decoder")
if decoder is None:
decoder = open_video_decoder(io.BytesIO(entry["bytes"]))
entry["decoder"] = decoder
return decoder
def get_frames(self, episode_index: int, camera_key: str, timestamps: list[float]):
span = self.manifest.lookup(episode_index, camera_key)
local_ts = [ts - span.source_start_pts for ts in timestamps]
decoder = self.get_decoder(episode_index, camera_key)
if hasattr(decoder, "get_frames_played_at"):
return decoder.get_frames_played_at(local_ts).data
metadata = decoder.metadata
fps = getattr(metadata, "average_fps", None)
if fps is None:
duration = max(getattr(metadata, "end_stream_seconds", 0.0), 1e-9)
fps = metadata.num_frames / duration
return decoder.get_frames_at(indices=[round(ts * fps) for ts in local_ts]).data
def timing_summary(self) -> dict[str, float]:
with self._lock:
summary = dict(self._timing_totals)
fetcher_summary = getattr(self.fetcher, "timing_summary", None)
if fetcher_summary is not None:
summary.update(fetcher_summary())
return summary
def _submit(self, episode_index: int, camera_key: str) -> Future[dict[str, Any]]:
key = (episode_index, camera_key)
with self._lock:
if key in self._cache:
future: Future[dict[str, Any]] = Future()
future.set_result(self._cache[key])
return future
future = self._futures.get(key)
if future is None:
future = self._pool.submit(self._fetch_and_synthesize, episode_index, camera_key)
self._futures[key] = future
return future
def _get_entry(self, episode_index: int, camera_key: str) -> dict[str, Any]:
key = (episode_index, camera_key)
with self._lock:
entry = self._cache.get(key)
if entry is not None:
self._cache.move_to_end(key)
return entry
future = self._submit(episode_index, camera_key)
entry = future.result()
store_start = time.perf_counter()
with self._lock:
self._futures.pop(key, None)
existing = self._cache.get(key)
if existing is not None:
self._cache.move_to_end(key)
return existing
self._cache[key] = entry
self._bytes += len(entry["bytes"])
self._evict_locked()
timings = entry.pop("_timings", None)
if timings is not None:
self._timing_totals["lookup_s"] += timings["lookup_s"]
self._timing_totals["fetch_s"] += timings["fetch_s"]
self._timing_totals["synthesize_s"] += timings["synthesize_s"]
self._timing_totals["store_s"] += time.perf_counter() - store_start
self._timing_totals["jobs"] += 1
return entry
def _evict_locked(self) -> None:
while self._bytes > self.byte_budget and self._cache:
_key, entry = self._cache.popitem(last=False)
self._bytes -= len(entry["bytes"])
def _fetch_and_synthesize(self, episode_index: int, camera_key: str) -> dict[str, Any]:
lookup_start = time.perf_counter()
span = self.manifest.lookup(episode_index, camera_key)
file_record = self.manifest.file_lookup(span.file_id)
sample_slice = Mp4SampleSlice(
sample_lo=span.sample_lo,
sample_hi=span.sample_hi,
byte_offset=span.mdat_offset,
byte_length=span.mdat_length,
source_start_pts=span.source_start_pts,
)
lookup_s = time.perf_counter() - lookup_start
fetch_start = time.perf_counter()
payload = self.fetcher.read_range(file_record.file_path, span.mdat_offset, span.mdat_length)
fetch_s = time.perf_counter() - fetch_start
if len(payload) != span.mdat_length:
raise OSError(
f"Short read for {file_record.file_path}: expected {span.mdat_length}, got {len(payload)}"
)
synthesize_start = time.perf_counter()
mp4_bytes = synthesize_mp4(file_record.mp4, sample_slice, payload)
synthesize_s = time.perf_counter() - synthesize_start
entry: dict[str, Any] = {
"bytes": mp4_bytes,
"decoder": None,
"_timings": {
"lookup_s": lookup_s,
"fetch_s": fetch_s,
"synthesize_s": synthesize_s,
},
}
if self.open_decoders:
entry["decoder"] = open_video_decoder(io.BytesIO(mp4_bytes))
return entry
def open_video_decoder(file_like_or_bytesio, frame_mappings=None):
if frame_mappings is not None:
raise ValueError("Synthesized episode videos use a local timeline; pass frame_mappings=None.")
from torchcodec.decoders import VideoDecoder
return VideoDecoder(file_like_or_bytesio, seek_mode="approximate")
def assert_hf_hub_range_cache_branch() -> None:
"""Fail unless huggingface_hub was installed from the required range-cache branch."""
try:
dist = metadata.distribution("huggingface_hub")
except metadata.PackageNotFoundError as exc:
raise AssertionError("huggingface_hub is not installed") from exc
candidates = []
direct_url = dist.read_text("direct_url.json")
if direct_url:
candidates.append(direct_url)
with contextlib.suppress(json.JSONDecodeError):
parsed = json.loads(direct_url)
candidates.append(str(parsed.get("url", "")))
candidates.append(str(parsed.get("vcs_info", {}).get("requested_revision", "")))
candidates.append(str(parsed.get("vcs_info", {}).get("commit_id", "")))
text = "\n".join(candidates)
if "feat/hffs-cache-cdn-range-reads" not in text:
raise AssertionError(
"huggingface_hub must be installed from "
"git+https://github.com/huggingface/huggingface_hub.git@feat/hffs-cache-cdn-range-reads"
)
@dataclass
class StageTimer:
fetch_ms: float = 0.0
decode_ms: float = 0.0
bytes_read: int = 0
misses: int = 0
def record_fetch(self, start: float, byte_count: int) -> None:
self.fetch_ms += (time.perf_counter() - start) * 1000
self.bytes_read += byte_count
self.misses += 1
+9 -38
View File
@@ -20,7 +20,6 @@ import datasets
import numpy as np
import pandas
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
@@ -271,49 +270,21 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
return items_dict
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
"""Write ``table`` with one parquet row group per episode (in episode order).
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
mirroring the recording writer. ``table`` must carry a contiguous
``episode_index`` column.
"""
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
try:
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
finally:
writer.close()
def to_parquet_with_hf_images(
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
) -> None:
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formatted images are returned.
Images are embedded into the arrow table first (``ParquetWriter.write_table``
does not embed external image files like ``Dataset.to_parquet`` does).
``features`` types image columns as ``Image()`` in the parquet schema.
Args:
df: DataFrame to write to parquet.
path: Path to write the parquet file.
features: Optional HuggingFace Features schema. If provided, ensures image columns
are properly typed as Image() in the parquet schema.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
ds = embed_images(ds)
table = ds.with_format("arrow")[:]
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
# No episode boundaries to align row groups to — keep a single write.
pq.write_table(table, str(path))
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
table = pa.Table.from_pandas(df, preserve_index=False)
if "episode_index" in table.column_names:
write_table_one_row_group_per_episode(table, path)
else:
pq.write_table(table, str(path))
ds.to_parquet(path)
def item_to_torch(item: dict) -> dict:
+666
View File
@@ -0,0 +1,666 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import struct
from collections.abc import Callable, Iterable
from dataclasses import dataclass
import numpy as np
@dataclass(frozen=True)
class Box:
type: bytes
start: int
header_size: int
end: int
@property
def payload_start(self) -> int:
return self.start + self.header_size
@property
def size(self) -> int:
return self.end - self.start
@dataclass(frozen=True)
class Mp4SampleSlice:
sample_lo: int
sample_hi: int
byte_offset: int
byte_length: int
source_start_pts: float
@dataclass(frozen=True)
class Mp4Index:
file_path: str
file_size: int
ftyp: bytes
moov_offset: int
mdat_offset: int
mdat_payload_offset: int
mdat_payload_size: int
faststart: bool
codec: str
timescale: int
duration: int
track_id: int
width: int
height: int
stsd_body: bytes
sample_pts: np.ndarray
sample_durations: np.ndarray
sample_sizes: np.ndarray
sample_offsets: np.ndarray
sync_samples: np.ndarray
def sample_slice(
self,
from_ts: float,
to_ts: float,
*,
keyframe_pad_s: float = 0.1,
keyframe_pad_fraction: float = 0.05,
file_size: int | None = None,
) -> Mp4SampleSlice:
if to_ts < from_ts:
raise ValueError(f"Invalid timestamp span: {from_ts=} {to_ts=}")
if len(self.sample_pts) == 0:
raise ValueError(f"{self.file_path} contains no indexed samples")
pad = max(keyframe_pad_s, (to_ts - from_ts) * keyframe_pad_fraction)
lo_ts = max(0.0, from_ts - pad)
hi_ts = to_ts + pad
lo = int(np.searchsorted(self.sample_pts, lo_ts, side="left"))
hi = int(np.searchsorted(self.sample_pts, hi_ts, side="right")) - 1
lo = min(max(lo, 0), len(self.sample_pts) - 1)
hi = min(max(hi, lo), len(self.sample_pts) - 1)
if len(self.sync_samples):
prev_sync = self.sync_samples[self.sync_samples <= lo]
if len(prev_sync):
lo = int(prev_sync[-1])
else:
lo = int(self.sync_samples[0])
if lo > hi:
hi = lo
offsets = self.sample_offsets[lo : hi + 1]
sizes = self.sample_sizes[lo : hi + 1]
slice_lo = int(offsets.min())
slice_hi = int((offsets + sizes).max())
if file_size is not None:
slice_hi = min(slice_hi, int(file_size))
return Mp4SampleSlice(
sample_lo=lo,
sample_hi=hi,
byte_offset=slice_lo,
byte_length=slice_hi - slice_lo,
source_start_pts=float(self.sample_pts[lo]),
)
def to_dict(self) -> dict:
return {
"file_path": self.file_path,
"file_size": self.file_size,
"ftyp": self.ftyp.hex(),
"moov_offset": self.moov_offset,
"mdat_offset": self.mdat_offset,
"mdat_payload_offset": self.mdat_payload_offset,
"mdat_payload_size": self.mdat_payload_size,
"faststart": self.faststart,
"codec": self.codec,
"timescale": self.timescale,
"duration": self.duration,
"track_id": self.track_id,
"width": self.width,
"height": self.height,
"stsd_body": self.stsd_body.hex(),
}
@classmethod
def from_dict(cls, data: dict, arrays: dict[str, np.ndarray]) -> Mp4Index:
return cls(
file_path=data["file_path"],
file_size=int(data["file_size"]),
ftyp=bytes.fromhex(data["ftyp"]),
moov_offset=int(data["moov_offset"]),
mdat_offset=int(data["mdat_offset"]),
mdat_payload_offset=int(data["mdat_payload_offset"]),
mdat_payload_size=int(data["mdat_payload_size"]),
faststart=bool(data["faststart"]),
codec=data["codec"],
timescale=int(data["timescale"]),
duration=int(data["duration"]),
track_id=int(data["track_id"]),
width=int(data["width"]),
height=int(data["height"]),
stsd_body=bytes.fromhex(data["stsd_body"]),
sample_pts=arrays["sample_pts"],
sample_durations=arrays["sample_durations"],
sample_sizes=arrays["sample_sizes"],
sample_offsets=arrays["sample_offsets"],
sync_samples=arrays["sync_samples"],
)
def fetch_mp4_index(
path: str,
read_range: Callable[[str, int, int], bytes],
*,
file_size: int,
header_probe_bytes: int = 4 * 1024 * 1024,
max_probe_bytes: int = 64 * 1024 * 1024,
) -> Mp4Index:
probe_size = min(header_probe_bytes, file_size)
while True:
data = read_range(path, 0, probe_size)
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
has_mdat = any(box.type == b"mdat" for box in top)
has_moov = any(box.type == b"moov" and box.end <= len(data) for box in top)
if has_mdat and has_moov:
return parse_mp4_index(path, data, file_size=file_size)
if probe_size >= min(max_probe_bytes, file_size):
if has_mdat and not has_moov:
tail_index = _fetch_tail_moov_index(path, read_range, data, top, file_size, max_probe_bytes)
if tail_index is not None:
return tail_index
missing = []
if not has_mdat:
missing.append("mdat")
if not has_moov:
missing.append("moov")
raise ValueError(
f"Could not find complete {'/'.join(missing)} in first {probe_size} bytes of {path}"
)
probe_size = min(probe_size * 2, max_probe_bytes, file_size)
def _fetch_tail_moov_index(
path: str,
read_range: Callable[[str, int, int], bytes],
prefix: bytes,
top_boxes: list[Box],
file_size: int,
max_probe_bytes: int,
) -> Mp4Index | None:
mdat_box = _one(top_boxes, b"mdat")
if mdat_box is None or mdat_box.end >= file_size:
return None
tail_offset = mdat_box.end
tail_length = min(max_probe_bytes, file_size - tail_offset)
tail = read_range(path, tail_offset, tail_length)
tail_boxes = list(iter_boxes(tail, 0, len(tail), absolute_base=tail_offset, allow_truncated=True))
moov_box = next(
(box for box in tail_boxes if box.type == b"moov" and box.end <= tail_offset + len(tail)), None
)
if moov_box is None:
return None
ftyp_box = _one(top_boxes, b"ftyp", required=False)
ftyp = (
prefix[ftyp_box.start : ftyp_box.end]
if ftyp_box is not None
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
)
moov_start = moov_box.payload_start - tail_offset
moov_end = moov_box.end - tail_offset
return _parse_mp4_index_from_layout(
path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_box.start,
moov=tail[moov_start:moov_end],
mdat_box=mdat_box,
)
def parse_mp4_index(path: str, data: bytes, *, file_size: int | None = None) -> Mp4Index:
if file_size is None:
file_size = len(data)
top = list(iter_boxes(data, 0, len(data), absolute_base=0, allow_truncated=True))
ftyp_box = _one(top, b"ftyp", required=False)
moov_box = _one(top, b"moov")
mdat_box = _one(top, b"mdat")
if moov_box.end > len(data):
raise ValueError(f"{path}: moov box is truncated")
moov = data[moov_box.payload_start : moov_box.end]
ftyp = (
data[ftyp_box.start : ftyp_box.end]
if ftyp_box is not None
else _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
)
return _parse_mp4_index_from_layout(
path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_box.start,
moov=moov,
mdat_box=mdat_box,
)
def _parse_mp4_index_from_layout(
path: str,
*,
file_size: int,
ftyp: bytes,
moov_offset: int,
moov: bytes,
mdat_box: Box,
) -> Mp4Index:
mvhd_timescale, mvhd_duration = _parse_mvhd(_find_descendant(moov, [b"mvhd"]))
trak_box, trak_payload = _find_video_trak(moov)
_ = trak_box
tkhd = _parse_tkhd(_find_descendant(trak_payload, [b"tkhd"]))
mdhd_timescale, mdhd_duration = _parse_mdhd(_find_descendant(trak_payload, [b"mdia", b"mdhd"]))
stbl = _find_descendant(trak_payload, [b"mdia", b"minf", b"stbl"])
stsd = _find_child(stbl, b"stsd")
stsd_body = stbl[stsd.payload_start : stsd.end]
codec = _parse_stsd_codec(stsd_body)
stts = _parse_stts(_payload(stbl, b"stts"))
sample_sizes = _parse_stsz(_payload(stbl, b"stsz"))
stsc = _parse_stsc(_payload(stbl, b"stsc"))
chunk_offsets = _parse_chunk_offsets(stbl)
sync_samples = _parse_stss(stbl, len(sample_sizes))
sample_durations = _expand_stts(stts, len(sample_sizes))
sample_pts_units = np.empty(len(sample_durations), dtype=np.int64)
if len(sample_durations):
sample_pts_units[0] = 0
if len(sample_durations) > 1:
sample_pts_units[1:] = np.cumsum(sample_durations[:-1], dtype=np.int64)
sample_pts = sample_pts_units.astype(np.float64) / float(mdhd_timescale)
sample_offsets = _sample_offsets(stsc, chunk_offsets, sample_sizes)
return Mp4Index(
file_path=path,
file_size=file_size,
ftyp=ftyp,
moov_offset=moov_offset,
mdat_offset=mdat_box.start,
mdat_payload_offset=mdat_box.payload_start,
mdat_payload_size=mdat_box.end - mdat_box.payload_start
if mdat_box.end <= file_size
else file_size - mdat_box.payload_start,
faststart=moov_offset < mdat_box.start,
codec=codec,
timescale=mdhd_timescale,
duration=mdhd_duration or mvhd_duration,
track_id=tkhd["track_id"],
width=tkhd["width"],
height=tkhd["height"],
stsd_body=stsd_body,
sample_pts=sample_pts,
sample_durations=sample_durations,
sample_sizes=sample_sizes,
sample_offsets=sample_offsets,
sync_samples=sync_samples,
)
def synthesize_mp4(index: Mp4Index, sample_slice: Mp4SampleSlice, mdat_payload: bytes) -> bytes:
lo = sample_slice.sample_lo
hi = sample_slice.sample_hi + 1
if lo < 0 or hi > len(index.sample_sizes) or lo >= hi:
raise ValueError(f"Invalid sample range [{lo}, {hi}) for {index.file_path}")
offsets = index.sample_offsets[lo:hi]
sizes = index.sample_sizes[lo:hi]
rel_offsets = offsets - sample_slice.byte_offset
if int(rel_offsets.min()) != 0:
raise ValueError("Sample slice must start at the minimum referenced sample offset")
if int((rel_offsets + sizes).max()) > len(mdat_payload):
raise ValueError("Sample slice does not cover all referenced samples")
durations = index.sample_durations[lo:hi]
sync = index.sync_samples[(index.sync_samples >= lo) & (index.sync_samples < hi)] - lo + 1
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=0)
header_size = len(index.ftyp) + len(moov)
moov = _make_moov(index, durations, sizes, rel_offsets, sync, mdat_data_offset=header_size + 8)
return index.ftyp + moov + _box(b"mdat", mdat_payload)
def iter_boxes(
data: bytes,
start: int,
end: int,
*,
absolute_base: int = 0,
allow_truncated: bool = False,
) -> Iterable[Box]:
pos = start
while pos + 8 <= end:
size = struct.unpack_from(">I", data, pos)[0]
typ = data[pos + 4 : pos + 8]
header_size = 8
if size == 1:
if pos + 16 > end:
break
size = struct.unpack_from(">Q", data, pos + 8)[0]
header_size = 16
elif size == 0:
size = end - pos
if size < header_size:
break
box_end = pos + size
if box_end > end and not allow_truncated:
break
yield Box(typ, absolute_base + pos, header_size, absolute_base + box_end)
pos = box_end
def _find_video_trak(moov: bytes) -> tuple[Box, bytes]:
for trak in _children(moov, 0, len(moov)):
if trak.type != b"trak":
continue
payload = moov[trak.payload_start : trak.end]
hdlr = _find_descendant(payload, [b"mdia", b"hdlr"])
if hdlr[8:12] == b"vide":
return trak, payload
raise ValueError("No video track found")
def _find_descendant(data: bytes, path: list[bytes]) -> bytes:
current = data
for typ in path:
box = _find_child(current, typ)
current = current[box.payload_start : box.end]
return current
def _find_child(data: bytes, typ: bytes) -> Box:
for box in _children(data, 0, len(data)):
if box.type == typ:
return box
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
def _children(data: bytes, start: int, end: int) -> Iterable[Box]:
return iter_boxes(data, start, end, absolute_base=0)
def _one(boxes: list[Box], typ: bytes, *, required: bool = True) -> Box | None:
matches = [box for box in boxes if box.type == typ]
if not matches and required:
raise ValueError(f"Missing MP4 box {typ.decode('latin1')}")
return matches[0] if matches else None
def _payload(parent: bytes, typ: bytes) -> bytes:
box = _find_child(parent, typ)
return parent[box.payload_start : box.end]
def _parse_mvhd(payload: bytes) -> tuple[int, int]:
version = payload[0]
if version == 1:
return struct.unpack_from(">IQ", payload, 20)
return struct.unpack_from(">II", payload, 12)
def _parse_mdhd(payload: bytes) -> tuple[int, int]:
version = payload[0]
if version == 1:
return struct.unpack_from(">IQ", payload, 20)
return struct.unpack_from(">II", payload, 12)
def _parse_tkhd(payload: bytes) -> dict[str, int]:
version = payload[0]
if version == 1:
track_id = struct.unpack_from(">I", payload, 20)[0]
duration = struct.unpack_from(">Q", payload, 28)[0]
width, height = struct.unpack_from(">II", payload, 88)
else:
track_id = struct.unpack_from(">I", payload, 12)[0]
duration = struct.unpack_from(">I", payload, 20)[0]
width, height = struct.unpack_from(">II", payload, 76)
return {"track_id": track_id, "duration": duration, "width": width >> 16, "height": height >> 16}
def _parse_stsd_codec(stsd_body: bytes) -> str:
if len(stsd_body) < 16:
return "unknown"
return stsd_body[12:16].decode("latin1")
def _parse_stts(payload: bytes) -> list[tuple[int, int]]:
count = struct.unpack_from(">I", payload, 4)[0]
out = []
offset = 8
for _ in range(count):
out.append(struct.unpack_from(">II", payload, offset))
offset += 8
return out
def _expand_stts(entries: list[tuple[int, int]], sample_count: int) -> np.ndarray:
values = np.empty(sample_count, dtype=np.int64)
pos = 0
for count, delta in entries:
values[pos : pos + count] = delta
pos += count
if pos != sample_count:
raise ValueError(f"stts describes {pos} samples, stsz describes {sample_count}")
return values
def _parse_stsz(payload: bytes) -> np.ndarray:
sample_size, sample_count = struct.unpack_from(">II", payload, 4)
if sample_size:
return np.full(sample_count, sample_size, dtype=np.int64)
offset = 12
values = np.empty(sample_count, dtype=np.int64)
for idx in range(sample_count):
values[idx] = struct.unpack_from(">I", payload, offset)[0]
offset += 4
return values
def _parse_stsc(payload: bytes) -> list[tuple[int, int, int]]:
count = struct.unpack_from(">I", payload, 4)[0]
out = []
offset = 8
for _ in range(count):
out.append(struct.unpack_from(">III", payload, offset))
offset += 12
return out
def _parse_chunk_offsets(stbl: bytes) -> np.ndarray:
with_stco = None
with_co64 = None
for box in _children(stbl, 0, len(stbl)):
if box.type == b"stco":
with_stco = stbl[box.payload_start : box.end]
elif box.type == b"co64":
with_co64 = stbl[box.payload_start : box.end]
if with_co64 is not None:
count = struct.unpack_from(">I", with_co64, 4)[0]
return np.array(
[struct.unpack_from(">Q", with_co64, 8 + idx * 8)[0] for idx in range(count)], dtype=np.int64
)
if with_stco is None:
raise ValueError("Missing stco/co64 chunk offsets")
count = struct.unpack_from(">I", with_stco, 4)[0]
return np.array(
[struct.unpack_from(">I", with_stco, 8 + idx * 4)[0] for idx in range(count)], dtype=np.int64
)
def _parse_stss(stbl: bytes, sample_count: int) -> np.ndarray:
for box in _children(stbl, 0, len(stbl)):
if box.type == b"stss":
payload = stbl[box.payload_start : box.end]
count = struct.unpack_from(">I", payload, 4)[0]
return np.array(
[struct.unpack_from(">I", payload, 8 + idx * 4)[0] - 1 for idx in range(count)],
dtype=np.int64,
)
return np.arange(sample_count, dtype=np.int64)
def _sample_offsets(
stsc: list[tuple[int, int, int]], chunk_offsets: np.ndarray, sample_sizes: np.ndarray
) -> np.ndarray:
if not stsc:
raise ValueError("stsc is empty")
offsets = np.empty(len(sample_sizes), dtype=np.int64)
sample_idx = 0
for entry_idx, (first_chunk, samples_per_chunk, _desc_idx) in enumerate(stsc):
next_first = stsc[entry_idx + 1][0] if entry_idx + 1 < len(stsc) else len(chunk_offsets) + 1
for chunk_number in range(first_chunk, next_first):
if chunk_number < 1 or chunk_number > len(chunk_offsets):
raise ValueError("stsc references a chunk outside stco/co64")
chunk_pos = int(chunk_offsets[chunk_number - 1])
for _ in range(samples_per_chunk):
if sample_idx >= len(sample_sizes):
return offsets
offsets[sample_idx] = chunk_pos
chunk_pos += int(sample_sizes[sample_idx])
sample_idx += 1
if sample_idx != len(sample_sizes):
raise ValueError(f"stsc describes {sample_idx} samples, stsz describes {len(sample_sizes)}")
return offsets
def _make_moov(
index: Mp4Index,
durations: np.ndarray,
sizes: np.ndarray,
rel_offsets: np.ndarray,
sync_samples: np.ndarray,
*,
mdat_data_offset: int,
) -> bytes:
duration = int(durations.sum())
stco_values = [int(mdat_data_offset + value) for value in rel_offsets]
if any(value > 0xFFFFFFFF for value in stco_values):
offset_box = _co64(stco_values)
else:
offset_box = _stco(stco_values)
stbl = _box(
b"stbl",
_box(b"stsd", index.stsd_body)
+ _stts(durations)
+ _stsc_one_sample_per_chunk(len(sizes))
+ _stsz(sizes)
+ offset_box
+ (_stss(sync_samples) if len(sync_samples) else b""),
)
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
mdia = _box(b"mdia", _mdhd(index.timescale, duration) + _hdlr() + minf)
trak = _box(b"trak", _tkhd(index.track_id, duration, index.width, index.height) + mdia)
return _box(b"moov", _mvhd(index.timescale, duration, index.track_id + 1) + trak)
def _full_box(typ: bytes, version: int, flags: int, payload: bytes = b"") -> bytes:
return _box(typ, bytes([version]) + flags.to_bytes(3, "big") + payload)
def _box(typ: bytes, payload: bytes) -> bytes:
size = len(payload) + 8
if size <= 0xFFFFFFFF:
return struct.pack(">I4s", size, typ) + payload
return struct.pack(">I4sQ", 1, typ, size + 8) + payload
def _mvhd(timescale: int, duration: int, next_track_id: int) -> bytes:
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
payload = (
struct.pack(">IIII", 0, 0, timescale, duration)
+ struct.pack(">IHH", 0x00010000, 0x0100, 0)
+ b"\0" * 8
+ matrix
+ b"\0" * 24
+ struct.pack(">I", next_track_id)
)
return _full_box(b"mvhd", 0, 0, payload)
def _tkhd(track_id: int, duration: int, width: int, height: int) -> bytes:
matrix = struct.pack(">9I", 0x00010000, 0, 0, 0, 0x00010000, 0, 0, 0, 0x40000000)
payload = (
struct.pack(">IIIII", 0, 0, track_id, 0, duration)
+ b"\0" * 8
+ struct.pack(">hhhh", 0, 0, 0, 0)
+ matrix
+ struct.pack(">II", width << 16, height << 16)
)
return _full_box(b"tkhd", 0, 7, payload)
def _mdhd(timescale: int, duration: int) -> bytes:
return _full_box(b"mdhd", 0, 0, struct.pack(">IIIIH", 0, 0, timescale, duration, 0x55C4) + b"\0\0")
def _hdlr() -> bytes:
return _full_box(b"hdlr", 0, 0, b"\0" * 4 + b"vide" + b"\0" * 12 + b"VideoHandler\0")
def _vmhd() -> bytes:
return _full_box(b"vmhd", 0, 1, struct.pack(">HHHH", 0, 0, 0, 0))
def _dinf() -> bytes:
url = _full_box(b"url ", 0, 1)
dref = _full_box(b"dref", 0, 0, struct.pack(">I", 1) + url)
return _box(b"dinf", dref)
def _stts(durations: np.ndarray) -> bytes:
runs = []
for duration in durations.tolist():
if runs and runs[-1][1] == int(duration):
runs[-1][0] += 1
else:
runs.append([1, int(duration)])
payload = struct.pack(">I", len(runs)) + b"".join(
struct.pack(">II", count, delta) for count, delta in runs
)
return _full_box(b"stts", 0, 0, payload)
def _stsc_one_sample_per_chunk(sample_count: int) -> bytes:
return _full_box(b"stsc", 0, 0, struct.pack(">IIII", 1, 1, 1, 1))
def _stsz(sizes: np.ndarray) -> bytes:
return _full_box(
b"stsz",
0,
0,
struct.pack(">II", 0, len(sizes)) + b"".join(struct.pack(">I", int(size)) for size in sizes.tolist()),
)
def _stco(values: list[int]) -> bytes:
return _full_box(
b"stco", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">I", v) for v in values)
)
def _co64(values: list[int]) -> bytes:
return _full_box(
b"co64", 0, 0, struct.pack(">I", len(values)) + b"".join(struct.pack(">Q", v) for v in values)
)
def _stss(values: np.ndarray) -> bytes:
return _full_box(
b"stss",
0,
0,
struct.pack(">I", len(values)) + b"".join(struct.pack(">I", int(value)) for value in values.tolist()),
)
@@ -68,6 +68,6 @@ class UnitreeG1Config(RobotConfig):
# Compensates for gravity on the unitree's arms using the arm ik solver
gravity_compensation: bool = False
# Locomotion controller class name, e.g. "GrootLocomotionController",
# "HolosomaLocomotionController", or "SonicWholeBodyController". None disables it.
# Lower-body controller class name, e.g. "GrootLocomotionController" or
# "HolosomaLocomotionController". None disables it.
controller: str | None = None
@@ -1,8 +0,0 @@
"""Unitree G1 locomotion controllers (Groot, Holosoma, SONIC)."""
__all__ = [
"GrootLocomotionController",
"HolosomaLocomotionController",
"SonicWholeBodyController",
"SonicRuntime",
]
@@ -1,941 +0,0 @@
"""SONIC planner pipeline: ONNX enc/dec/planner, movement state, and input helpers."""
import math
import queue
import select
import struct
import sys
import termios
import threading
import time
import tty
from dataclasses import dataclass
from enum import IntEnum
import numpy as np
import onnxruntime as ort
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
# ── Constants ────────────────────────────────────────────────────────────────
DEFAULT_ANGLES = np.array([
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
0.0, 0.0, 0.0,
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
], dtype=np.float32)
NATURAL_FREQ = 10.0 * 2.0 * np.pi
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
def _action_scale(k):
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
CONTROL_DT = 0.02
DEFAULT_HEIGHT = 0.788740
TOKEN_DIM = 64
ENCODER_UPDATE_EVERY = 5
DEBUG_PRINT_EVERY = 100
MOTION_LOOK_AHEAD_STEPS = 2
INITIAL_RANDOM_SEED = 1234
MIN_TOKENS, MAX_TOKENS = 6, 16
K = MAX_TOKENS - MIN_TOKENS + 1
DEADZONE = 0.05
BLEND_FRAMES = 8
REPLAN_INTERVAL = {
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
}
ISAACLAB_TO_MUJOCO = np.array([
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
], dtype=np.int32)
MUJOCO_TO_ISAACLAB = np.array([
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
], dtype=np.int32)
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
SMPL_DEF = np.zeros(720, dtype=np.float32)
# ── PD gains ─────────────────────────────────────────────────────────────────
def compute_kp_kd():
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
["7520_14","5020","5020"] + \
["5020","5020","5020","5020","5020","4010","4010"] * 2
_kd_keys = _kp_keys
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
return kp, kd
_kp_kd = compute_kp_kd # backward-compatible alias
def lowstate_to_obs(lowstate) -> dict:
"""Build a robot observation dict from Unitree lowstate."""
obs: dict = {}
for motor in G1_29_JointIndex:
idx = motor.value
obs[f"{motor.name}.q"] = float(lowstate.motor_state[idx].q)
obs[f"{motor.name}.dq"] = float(lowstate.motor_state[idx].dq)
quat = lowstate.imu_state.quaternion
obs["imu.quat.w"] = float(quat[0])
obs["imu.quat.x"] = float(quat[1])
obs["imu.quat.y"] = float(quat[2])
obs["imu.quat.z"] = float(quat[3])
gyro = lowstate.imu_state.gyroscope
obs["imu.gyro.x"] = float(gyro[0])
obs["imu.gyro.y"] = float(gyro[1])
obs["imu.gyro.z"] = float(gyro[2])
wr = getattr(lowstate, "wireless_remote", None)
if wr is not None:
obs["wireless_remote"] = bytes(wr) if not isinstance(wr, (bytes, bytearray)) else wr
return obs
# ── Quaternion helpers ────────────────────────────────────────────────────────
def quat_conj(q):
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
def quat_mul(q1, q2):
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
return np.array([
w1*w2 - x1*x2 - y1*y2 - z1*z2,
w1*x2 + x1*w2 + y1*z2 - z1*y2,
w1*y2 - x1*z2 + y1*w2 + z1*x2,
w1*z2 + x1*y2 - y1*x2 + z1*w2,
], dtype=np.float32)
def gravity_dir(q):
q = q / (np.linalg.norm(q) + 1e-8)
qv = np.array([0, 0, 0, -1], dtype=np.float32)
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
def quat_to_6d(q):
w,x,y,z = q
return np.array([
1-2*(y*y+z*z), 2*(x*y-z*w),
2*(x*y+z*w), 1-2*(x*x+z*z),
2*(x*z-y*w), 2*(y*z+x*w),
], dtype=np.float32)
def calc_heading(q):
w,x,y,z = q
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
def heading_quat(q, sign=1.0):
a = sign * calc_heading(q) / 2.0
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
heading_quat_inv = lambda q: heading_quat(q, -1.0)
def quat_slerp(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0)+1e-12); q1 = q1 / (np.linalg.norm(q1)+1e-12)
dot = float(np.dot(q0, q1))
if dot < 0: q1, dot = -q1, -dot
dot = min(dot, 1.0)
if dot > 0.9995:
r = q0 + t*(q1-q0); return r/(np.linalg.norm(r)+1e-12)
th = np.arccos(dot); st = np.sin(th)
return (np.sin((1-t)*th)/st)*q0 + (np.sin(t*th)/st)*q1
def quat_slerp_batch(q0, q1, t):
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
dot = np.sum(q0*q1, axis=1); neg = dot<0
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
c0[lin]=1-t[lin]; c1[lin]=t[lin]
r = c0[:,None]*q0 + c1[:,None]*q1
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
# ── Locomotion modes ──────────────────────────────────────────────────────────
class LocomotionMode(IntEnum):
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
LM = LocomotionMode
MOTION_SETS = [
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
]
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
def clamp_mode_params(ms):
m = LM(ms.mode)
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
if m in STATIC_MODES:
ms.speed = -1.0
elif m in SPEED_RANGES:
lo, hi = SPEED_RANGES[m]
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
elif m in BOXING_MODES:
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
else:
ms.speed = -1.0
def replan_interval(mode):
m = LM(mode)
if m == LM.RUN: return REPLAN_INTERVAL["running"]
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
return REPLAN_INTERVAL["boxing"]
return REPLAN_INTERVAL["default"]
def _ort_providers(force_cpu: bool = False) -> list[str]:
"""Prefer CUDA for enc/dec/planner (matches deploy when onnxruntime-gpu is installed)."""
avail = ort.get_available_providers()
if not force_cpu and "CUDAExecutionProvider" in avail:
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
return ["CPUExecutionProvider"]
# ── Movement state ────────────────────────────────────────────────────────────
@dataclass
class MovementState:
mode: int = LM.SLOW_WALK # not IDLE — walking modes respond to WASD
speed: float = -1.0
height: float = -1.0
facing_angle: float = 0.0
movement_angle: float = 0.0
has_movement: bool = False
motion_set_idx: int = 0
needs_replan: bool = False
@property
def movement_direction(self):
if not self.has_movement: return (0.0, 0.0, 0.0)
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
@property
def facing_direction(self):
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
def status_line(self):
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
f"facing={math.degrees(self.facing_angle):.0f}° "
f"{'moving' if self.has_movement else 'still'}")
@dataclass
class MovementSnapshot:
mode: int = 0
speed: float = -1.0
height: float = -1.0
movement: tuple[float, float, float] = (0.0, 0.0, 0.0)
facing: tuple[float, float, float] = (1.0, 0.0, 0.0)
def _snapshot_ms(ms: MovementState) -> MovementSnapshot:
md, fd = ms.movement_direction, ms.facing_direction
return MovementSnapshot(ms.mode, ms.speed, ms.height, (md[0], md[1], md[2]), (fd[0], fd[1], fd[2]))
def should_replan_request(ms: MovementState, last: MovementSnapshot, replan_timer: float, step: int) -> bool:
"""Match C++ G1Deploy::Planner replan triggers (g1_deploy_onnx_ref.cpp)."""
if step <= 0:
return False
if ms.needs_replan:
return True
md, fd = ms.movement_direction, ms.facing_direction
facing_changed = fd != last.facing
height_changed = ms.height != last.height
mode_changed = ms.mode != last.mode
speed_changed = ms.speed != last.speed
dir_changed = md != last.movement
is_static = LM(ms.mode) in STATIC_MODES
if mode_changed or facing_changed or height_changed:
return True
time_to_replan = replan_timer >= replan_interval(ms.mode)
if not is_static and (speed_changed or dir_changed or (time_to_replan and ms.speed != 0)):
return True
return False
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
class StandingEncoderDecoder:
def __init__(self, encoder, decoder):
self.encoder, self.decoder = encoder, decoder
self.encoder_input = encoder.get_inputs()[0].name
self.decoder_input = decoder.get_inputs()[0].name
enc_dim = int(encoder.get_inputs()[0].shape[1])
dec_dim = int(decoder.get_inputs()[0].shape[1])
if enc_dim != 1762 or dec_dim != 994:
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
self.token = np.zeros(TOKEN_DIM, np.float32)
self.last_action_mj = np.zeros(29, np.float32)
self.h_q_mj = [np.zeros(29, np.float32)] * 10
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
self.h_ang = [np.zeros(3, np.float32)] * 10
self.h_act_mj = [np.zeros(29, np.float32)] * 10
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
self.init_base_quat = np.array([1,0,0,0], np.float32)
self.init_ref_quat = np.array([1,0,0,0], np.float32)
self._heading_init = False
self.encode_mode = 0
self.vr_3point_local_target = VR_TARGET_DEF.copy()
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
self.set_zero_reference()
def update_history(self, q, dq, ang, quat):
quat = quat / (np.linalg.norm(quat)+1e-8)
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
self.h_ang = [ang.copy()] + self.h_ang[:-1]
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
self.h_quat = [quat.copy()] + self.h_quat[:-1]
if not self._heading_init:
self.init_base_quat = quat.copy(); self._heading_init = True
def _heading_quat(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
def _heading_quat_inv(self, q):
h = calc_heading(q) / 2.0
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
new_ref = quat_mul(delta, ref_quat)
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
def set_zero_reference(self):
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
self.motion_joint_velocities = [np.zeros(29, np.float32)]
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
self.motion_body_z = [DEFAULT_HEIGHT]
self.motion_timesteps = 1
self.freeze_ref_frame = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32)
obs[0] = float(self.encode_mode)
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
if self.encode_mode == 0:
for f in range(10):
obs[4+29*f:4+29*(f+1)] = ref_pos
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 1:
ref_lower = ref_pos[LOWER_BODY_IL]
for f in range(10):
obs[661+12*f:661+12*(f+1)] = ref_lower
obs[901:910] = self.vr_3point_local_target
obs[910:922] = self.vr_3point_local_orn_target
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
elif self.encode_mode == 2:
obs[922:1642] = self.smpl_joints_10frame_step1
for f in range(10):
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
else:
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
return obs
def build_decoder_obs(self):
obs = np.zeros(994, np.float32); off = 0
obs[off:off+64] = self.token; off += 64
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
for f in range(10): obs[off:off+sz] = h[f]; off += sz
for q in reversed(self.h_quat):
obs[off:off+3] = gravity_dir(q); off += 3
assert off == 994, f"Decoder obs mismatch: {off}"
return obs
def run_encoder(self):
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
def step(self, robot_obs, update_encoder, debug=False):
jnames = [m.name for m in G1_29_JointIndex]
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
self.update_history(q, dq, ang, quat)
if update_encoder: self.token = self.run_encoder()
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
self.last_action_mj = action_mj.copy()
target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
if debug:
delta = target - q
print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}
def print_input_diagnostics(self):
print("\n[Diag] Standing reference checks")
names = {0:"g1", 1:"teleop", 2:"smpl"}
print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
dec0 = self.build_decoder_obs()
print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")
# ── Planner motion buffer ─────────────────────────────────────────────────────
class PlannerMotion:
def __init__(self, max_frames=1500):
self.timesteps = 0
self.joint_positions = np.zeros((max_frames, 29), np.float64)
self.joint_velocities = np.zeros((max_frames, 29), np.float64)
self.body_positions = np.zeros((max_frames, 3), np.float64)
self.body_quaternions = np.zeros((max_frames, 4), np.float64)
self.body_quaternions[:, 0] = 1.0
# ── Subprocess planner ────────────────────────────────────────────────────────
def _resample_30_to_50(qpos, n30):
t50 = int(np.floor(n30 / 30.0 * 50))
f30 = np.arange(t50) / 50.0 * 30.0
f0 = np.floor(f30).astype(int)
f1 = np.minimum(f0+1, n30-1)
frac, w0 = (f30-f0).astype(np.float64), None
w0 = 1.0 - frac
jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
jv = np.zeros_like(jp)
if t50 >= 2: jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0; jv[-1] = jv[-2]
return {
"timesteps": t50,
"joint_positions": jp,
"joint_velocities": jv,
"body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
"body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
}
def _build_planner_inputs(ctx, ms_dict, version, seed):
inp = {
"context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
"target_vel": np.array([ms_dict["speed"]], np.float32),
"mode": np.array([ms_dict["mode"]], np.int64),
"movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
"facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
"random_seed": np.array([seed], np.int64),
}
if version >= 1:
# TensorRT deploy: allow 911 prediction tokens only (indices 35 for MIN_TOKENS=6).
allowed = np.zeros((1, K), np.int64)
if K >= 6:
allowed[0, 3:6] = 1
inp.update({
"height": np.array([ms_dict["height"]], np.float32),
"has_specific_target": np.array([[0]], np.int64),
"specific_target_positions": np.zeros((1,4,3), np.float32),
"specific_target_headings": np.zeros((1,4), np.float32),
"allowed_pred_num_tokens": allowed,
})
return inp
def _planner_worker(path, req_q, res_q, stop_evt, version, seed, use_gpu):
so = ort.SessionOptions(); so.log_severity_level = 3
providers = _ort_providers(force_cpu=not use_gpu)
sess = ort.InferenceSession(path, sess_options=so, providers=providers)
while not stop_evt.is_set():
try: ctx, gf, ms_dict = req_q.get(timeout=0.05)
except Exception: continue
try:
inp = _build_planner_inputs(ctx, ms_dict, version, seed)
t0 = time.time()
qpos_out, num_pred = sess.run(None, inp)
t_inf = time.time()
n = int(num_pred.flat[0])
qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): continue
motion = _resample_30_to_50(qpos, n)
motion["gen_frame"] = gf
print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
while not res_q.empty():
try: res_q.get_nowait()
except queue.Empty: break
res_q.put(motion)
except Exception as e:
print(f"[Planner] Error: {e}", flush=True)
# ── SonicPlanner ──────────────────────────────────────────────────────────────
class SonicPlanner:
def __init__(self, session, planner_path):
self.session = session
self.planner_path = planner_path
self.gen_frame = 0
self.random_seed = INITIAL_RANDOM_SEED
self.version = 1 if len(session.get_inputs()) >= 11 else 0
self.motion_50hz = PlannerMotion()
self._snapshot = PlannerMotion()
self._req_q = self._res_q = self._stop_evt = self._planner_thread = None
self._ctrl = None
def _build_inputs(self, ctx, ms):
return _build_planner_inputs(
ctx,
{"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)},
self.version, self.random_seed)
@staticmethod
def build_initial_context(joint_positions):
ctx = np.zeros((4, 36), np.float32)
jp_mj = joint_positions.astype(np.float32)[ISAACLAB_TO_MUJOCO]
for n in range(4):
ctx[n, 2] = DEFAULT_HEIGHT
ctx[n, 3] = 1.0
ctx[n, 7:36] = jp_mj
return ctx
def _context_from_controller(self, current_frame):
ctrl = self._ctrl
gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
t_arr = gen_frame / 50.0 + np.arange(4) / 30.0
f50 = t_arr * 50.0
with ctrl.motion_lock:
ts = ctrl.motion_timesteps
if ts <= 0:
return self.build_initial_context(DEFAULT_ANGLES)
bp, bq, jp = ctrl.motion_body_pos, ctrl.motion_body_quats, ctrl.motion_joint_positions
f0 = np.minimum(np.floor(f50).astype(int), ts - 1)
f1 = np.minimum(f0 + 1, ts - 1)
frac = f50 - f0
w0 = 1.0 - frac
ctx = np.zeros((4, 36), np.float32)
ctx[:, 0:3] = w0[:, None] * bp[f0] + frac[:, None] * bp[f1]
ctx[:, 3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
ij = w0[:, None] * jp[f0] + frac[:, None] * jp[f1]
ctx[:, 7:36] = ij[:, ISAACLAB_TO_MUJOCO]
self.gen_frame = gen_frame
return ctx
def _load_motion_in_place(self, qpos, n30, target=None):
if target is None: target = self.motion_50hz
r = _resample_30_to_50(qpos, n30)
n = r["timesteps"]; target.timesteps = n
target.joint_positions[:n] = r["joint_positions"]
target.joint_velocities[:n] = r["joint_velocities"]
target.body_positions[:n] = r["body_positions"]
target.body_quaternions[:n] = r["body_quaternions"]
return target
def initialize(self, joint_positions, ms):
ctx = self.build_initial_context(joint_positions)
qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
n = int(num_pred.flat[0]); qpos = qpos_out[0,:n]
if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
print(f"[Planner] Init: {n} frames @ 30 Hz")
self._load_motion_in_place(qpos, n)
print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
return self.motion_50hz
def request_replan(self, cursor, ms):
if self._req_q is None: return
ctx = self._context_from_controller(cursor)
ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
"movement_direction": list(ms.movement_direction),
"facing_direction": list(ms.facing_direction)}
while not self._req_q.empty():
try: self._req_q.get_nowait()
except queue.Empty: break
self._req_q.put((ctx, self.gen_frame, ms_dict))
def try_get_new_motion(self):
if self._res_q is None: return None
result = None
while not self._res_q.empty():
try: result = self._res_q.get_nowait()
except queue.Empty: break
if result is None: return None
n, gf = result["timesteps"], result["gen_frame"]
s = self._snapshot; s.timesteps = n
s.joint_positions[:n] = result["joint_positions"]
s.joint_velocities[:n] = result["joint_velocities"]
s.body_positions[:n] = result["body_positions"]
s.body_quaternions[:n] = result["body_quaternions"]
return s, gf
def start_subprocess(self, controller, use_gpu: bool = False):
"""Run planner ONNX in a background thread (avoids mp spawn/fork + CUDA/MuJoCo issues)."""
self._ctrl = controller
self._req_q = queue.Queue()
self._res_q = queue.Queue()
self._stop_evt = threading.Event()
self._planner_thread = threading.Thread(
target=_planner_worker,
args=(self.planner_path, self._req_q, self._res_q,
self._stop_evt, self.version, self.random_seed, use_gpu),
daemon=True,
name="sonic-planner",
)
self._planner_thread.start()
print(f"[Planner] Background thread started ({'GPU' if use_gpu else 'CPU'})")
def stop_subprocess(self):
if self._stop_evt:
self._stop_evt.set()
if self._planner_thread is not None:
self._planner_thread.join(timeout=3.0)
print("[Planner] Background thread stopped")
self._planner_thread = None
self._req_q = self._res_q = self._stop_evt = None
# ── PlannerController ─────────────────────────────────────────────────────────
class PlannerController(StandingEncoderDecoder):
def __init__(self, planner, encoder, decoder):
super().__init__(encoder, decoder)
self.planner = planner
self.ref_cursor = 0
self.motion_timesteps = 0
self.motion_joint_positions = np.zeros((1500,29), np.float64)
self.motion_joint_velocities = np.zeros((1500,29), np.float64)
self.motion_body_quats = np.zeros((1500,4), np.float64); self.motion_body_quats[:,0] = 1.0
self.motion_body_pos = np.zeros((1500,3), np.float64)
self.init_ref_quat = np.array([1,0,0,0], np.float64)
self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
self.delta_heading = 0.0
self.reinit_heading = False
self.playing = self.first_motion = False
self.motion_lock = threading.Lock()
def load_initial_motion(self, motion):
with self.motion_lock:
n = motion.timesteps
self.motion_timesteps = n
self.motion_joint_positions[:n] = motion.joint_positions[:n]
self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
self.motion_body_quats[:n] = motion.body_quaternions[:n]
self.motion_body_pos[:n] = motion.body_positions[:n]
self.init_ref_quat = motion.body_quaternions[0].copy()
self.ref_cursor = 0; self.first_motion = True
self.playing = True; self.delta_heading = 0.0
def blend_new_motion(self, new_motion, gen_frame):
"""Blend like C++ CurrentFrameAdvancement: 8-frame cross-fade, then copy tail."""
with self.motion_lock:
cur = self.ref_cursor
new_len = gen_frame - cur + new_motion.timesteps
if new_len <= 0:
return
if self.motion_timesteps == 0:
n = new_motion.timesteps
self.motion_joint_positions[:n] = new_motion.joint_positions[:n]
self.motion_joint_velocities[:n] = new_motion.joint_velocities[:n]
self.motion_body_pos[:n] = new_motion.body_positions[:n]
self.motion_body_quats[:n] = new_motion.body_quaternions[:n]
self.motion_timesteps = n
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
self.first_motion = False
return
blend_start = max(0, gen_frame - cur)
blend_end = min(new_len, blend_start + BLEND_FRAMES)
for f in range(blend_end):
f_old = min(f + cur, self.motion_timesteps - 1)
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
w_new = min(1.0, max(0.0, (f - blend_start) / BLEND_FRAMES))
w_old = 1.0 - w_new
self.motion_joint_positions[f] = (
w_old * self.motion_joint_positions[f_old]
+ w_new * new_motion.joint_positions[f_new]
)
self.motion_joint_velocities[f] = (
w_old * self.motion_joint_velocities[f_old]
+ w_new * new_motion.joint_velocities[f_new]
)
self.motion_body_pos[f] = (
w_old * self.motion_body_pos[f_old]
+ w_new * new_motion.body_positions[f_new]
)
self.motion_body_quats[f] = quat_slerp(
self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new
)
for f in range(blend_end, new_len):
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
self.motion_joint_positions[f] = new_motion.joint_positions[f_new]
self.motion_joint_velocities[f] = new_motion.joint_velocities[f_new]
self.motion_body_pos[f] = new_motion.body_positions[f_new]
self.motion_body_quats[f] = new_motion.body_quaternions[f_new].copy()
self.motion_timesteps = new_len
self.first_motion = False
self.ref_cursor = 0
self.init_ref_quat = self.motion_body_quats[0].copy()
def _heading_apply_delta(self):
delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
heading_quat_inv(self.init_ref_quat).astype(np.float32))
if self.delta_heading:
h = self.delta_heading / 2.0
delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
return delta
def _anchor_6d(self, base_quat, ref_quat=None):
if ref_quat is None: ref_quat = self.init_ref_quat
new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))
def build_encoder_obs(self):
obs = np.zeros(1762, np.float32); obs[0] = float(self.encode_mode)
with self.motion_lock:
if self.encode_mode == 2:
# SMPL whole-body imitation: the 720-dim SMPL window carries the
# target pose; the planner reference frame supplies anchor + wrist.
rf = min(self.ref_cursor, self.motion_timesteps - 1)
ref_pos = self.motion_joint_positions[rf].astype(np.float32)
ref_quat = self.motion_body_quats[rf].astype(np.float32)
anchor = self._anchor_6d(self.h_quat[0], ref_quat)
wrist = ref_pos[WRIST_IL]
obs[922:1642] = self.smpl_joints_10frame_step1
for f in range(10):
obs[1642+6*f:1642+6*(f+1)] = anchor
obs[1702+6*f:1702+6*(f+1)] = wrist
return obs
for f in range(10):
tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
self.motion_timesteps - 1)
obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
if self.playing:
obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
return obs
def step(self, robot_obs, update_encoder, debug=False):
if robot_obs and (self.first_motion or self.reinit_heading):
q = None
if "imu.quat.w" in robot_obs:
q = np.array([
robot_obs["imu.quat.w"], robot_obs["imu.quat.x"],
robot_obs["imu.quat.y"], robot_obs["imu.quat.z"],
], np.float64)
else:
q = robot_obs.get("imu.quaternion")
if q is not None:
q = np.array(q, np.float64)
if q is not None:
self.heading_init_base_quat = np.array(q, np.float64)
with self.motion_lock:
rf = min(self.ref_cursor, self.motion_timesteps - 1)
self.init_ref_quat = self.motion_body_quats[rf].copy()
self.delta_heading = 0.0
self.first_motion = False
self.reinit_heading = False
print(f"[Heading] init quat: {self.heading_init_base_quat}")
return super().step(robot_obs, update_encoder=update_encoder, debug=debug)
def advance_cursor(self):
"""Advance one frame per 50 Hz tick (C++ current_frame_ += 1), no wall-clock catch-up."""
if not self.playing:
return
with self.motion_lock:
if self.motion_timesteps > 0:
self.ref_cursor = min(self.ref_cursor + 1, self.motion_timesteps - 1)
# ── Keyboard ──────────────────────────────────────────────────────────────────
class RawKeyboard:
def __init__(self):
self.fd = sys.stdin.fileno()
self.old = termios.tcgetattr(self.fd)
def __enter__(self): tty.setcbreak(self.fd); return self
def __exit__(self, *_): termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)
def get_key(self):
return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None
def drain_keyboard(kb, ms, controller=None) -> bool:
"""Process all pending terminal keys this frame (return True to quit)."""
quit_requested = False
while True:
key = kb.get_key()
if key is None:
break
if process_keyboard(key, ms, controller):
quit_requested = True
return quit_requested
def process_keyboard(key, ms, controller=None):
if key is None: return False
if key == '\x1b': return True
if key == ' ':
ms.mode = LM.IDLE; ms.speed = ms.height = -1.0
ms.has_movement = False; ms.needs_replan = True
if controller: controller.playing = False; controller.reinit_heading = True
print("\n >> EMERGENCY STOP -> IDLE"); return False
if key in ('r','R'):
ms.needs_replan = True; print("\n >> Manual replan"); return False
if key in ('m','M'):
if controller is not None and getattr(controller, "smpl_motion", None) is not None:
if controller.encode_mode == 2:
controller.encode_mode = 0
controller.playing = True; controller.reinit_heading = True
ms.needs_replan = True
print("\n >> Motion OFF -> locomotion (mode 0). WASD to drive.")
else:
controller.encode_mode = 2
controller.reinit_heading = True
controller.smpl_motion.reset()
print("\n >> Motion ON -> SMPL playback (mode 2)")
else:
print("\n >> No motion loaded (start with --motion-file)")
return False
if key in ('n','N','p','P'):
ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
name, modes = MOTION_SETS[ms.motion_set_idx]
print(f"\n >> Motion set: {name}")
[print(f" {i+1}: {m.name}") for i,m in enumerate(modes)]
return False
if key.isdigit() and key not in ('9','0'):
idx = int(key) - 1; modes = MOTION_SETS[ms.motion_set_idx][1]
if 0 <= idx < len(modes):
ms.mode = modes[idx]; ms.needs_replan = True
if controller: controller.playing = True; controller.reinit_heading = True
print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
return False
if key == '9':
ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '0':
ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
print(f"\n >> Speed: {ms.speed:.1f}"); return False
if key == '-':
ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key == '=':
ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
print(f"\n >> Height: {ms.height:.2f}"); return False
if key.lower() == 'w': ms.movement_angle = ms.facing_angle
elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
if key.lower() in ('w','s','a','d'):
ms.has_movement = ms.needs_replan = True
if controller:
controller.playing = True
print(f"\n >> Move {key.upper()} (replanning...)")
elif key.lower() == 'q':
ms.facing_angle += 0.1
if controller: controller.delta_heading += 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
elif key.lower() == 'e':
ms.facing_angle -= 0.1
if controller: controller.delta_heading -= 0.1
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
return False
_joy_prev_active = False
def _parse_wireless(wr):
"""Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
import struct as _st
if not isinstance(wr, (bytes, bytearray)):
wr = bytes(wr)
if len(wr) < 24:
return None
lx = _st.unpack("f", wr[4:8])[0]
rx = _st.unpack("f", wr[8:12])[0]
ry = _st.unpack("f", wr[12:16])[0]
ly = _st.unpack("f", wr[20:24])[0]
return lx, ly, rx, ry
def process_joystick(obs, ms, controller=None):
"""Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
global _joy_prev_active
wr = obs.get("wireless_remote")
if wr is None:
return
parsed = _parse_wireless(wr)
if parsed is None:
return
lx, ly, rx, ry = parsed
# Dead zone + negate both Y axes (bridge already flips them once)
lx = 0.0 if abs(lx) < DEADZONE else lx
ly = 0.0 if abs(ly) < DEADZONE else -ly
rx = 0.0 if abs(rx) < DEADZONE else rx
ry = 0.0 if abs(ry) < DEADZONE else -ry
left_active = abs(lx) > 0 or abs(ly) > 0
# Left stick → WASD (movement direction relative to facing)
if left_active:
ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
ms.has_movement = True
if not _joy_prev_active:
ms.needs_replan = True
_joy_prev_active = True
elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
_joy_prev_active = False
ms.has_movement = False
# Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
if abs(rx) > 0:
delta = -0.02 * rx
ms.facing_angle += delta
if controller:
controller.delta_heading += delta
# Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
if abs(ry) > 0:
step = -0.005 * ry
ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
@@ -1,152 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SONIC full-body controller for Unitree G1."""
import logging
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
CONTROL_DT,
DEBUG_PRINT_EVERY,
DEFAULT_ANGLES,
ENCODER_UPDATE_EVERY,
LM,
MOTION_SETS,
MovementState,
PlannerController,
SonicPlanner,
clamp_mode_params,
compute_kp_kd,
lowstate_to_obs,
process_joystick,
should_replan_request,
_ort_providers,
_snapshot_ms,
)
logger = logging.getLogger(__name__)
class SonicRuntime:
"""Shared SONIC control loop state (standalone demo + locomotion controller)."""
def __init__(self, force_cpu: bool = False):
planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")
providers = _ort_providers(force_cpu=force_cpu)
self.use_gpu = providers[0] == "CUDAExecutionProvider"
so = ort.SessionOptions()
so.log_severity_level = 3
planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=providers)
encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=providers)
decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=providers)
self.kp, self.kd = compute_kp_kd()
self.ms = MovementState()
self.planner = SonicPlanner(planner_sess, planner_path)
self.controller = PlannerController(self.planner, encoder_sess, decoder_sess)
motion = self.planner.initialize(DEFAULT_ANGLES, self.ms)
self.controller.load_initial_motion(motion)
self.planner.start_subprocess(self.controller, use_gpu=self.use_gpu)
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
@property
def pipeline(self):
return self.controller
def tick(self, obs: dict, *, debug: bool | None = None, use_joystick: bool = True) -> dict:
if not obs:
self.step += 1
return {}
if use_joystick:
process_joystick(obs, self.ms, self.controller)
clamp_mode_params(self.ms)
if self.step > 0:
self.replan_timer += CONTROL_DT
if should_replan_request(self.ms, self.last_ms, self.replan_timer, self.step):
self.planner.request_replan(self.controller.ref_cursor, self.ms)
self.replan_timer = 0.0
self.ms.needs_replan = False
self.last_ms = _snapshot_ms(self.ms)
do_enc = self.step % ENCODER_UPDATE_EVERY == 0
if debug is None:
debug = self.step % DEBUG_PRINT_EVERY == 0
action = self.controller.step(obs, update_encoder=do_enc, debug=debug)
result = self.planner.try_get_new_motion()
if result:
self.controller.blend_new_motion(*result)
self.controller.advance_cursor()
self.step += 1
return action
def reset(self):
self.ms = MovementState()
self.controller.reinit_heading = True
self.controller.playing = True
self.step = 0
self.replan_timer = 0.0
self.last_ms = _snapshot_ms(self.ms)
def shutdown(self):
self.planner.stop_subprocess()
class SonicWholeBodyController:
"""Full-body SONIC controller for UnitreeG1's background controller thread."""
control_dt = CONTROL_DT
full_body = True
def __init__(self, force_cpu: bool = False):
logger.info("Loading SONIC whole-body controller...")
self._runtime = SonicRuntime(force_cpu=force_cpu)
self.kp = self._runtime.kp
self.kd = self._runtime.kd
self.controller = self._runtime.controller
self.ms = self._runtime.ms
logger.info(
"SONIC ready: %s (default mode: %s)",
MOTION_SETS[0][0],
LM(self.ms.mode).name,
)
def run_step(self, action: dict, lowstate) -> dict:
if lowstate is None:
return {}
obs = lowstate_to_obs(lowstate)
return self._runtime.tick(obs, debug=False)
def reset(self):
self._runtime.reset()
def shutdown(self):
self._runtime.shutdown()
+2 -3
View File
@@ -68,9 +68,8 @@ def make_locomotion_controller(name: str | None):
if name is None:
return None
controllers = {
"GrootLocomotionController": "lerobot.robots.unitree_g1.controllers.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.controllers.holosoma_locomotion",
"SonicWholeBodyController": "lerobot.robots.unitree_g1.controllers.sonic_whole_body",
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
}
module_path = controllers.get(name)
if module_path is None:
@@ -21,7 +21,7 @@ import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.g1_utils import (
from .g1_utils import (
REMOTE_AXES,
REMOTE_BUTTONS,
G1_29_JointIndex,
@@ -22,7 +22,7 @@ import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.g1_utils import (
from .g1_utils import (
REMOTE_AXES,
G1_29_JointArmIndex,
G1_29_JointIndex,
+1 -9
View File
@@ -338,9 +338,6 @@ class UnitreeG1(Robot):
self.kp = np.array(self.config.kp, dtype=np.float32)
self.kd = np.array(self.config.kd, dtype=np.float32)
if self.controller is not None and hasattr(self.controller, "kp"):
self.kp = np.array(self.controller.kp, dtype=np.float32)
self.kd = np.array(self.controller.kd, dtype=np.float32)
for joint in G1_29_JointIndex:
self.msg.motor_cmd[joint].mode = 1
@@ -377,9 +374,6 @@ class UnitreeG1(Robot):
# Signal thread to stop and unblock any waits
self._shutdown_event.set()
if self.controller is not None and hasattr(self.controller, "shutdown"):
self.controller.shutdown()
# Wait for subscribe thread to finish
if self.subscribe_thread is not None:
self.subscribe_thread.join(timeout=2.0)
@@ -471,11 +465,9 @@ class UnitreeG1(Robot):
def send_action(self, action: RobotAction) -> RobotAction:
action_to_publish = action
if self.controller is not None:
self._update_controller_action(action)
if getattr(self.controller, "full_body", False):
return action
# Controller thread owns legs/waist. Here we only update joystick inputs
# and publish arm targets from the teleoperator.
self._update_controller_action(action)
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
action_to_publish = {
key: value
-73
View File
@@ -28,7 +28,6 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
import pandas as pd # noqa: E402
import pyarrow.parquet as pq # noqa: E402
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
@@ -345,78 +344,6 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
assert len(dataset) == 24
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
from lerobot.datasets.io_utils import write_tasks
from lerobot.utils.io_utils import write_json
data_dir = root / "data" / "chunk-000"
data_dir.mkdir(parents=True, exist_ok=True)
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
for ep, length in enumerate(episode_lengths):
episode_index += [ep] * length
frame_index += list(range(length))
timestamp += [round(i / fps, 6) for i in range(length)]
task_index += [0] * length
subtask_index += [0] * length # legacy column the writer must drop
pd.DataFrame(
{
"episode_index": episode_index,
"frame_index": frame_index,
"timestamp": timestamp,
"task_index": task_index,
"subtask_index": subtask_index,
}
).to_parquet(data_dir / "file-000.parquet", index=False)
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
write_tasks(tasks_df, root)
write_json(
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
root / "meta" / "info.json",
)
return root
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
"""Rewriting a packed shard must keep one row group per episode, not collapse
every episode into a single giant row group."""
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
shard = root / "data" / "chunk-000" / "file-000.parquet"
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
staging_dir = tmp_path / "stage"
for ep in range(len(episode_lengths)):
_stage_episode(
staging_dir,
ep,
plan=[
{
"role": "assistant",
"content": f"subtask for ep {ep}",
"style": "subtask",
"timestamp": 0.0,
"tool_calls": None,
}
],
)
records = list(iter_episodes(root))
LanguageColumnsWriter().write_all(records, staging_dir, root)
# One row group per episode, with row counts matching the episode lengths.
md = pq.ParquetFile(shard).metadata
assert md.num_row_groups == len(episode_lengths)
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
# Language columns are still present after the per-episode rewrite.
table = pq.read_table(shard)
assert "language_persistent" in table.column_names
assert "language_events" in table.column_names
def test_speech_atom_shape_matches_plan_spec() -> None:
atom = speech_atom(2.5, "I'm cleaning up!")
assert atom["role"] == "assistant"
-55
View File
@@ -32,26 +32,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def assert_data_shards_one_row_group_per_episode(root):
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
import pyarrow.parquet as pq
shards = sorted((root / "data").rglob("*.parquet"))
assert shards, f"no data shards found under {root}/data"
n_episodes = 0
for shard in shards:
pf = pq.ParquetFile(shard)
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
assert pf.metadata.num_row_groups == len(set(episodes)), shard
for i in range(pf.metadata.num_row_groups):
rg_episodes = set(
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
)
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
n_episodes += len(set(episodes))
return n_episodes
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, (
@@ -586,41 +566,6 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
)
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
Covers both the non-image (``df.to_parquet``) and image
(``to_parquet_with_hf_images``) write branches, including the merge-into-
existing-file branch via a low file-size threshold that forces packing.
"""
ds_0 = lerobot_dataset_factory(
root=tmp_path / "rg_0",
repo_id=f"{DUMMY_REPO_ID}_rg_0",
total_episodes=3,
total_frames=60,
use_videos=use_videos,
)
ds_1 = lerobot_dataset_factory(
root=tmp_path / "rg_1",
repo_id=f"{DUMMY_REPO_ID}_rg_1",
total_episodes=4,
total_frames=80,
use_videos=use_videos,
)
aggr_root = tmp_path / "rg_aggr"
aggregate_datasets(
repo_ids=[ds_0.repo_id, ds_1.repo_id],
roots=[ds_0.root, ds_1.root],
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
aggr_root=aggr_root,
)
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
import json
import struct
import numpy as np
import pytest
from lerobot.datasets.episode_video_streaming import assert_hf_hub_range_cache_branch
from lerobot.datasets.mp4 import (
_box,
_co64,
_dinf,
_hdlr,
_mdhd,
_mvhd,
_stco,
_stsc_one_sample_per_chunk,
_stss,
_stsz,
_stts,
_tkhd,
_vmhd,
parse_mp4_index,
synthesize_mp4,
)
def _minimal_mp4(sample_offsets: list[int], *, use_co64: bool = False) -> bytes:
ftyp = _box(b"ftyp", b"isom\0\0\2\0isomiso2mp41")
sizes = np.array([10, 10, 10], dtype=np.int64)
durations = np.array([1000, 1000, 1000], dtype=np.int64)
stsd_body = struct.pack(">II", 0, 1) + struct.pack(">I4s", 16, b"avc1") + b"\0" * 8
offsets = _co64(sample_offsets) if use_co64 else _stco(sample_offsets)
stbl = _box(
b"stbl",
_box(b"stsd", stsd_body)
+ _stts(durations)
+ _stsc_one_sample_per_chunk(len(sizes))
+ _stsz(sizes)
+ offsets
+ _stss(np.array([1], dtype=np.int64)),
)
minf = _box(b"minf", _vmhd() + _dinf() + stbl)
mdia = _box(b"mdia", _mdhd(1000, 3000) + _hdlr() + minf)
trak = _box(b"trak", _tkhd(1, 3000, 64, 48) + mdia)
moov = _box(b"moov", _mvhd(1000, 3000, 2) + trak)
mdat_payload_start = 10_000
free_size = mdat_payload_start - 8 - len(ftyp) - len(moov)
assert free_size >= 8
free = _box(b"free", b"\0" * (free_size - 8))
return ftyp + moov + free + _box(b"mdat", b"x" * 128)
def test_episode_slice_uses_min_max_sample_offsets_for_reordered_chunks():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
assert sample_slice.byte_offset == 10_000
assert sample_slice.byte_length == 60
assert sample_slice.sample_lo == 0
assert sample_slice.sample_hi == 2
def test_synthesized_mp4_rebases_one_chunk_per_sample_offsets():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025]))
sample_slice = mp4.sample_slice(0.0, 2.0, keyframe_pad_s=0, keyframe_pad_fraction=0)
mini = synthesize_mp4(mp4, sample_slice, b"x" * sample_slice.byte_length)
mini_index = parse_mp4_index("mini.mp4", mini)
expected = np.array([0, 50, 25], dtype=np.int64) + mini_index.mdat_payload_offset
np.testing.assert_array_equal(mini_index.sample_offsets, expected)
np.testing.assert_array_equal(mini_index.sample_sizes, np.array([10, 10, 10]))
def test_parser_accepts_co64_chunk_offsets():
mp4 = parse_mp4_index("test.mp4", _minimal_mp4([10_000, 10_050, 10_025], use_co64=True))
np.testing.assert_array_equal(mp4.sample_offsets, np.array([10_000, 10_050, 10_025]))
def test_hf_hub_branch_assertion_accepts_requested_revision(monkeypatch):
class FakeDist:
def read_text(self, name):
assert name == "direct_url.json"
return json.dumps(
{
"url": "https://github.com/huggingface/huggingface_hub.git",
"vcs_info": {"requested_revision": "feat/hffs-cache-cdn-range-reads"},
}
)
monkeypatch.setattr(
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
)
assert_hf_hub_range_cache_branch()
def test_hf_hub_branch_assertion_rejects_plain_install(monkeypatch):
class FakeDist:
def read_text(self, name):
assert name == "direct_url.json"
return json.dumps({"url": "https://github.com/huggingface/huggingface_hub.git"})
monkeypatch.setattr(
"lerobot.datasets.episode_video_streaming.metadata.distribution", lambda _: FakeDist()
)
with pytest.raises(AssertionError):
assert_hf_hub_range_cache_branch()
Generated
+8 -16
View File
@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
@@ -1089,8 +1089,8 @@ wheels = [
[[package]]
name = "datasets"
version = "4.8.5"
source = { registry = "https://pypi.org/simple" }
version = "5.0.1.dev0"
source = { git = "https://github.com/huggingface/datasets.git?branch=main#06fcc085fcdd22fc5cc741954f6187dd879543b6" }
dependencies = [
{ name = "dill" },
{ name = "filelock" },
@@ -1107,10 +1107,6 @@ dependencies = [
{ name = "tqdm" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
]
[[package]]
name = "debugpy"
@@ -1147,7 +1143,7 @@ name = "decord"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "numpy", marker = "(platform_machine != 'arm64' and platform_machine != 's390x' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" },
@@ -2050,8 +2046,8 @@ wheels = [
[[package]]
name = "huggingface-hub"
version = "1.19.0"
source = { registry = "https://pypi.org/simple" }
version = "1.20.0.dev0"
source = { git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads#5319b287faa73239bb40df16d69c39e5d6daf0f7" }
dependencies = [
{ name = "click" },
{ name = "filelock" },
@@ -2064,10 +2060,6 @@ dependencies = [
{ name = "typer" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/88/27/629cfe58c582f92ded066c4a07d1a057ff617118ab7973200f770bd853cb/huggingface_hub-1.19.0.tar.gz", hash = "sha256:fd771622182d40977272a923953ee3b1b13538f9f8a7f5d78398f10af0f1c0bd", size = 824721, upload-time = "2026-06-11T12:33:18.665Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/a5/558da89f66464d8d0229ff497e8b8666977de2d8cf48c28a2862ecf1250f/huggingface_hub-1.19.0-py3-none-any.whl", hash = "sha256:1dc72e1f6b4d6df6b30eb72e57d00514ef453d660f04af2b87f0e67267f31ee0", size = 693398, upload-time = "2026-06-11T12:33:16.695Z" },
]
[[package]]
name = "hydra-core"
@@ -3187,7 +3179,7 @@ requires-dist = [
{ name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.7.0,<5.0.0" },
{ name = "datasets", marker = "extra == 'dataset'", git = "https://github.com/huggingface/datasets.git?branch=main" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
{ name = "deepdiff", marker = "extra == 'deepdiff-dep'", specifier = ">=7.0.1,<9.0.0" },
@@ -3210,7 +3202,7 @@ requires-dist = [
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
{ name = "huggingface-hub", git = "https://github.com/huggingface/huggingface_hub.git?branch=feat%2Fhffs-cache-cdn-range-reads" },
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
{ name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" },