Files
lerobot/examples/unitree_g1/dance.py
T
2025-12-17 16:00:43 +01:00

455 lines
17 KiB
Python

#!/usr/bin/env python3
"""
WBT (Whole Body Tracking) Dance Policy for Unitree G1
Uses ONNX model with motion data baked in.
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
Usage:
python examples/unitree_g1/dance.py
"""
import argparse
import json
import logging
import threading
import time
from xml.etree import ElementTree
import numpy as np
import onnx
import onnxruntime as ort
import pinocchio as pin
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
CONTROL_DT = 0.02  # control period in seconds (1 / 0.02 = 50 Hz)
NUM_DOFS = 29

# Default joint positions (holosoma training defaults).
# Policy actions and the "dof_pos" observation are expressed relative to these.
DEFAULT_DOF_POS = np.array(
    [
        # Left leg (6)
        -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
        # Right leg (6)
        -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
        # Waist (3)
        0.0, 0.0, 0.0,
        # Left arm (7)
        0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
        # Right arm (7)
        0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
    ],
    dtype=np.float32,
)

# Stiff-hold position gains (used while moving to the start pose, and as the
# fallback when the ONNX metadata carries no "kp" entry).
# Layout: leg pattern x2, waist, arm pattern x2.
STIFF_KP = np.array(
    [150, 150, 200, 200, 40, 40] * 2
    + [200, 200, 100]
    + [100, 100, 100, 100, 50, 50, 50] * 2,
    dtype=np.float32,
)

# Stiff-hold damping gains: 2.5 everywhere except the three waist joints (5.0).
STIFF_KD = np.array(
    [2.5] * 12 + [5.0] * 3 + [2.5] * 14,
    dtype=np.float32,
)

# Joint indices that are not driven by the policy: they are commanded to 0 rad
# with the high gains below. In DOF_NAMES order these are waist roll/pitch and
# the pitch/yaw wrist joints of each arm.
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
FROZEN_KP = 500.0
FROZEN_KD = 5.0
# =============================================================================
# QUATERNION UTILITIES
# =============================================================================
def quat_inverse(q):
    """Return the conjugate of each quaternion in a 2-D array.

    Column 0 is kept and every remaining column is negated, i.e. the input is
    treated as scalar-first (w, x, y, z). For unit quaternions the conjugate
    is also the inverse; no normalisation is performed here.

    Args:
        q: 2-D array with one quaternion per row.

    Returns:
        A new array of the same shape holding the conjugated quaternions.
    """
    scalar_part = q[:, :1]
    vector_part = q[:, 1:]
    return np.concatenate((scalar_part, -vector_part), axis=1)
def quat_mul(a, b):
    """Multiply quaternions ``a * b`` (Hamilton product, scalar-first layout).

    Both inputs are first reshaped to ``(-1, 4)``, so the result is always
    2-D with one quaternion per row, even when 1-D quaternions are passed in.

    Args:
        a: Left operand(s), any shape whose size is a multiple of 4.
        b: Right operand(s), same convention as ``a``.

    Returns:
        Array shaped like the reshaped ``a`` (``(N, 4)``) with the products.
    """
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)
    w1, x1, y1, z1 = lhs.T
    w2, x2, y2, z2 = rhs.T

    # Factored form of the Hamilton product that shares intermediate terms;
    # the evaluation order is kept as-is so results match bit-for-bit.
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))

    out_w = qq - ww + (z1 - y1) * (y2 - z2)
    out_x = qq - xx + (x1 + w1) * (x2 + w2)
    out_y = qq - yy + (w1 - x1) * (y2 + z2)
    out_z = qq - zz + (z1 + y1) * (w2 - x2)

    return np.stack((out_w, out_x, out_y, out_z), axis=-1).reshape(lhs.shape)
def subtract_frame_transforms(q01, q02):
    """Return the orientation of frame 2 expressed relative to frame 1.

    Computed as ``quat_inverse(q01) * q02``; only the rotational part is
    handled (no translations).

    Args:
        q01: 2-D array of quaternions for frame 1 (as accepted by
            ``quat_inverse``).
        q02: Quaternions for frame 2, same layout.

    Returns:
        The relative quaternions as returned by ``quat_mul``.
    """
    q10 = quat_inverse(q01)
    return quat_mul(q10, q02)
