mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a130a9db39 | |||
| 4f5e6596be | |||
| afeeeb8982 | |||
| 040c6b3d66 | |||
| 287c823f13 | |||
| acd31c7de2 | |||
| 240393d238 |
@@ -1,174 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Load an SMPL motion clip and expose it in SONIC's encoder format.
|
||||
|
||||
SONIC's whole-body tracking mode (``encode_mode == 2``) consumes a flat
|
||||
720-vector ``smpl_joints_10frame_step1`` = 10 consecutive frames x 24 SMPL
|
||||
joints x 3 (xyz) at 50 Hz.
|
||||
|
||||
IMPORTANT - frame convention: the encoder expects each frame's joints with the
|
||||
body's *root orientation removed* (per-frame canonical), exactly like the live
|
||||
deploy stream's ``smpl_joints_local`` (see ``process_smpl_joints`` in the GEAR
|
||||
PICO teleop and ``smpl_joints_multi_future_local`` in training). The reference
|
||||
``smpl_filtered`` clips instead store **world-frame** joints (heading retained),
|
||||
so feeding them raw makes the robot move but mis-track / never face-forward.
|
||||
This loader therefore canonicalizes on load using the clip's per-frame root
|
||||
orientation (``pose_aa[:, :3]``):
|
||||
|
||||
A = Rx(+90deg) * rotvec(pose_aa[:, :3]) # y-up -> z-up root quat
|
||||
local = base120 * A^-1 * joints # remove root orient
|
||||
|
||||
with ``base120 = quat(0.5,0.5,0.5,0.5)`` (SMPL base rotation). This reproduces
|
||||
the deployed transform (verified: per-frame hip-heading std -> 0).
|
||||
|
||||
Clip is read from a numpy ``.npz``. Expected keys:
|
||||
smpl_joints : (T, 24, 3) float32 -- world-frame joint positions, 50 fps
|
||||
pose_aa : (T, 72) float32 -- SMPL axis-angle (root = [:, :3])
|
||||
transl : (T, 3) float32 -- global root translation (optional)
|
||||
fps : scalar
|
||||
|
||||
Example:
|
||||
python examples/unitree_g1/motion_loader.py \
|
||||
--motion examples/unitree_g1/motions/walk_forward.npz
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
WINDOW = 10 # frames per encoder window (smpl_joints_10frame_step1)
|
||||
N_JOINTS = 24
|
||||
JOINT_DIM = 3
|
||||
SMPL_OBS_DIM = WINDOW * N_JOINTS * JOINT_DIM # 720
|
||||
|
||||
|
||||
def canonicalize_smpl_joints(smpl_joints: np.ndarray, root_aa: np.ndarray) -> np.ndarray:
|
||||
"""Remove per-frame root orientation -> SONIC ``smpl_joints_local`` format.
|
||||
|
||||
Args:
|
||||
smpl_joints: (T, 24, 3) world-frame (z-up) SMPL joint positions.
|
||||
root_aa: (T, 3) SMPL global-orient axis-angle (y-up convention).
|
||||
|
||||
Returns:
|
||||
(T, 24, 3) per-frame root-orientation-removed joints.
|
||||
"""
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
rx90 = R.from_euler("x", 90, degrees=True) # smpl_root_ytoz_up
|
||||
base120 = R.from_quat([0.5, 0.5, 0.5, 0.5]) # remove_smpl_base_rot
|
||||
a = rx90 * R.from_rotvec(root_aa) # z-up root quat (left-mult)
|
||||
b_inv = base120 * a.inv() # inv(remove_smpl_base_rot(a))
|
||||
return np.einsum("tij,tkj->tki", b_inv.as_matrix(), smpl_joints).astype(np.float32)
|
||||
|
||||
|
||||
class SmplMotion:
|
||||
"""A single SMPL clip with SONIC-format windowing."""
|
||||
|
||||
def __init__(self, path: str, loop: bool = True, canonicalize: bool = True):
|
||||
data = np.load(path)
|
||||
smpl_joints = data["smpl_joints"].astype(np.float32) # (T, 24, 3)
|
||||
self.pose_aa = data["pose_aa"].astype(np.float32) if "pose_aa" in data.files else None
|
||||
self.transl = data["transl"].astype(np.float32) if "transl" in data.files else None
|
||||
self.fps = float(data["fps"]) if "fps" in data.files else 50.0
|
||||
self.loop = loop
|
||||
|
||||
if smpl_joints.ndim != 3 or smpl_joints.shape[1:] != (N_JOINTS, JOINT_DIM):
|
||||
raise ValueError(
|
||||
f"Expected smpl_joints (T, {N_JOINTS}, {JOINT_DIM}), got {smpl_joints.shape}"
|
||||
)
|
||||
|
||||
# Reference clips store world-frame joints; the encoder wants per-frame
|
||||
# root-orientation-removed joints. Canonicalize when we have the root pose.
|
||||
self.canonicalized = False
|
||||
if canonicalize and self.pose_aa is not None:
|
||||
smpl_joints = canonicalize_smpl_joints(smpl_joints, self.pose_aa[:, :3])
|
||||
self.canonicalized = True
|
||||
self.smpl_joints = smpl_joints
|
||||
|
||||
self.num_frames = self.smpl_joints.shape[0]
|
||||
self._cursor = 0
|
||||
|
||||
def window(self, start: int) -> np.ndarray:
|
||||
"""Return the 720-vector for the 10-frame window beginning at ``start``.
|
||||
|
||||
Frames are laid out oldest->newest, joint-major within a frame:
|
||||
[f0_j0_xyz, f0_j1_xyz, ..., f9_j23_xyz].
|
||||
"""
|
||||
idx = np.arange(start, start + WINDOW)
|
||||
if self.loop:
|
||||
idx = np.mod(idx, self.num_frames)
|
||||
else:
|
||||
idx = np.clip(idx, 0, self.num_frames - 1)
|
||||
return self.smpl_joints[idx].reshape(-1).astype(np.float32)
|
||||
|
||||
def reset(self):
|
||||
self._cursor = 0
|
||||
|
||||
def step(self) -> np.ndarray:
|
||||
"""Advance one frame and return the current 720-vector window."""
|
||||
w = self.window(self._cursor)
|
||||
self._cursor += 1
|
||||
if self.loop:
|
||||
self._cursor %= self.num_frames
|
||||
return w
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
return (not self.loop) and (self._cursor + WINDOW >= self.num_frames)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--motion", required=True, help="Path to motion .npz")
|
||||
parser.add_argument("--no-loop", action="store_true")
|
||||
parser.add_argument("--no-canon", action="store_true",
|
||||
help="Skip canonicalization (feed raw stored joints)")
|
||||
args = parser.parse_args()
|
||||
|
||||
m = SmplMotion(args.motion, loop=not args.no_loop, canonicalize=not args.no_canon)
|
||||
duration = m.num_frames / m.fps
|
||||
print(f"Loaded '{args.motion}'")
|
||||
print(f" frames={m.num_frames} fps={m.fps:.1f} duration={duration:.1f}s")
|
||||
print(f" smpl_joints={m.smpl_joints.shape} canonicalized={m.canonicalized} "
|
||||
f"pose_aa={None if m.pose_aa is None else m.pose_aa.shape} "
|
||||
f"transl={None if m.transl is None else m.transl.shape}")
|
||||
|
||||
# Sanity: after canonicalization the per-frame body heading should be fixed.
|
||||
j = m.smpl_joints
|
||||
v = (j[:, 2, :2] - j[:, 1, :2]) # R_hip - L_hip, horizontal
|
||||
a = np.arctan2(v[:, 1], v[:, 0])
|
||||
rlen = np.clip(np.hypot(np.cos(a).mean(), np.sin(a).mean()), 1e-9, 1.0)
|
||||
circ_std = np.degrees(np.sqrt(-2 * np.log(rlen)))
|
||||
print(f" hip-heading circ-std={circ_std:.1f} deg "
|
||||
f"(~0 => orientation removed; large => world-frame)")
|
||||
|
||||
w0 = m.window(0)
|
||||
print(f" window(0): shape={w0.shape} (expected {SMPL_OBS_DIM}) "
|
||||
f"min={w0.min():.3f} max={w0.max():.3f}")
|
||||
assert w0.shape == (SMPL_OBS_DIM,), "window must be 720-dim for obs[922:1642]"
|
||||
|
||||
# Simulate a few control ticks.
|
||||
print(" stepping 5 ticks:")
|
||||
for t in range(5):
|
||||
w = m.step()
|
||||
print(f" t={t} cursor={m._cursor} window_norm={np.linalg.norm(w):.2f}")
|
||||
|
||||
print("OK: motion loads and yields SONIC-format 720-vec windows.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,88 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert a GEAR-SONIC / BONES-SEED ``smpl_filtered`` clip (.pkl) to .npz.
|
||||
|
||||
The reference clips are zlib-compressed joblib pickles holding a dict with
|
||||
``pose_aa`` (T, 72), ``transl`` (T, 3), ``smpl_joints`` (T, 24, 3), ``fps``.
|
||||
``motion_loader.SmplMotion`` consumes the .npz form so the runtime needs no
|
||||
joblib dependency. Canonicalization (root-orientation removal) happens at load
|
||||
time in ``motion_loader``, so this converter just repackages the raw arrays.
|
||||
|
||||
Run this in an environment that has ``joblib`` (e.g. the sonic teleop venv):
|
||||
|
||||
python examples/unitree_g1/pkl_to_npz.py \
|
||||
--pkl sample_data/smpl_filtered/walk_forward_amateur_001__A001.pkl \
|
||||
--out examples/unitree_g1/motions/walk_forward.npz
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_pkl(path: str) -> dict:
|
||||
try:
|
||||
import joblib
|
||||
|
||||
return joblib.load(path)
|
||||
except Exception:
|
||||
# joblib clips are zlib-compressed pickles; fall back to manual inflate.
|
||||
import pickle
|
||||
import zlib
|
||||
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
try:
|
||||
raw = zlib.decompress(raw)
|
||||
except zlib.error:
|
||||
pass
|
||||
return pickle.loads(raw)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--pkl", required=True, help="Input smpl_filtered .pkl")
|
||||
parser.add_argument("--out", required=True, help="Output .npz path")
|
||||
args = parser.parse_args()
|
||||
|
||||
d = load_pkl(args.pkl)
|
||||
if not isinstance(d, dict) or "smpl_joints" not in d:
|
||||
raise ValueError(f"Unexpected pkl structure; keys={list(d) if isinstance(d, dict) else type(d)}")
|
||||
|
||||
smpl_joints = np.asarray(d["smpl_joints"], np.float32)
|
||||
if smpl_joints.ndim != 3 or smpl_joints.shape[1:] != (24, 3):
|
||||
raise ValueError(f"smpl_joints must be (T,24,3), got {smpl_joints.shape}")
|
||||
|
||||
out = {"smpl_joints": smpl_joints, "fps": np.float32(d.get("fps", 50.0))}
|
||||
if "pose_aa" in d:
|
||||
out["pose_aa"] = np.asarray(d["pose_aa"], np.float32)
|
||||
else:
|
||||
print("[warn] no pose_aa -> loader cannot canonicalize (will feed raw)")
|
||||
if "transl" in d:
|
||||
out["transl"] = np.asarray(d["transl"], np.float32)
|
||||
|
||||
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez_compressed(args.out, **out)
|
||||
dur = smpl_joints.shape[0] / float(out["fps"])
|
||||
print(f"Wrote {args.out}")
|
||||
print(f" frames={smpl_joints.shape[0]} fps={float(out['fps']):.1f} duration={dur:.1f}s "
|
||||
f"keys={sorted(out)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,255 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
SONIC planner with full mode control.
|
||||
|
||||
Keyboard controls:
|
||||
N / P - next / previous motion set
|
||||
1-8 - select mode within current set
|
||||
WASD - movement direction
|
||||
Q / E - rotate facing left / right
|
||||
9 / 0 - decrease / increase speed
|
||||
- / = - decrease / increase height
|
||||
R - force replan
|
||||
M - toggle SMPL motion playback <-> locomotion (needs --motion-file)
|
||||
Space - emergency stop -> IDLE
|
||||
Esc - quit
|
||||
|
||||
Gamepad controls (Unitree wireless controller):
|
||||
Left stick Y - speed (forward = fast, back = stop)
|
||||
Left stick X - movement direction (offset from facing)
|
||||
Right stick X - facing direction (incremental rotation)
|
||||
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
|
||||
Buttons - unused (mode selection is keyboard-only)
|
||||
|
||||
For teleop integration use --robot.controller=SonicWholeBodyController instead.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import faulthandler
|
||||
import gc
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.controllers.sonic_whole_body import SonicRuntime
|
||||
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
|
||||
CONTROL_DT,
|
||||
DEFAULT_ANGLES,
|
||||
LM,
|
||||
MOTION_SETS,
|
||||
RawKeyboard,
|
||||
compute_kp_kd,
|
||||
drain_keyboard,
|
||||
)
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
from motion_loader import SmplMotion
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
|
||||
parser.add_argument("--ip", type=str, default=None,
|
||||
help="Robot IP for real hardware (e.g. 192.168.123.164). "
|
||||
"Omit for simulation.")
|
||||
parser.add_argument("--log-csv", action="store_true",
|
||||
help="Write /tmp/sonic_pose_log.csv (disabled by default for teleop perf)")
|
||||
parser.add_argument("--cpu", action="store_true",
|
||||
help="Force CPU ONNX Runtime (skip CUDA even if onnxruntime-gpu is installed)")
|
||||
parser.add_argument("--headless", action="store_true",
|
||||
help="Ignored for sim (stock UnitreeG1 uses hub MuJoCo defaults)")
|
||||
parser.add_argument("--gamepad", action="store_true",
|
||||
help="Read Unitree wireless gamepad in sim (default: keyboard-only in sim)")
|
||||
parser.add_argument("--keyboard-only", action="store_true",
|
||||
help="Ignore wireless gamepad (terminal keyboard only)")
|
||||
parser.add_argument("--motion-file", type=str, default=None,
|
||||
help="Play an SMPL motion clip (.npz) via SONIC whole-body mode "
|
||||
"(encode_mode=2) instead of locomotion planning.")
|
||||
parser.add_argument("--no-loop", action="store_true",
|
||||
help="With --motion-file, play once instead of looping")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Surface native crashes (onnxruntime / mujoco) with a real traceback, and
|
||||
# avoid losing buffered diagnostics if the process dies mid-loop.
|
||||
faulthandler.enable()
|
||||
try:
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("=" * 60)
|
||||
print("SONIC planner - full mode control")
|
||||
print(" N/P cycle sets | 1-8 select mode | WASD move")
|
||||
print(" Q/E rotate | 9/0 speed | -/= height")
|
||||
print(" R replan | Space IDLE | Esc quit")
|
||||
if args.ip:
|
||||
print(f" Robot IP: {args.ip}")
|
||||
else:
|
||||
print(" Mode: simulation")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
cfg = UnitreeG1Config(controller=None) # full-body SONIC; standalone loop owns publish
|
||||
if args.ip:
|
||||
cfg.is_simulation = False
|
||||
cfg.robot_ip = args.ip
|
||||
else:
|
||||
cfg.is_simulation = True
|
||||
if args.headless:
|
||||
print("[Note] --headless ignored: sim uses stock UnitreeG1 + hub env")
|
||||
robot = UnitreeG1(cfg)
|
||||
robot.connect()
|
||||
kp, kd = compute_kp_kd()
|
||||
robot.kp = kp.copy()
|
||||
robot.kd = kd.copy()
|
||||
|
||||
runtime = SonicRuntime(force_cpu=args.cpu)
|
||||
controller = runtime.controller
|
||||
ms = runtime.ms
|
||||
|
||||
motion = None
|
||||
if args.motion_file:
|
||||
motion = SmplMotion(args.motion_file, loop=not args.no_loop)
|
||||
controller.smpl_motion = motion # lets 'M' key toggle playback
|
||||
controller.encode_mode = 2 # start in SONIC whole-body SMPL imitation
|
||||
dur = motion.num_frames / motion.fps
|
||||
print(f"\n[Motion] SMPL whole-body playback: {args.motion_file}")
|
||||
print(f" frames={motion.num_frames} fps={motion.fps:.1f} "
|
||||
f"duration={dur:.1f}s loop={not args.no_loop} encode_mode=2")
|
||||
print(" Press 'M' to toggle SMPL playback <-> locomotion at runtime.")
|
||||
|
||||
runtime.controller.print_input_diagnostics()
|
||||
|
||||
print(f"\nStarting: {MOTION_SETS[0][0]} (default mode: {LM(ms.mode).name})")
|
||||
[print(f" {i+1}: {m.name}") for i, m in enumerate(MOTION_SETS[0][1])]
|
||||
print("\n[Ready] Click THIS terminal, then W/A/S/D to move. "
|
||||
"1-6 change mode, 9/0 speed, Esc quit.\n", flush=True)
|
||||
|
||||
# Sim hub publishes wireless_remote bytes that can fight terminal WASD.
|
||||
base_joystick = not args.keyboard_only and (args.gamepad or args.ip is not None)
|
||||
|
||||
with RawKeyboard() as kb:
|
||||
try:
|
||||
gc.disable()
|
||||
gc_timer = 0.0
|
||||
robot.reset(CONTROL_DT, DEFAULT_ANGLES)
|
||||
time.sleep(1.0)
|
||||
|
||||
last_status = time.time() - 2.1
|
||||
loop_t = enc_t = dec_t = obs_t = act_t = []
|
||||
slow_n = blend_n = 0
|
||||
stall_src = ""
|
||||
did_blend = False
|
||||
prev_end = time.time()
|
||||
t_start = time.time()
|
||||
|
||||
log_path = "/tmp/sonic_pose_log.csv"
|
||||
jnames = [m.name for m in G1_29_JointIndex]
|
||||
log_ctx = open(log_path, "w") if args.log_csv else None
|
||||
if log_ctx:
|
||||
log_ctx.write("t,step,cursor,ts,blend,mode," +
|
||||
",".join(f"q{i}" for i in range(29)) + "," +
|
||||
",".join(f"ref{i}" for i in range(29)) + "," +
|
||||
",".join(f"act{i}" for i in range(29)) +
|
||||
",delta_max,action_norm,token_norm\n")
|
||||
|
||||
try:
|
||||
while not robot._shutdown_event.is_set():
|
||||
t0 = time.time()
|
||||
if drain_keyboard(kb, ms, controller):
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
t_obs = time.time()
|
||||
obs_t.append(1000 * (t_obs - t0))
|
||||
if not obs:
|
||||
runtime.tick({}, use_joystick=False)
|
||||
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
|
||||
continue
|
||||
|
||||
# SMPL playback only while in whole-body mode; 'M' toggles it.
|
||||
motion_active = motion is not None and controller.encode_mode == 2
|
||||
if motion_active:
|
||||
controller.smpl_joints_10frame_step1 = motion.step()
|
||||
if motion.done:
|
||||
print("\n[Motion] clip finished")
|
||||
break
|
||||
|
||||
step_before = runtime.step
|
||||
t_step = time.time()
|
||||
action = runtime.tick(obs, use_joystick=base_joystick and not motion_active)
|
||||
step_ms = 1000 * (time.time() - t_step)
|
||||
do_enc = step_before % 5 == 0
|
||||
(enc_t if do_enc else dec_t).append(step_ms)
|
||||
|
||||
t_act = time.time()
|
||||
robot.send_action(action)
|
||||
act_t.append(1000 * (time.time() - t_act))
|
||||
|
||||
if log_ctx and runtime.step % 5 == 0:
|
||||
t_rel = time.time() - t_start
|
||||
q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
|
||||
a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
|
||||
cur, ts = controller.ref_cursor, controller.motion_timesteps
|
||||
q_ref = controller.motion_joint_positions[min(cur, ts - 1)] if ts > 0 else np.zeros(29)
|
||||
log_ctx.write(f"{t_rel:.4f},{runtime.step},{cur},{ts},{int(did_blend)},{ms.mode}," +
|
||||
",".join(f"{v:.6f}" for v in q_r) + "," +
|
||||
",".join(f"{v:.6f}" for v in q_ref) + "," +
|
||||
",".join(f"{v:.6f}" for v in a_v) + "," +
|
||||
f"{np.max(np.abs(a_v - q_r)):.6f},"
|
||||
f"{np.linalg.norm(a_v):.6f},"
|
||||
f"{np.linalg.norm(controller.token):.6f}\n")
|
||||
did_blend = False
|
||||
|
||||
now = time.time()
|
||||
loop_ms = 1000 * (now - t0)
|
||||
if loop_ms > 50:
|
||||
stall_src = (f"[STALL] {loop_ms:.0f}ms: "
|
||||
f"obs={obs_t[-1]:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
|
||||
if loop_ms > CONTROL_DT * 1500:
|
||||
slow_n += 1
|
||||
|
||||
if now - last_status > 2.0:
|
||||
def _avg(lst):
|
||||
return sum(lst) / len(lst) if lst else 0
|
||||
|
||||
hz = 1000 / _avg(loop_t) if _avg(loop_t) else 0
|
||||
print(f"\r {ms.status_line()} step={runtime.step} "
|
||||
f"ref={controller.ref_cursor}/{controller.motion_timesteps} "
|
||||
f"loop={_avg(loop_t):.1f}ms(max={max(loop_t, default=0):.1f}) hz={hz:.0f} "
|
||||
f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
|
||||
f"slow={slow_n} blends={blend_n}", end="", flush=True)
|
||||
if stall_src:
|
||||
print(f"\n {stall_src}")
|
||||
stall_src = ""
|
||||
last_status = now
|
||||
loop_t = enc_t = dec_t = obs_t = act_t = []
|
||||
slow_n = blend_n = 0
|
||||
|
||||
prev_end = time.time()
|
||||
gc_timer += CONTROL_DT
|
||||
if gc_timer >= 10.0:
|
||||
gc.collect()
|
||||
gc_timer = 0.0
|
||||
loop_t.append(loop_ms)
|
||||
time.sleep(max(0.0, CONTROL_DT - (time.time() - t0)))
|
||||
finally:
|
||||
if log_ctx:
|
||||
log_ctx.close()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
gc.enable()
|
||||
if args.log_csv:
|
||||
print(f"\n[Log] Saved to {log_path}")
|
||||
runtime.shutdown()
|
||||
print("\nStopping...")
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -73,8 +73,17 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot dataset on disk.
|
||||
recording: bool = False
|
||||
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
|
||||
# suffixed by task and env index). Requires recording=true.
|
||||
recording_repo_id: str | None = None
|
||||
# Whether the pushed recording repositories should be private.
|
||||
recording_private: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.recording_repo_id is not None and not self.recording:
|
||||
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
|
||||
if self.batch_size == 0:
|
||||
self.batch_size = self._auto_batch_size()
|
||||
if self.batch_size > self.n_episodes:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
|
||||
@@ -68,6 +68,6 @@ class UnitreeG1Config(RobotConfig):
|
||||
# Compensates for gravity on the unitree's arms using the arm ik solver
|
||||
gravity_compensation: bool = False
|
||||
|
||||
# Locomotion controller class name, e.g. "GrootLocomotionController",
|
||||
# "HolosomaLocomotionController", or "SonicWholeBodyController". None disables it.
|
||||
# Lower-body controller class name, e.g. "GrootLocomotionController" or
|
||||
# "HolosomaLocomotionController". None disables it.
|
||||
controller: str | None = None
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Unitree G1 locomotion controllers (Groot, Holosoma, SONIC)."""
|
||||
|
||||
__all__ = [
|
||||
"GrootLocomotionController",
|
||||
"HolosomaLocomotionController",
|
||||
"SonicWholeBodyController",
|
||||
"SonicRuntime",
|
||||
]
|
||||
@@ -1,941 +0,0 @@
|
||||
"""SONIC planner pipeline: ONNX enc/dec/planner, movement state, and input helpers."""
|
||||
|
||||
import math
|
||||
import queue
|
||||
import select
|
||||
import struct
|
||||
import sys
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
import tty
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
], dtype=np.float32)
|
||||
|
||||
NATURAL_FREQ = 10.0 * 2.0 * np.pi
|
||||
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
|
||||
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
|
||||
|
||||
def _action_scale(k):
|
||||
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
|
||||
|
||||
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
|
||||
|
||||
CONTROL_DT = 0.02
|
||||
DEFAULT_HEIGHT = 0.788740
|
||||
TOKEN_DIM = 64
|
||||
ENCODER_UPDATE_EVERY = 5
|
||||
DEBUG_PRINT_EVERY = 100
|
||||
MOTION_LOOK_AHEAD_STEPS = 2
|
||||
INITIAL_RANDOM_SEED = 1234
|
||||
MIN_TOKENS, MAX_TOKENS = 6, 16
|
||||
K = MAX_TOKENS - MIN_TOKENS + 1
|
||||
DEADZONE = 0.05
|
||||
BLEND_FRAMES = 8
|
||||
|
||||
REPLAN_INTERVAL = {
|
||||
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
|
||||
}
|
||||
|
||||
ISAACLAB_TO_MUJOCO = np.array([
|
||||
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
|
||||
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
MUJOCO_TO_ISAACLAB = np.array([
|
||||
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
|
||||
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
|
||||
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
|
||||
|
||||
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
|
||||
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
|
||||
|
||||
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
|
||||
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
|
||||
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
|
||||
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
|
||||
SMPL_DEF = np.zeros(720, dtype=np.float32)
|
||||
|
||||
# ── PD gains ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def compute_kp_kd():
|
||||
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
|
||||
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
|
||||
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
_kd_keys = _kp_keys
|
||||
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
|
||||
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
|
||||
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
|
||||
return kp, kd
|
||||
|
||||
|
||||
_kp_kd = compute_kp_kd # backward-compatible alias
|
||||
|
||||
|
||||
def lowstate_to_obs(lowstate) -> dict:
|
||||
"""Build a robot observation dict from Unitree lowstate."""
|
||||
obs: dict = {}
|
||||
for motor in G1_29_JointIndex:
|
||||
idx = motor.value
|
||||
obs[f"{motor.name}.q"] = float(lowstate.motor_state[idx].q)
|
||||
obs[f"{motor.name}.dq"] = float(lowstate.motor_state[idx].dq)
|
||||
quat = lowstate.imu_state.quaternion
|
||||
obs["imu.quat.w"] = float(quat[0])
|
||||
obs["imu.quat.x"] = float(quat[1])
|
||||
obs["imu.quat.y"] = float(quat[2])
|
||||
obs["imu.quat.z"] = float(quat[3])
|
||||
gyro = lowstate.imu_state.gyroscope
|
||||
obs["imu.gyro.x"] = float(gyro[0])
|
||||
obs["imu.gyro.y"] = float(gyro[1])
|
||||
obs["imu.gyro.z"] = float(gyro[2])
|
||||
wr = getattr(lowstate, "wireless_remote", None)
|
||||
if wr is not None:
|
||||
obs["wireless_remote"] = bytes(wr) if not isinstance(wr, (bytes, bytearray)) else wr
|
||||
return obs
|
||||
|
||||
|
||||
# ── Quaternion helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def quat_conj(q):
|
||||
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
|
||||
|
||||
def quat_mul(q1, q2):
|
||||
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
|
||||
return np.array([
|
||||
w1*w2 - x1*x2 - y1*y2 - z1*z2,
|
||||
w1*x2 + x1*w2 + y1*z2 - z1*y2,
|
||||
w1*y2 - x1*z2 + y1*w2 + z1*x2,
|
||||
w1*z2 + x1*y2 - y1*x2 + z1*w2,
|
||||
], dtype=np.float32)
|
||||
|
||||
def gravity_dir(q):
|
||||
q = q / (np.linalg.norm(q) + 1e-8)
|
||||
qv = np.array([0, 0, 0, -1], dtype=np.float32)
|
||||
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
|
||||
|
||||
def quat_to_6d(q):
|
||||
w,x,y,z = q
|
||||
return np.array([
|
||||
1-2*(y*y+z*z), 2*(x*y-z*w),
|
||||
2*(x*y+z*w), 1-2*(x*x+z*z),
|
||||
2*(x*z-y*w), 2*(y*z+x*w),
|
||||
], dtype=np.float32)
|
||||
|
||||
def calc_heading(q):
|
||||
w,x,y,z = q
|
||||
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
|
||||
|
||||
def heading_quat(q, sign=1.0):
|
||||
a = sign * calc_heading(q) / 2.0
|
||||
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
|
||||
|
||||
heading_quat_inv = lambda q: heading_quat(q, -1.0)
|
||||
|
||||
def quat_slerp(q0, q1, t):
|
||||
q0 = q0 / (np.linalg.norm(q0)+1e-12); q1 = q1 / (np.linalg.norm(q1)+1e-12)
|
||||
dot = float(np.dot(q0, q1))
|
||||
if dot < 0: q1, dot = -q1, -dot
|
||||
dot = min(dot, 1.0)
|
||||
if dot > 0.9995:
|
||||
r = q0 + t*(q1-q0); return r/(np.linalg.norm(r)+1e-12)
|
||||
th = np.arccos(dot); st = np.sin(th)
|
||||
return (np.sin((1-t)*th)/st)*q0 + (np.sin(t*th)/st)*q1
|
||||
|
||||
def quat_slerp_batch(q0, q1, t):
|
||||
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
|
||||
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
|
||||
dot = np.sum(q0*q1, axis=1); neg = dot<0
|
||||
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
|
||||
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
|
||||
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
|
||||
c0[lin]=1-t[lin]; c1[lin]=t[lin]
|
||||
r = c0[:,None]*q0 + c1[:,None]*q1
|
||||
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
|
||||
|
||||
# ── Locomotion modes ──────────────────────────────────────────────────────────
|
||||
|
||||
class LocomotionMode(IntEnum):
|
||||
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
|
||||
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
|
||||
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
|
||||
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
|
||||
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
|
||||
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
|
||||
|
||||
LM = LocomotionMode
|
||||
|
||||
MOTION_SETS = [
|
||||
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
|
||||
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
|
||||
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
|
||||
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
|
||||
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
|
||||
]
|
||||
|
||||
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
|
||||
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
|
||||
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
|
||||
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
|
||||
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
|
||||
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
|
||||
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
|
||||
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
|
||||
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
|
||||
|
||||
def clamp_mode_params(ms):
|
||||
m = LM(ms.mode)
|
||||
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
|
||||
if m in STATIC_MODES:
|
||||
ms.speed = -1.0
|
||||
elif m in SPEED_RANGES:
|
||||
lo, hi = SPEED_RANGES[m]
|
||||
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
|
||||
elif m in BOXING_MODES:
|
||||
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
|
||||
else:
|
||||
ms.speed = -1.0
|
||||
|
||||
def replan_interval(mode):
|
||||
m = LM(mode)
|
||||
if m == LM.RUN: return REPLAN_INTERVAL["running"]
|
||||
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
|
||||
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
|
||||
return REPLAN_INTERVAL["boxing"]
|
||||
return REPLAN_INTERVAL["default"]
|
||||
|
||||
|
||||
def _ort_providers(force_cpu: bool = False) -> list[str]:
|
||||
"""Prefer CUDA for enc/dec/planner (matches deploy when onnxruntime-gpu is installed)."""
|
||||
avail = ort.get_available_providers()
|
||||
if not force_cpu and "CUDAExecutionProvider" in avail:
|
||||
return ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
return ["CPUExecutionProvider"]
|
||||
|
||||
# ── Movement state ────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class MovementState:
|
||||
mode: int = LM.SLOW_WALK # not IDLE — walking modes respond to WASD
|
||||
speed: float = -1.0
|
||||
height: float = -1.0
|
||||
facing_angle: float = 0.0
|
||||
movement_angle: float = 0.0
|
||||
has_movement: bool = False
|
||||
motion_set_idx: int = 0
|
||||
needs_replan: bool = False
|
||||
|
||||
@property
|
||||
def movement_direction(self):
|
||||
if not self.has_movement: return (0.0, 0.0, 0.0)
|
||||
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
|
||||
|
||||
@property
|
||||
def facing_direction(self):
|
||||
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
|
||||
|
||||
def status_line(self):
|
||||
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
|
||||
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
|
||||
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
|
||||
f"facing={math.degrees(self.facing_angle):.0f}° "
|
||||
f"{'moving' if self.has_movement else 'still'}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MovementSnapshot:
|
||||
mode: int = 0
|
||||
speed: float = -1.0
|
||||
height: float = -1.0
|
||||
movement: tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
facing: tuple[float, float, float] = (1.0, 0.0, 0.0)
|
||||
|
||||
|
||||
def _snapshot_ms(ms: MovementState) -> MovementSnapshot:
|
||||
md, fd = ms.movement_direction, ms.facing_direction
|
||||
return MovementSnapshot(ms.mode, ms.speed, ms.height, (md[0], md[1], md[2]), (fd[0], fd[1], fd[2]))
|
||||
|
||||
|
||||
def should_replan_request(ms: MovementState, last: MovementSnapshot, replan_timer: float, step: int) -> bool:
|
||||
"""Match C++ G1Deploy::Planner replan triggers (g1_deploy_onnx_ref.cpp)."""
|
||||
if step <= 0:
|
||||
return False
|
||||
if ms.needs_replan:
|
||||
return True
|
||||
md, fd = ms.movement_direction, ms.facing_direction
|
||||
facing_changed = fd != last.facing
|
||||
height_changed = ms.height != last.height
|
||||
mode_changed = ms.mode != last.mode
|
||||
speed_changed = ms.speed != last.speed
|
||||
dir_changed = md != last.movement
|
||||
is_static = LM(ms.mode) in STATIC_MODES
|
||||
if mode_changed or facing_changed or height_changed:
|
||||
return True
|
||||
time_to_replan = replan_timer >= replan_interval(ms.mode)
|
||||
if not is_static and (speed_changed or dir_changed or (time_to_replan and ms.speed != 0)):
|
||||
return True
|
||||
return False
|
||||
|
||||
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
|
||||
|
||||
class StandingEncoderDecoder:
|
||||
def __init__(self, encoder, decoder):
|
||||
self.encoder, self.decoder = encoder, decoder
|
||||
self.encoder_input = encoder.get_inputs()[0].name
|
||||
self.decoder_input = decoder.get_inputs()[0].name
|
||||
enc_dim = int(encoder.get_inputs()[0].shape[1])
|
||||
dec_dim = int(decoder.get_inputs()[0].shape[1])
|
||||
if enc_dim != 1762 or dec_dim != 994:
|
||||
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
|
||||
self.token = np.zeros(TOKEN_DIM, np.float32)
|
||||
self.last_action_mj = np.zeros(29, np.float32)
|
||||
self.h_q_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_ang = [np.zeros(3, np.float32)] * 10
|
||||
self.h_act_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
|
||||
self.init_base_quat = np.array([1,0,0,0], np.float32)
|
||||
self.init_ref_quat = np.array([1,0,0,0], np.float32)
|
||||
self._heading_init = False
|
||||
self.encode_mode = 0
|
||||
self.vr_3point_local_target = VR_TARGET_DEF.copy()
|
||||
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
|
||||
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
|
||||
self.set_zero_reference()
|
||||
|
||||
def update_history(self, q, dq, ang, quat):
|
||||
quat = quat / (np.linalg.norm(quat)+1e-8)
|
||||
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
|
||||
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
|
||||
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
|
||||
self.h_ang = [ang.copy()] + self.h_ang[:-1]
|
||||
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
|
||||
self.h_quat = [quat.copy()] + self.h_quat[:-1]
|
||||
if not self._heading_init:
|
||||
self.init_base_quat = quat.copy(); self._heading_init = True
|
||||
|
||||
def _heading_quat(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
|
||||
|
||||
def _heading_quat_inv(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
|
||||
|
||||
def _anchor_6d(self, base_quat, ref_quat=None):
|
||||
if ref_quat is None: ref_quat = self.init_ref_quat
|
||||
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
|
||||
new_ref = quat_mul(delta, ref_quat)
|
||||
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
|
||||
|
||||
def set_zero_reference(self):
|
||||
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
|
||||
self.motion_joint_velocities = [np.zeros(29, np.float32)]
|
||||
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
|
||||
self.motion_body_z = [DEFAULT_HEIGHT]
|
||||
self.motion_timesteps = 1
|
||||
self.freeze_ref_frame = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
|
||||
def build_encoder_obs(self):
|
||||
obs = np.zeros(1762, np.float32)
|
||||
obs[0] = float(self.encode_mode)
|
||||
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
|
||||
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
|
||||
if self.encode_mode == 0:
|
||||
for f in range(10):
|
||||
obs[4+29*f:4+29*(f+1)] = ref_pos
|
||||
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 1:
|
||||
ref_lower = ref_pos[LOWER_BODY_IL]
|
||||
for f in range(10):
|
||||
obs[661+12*f:661+12*(f+1)] = ref_lower
|
||||
obs[901:910] = self.vr_3point_local_target
|
||||
obs[910:922] = self.vr_3point_local_orn_target
|
||||
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 2:
|
||||
obs[922:1642] = self.smpl_joints_10frame_step1
|
||||
for f in range(10):
|
||||
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
|
||||
return obs
|
||||
|
||||
def build_decoder_obs(self):
|
||||
obs = np.zeros(994, np.float32); off = 0
|
||||
obs[off:off+64] = self.token; off += 64
|
||||
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
|
||||
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
|
||||
for f in range(10): obs[off:off+sz] = h[f]; off += sz
|
||||
for q in reversed(self.h_quat):
|
||||
obs[off:off+3] = gravity_dir(q); off += 3
|
||||
assert off == 994, f"Decoder obs mismatch: {off}"
|
||||
return obs
|
||||
|
||||
def run_encoder(self):
|
||||
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
|
||||
def step(self, robot_obs, update_encoder, debug=False):
|
||||
jnames = [m.name for m in G1_29_JointIndex]
|
||||
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
|
||||
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
|
||||
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
|
||||
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
|
||||
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
|
||||
self.update_history(q, dq, ang, quat)
|
||||
if update_encoder: self.token = self.run_encoder()
|
||||
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
self.last_action_mj = action_mj.copy()
|
||||
target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
|
||||
if debug:
|
||||
delta = target - q
|
||||
print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
|
||||
f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
|
||||
return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}
|
||||
|
||||
def print_input_diagnostics(self):
|
||||
print("\n[Diag] Standing reference checks")
|
||||
names = {0:"g1", 1:"teleop", 2:"smpl"}
|
||||
print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
|
||||
print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
|
||||
print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
|
||||
print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
|
||||
dec0 = self.build_decoder_obs()
|
||||
print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
|
||||
print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")
|
||||
|
||||
# ── Planner motion buffer ─────────────────────────────────────────────────────
|
||||
|
||||
class PlannerMotion:
|
||||
def __init__(self, max_frames=1500):
|
||||
self.timesteps = 0
|
||||
self.joint_positions = np.zeros((max_frames, 29), np.float64)
|
||||
self.joint_velocities = np.zeros((max_frames, 29), np.float64)
|
||||
self.body_positions = np.zeros((max_frames, 3), np.float64)
|
||||
self.body_quaternions = np.zeros((max_frames, 4), np.float64)
|
||||
self.body_quaternions[:, 0] = 1.0
|
||||
|
||||
# ── Subprocess planner ────────────────────────────────────────────────────────
|
||||
|
||||
def _resample_30_to_50(qpos, n30):
|
||||
t50 = int(np.floor(n30 / 30.0 * 50))
|
||||
f30 = np.arange(t50) / 50.0 * 30.0
|
||||
f0 = np.floor(f30).astype(int)
|
||||
f1 = np.minimum(f0+1, n30-1)
|
||||
frac, w0 = (f30-f0).astype(np.float64), None
|
||||
w0 = 1.0 - frac
|
||||
jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
|
||||
jv = np.zeros_like(jp)
|
||||
if t50 >= 2: jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0; jv[-1] = jv[-2]
|
||||
return {
|
||||
"timesteps": t50,
|
||||
"joint_positions": jp,
|
||||
"joint_velocities": jv,
|
||||
"body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
|
||||
"body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
|
||||
}
|
||||
|
||||
def _build_planner_inputs(ctx, ms_dict, version, seed):
|
||||
inp = {
|
||||
"context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
|
||||
"target_vel": np.array([ms_dict["speed"]], np.float32),
|
||||
"mode": np.array([ms_dict["mode"]], np.int64),
|
||||
"movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
|
||||
"facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
|
||||
"random_seed": np.array([seed], np.int64),
|
||||
}
|
||||
if version >= 1:
|
||||
# TensorRT deploy: allow 9–11 prediction tokens only (indices 3–5 for MIN_TOKENS=6).
|
||||
allowed = np.zeros((1, K), np.int64)
|
||||
if K >= 6:
|
||||
allowed[0, 3:6] = 1
|
||||
inp.update({
|
||||
"height": np.array([ms_dict["height"]], np.float32),
|
||||
"has_specific_target": np.array([[0]], np.int64),
|
||||
"specific_target_positions": np.zeros((1,4,3), np.float32),
|
||||
"specific_target_headings": np.zeros((1,4), np.float32),
|
||||
"allowed_pred_num_tokens": allowed,
|
||||
})
|
||||
return inp
|
||||
|
||||
def _planner_worker(path, req_q, res_q, stop_evt, version, seed, use_gpu):
|
||||
so = ort.SessionOptions(); so.log_severity_level = 3
|
||||
providers = _ort_providers(force_cpu=not use_gpu)
|
||||
sess = ort.InferenceSession(path, sess_options=so, providers=providers)
|
||||
while not stop_evt.is_set():
|
||||
try: ctx, gf, ms_dict = req_q.get(timeout=0.05)
|
||||
except Exception: continue
|
||||
try:
|
||||
inp = _build_planner_inputs(ctx, ms_dict, version, seed)
|
||||
t0 = time.time()
|
||||
qpos_out, num_pred = sess.run(None, inp)
|
||||
t_inf = time.time()
|
||||
n = int(num_pred.flat[0])
|
||||
qpos = qpos_out[0,:n]
|
||||
if np.any(np.isnan(qpos)): continue
|
||||
motion = _resample_30_to_50(qpos, n)
|
||||
motion["gen_frame"] = gf
|
||||
print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
|
||||
while not res_q.empty():
|
||||
try: res_q.get_nowait()
|
||||
except queue.Empty: break
|
||||
res_q.put(motion)
|
||||
except Exception as e:
|
||||
print(f"[Planner] Error: {e}", flush=True)
|
||||
|
||||
# ── SonicPlanner ──────────────────────────────────────────────────────────────
|
||||
|
||||
class SonicPlanner:
|
||||
def __init__(self, session, planner_path):
|
||||
self.session = session
|
||||
self.planner_path = planner_path
|
||||
self.gen_frame = 0
|
||||
self.random_seed = INITIAL_RANDOM_SEED
|
||||
self.version = 1 if len(session.get_inputs()) >= 11 else 0
|
||||
self.motion_50hz = PlannerMotion()
|
||||
self._snapshot = PlannerMotion()
|
||||
self._req_q = self._res_q = self._stop_evt = self._planner_thread = None
|
||||
self._ctrl = None
|
||||
|
||||
def _build_inputs(self, ctx, ms):
|
||||
return _build_planner_inputs(
|
||||
ctx,
|
||||
{"mode": ms.mode, "speed": ms.speed, "height": ms.height,
|
||||
"movement_direction": list(ms.movement_direction),
|
||||
"facing_direction": list(ms.facing_direction)},
|
||||
self.version, self.random_seed)
|
||||
|
||||
@staticmethod
|
||||
def build_initial_context(joint_positions):
|
||||
ctx = np.zeros((4, 36), np.float32)
|
||||
jp_mj = joint_positions.astype(np.float32)[ISAACLAB_TO_MUJOCO]
|
||||
for n in range(4):
|
||||
ctx[n, 2] = DEFAULT_HEIGHT
|
||||
ctx[n, 3] = 1.0
|
||||
ctx[n, 7:36] = jp_mj
|
||||
return ctx
|
||||
|
||||
def _context_from_controller(self, current_frame):
|
||||
ctrl = self._ctrl
|
||||
gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
|
||||
t_arr = gen_frame / 50.0 + np.arange(4) / 30.0
|
||||
f50 = t_arr * 50.0
|
||||
with ctrl.motion_lock:
|
||||
ts = ctrl.motion_timesteps
|
||||
if ts <= 0:
|
||||
return self.build_initial_context(DEFAULT_ANGLES)
|
||||
bp, bq, jp = ctrl.motion_body_pos, ctrl.motion_body_quats, ctrl.motion_joint_positions
|
||||
f0 = np.minimum(np.floor(f50).astype(int), ts - 1)
|
||||
f1 = np.minimum(f0 + 1, ts - 1)
|
||||
frac = f50 - f0
|
||||
w0 = 1.0 - frac
|
||||
ctx = np.zeros((4, 36), np.float32)
|
||||
ctx[:, 0:3] = w0[:, None] * bp[f0] + frac[:, None] * bp[f1]
|
||||
ctx[:, 3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
|
||||
ij = w0[:, None] * jp[f0] + frac[:, None] * jp[f1]
|
||||
ctx[:, 7:36] = ij[:, ISAACLAB_TO_MUJOCO]
|
||||
self.gen_frame = gen_frame
|
||||
return ctx
|
||||
|
||||
def _load_motion_in_place(self, qpos, n30, target=None):
|
||||
if target is None: target = self.motion_50hz
|
||||
r = _resample_30_to_50(qpos, n30)
|
||||
n = r["timesteps"]; target.timesteps = n
|
||||
target.joint_positions[:n] = r["joint_positions"]
|
||||
target.joint_velocities[:n] = r["joint_velocities"]
|
||||
target.body_positions[:n] = r["body_positions"]
|
||||
target.body_quaternions[:n] = r["body_quaternions"]
|
||||
return target
|
||||
|
||||
def initialize(self, joint_positions, ms):
|
||||
ctx = self.build_initial_context(joint_positions)
|
||||
qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
|
||||
n = int(num_pred.flat[0]); qpos = qpos_out[0,:n]
|
||||
if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
|
||||
print(f"[Planner] Init: {n} frames @ 30 Hz")
|
||||
self._load_motion_in_place(qpos, n)
|
||||
print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
|
||||
return self.motion_50hz
|
||||
|
||||
def request_replan(self, cursor, ms):
|
||||
if self._req_q is None: return
|
||||
ctx = self._context_from_controller(cursor)
|
||||
ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
|
||||
"movement_direction": list(ms.movement_direction),
|
||||
"facing_direction": list(ms.facing_direction)}
|
||||
while not self._req_q.empty():
|
||||
try: self._req_q.get_nowait()
|
||||
except queue.Empty: break
|
||||
self._req_q.put((ctx, self.gen_frame, ms_dict))
|
||||
|
||||
def try_get_new_motion(self):
|
||||
if self._res_q is None: return None
|
||||
result = None
|
||||
while not self._res_q.empty():
|
||||
try: result = self._res_q.get_nowait()
|
||||
except queue.Empty: break
|
||||
if result is None: return None
|
||||
n, gf = result["timesteps"], result["gen_frame"]
|
||||
s = self._snapshot; s.timesteps = n
|
||||
s.joint_positions[:n] = result["joint_positions"]
|
||||
s.joint_velocities[:n] = result["joint_velocities"]
|
||||
s.body_positions[:n] = result["body_positions"]
|
||||
s.body_quaternions[:n] = result["body_quaternions"]
|
||||
return s, gf
|
||||
|
||||
def start_subprocess(self, controller, use_gpu: bool = False):
|
||||
"""Run planner ONNX in a background thread (avoids mp spawn/fork + CUDA/MuJoCo issues)."""
|
||||
self._ctrl = controller
|
||||
self._req_q = queue.Queue()
|
||||
self._res_q = queue.Queue()
|
||||
self._stop_evt = threading.Event()
|
||||
self._planner_thread = threading.Thread(
|
||||
target=_planner_worker,
|
||||
args=(self.planner_path, self._req_q, self._res_q,
|
||||
self._stop_evt, self.version, self.random_seed, use_gpu),
|
||||
daemon=True,
|
||||
name="sonic-planner",
|
||||
)
|
||||
self._planner_thread.start()
|
||||
print(f"[Planner] Background thread started ({'GPU' if use_gpu else 'CPU'})")
|
||||
|
||||
def stop_subprocess(self):
|
||||
if self._stop_evt:
|
||||
self._stop_evt.set()
|
||||
if self._planner_thread is not None:
|
||||
self._planner_thread.join(timeout=3.0)
|
||||
print("[Planner] Background thread stopped")
|
||||
self._planner_thread = None
|
||||
self._req_q = self._res_q = self._stop_evt = None
|
||||
|
||||
# ── PlannerController ─────────────────────────────────────────────────────────
|
||||
|
||||
class PlannerController(StandingEncoderDecoder):
|
||||
def __init__(self, planner, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
self.planner = planner
|
||||
self.ref_cursor = 0
|
||||
self.motion_timesteps = 0
|
||||
self.motion_joint_positions = np.zeros((1500,29), np.float64)
|
||||
self.motion_joint_velocities = np.zeros((1500,29), np.float64)
|
||||
self.motion_body_quats = np.zeros((1500,4), np.float64); self.motion_body_quats[:,0] = 1.0
|
||||
self.motion_body_pos = np.zeros((1500,3), np.float64)
|
||||
self.init_ref_quat = np.array([1,0,0,0], np.float64)
|
||||
self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
|
||||
self.delta_heading = 0.0
|
||||
self.reinit_heading = False
|
||||
self.playing = self.first_motion = False
|
||||
self.motion_lock = threading.Lock()
|
||||
|
||||
def load_initial_motion(self, motion):
|
||||
with self.motion_lock:
|
||||
n = motion.timesteps
|
||||
self.motion_timesteps = n
|
||||
self.motion_joint_positions[:n] = motion.joint_positions[:n]
|
||||
self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
|
||||
self.motion_body_quats[:n] = motion.body_quaternions[:n]
|
||||
self.motion_body_pos[:n] = motion.body_positions[:n]
|
||||
self.init_ref_quat = motion.body_quaternions[0].copy()
|
||||
self.ref_cursor = 0; self.first_motion = True
|
||||
self.playing = True; self.delta_heading = 0.0
|
||||
|
||||
def blend_new_motion(self, new_motion, gen_frame):
|
||||
"""Blend like C++ CurrentFrameAdvancement: 8-frame cross-fade, then copy tail."""
|
||||
with self.motion_lock:
|
||||
cur = self.ref_cursor
|
||||
new_len = gen_frame - cur + new_motion.timesteps
|
||||
if new_len <= 0:
|
||||
return
|
||||
if self.motion_timesteps == 0:
|
||||
n = new_motion.timesteps
|
||||
self.motion_joint_positions[:n] = new_motion.joint_positions[:n]
|
||||
self.motion_joint_velocities[:n] = new_motion.joint_velocities[:n]
|
||||
self.motion_body_pos[:n] = new_motion.body_positions[:n]
|
||||
self.motion_body_quats[:n] = new_motion.body_quaternions[:n]
|
||||
self.motion_timesteps = n
|
||||
self.ref_cursor = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
self.first_motion = False
|
||||
return
|
||||
|
||||
blend_start = max(0, gen_frame - cur)
|
||||
blend_end = min(new_len, blend_start + BLEND_FRAMES)
|
||||
|
||||
for f in range(blend_end):
|
||||
f_old = min(f + cur, self.motion_timesteps - 1)
|
||||
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
|
||||
w_new = min(1.0, max(0.0, (f - blend_start) / BLEND_FRAMES))
|
||||
w_old = 1.0 - w_new
|
||||
self.motion_joint_positions[f] = (
|
||||
w_old * self.motion_joint_positions[f_old]
|
||||
+ w_new * new_motion.joint_positions[f_new]
|
||||
)
|
||||
self.motion_joint_velocities[f] = (
|
||||
w_old * self.motion_joint_velocities[f_old]
|
||||
+ w_new * new_motion.joint_velocities[f_new]
|
||||
)
|
||||
self.motion_body_pos[f] = (
|
||||
w_old * self.motion_body_pos[f_old]
|
||||
+ w_new * new_motion.body_positions[f_new]
|
||||
)
|
||||
self.motion_body_quats[f] = quat_slerp(
|
||||
self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new
|
||||
)
|
||||
|
||||
for f in range(blend_end, new_len):
|
||||
f_new = max(0, min(f + cur - gen_frame, new_motion.timesteps - 1))
|
||||
self.motion_joint_positions[f] = new_motion.joint_positions[f_new]
|
||||
self.motion_joint_velocities[f] = new_motion.joint_velocities[f_new]
|
||||
self.motion_body_pos[f] = new_motion.body_positions[f_new]
|
||||
self.motion_body_quats[f] = new_motion.body_quaternions[f_new].copy()
|
||||
|
||||
self.motion_timesteps = new_len
|
||||
self.first_motion = False
|
||||
self.ref_cursor = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
|
||||
def _heading_apply_delta(self):
|
||||
delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
|
||||
heading_quat_inv(self.init_ref_quat).astype(np.float32))
|
||||
if self.delta_heading:
|
||||
h = self.delta_heading / 2.0
|
||||
delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
|
||||
return delta
|
||||
|
||||
def _anchor_6d(self, base_quat, ref_quat=None):
|
||||
if ref_quat is None: ref_quat = self.init_ref_quat
|
||||
new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
|
||||
return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))
|
||||
|
||||
def build_encoder_obs(self):
|
||||
obs = np.zeros(1762, np.float32); obs[0] = float(self.encode_mode)
|
||||
with self.motion_lock:
|
||||
if self.encode_mode == 2:
|
||||
# SMPL whole-body imitation: the 720-dim SMPL window carries the
|
||||
# target pose; the planner reference frame supplies anchor + wrist.
|
||||
rf = min(self.ref_cursor, self.motion_timesteps - 1)
|
||||
ref_pos = self.motion_joint_positions[rf].astype(np.float32)
|
||||
ref_quat = self.motion_body_quats[rf].astype(np.float32)
|
||||
anchor = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
wrist = ref_pos[WRIST_IL]
|
||||
obs[922:1642] = self.smpl_joints_10frame_step1
|
||||
for f in range(10):
|
||||
obs[1642+6*f:1642+6*(f+1)] = anchor
|
||||
obs[1702+6*f:1702+6*(f+1)] = wrist
|
||||
return obs
|
||||
for f in range(10):
|
||||
tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
|
||||
self.motion_timesteps - 1)
|
||||
obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
|
||||
if self.playing:
|
||||
obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
|
||||
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
|
||||
self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
|
||||
return obs
|
||||
|
||||
def step(self, robot_obs, update_encoder, debug=False):
|
||||
if robot_obs and (self.first_motion or self.reinit_heading):
|
||||
q = None
|
||||
if "imu.quat.w" in robot_obs:
|
||||
q = np.array([
|
||||
robot_obs["imu.quat.w"], robot_obs["imu.quat.x"],
|
||||
robot_obs["imu.quat.y"], robot_obs["imu.quat.z"],
|
||||
], np.float64)
|
||||
else:
|
||||
q = robot_obs.get("imu.quaternion")
|
||||
if q is not None:
|
||||
q = np.array(q, np.float64)
|
||||
if q is not None:
|
||||
self.heading_init_base_quat = np.array(q, np.float64)
|
||||
with self.motion_lock:
|
||||
rf = min(self.ref_cursor, self.motion_timesteps - 1)
|
||||
self.init_ref_quat = self.motion_body_quats[rf].copy()
|
||||
self.delta_heading = 0.0
|
||||
self.first_motion = False
|
||||
self.reinit_heading = False
|
||||
print(f"[Heading] init quat: {self.heading_init_base_quat}")
|
||||
return super().step(robot_obs, update_encoder=update_encoder, debug=debug)
|
||||
|
||||
def advance_cursor(self):
|
||||
"""Advance one frame per 50 Hz tick (C++ current_frame_ += 1), no wall-clock catch-up."""
|
||||
if not self.playing:
|
||||
return
|
||||
with self.motion_lock:
|
||||
if self.motion_timesteps > 0:
|
||||
self.ref_cursor = min(self.ref_cursor + 1, self.motion_timesteps - 1)
|
||||
|
||||
# ── Keyboard ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class RawKeyboard:
|
||||
def __init__(self):
|
||||
self.fd = sys.stdin.fileno()
|
||||
self.old = termios.tcgetattr(self.fd)
|
||||
def __enter__(self): tty.setcbreak(self.fd); return self
|
||||
def __exit__(self, *_): termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)
|
||||
def get_key(self):
|
||||
return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None
|
||||
|
||||
|
||||
def drain_keyboard(kb, ms, controller=None) -> bool:
|
||||
"""Process all pending terminal keys this frame (return True to quit)."""
|
||||
quit_requested = False
|
||||
while True:
|
||||
key = kb.get_key()
|
||||
if key is None:
|
||||
break
|
||||
if process_keyboard(key, ms, controller):
|
||||
quit_requested = True
|
||||
return quit_requested
|
||||
|
||||
def process_keyboard(key, ms, controller=None):
|
||||
if key is None: return False
|
||||
if key == '\x1b': return True
|
||||
if key == ' ':
|
||||
ms.mode = LM.IDLE; ms.speed = ms.height = -1.0
|
||||
ms.has_movement = False; ms.needs_replan = True
|
||||
if controller: controller.playing = False; controller.reinit_heading = True
|
||||
print("\n >> EMERGENCY STOP -> IDLE"); return False
|
||||
if key in ('r','R'):
|
||||
ms.needs_replan = True; print("\n >> Manual replan"); return False
|
||||
if key in ('m','M'):
|
||||
if controller is not None and getattr(controller, "smpl_motion", None) is not None:
|
||||
if controller.encode_mode == 2:
|
||||
controller.encode_mode = 0
|
||||
controller.playing = True; controller.reinit_heading = True
|
||||
ms.needs_replan = True
|
||||
print("\n >> Motion OFF -> locomotion (mode 0). WASD to drive.")
|
||||
else:
|
||||
controller.encode_mode = 2
|
||||
controller.reinit_heading = True
|
||||
controller.smpl_motion.reset()
|
||||
print("\n >> Motion ON -> SMPL playback (mode 2)")
|
||||
else:
|
||||
print("\n >> No motion loaded (start with --motion-file)")
|
||||
return False
|
||||
if key in ('n','N','p','P'):
|
||||
ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
|
||||
name, modes = MOTION_SETS[ms.motion_set_idx]
|
||||
print(f"\n >> Motion set: {name}")
|
||||
[print(f" {i+1}: {m.name}") for i,m in enumerate(modes)]
|
||||
return False
|
||||
if key.isdigit() and key not in ('9','0'):
|
||||
idx = int(key) - 1; modes = MOTION_SETS[ms.motion_set_idx][1]
|
||||
if 0 <= idx < len(modes):
|
||||
ms.mode = modes[idx]; ms.needs_replan = True
|
||||
if controller: controller.playing = True; controller.reinit_heading = True
|
||||
print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
|
||||
return False
|
||||
if key == '9':
|
||||
ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
|
||||
print(f"\n >> Speed: {ms.speed:.1f}"); return False
|
||||
if key == '0':
|
||||
ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
|
||||
print(f"\n >> Speed: {ms.speed:.1f}"); return False
|
||||
if key == '-':
|
||||
ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
|
||||
print(f"\n >> Height: {ms.height:.2f}"); return False
|
||||
if key == '=':
|
||||
ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
|
||||
print(f"\n >> Height: {ms.height:.2f}"); return False
|
||||
if key.lower() == 'w': ms.movement_angle = ms.facing_angle
|
||||
elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
|
||||
elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
|
||||
elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
|
||||
if key.lower() in ('w','s','a','d'):
|
||||
ms.has_movement = ms.needs_replan = True
|
||||
if controller:
|
||||
controller.playing = True
|
||||
print(f"\n >> Move {key.upper()} (replanning...)")
|
||||
elif key.lower() == 'q':
|
||||
ms.facing_angle += 0.1
|
||||
if controller: controller.delta_heading += 0.1
|
||||
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
|
||||
elif key.lower() == 'e':
|
||||
ms.facing_angle -= 0.1
|
||||
if controller: controller.delta_heading -= 0.1
|
||||
print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
|
||||
return False
|
||||
|
||||
_joy_prev_active = False
|
||||
|
||||
|
||||
def _parse_wireless(wr):
|
||||
"""Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
|
||||
import struct as _st
|
||||
if not isinstance(wr, (bytes, bytearray)):
|
||||
wr = bytes(wr)
|
||||
if len(wr) < 24:
|
||||
return None
|
||||
lx = _st.unpack("f", wr[4:8])[0]
|
||||
rx = _st.unpack("f", wr[8:12])[0]
|
||||
ry = _st.unpack("f", wr[12:16])[0]
|
||||
ly = _st.unpack("f", wr[20:24])[0]
|
||||
return lx, ly, rx, ry
|
||||
|
||||
|
||||
def process_joystick(obs, ms, controller=None):
|
||||
"""Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
|
||||
global _joy_prev_active
|
||||
wr = obs.get("wireless_remote")
|
||||
if wr is None:
|
||||
return
|
||||
parsed = _parse_wireless(wr)
|
||||
if parsed is None:
|
||||
return
|
||||
lx, ly, rx, ry = parsed
|
||||
|
||||
# Dead zone + negate both Y axes (bridge already flips them once)
|
||||
lx = 0.0 if abs(lx) < DEADZONE else lx
|
||||
ly = 0.0 if abs(ly) < DEADZONE else -ly
|
||||
rx = 0.0 if abs(rx) < DEADZONE else rx
|
||||
ry = 0.0 if abs(ry) < DEADZONE else -ry
|
||||
|
||||
left_active = abs(lx) > 0 or abs(ly) > 0
|
||||
|
||||
# Left stick → WASD (movement direction relative to facing)
|
||||
if left_active:
|
||||
ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
|
||||
ms.has_movement = True
|
||||
if not _joy_prev_active:
|
||||
ms.needs_replan = True
|
||||
_joy_prev_active = True
|
||||
elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
|
||||
_joy_prev_active = False
|
||||
ms.has_movement = False
|
||||
|
||||
# Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
|
||||
if abs(rx) > 0:
|
||||
delta = -0.02 * rx
|
||||
ms.facing_angle += delta
|
||||
if controller:
|
||||
controller.delta_heading += delta
|
||||
|
||||
# Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
|
||||
if abs(ry) > 0:
|
||||
step = -0.005 * ry
|
||||
ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
|
||||
@@ -1,152 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""SONIC full-body controller for Unitree G1."""
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.controllers.sonic_pipeline import (
|
||||
CONTROL_DT,
|
||||
DEBUG_PRINT_EVERY,
|
||||
DEFAULT_ANGLES,
|
||||
ENCODER_UPDATE_EVERY,
|
||||
LM,
|
||||
MOTION_SETS,
|
||||
MovementState,
|
||||
PlannerController,
|
||||
SonicPlanner,
|
||||
clamp_mode_params,
|
||||
compute_kp_kd,
|
||||
lowstate_to_obs,
|
||||
process_joystick,
|
||||
should_replan_request,
|
||||
_ort_providers,
|
||||
_snapshot_ms,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SonicRuntime:
|
||||
"""Shared SONIC control loop state (standalone demo + locomotion controller)."""
|
||||
|
||||
def __init__(self, force_cpu: bool = False):
|
||||
planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
|
||||
encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
|
||||
decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")
|
||||
|
||||
providers = _ort_providers(force_cpu=force_cpu)
|
||||
self.use_gpu = providers[0] == "CUDAExecutionProvider"
|
||||
so = ort.SessionOptions()
|
||||
so.log_severity_level = 3
|
||||
|
||||
planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=providers)
|
||||
encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=providers)
|
||||
decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=providers)
|
||||
|
||||
self.kp, self.kd = compute_kp_kd()
|
||||
self.ms = MovementState()
|
||||
self.planner = SonicPlanner(planner_sess, planner_path)
|
||||
self.controller = PlannerController(self.planner, encoder_sess, decoder_sess)
|
||||
|
||||
motion = self.planner.initialize(DEFAULT_ANGLES, self.ms)
|
||||
self.controller.load_initial_motion(motion)
|
||||
self.planner.start_subprocess(self.controller, use_gpu=self.use_gpu)
|
||||
|
||||
self.step = 0
|
||||
self.replan_timer = 0.0
|
||||
self.last_ms = _snapshot_ms(self.ms)
|
||||
|
||||
@property
|
||||
def pipeline(self):
|
||||
return self.controller
|
||||
|
||||
def tick(self, obs: dict, *, debug: bool | None = None, use_joystick: bool = True) -> dict:
|
||||
if not obs:
|
||||
self.step += 1
|
||||
return {}
|
||||
|
||||
if use_joystick:
|
||||
process_joystick(obs, self.ms, self.controller)
|
||||
clamp_mode_params(self.ms)
|
||||
|
||||
if self.step > 0:
|
||||
self.replan_timer += CONTROL_DT
|
||||
if should_replan_request(self.ms, self.last_ms, self.replan_timer, self.step):
|
||||
self.planner.request_replan(self.controller.ref_cursor, self.ms)
|
||||
self.replan_timer = 0.0
|
||||
self.ms.needs_replan = False
|
||||
self.last_ms = _snapshot_ms(self.ms)
|
||||
|
||||
do_enc = self.step % ENCODER_UPDATE_EVERY == 0
|
||||
if debug is None:
|
||||
debug = self.step % DEBUG_PRINT_EVERY == 0
|
||||
action = self.controller.step(obs, update_encoder=do_enc, debug=debug)
|
||||
|
||||
result = self.planner.try_get_new_motion()
|
||||
if result:
|
||||
self.controller.blend_new_motion(*result)
|
||||
|
||||
self.controller.advance_cursor()
|
||||
self.step += 1
|
||||
return action
|
||||
|
||||
def reset(self):
|
||||
self.ms = MovementState()
|
||||
self.controller.reinit_heading = True
|
||||
self.controller.playing = True
|
||||
self.step = 0
|
||||
self.replan_timer = 0.0
|
||||
self.last_ms = _snapshot_ms(self.ms)
|
||||
|
||||
def shutdown(self):
|
||||
self.planner.stop_subprocess()
|
||||
|
||||
|
||||
class SonicWholeBodyController:
|
||||
"""Full-body SONIC controller for UnitreeG1's background controller thread."""
|
||||
|
||||
control_dt = CONTROL_DT
|
||||
full_body = True
|
||||
|
||||
def __init__(self, force_cpu: bool = False):
|
||||
logger.info("Loading SONIC whole-body controller...")
|
||||
self._runtime = SonicRuntime(force_cpu=force_cpu)
|
||||
self.kp = self._runtime.kp
|
||||
self.kd = self._runtime.kd
|
||||
self.controller = self._runtime.controller
|
||||
self.ms = self._runtime.ms
|
||||
logger.info(
|
||||
"SONIC ready: %s (default mode: %s)",
|
||||
MOTION_SETS[0][0],
|
||||
LM(self.ms.mode).name,
|
||||
)
|
||||
|
||||
def run_step(self, action: dict, lowstate) -> dict:
|
||||
if lowstate is None:
|
||||
return {}
|
||||
obs = lowstate_to_obs(lowstate)
|
||||
return self._runtime.tick(obs, debug=False)
|
||||
|
||||
def reset(self):
|
||||
self._runtime.reset()
|
||||
|
||||
def shutdown(self):
|
||||
self._runtime.shutdown()
|
||||
@@ -68,9 +68,8 @@ def make_locomotion_controller(name: str | None):
|
||||
if name is None:
|
||||
return None
|
||||
controllers = {
|
||||
"GrootLocomotionController": "lerobot.robots.unitree_g1.controllers.gr00t_locomotion",
|
||||
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.controllers.holosoma_locomotion",
|
||||
"SonicWholeBodyController": "lerobot.robots.unitree_g1.controllers.sonic_whole_body",
|
||||
"GrootLocomotionController": "lerobot.robots.unitree_g1.gr00t_locomotion",
|
||||
"HolosomaLocomotionController": "lerobot.robots.unitree_g1.holosoma_locomotion",
|
||||
}
|
||||
module_path = controllers.get(name)
|
||||
if module_path is None:
|
||||
|
||||
+1
-1
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
from .g1_utils import (
|
||||
REMOTE_AXES,
|
||||
REMOTE_BUTTONS,
|
||||
G1_29_JointIndex,
|
||||
+1
-1
@@ -22,7 +22,7 @@ import onnx
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import (
|
||||
from .g1_utils import (
|
||||
REMOTE_AXES,
|
||||
G1_29_JointArmIndex,
|
||||
G1_29_JointIndex,
|
||||
@@ -338,9 +338,6 @@ class UnitreeG1(Robot):
|
||||
|
||||
self.kp = np.array(self.config.kp, dtype=np.float32)
|
||||
self.kd = np.array(self.config.kd, dtype=np.float32)
|
||||
if self.controller is not None and hasattr(self.controller, "kp"):
|
||||
self.kp = np.array(self.controller.kp, dtype=np.float32)
|
||||
self.kd = np.array(self.controller.kd, dtype=np.float32)
|
||||
|
||||
for joint in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[joint].mode = 1
|
||||
@@ -377,9 +374,6 @@ class UnitreeG1(Robot):
|
||||
# Signal thread to stop and unblock any waits
|
||||
self._shutdown_event.set()
|
||||
|
||||
if self.controller is not None and hasattr(self.controller, "shutdown"):
|
||||
self.controller.shutdown()
|
||||
|
||||
# Wait for subscribe thread to finish
|
||||
if self.subscribe_thread is not None:
|
||||
self.subscribe_thread.join(timeout=2.0)
|
||||
@@ -471,11 +465,9 @@ class UnitreeG1(Robot):
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
action_to_publish = action
|
||||
if self.controller is not None:
|
||||
self._update_controller_action(action)
|
||||
if getattr(self.controller, "full_body", False):
|
||||
return action
|
||||
# Controller thread owns legs/waist. Here we only update joystick inputs
|
||||
# and publish arm targets from the teleoperator.
|
||||
self._update_controller_action(action)
|
||||
arm_prefixes = tuple(j.name for j in G1_29_JointArmIndex)
|
||||
action_to_publish = {
|
||||
key: value
|
||||
|
||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
shape = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
return features
|
||||
|
||||
|
||||
def _build_raw_frame(
|
||||
raw_obs: dict,
|
||||
env_idx: int,
|
||||
action: np.ndarray,
|
||||
reward: float,
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if key.startswith("next."):
|
||||
continue
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
val = raw_obs[key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -105,6 +165,10 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -145,6 +209,33 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
recording_datasets: list[LeRobotDataset] | None = None
|
||||
raw_observation = None
|
||||
task_desc = ""
|
||||
if recording_dir is not None and env_features is not None:
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
recording_datasets = []
|
||||
for i in range(env.num_envs):
|
||||
multi_env = env.num_envs > 1
|
||||
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
|
||||
base_repo_id = recording_repo_id or "eval_recording"
|
||||
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
|
||||
recording_datasets.append(
|
||||
LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=True,
|
||||
)
|
||||
)
|
||||
raw_observation = deepcopy(observation)
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -217,6 +308,26 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_datasets is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_datasets[env_idx].features,
|
||||
)
|
||||
recording_datasets[env_idx].add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_datasets[env_idx].save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
@@ -255,6 +366,12 @@ def rollout(
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret[OBS_STR] = stacked_observations
|
||||
|
||||
if recording_datasets is not None:
|
||||
for ds in recording_datasets:
|
||||
ds.finalize()
|
||||
if recording_repo_id is not None:
|
||||
ds.push_to_hub(private=recording_private)
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
policy.use_original_modules()
|
||||
|
||||
@@ -273,6 +390,10 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -361,6 +482,10 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -563,6 +688,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -572,10 +701,15 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
recording_repo_id=cfg.eval.recording_repo_id,
|
||||
recording_private=cfg.eval.recording_private,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -618,6 +752,10 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -635,6 +773,10 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -661,6 +803,10 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -672,7 +818,13 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
task_recording_dir = None
|
||||
task_repo_id = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
if recording_repo_id is not None:
|
||||
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
|
||||
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
@@ -685,8 +837,12 @@ def run_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=task_recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=task_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -702,6 +858,10 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -761,6 +921,10 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
||||
)
|
||||
dataset_copy = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
||||
)
|
||||
|
||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
||||
|
||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
Reference in New Issue
Block a user