def matrix_from_quat(q):
    """Convert scalar-first quaternions to 3x3 rotation matrices.

    The quaternion does not need to be unit length: the entries are scaled by
    ``2 / |q|^2``.

    Args:
        q: Array of shape ``(..., 4)`` with components ordered (w, x, y, z).

    Returns:
        Array of shape ``(..., 3, 3)``.
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    scale = 2.0 / np.sum(q * q, axis=-1)

    # Row-major entries of the rotation matrix.
    m00 = 1 - scale * (y * y + z * z)
    m01 = scale * (x * y - z * w)
    m02 = scale * (x * z + y * w)
    m10 = scale * (x * y + z * w)
    m11 = 1 - scale * (x * x + z * z)
    m12 = scale * (y * z - x * w)
    m20 = scale * (x * z - y * w)
    m21 = scale * (y * z + x * w)
    m22 = 1 - scale * (x * x + y * y)

    flat = np.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), axis=-1)
    return flat.reshape(q.shape[:-1] + (3, 3))
def xyzw_to_wxyz(xyzw):
    """Reorder quaternions from (x, y, z, w) to (w, x, y, z).

    Args:
        xyzw: 2-D array with one scalar-last quaternion per row.

    Returns:
        A new 2-D array whose first column is the last input column, followed
        by the first three input columns.
    """
    scalar_col = xyzw[:, -1:]
    vector_cols = xyzw[:, :3]
    return np.hstack((scalar_col, vector_cols))
def quat_to_rpy(q):
    """Convert one scalar-first quaternion to roll, pitch and yaw (radians).

    Args:
        q: Sequence of four components ordered (w, x, y, z).

    Returns:
        Tuple ``(roll, pitch, yaw)``. The pitch argument is clipped to
        [-1, 1] before ``arcsin`` so rounding error cannot produce NaN.
    """
    w, x, y, z = q

    roll_num = 2 * (w * x + y * z)
    roll_den = 1 - 2 * (x ** 2 + y ** 2)
    roll = np.arctan2(roll_num, roll_den)

    sin_pitch = np.clip(2 * (w * y - z * x), -1, 1)
    pitch = np.arcsin(sin_pitch)

    yaw_num = 2 * (w * z + x * y)
    yaw_den = 1 - 2 * (y ** 2 + z ** 2)
    yaw = np.arctan2(yaw_num, yaw_den)

    return roll, pitch, yaw
def rpy_to_quat(rpy):
    """Convert roll, pitch, yaw (radians) to a scalar-first quaternion.

    Args:
        rpy: Sequence ``(roll, pitch, yaw)``.

    Returns:
        1-D array of four components ordered (w, x, y, z).
    """
    roll, pitch, yaw = rpy

    # Half-angle sines/cosines for each axis.
    cy, sy = np.cos(yaw * 0.5), np.sin(yaw * 0.5)
    cp, sp = np.cos(pitch * 0.5), np.sin(pitch * 0.5)
    cr, sr = np.cos(roll * 0.5), np.sin(roll * 0.5)

    qw = cr * cp * cy + sr * sp * sy
    qx = sr * cp * cy - cr * sp * sy
    qy = cr * sp * cy + sr * cp * sy
    qz = cr * cp * sy - sr * sp * cy
    return np.array([qw, qx, qy, qz])
# =============================================================================
# PINOCCHIO FK
# =============================================================================
# Joint names in the order used for the 29-element joint vectors in this file
# (index i here corresponds to motor_state[i] / motor_cmd[i]).
DOF_NAMES = (
    # Left leg
    "left_hip_pitch_joint",
    "left_hip_roll_joint",
    "left_hip_yaw_joint",
    "left_knee_joint",
    "left_ankle_pitch_joint",
    "left_ankle_roll_joint",
    # Right leg
    "right_hip_pitch_joint",
    "right_hip_roll_joint",
    "right_hip_yaw_joint",
    "right_knee_joint",
    "right_ankle_pitch_joint",
    "right_ankle_roll_joint",
    # Waist
    "waist_yaw_joint",
    "waist_roll_joint",
    "waist_pitch_joint",
    # Left arm
    "left_shoulder_pitch_joint",
    "left_shoulder_roll_joint",
    "left_shoulder_yaw_joint",
    "left_elbow_joint",
    "left_wrist_roll_joint",
    "left_wrist_pitch_joint",
    "left_wrist_yaw_joint",
    # Right arm
    "right_shoulder_pitch_joint",
    "right_shoulder_roll_joint",
    "right_shoulder_yaw_joint",
    "right_elbow_joint",
    "right_wrist_roll_joint",
    "right_wrist_pitch_joint",
    "right_wrist_yaw_joint",
)
class PinocchioFK:
    """Pinocchio forward kinematics for torso_link orientation."""

    def __init__(self, urdf_text: str):
        """Build a free-flyer Pinocchio model from URDF text.

        Args:
            urdf_text: Full URDF document as a string.

        Raises:
            ValueError: If the model contains a joint name that is not listed
                in ``DOF_NAMES`` (raised by ``tuple.index``).
        """
        urdf_root = ElementTree.fromstring(urdf_text)

        # Drop every <visual>/<collision> element (namespace prefix ignored)
        # before handing the XML to Pinocchio; only kinematics are used here.
        # NOTE(review): presumably this avoids resolving mesh files - confirm.
        stripped_tags = {"visual", "collision"}
        for element in urdf_root.iter():
            for sub in list(element):
                local_tag = sub.tag.split("}")[-1]
                if local_tag in stripped_tags:
                    element.remove(sub)

        urdf_xml = '<?xml version="1.0"?>\n' + ElementTree.tostring(urdf_root, encoding="unicode")
        self.model = pin.buildModelFromXML(urdf_xml, pin.JointModelFreeFlyer())
        self.data = self.model.createData()

        # idx_map[k] is the DOF_NAMES index of Pinocchio's k-th actuated joint,
        # so dof_pos[idx_map] reorders a DOF_NAMES-ordered vector for Pinocchio.
        skipped = ["universe", "root_joint"]
        pin_names = [name for name in self.model.names if name not in skipped]
        self.idx_map = np.array([DOF_NAMES.index(name) for name in pin_names])

        self.ref_frame_id = self.model.getFrameId("torso_link")
        logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")

    def get_torso_quat(self, pos, quat_wxyz, dof_pos):
        """Get torso_link orientation in world frame.

        Args:
            pos: Floating-base position (3 values).
            quat_wxyz: Floating-base orientation ordered (w, x, y, z).
            dof_pos: Joint positions in ``DOF_NAMES`` order.

        Returns:
            Array of shape ``(1, 4)`` with the torso_link quaternion ordered
            (w, x, y, z).
        """
        # The configuration vector is built with the base quaternion in
        # (x, y, z, w) order, and coeffs() is read back in that same order.
        base_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
        joint_cfg = dof_pos[self.idx_map]
        config = np.concatenate([pos, base_xyzw, joint_cfg])

        pin.framesForwardKinematics(self.model, self.data, config)

        torso_rot = self.data.oMf[self.ref_frame_id].rotation
        torso_xyzw = pin.Quaternion(torso_rot).coeffs()
        torso_wxyz = np.array([torso_xyzw[3], torso_xyzw[0], torso_xyzw[1], torso_xyzw[2]])
        return torso_wxyz.reshape(1, 4)
# =============================================================================
# DANCE CONTROLLER
# =============================================================================
class DanceController:
    """
    Handles WBT dance policy for the Unitree G1 robot.

    This controller manages:
    - 29-joint observation processing
    - Pinocchio FK for torso orientation
    - Policy inference with motion data from ONNX
    """

    def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
        """Store collaborators and query the policy for the first motion frame.

        Args:
            policy: Inference session exposing ``get_inputs()`` and ``run()``
                with inputs ``obs`` / ``time_step`` and outputs ``actions``,
                ``joint_pos``, ``joint_vel``, ``ref_quat_xyzw``.
            robot: Robot object providing ``lowstate_buffer``, ``msg``,
                ``send_action``, ``crc`` and ``lowcmd_publisher``.
            pinocchio_fk: ``PinocchioFK`` instance, or ``None`` to run without
                the torso-orientation observation (zeros are used instead).
            motor_kp: Per-joint position gains, indexable by joint index.
            motor_kd: Per-joint damping gains, indexable by joint index.
            action_scale: Multiplier applied to policy actions before they are
                added to ``DEFAULT_DOF_POS``.
        """
        self.policy = policy
        self.robot = robot
        self.pinocchio_fk = pinocchio_fk
        self.motor_kp = motor_kp
        self.motor_kd = motor_kd
        self.action_scale = action_scale

        self.obs_dim = policy.get_inputs()[0].shape[1]
        self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
        self.motion_command = None
        self.ref_quat_xyzw = None
        self.timestep = 0
        self.yaw_offset = 0.0

        # Get initial motion data from ONNX; its joint positions are the pose
        # reset_to_motion_pose() moves to.
        self.motion_start_pose = self._load_initial_motion().flatten()

        # Thread management
        self.dance_running = False
        self.dance_thread = None

        logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")

    def _load_initial_motion(self):
        """Query the policy at time step 0 and cache the reference motion.

        Sets ``motion_command`` (joint_pos and joint_vel concatenated along
        axis 1) and ``ref_quat_xyzw``. An all-zero observation is fed because
        only the motion outputs are read here, not the action.

        Returns:
            The ``joint_pos`` output for time step 0.
        """
        dummy_obs = np.zeros((1, self.obs_dim), dtype=np.float32)
        joint_pos, joint_vel, ref_quat = self.policy.run(
            ["joint_pos", "joint_vel", "ref_quat_xyzw"],
            {"obs": dummy_obs, "time_step": np.array([[0]], dtype=np.float32)},
        )
        self.motion_command = np.concatenate((joint_pos, joint_vel), axis=1)
        self.ref_quat_xyzw = ref_quat
        return joint_pos

    def _write_motor_commands(self, target_pos, kp, kd):
        """Fill ``robot.msg.motor_cmd`` for all joints (does not send it).

        Joints listed in ``FROZEN_JOINTS`` are commanded to 0 with
        ``FROZEN_KP`` / ``FROZEN_KD``; every other joint gets its entry from
        ``target_pos`` with gains ``kp[i]`` / ``kd[i]``. Velocity and
        feed-forward torque targets are always zero.
        """
        motor_cmd = self.robot.msg.motor_cmd
        for i in range(NUM_DOFS):
            cmd = motor_cmd[i]
            if i in FROZEN_JOINTS:
                cmd.q = 0.0
                cmd.kp = FROZEN_KP
                cmd.kd = FROZEN_KD
            else:
                cmd.q = float(target_pos[i])
                cmd.kp = kp[i]
                cmd.kd = kd[i]
            cmd.qd = 0
            cmd.tau = 0

    def capture_yaw_offset(self):
        """Capture robot's current yaw for relative tracking.

        Does nothing (keeps the previous offset) when no robot state is
        available yet or when no FK model was provided.
        """
        robot_state = self.robot.lowstate_buffer.get_data()
        if not (robot_state and self.pinocchio_fk):
            return
        quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
        dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
        torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
        _, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
        logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")

    def _remove_yaw_offset(self, quat_wxyz):
        """Remove stored yaw offset from orientation.

        Left-multiplies by a pure yaw rotation of ``-yaw_offset``; returns the
        input unchanged when the offset is negligible.
        """
        if abs(self.yaw_offset) < 1e-6:
            return quat_wxyz
        yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
        return quat_mul(yaw_q, quat_wxyz)

    def run_step(self):
        """Single dance step - reads state, runs policy, sends commands.

        Returns immediately, without advancing ``timestep``, when the robot
        has not published a state yet.
        """
        robot_state = self.robot.lowstate_buffer.get_data()
        if robot_state is None:
            return

        # Read robot state
        quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
        ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
        dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
        dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)

        # Compute motion_ref_ori_b using FK: the reference orientation
        # relative to the (yaw-corrected) torso, encoded as the first two
        # columns of its rotation matrix (6 values).
        if self.pinocchio_fk:
            torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
            torso_q = self._remove_yaw_offset(torso_q)
            motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
            rel_quat = subtract_frame_transforms(torso_q, motion_ori)
            ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
        else:
            ori_b = np.zeros((1, 6), dtype=np.float32)

        dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)

        # Build observation (alphabetical order of the term names)
        obs_dict = {
            "actions": self.last_action,
            "base_ang_vel": ang_vel.reshape(1, 3),
            "dof_pos": dof_rel,
            "dof_vel": dof_vel.reshape(1, -1),
            "motion_command": self.motion_command,
            "motion_ref_ori_b": ori_b,
        }
        obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict)], axis=1)
        obs = np.clip(obs, -100, 100)

        # Run policy; the motion outputs become the next step's observation.
        raw_action, joint_pos, joint_vel, ref_quat = self.policy.run(
            ["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
            {"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)},
        )
        action = np.clip(raw_action, -100, 100)
        self.motion_command = np.concatenate((joint_pos, joint_vel), axis=1)
        self.ref_quat_xyzw = ref_quat
        self.last_action = action.copy()

        # Compute target positions
        target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale

        # Send commands
        self._write_motor_commands(target_pos, self.motor_kp, self.motor_kd)
        self.robot.send_action(self.robot.msg)
        self.timestep += 1

    def _dance_thread_loop(self):
        """Background thread that runs the dance policy.

        Calls ``run_step`` once per ``CONTROL_DT`` until ``dance_running`` is
        cleared. Exceptions from a step are logged with their traceback and
        the loop keeps running.
        """
        logger.info("Dance thread started")
        while self.dance_running:
            # Monotonic clock: the period must not be affected by wall-clock
            # adjustments.
            start_time = time.perf_counter()
            try:
                self.run_step()
            except Exception as e:
                logger.exception(f"Error in dance loop: {e}")
            elapsed = time.perf_counter() - start_time
            time.sleep(max(0, CONTROL_DT - elapsed))
        logger.info("Dance thread stopped")

    def start_dance_thread(self):
        """Start the dance control thread.

        Resets the step counter, last action and reference motion, captures
        the current yaw, then launches a daemon thread. Logs a warning and
        returns if the thread is already running.
        """
        if self.dance_running:
            logger.warning("Dance thread already running")
            return

        # Reset state for fresh start
        self.timestep = 0
        self.last_action.fill(0)

        # Re-get initial motion data
        self._load_initial_motion()
        self.capture_yaw_offset()

        logger.info("Starting dance control thread...")
        self.dance_running = True
        self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
        self.dance_thread.start()

    def stop_dance_thread(self):
        """Stop the dance control thread.

        Clears the run flag and waits up to 2 seconds for the thread to exit.
        No-op when the thread is not running.
        """
        if not self.dance_running:
            return

        logger.info("Stopping dance control thread...")
        self.dance_running = False
        if self.dance_thread:
            self.dance_thread.join(timeout=2.0)
            if self.dance_thread.is_alive():
                # join() returns silently on timeout; do not claim success.
                logger.warning("Dance control thread did not stop within 2.0s")
                return
        logger.info("Dance control thread stopped")

    def reset_to_motion_pose(self, duration: float = 3.0):
        """Move robot to initial motion pose over given duration.

        Linearly interpolates from the current joint positions to
        ``motion_start_pose`` using the stiff gains, publishing one command
        every ``CONTROL_DT``. The last command equals the target pose.

        Args:
            duration: Interpolation time in seconds.

        Raises:
            RuntimeError: If the robot has not published any state yet.
        """
        logger.info(f"Moving to dance start pose ({duration}s)...")
        robot_state = self.robot.lowstate_buffer.get_data()
        if robot_state is None:
            raise RuntimeError("No robot state available; cannot move to dance start pose")

        init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
        target_pos = self.motion_start_pose
        num_steps = int(duration / CONTROL_DT)

        for step in range(num_steps):
            # (step + 1) so that the final iteration commands alpha == 1.0,
            # i.e. exactly the target pose.
            alpha = (step + 1) / num_steps
            interp = init_pos * (1 - alpha) + target_pos * alpha
            self._write_motor_commands(interp, STIFF_KP, STIFF_KD)
            self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
            self.robot.lowcmd_publisher.Write(self.robot.msg)
            time.sleep(CONTROL_DT)

        logger.info("At dance start pose!")
# =============================================================================
# MAIN
# =============================================================================
def _decode_metadata_value(raw):
    """Decode one ONNX metadata value, falling back to the raw string.

    Values that are not valid JSON are returned unchanged so that one
    plain-text property cannot abort loading the policy.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return raw


def load_dance_policy(onnx_path: str):
    """Load dance policy and extract metadata.

    Reads the optional metadata keys ``kp``, ``kd``, ``action_scale`` and
    ``robot_urdf`` from the ONNX model. Missing gains fall back to
    ``STIFF_KP`` / ``STIFF_KD``, and a missing action scale falls back to 1.0.

    Args:
        onnx_path: Path to the ONNX model file.

    Returns:
        Tuple ``(policy, pinocchio_fk, motor_kp, motor_kd, action_scale)``.
        ``pinocchio_fk`` is ``None`` when the model carries no URDF.

    Raises:
        ValueError: If the resolved ``kp`` or ``kd`` array does not have
            exactly ``NUM_DOFS`` entries.
    """
    logger.info(f"Loading dance policy: {onnx_path}")
    policy = ort.InferenceSession(onnx_path)
    model = onnx.load(onnx_path)
    metadata = {p.key: _decode_metadata_value(p.value) for p in model.metadata_props}

    motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
    motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
    # Fail here rather than with an IndexError inside the control loop, where
    # the exception would be caught and logged on every cycle.
    for gain_name, gains in (("kp", motor_kp), ("kd", motor_kd)):
        if gains.shape != (NUM_DOFS,):
            raise ValueError(
                f"Model metadata '{gain_name}' has shape {gains.shape}, expected ({NUM_DOFS},)"
            )

    action_scale = float(metadata.get("action_scale", 1.0))
    urdf_text = metadata.get("robot_urdf", None)

    logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
    logger.info(f" Action scale: {action_scale}")
    logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")

    # Build Pinocchio FK if URDF available
    pinocchio_fk = None
    if urdf_text:
        logger.info(" Building Pinocchio FK from URDF...")
        pinocchio_fk = PinocchioFK(urdf_text)
    else:
        logger.warning(" No URDF in metadata - FK will not work!")

    return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
def main():
    """Command-line entry point: load the policy, connect, and dance.

    Moves the robot to the motion's start pose, starts the control thread and
    logs the step counter every 2 seconds until interrupted with Ctrl+C. The
    control thread is stopped and the robot disconnected on every exit path
    after the robot object has been created.
    """
    parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
    parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
    parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
    args = parser.parse_args()

    print("=" * 70)
    print("💃 WBT DANCE POLICY")
    print("=" * 70)

    if args.sim:
        # The flag is accepted but never forwarded to the robot config below;
        # say so instead of silently behaving as if it had not been given.
        # NOTE(review): wire this into UnitreeG1Config once the intended
        # option is confirmed.
        logger.warning("--sim is currently ignored: the robot is created with the default UnitreeG1Config()")

    # Load policy
    policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)

    # Initialize robot
    logger.info("Initializing robot...")
    config = UnitreeG1Config()
    robot = UnitreeG1(config)
    logger.info("Robot connected!")

    controller = None
    try:
        # Create controller inside the try block so the robot is disconnected
        # even if construction fails.
        controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)

        # Move to start pose
        controller.reset_to_motion_pose(duration=3.0)

        # Start dancing
        controller.start_dance_thread()
        logger.info("Dancing! Press Ctrl+C to stop.")
        print("-" * 70)

        # Log status periodically
        while True:
            time.sleep(2.0)
            logger.info(f"timestep={controller.timestep}")
    except KeyboardInterrupt:
        print("\n\nStopping...")
    finally:
        if controller is not None:
            controller.stop_dance_thread()
        robot.disconnect()
    print("\nDone!")


if __name__ == "__main__":
    main()