mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ee24f64ae5 | |||
| 123b9f7851 |
@@ -0,0 +1,454 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
WBT (Whole Body Tracking) Dance Policy for Unitree G1
|
||||||
|
|
||||||
|
Uses ONNX model with motion data baked in.
|
||||||
|
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/unitree_g1/dance.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from xml.etree import ElementTree
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
import pinocchio as pin
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CONFIGURATION
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||||
|
CONTROL_DT = 0.02 # 50 Hz
|
||||||
|
NUM_DOFS = 29
|
||||||
|
|
||||||
|
# Default joint positions (holosoma training defaults)
|
||||||
|
DEFAULT_DOF_POS = np.array([
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Left leg (6)
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Right leg (6)
|
||||||
|
0.0, 0.0, 0.0, # Waist (3)
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Left arm (7)
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Right arm (7)
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# Stiff hold KP/KD (for initialization)
|
||||||
|
STIFF_KP = np.array([
|
||||||
|
150, 150, 200, 200, 40, 40,
|
||||||
|
150, 150, 200, 200, 40, 40,
|
||||||
|
200, 200, 100,
|
||||||
|
100, 100, 100, 100, 50, 50, 50,
|
||||||
|
100, 100, 100, 100, 50, 50, 50,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
STIFF_KD = np.array([
|
||||||
|
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||||
|
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||||
|
5.0, 5.0, 5.0,
|
||||||
|
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||||
|
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# Joints to freeze at 0 with high KP
|
||||||
|
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||||
|
FROZEN_KP = 500.0
|
||||||
|
FROZEN_KD = 5.0
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# QUATERNION UTILITIES
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def quat_inverse(q):
|
||||||
|
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||||
|
|
||||||
|
def quat_mul(a, b):
|
||||||
|
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||||
|
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||||
|
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||||
|
ww = (z1 + x1) * (x2 + y2)
|
||||||
|
yy = (w1 - y1) * (w2 + z2)
|
||||||
|
zz = (w1 + y1) * (w2 - z2)
|
||||||
|
xx = ww + yy + zz
|
||||||
|
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||||
|
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||||
|
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||||
|
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||||
|
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||||
|
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||||
|
|
||||||
|
def subtract_frame_transforms(q01, q02):
|
||||||
|
return quat_mul(quat_inverse(q01), q02)
|
||||||
|
|
||||||
|
def matrix_from_quat(q):
|
||||||
|
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||||
|
two_s = 2.0 / (q * q).sum(-1)
|
||||||
|
o = np.stack((
|
||||||
|
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||||
|
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||||
|
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||||
|
), -1)
|
||||||
|
return o.reshape(q.shape[:-1] + (3, 3))
|
||||||
|
|
||||||
|
def xyzw_to_wxyz(xyzw):
|
||||||
|
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||||
|
|
||||||
|
def quat_to_rpy(q):
|
||||||
|
w, x, y, z = q
|
||||||
|
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||||
|
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||||
|
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||||
|
return roll, pitch, yaw
|
||||||
|
|
||||||
|
def rpy_to_quat(rpy):
|
||||||
|
roll, pitch, yaw = rpy
|
||||||
|
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||||
|
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||||
|
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||||
|
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||||
|
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PINOCCHIO FK
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
DOF_NAMES = (
|
||||||
|
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||||
|
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||||
|
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||||
|
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||||
|
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||||
|
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||||
|
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||||
|
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||||
|
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PinocchioFK:
|
||||||
|
"""Pinocchio forward kinematics for torso_link orientation."""
|
||||||
|
|
||||||
|
def __init__(self, urdf_text: str):
|
||||||
|
root = ElementTree.fromstring(urdf_text)
|
||||||
|
for parent in root.iter():
|
||||||
|
for child in list(parent):
|
||||||
|
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||||
|
parent.remove(child)
|
||||||
|
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||||
|
|
||||||
|
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||||
|
self.data = self.model.createData()
|
||||||
|
|
||||||
|
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||||
|
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||||
|
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||||
|
logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")
|
||||||
|
|
||||||
|
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||||
|
"""Get torso_link orientation in world frame."""
|
||||||
|
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||||
|
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||||
|
pin.framesForwardKinematics(self.model, self.data, config)
|
||||||
|
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||||
|
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# DANCE CONTROLLER
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class DanceController:
|
||||||
|
"""
|
||||||
|
Handles WBT dance policy for the Unitree G1 robot.
|
||||||
|
|
||||||
|
This controller manages:
|
||||||
|
- 29-joint observation processing
|
||||||
|
- Pinocchio FK for torso orientation
|
||||||
|
- Policy inference with motion data from ONNX
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||||
|
self.policy = policy
|
||||||
|
self.robot = robot
|
||||||
|
self.pinocchio_fk = pinocchio_fk
|
||||||
|
self.motor_kp = motor_kp
|
||||||
|
self.motor_kd = motor_kd
|
||||||
|
self.action_scale = action_scale
|
||||||
|
|
||||||
|
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||||
|
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||||
|
self.motion_command = None
|
||||||
|
self.ref_quat_xyzw = None
|
||||||
|
self.timestep = 0
|
||||||
|
self.yaw_offset = 0.0
|
||||||
|
|
||||||
|
# Get initial motion data from ONNX
|
||||||
|
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||||
|
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||||
|
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[2]
|
||||||
|
self.motion_start_pose = outs[0].flatten()
|
||||||
|
|
||||||
|
# Thread management
|
||||||
|
self.dance_running = False
|
||||||
|
self.dance_thread = None
|
||||||
|
|
||||||
|
logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||||
|
|
||||||
|
def capture_yaw_offset(self):
|
||||||
|
"""Capture robot's current yaw for relative tracking."""
|
||||||
|
robot_state = self.robot.lowstate_buffer.get_data()
|
||||||
|
if robot_state and self.pinocchio_fk:
|
||||||
|
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||||
|
dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||||
|
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||||
|
logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||||
|
|
||||||
|
def _remove_yaw_offset(self, quat_wxyz):
|
||||||
|
"""Remove stored yaw offset from orientation."""
|
||||||
|
if abs(self.yaw_offset) < 1e-6:
|
||||||
|
return quat_wxyz
|
||||||
|
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||||
|
return quat_mul(yaw_q, quat_wxyz)
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
"""Single dance step - reads state, runs policy, sends commands."""
|
||||||
|
robot_state = self.robot.lowstate_buffer.get_data()
|
||||||
|
if robot_state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read robot state
|
||||||
|
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||||
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
|
||||||
|
# Compute motion_ref_ori_b using FK
|
||||||
|
if self.pinocchio_fk:
|
||||||
|
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||||
|
torso_q = self._remove_yaw_offset(torso_q)
|
||||||
|
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||||
|
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||||
|
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||||
|
else:
|
||||||
|
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||||
|
|
||||||
|
dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)
|
||||||
|
|
||||||
|
# Build observation (alphabetical order)
|
||||||
|
obs_dict = {
|
||||||
|
"actions": self.last_action,
|
||||||
|
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||||
|
"dof_pos": dof_rel,
|
||||||
|
"dof_vel": dof_vel.reshape(1, -1),
|
||||||
|
"motion_command": self.motion_command,
|
||||||
|
"motion_ref_ori_b": ori_b,
|
||||||
|
}
|
||||||
|
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||||
|
obs = np.clip(obs, -100, 100)
|
||||||
|
|
||||||
|
# Run policy
|
||||||
|
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||||
|
|
||||||
|
action = np.clip(outs[0], -100, 100)
|
||||||
|
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[3]
|
||||||
|
self.last_action = action.copy()
|
||||||
|
|
||||||
|
# Compute target positions
|
||||||
|
target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale
|
||||||
|
|
||||||
|
# Send commands
|
||||||
|
for i in range(NUM_DOFS):
|
||||||
|
if i in FROZEN_JOINTS:
|
||||||
|
self.robot.msg.motor_cmd[i].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||||
|
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||||
|
else:
|
||||||
|
self.robot.msg.motor_cmd[i].q = float(target_pos[i])
|
||||||
|
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||||
|
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||||
|
self.robot.msg.motor_cmd[i].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[i].tau = 0
|
||||||
|
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
self.timestep += 1
|
||||||
|
|
||||||
|
def _dance_thread_loop(self):
|
||||||
|
"""Background thread that runs the dance policy."""
|
||||||
|
logger.info("Dance thread started")
|
||||||
|
while self.dance_running:
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
self.run_step()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in dance loop: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
logger.info("Dance thread stopped")
|
||||||
|
|
||||||
|
def start_dance_thread(self):
|
||||||
|
"""Start the dance control thread."""
|
||||||
|
if self.dance_running:
|
||||||
|
logger.warning("Dance thread already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reset state for fresh start
|
||||||
|
self.timestep = 0
|
||||||
|
self.last_action.fill(0)
|
||||||
|
|
||||||
|
# Re-get initial motion data
|
||||||
|
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||||
|
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||||
|
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[2]
|
||||||
|
|
||||||
|
self.capture_yaw_offset()
|
||||||
|
|
||||||
|
logger.info("Starting dance control thread...")
|
||||||
|
self.dance_running = True
|
||||||
|
self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
|
||||||
|
self.dance_thread.start()
|
||||||
|
|
||||||
|
def stop_dance_thread(self):
|
||||||
|
"""Stop the dance control thread."""
|
||||||
|
if not self.dance_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping dance control thread...")
|
||||||
|
self.dance_running = False
|
||||||
|
if self.dance_thread:
|
||||||
|
self.dance_thread.join(timeout=2.0)
|
||||||
|
logger.info("Dance control thread stopped")
|
||||||
|
|
||||||
|
def reset_to_motion_pose(self, duration: float = 3.0):
|
||||||
|
"""Move robot to initial motion pose over given duration."""
|
||||||
|
logger.info(f"Moving to dance start pose ({duration}s)...")
|
||||||
|
|
||||||
|
robot_state = self.robot.lowstate_buffer.get_data()
|
||||||
|
init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
target_pos = self.motion_start_pose
|
||||||
|
|
||||||
|
num_steps = int(duration / CONTROL_DT)
|
||||||
|
for step in range(num_steps):
|
||||||
|
alpha = step / num_steps
|
||||||
|
interp = init_pos * (1 - alpha) + target_pos * alpha
|
||||||
|
|
||||||
|
for i in range(NUM_DOFS):
|
||||||
|
if i in FROZEN_JOINTS:
|
||||||
|
self.robot.msg.motor_cmd[i].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||||
|
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||||
|
else:
|
||||||
|
self.robot.msg.motor_cmd[i].q = float(interp[i])
|
||||||
|
self.robot.msg.motor_cmd[i].kp = STIFF_KP[i]
|
||||||
|
self.robot.msg.motor_cmd[i].kd = STIFF_KD[i]
|
||||||
|
self.robot.msg.motor_cmd[i].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[i].tau = 0
|
||||||
|
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(CONTROL_DT)
|
||||||
|
|
||||||
|
logger.info("At dance start pose!")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MAIN
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_dance_policy(onnx_path: str):
|
||||||
|
"""Load dance policy and extract metadata."""
|
||||||
|
logger.info(f"Loading dance policy: {onnx_path}")
|
||||||
|
|
||||||
|
policy = ort.InferenceSession(onnx_path)
|
||||||
|
model = onnx.load(onnx_path)
|
||||||
|
metadata = {p.key: json.loads(p.value) for p in model.metadata_props}
|
||||||
|
|
||||||
|
motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
|
||||||
|
motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
|
||||||
|
action_scale = float(metadata.get("action_scale", 1.0))
|
||||||
|
urdf_text = metadata.get("robot_urdf", None)
|
||||||
|
|
||||||
|
logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
|
||||||
|
logger.info(f" Action scale: {action_scale}")
|
||||||
|
logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")
|
||||||
|
|
||||||
|
# Build Pinocchio FK if URDF available
|
||||||
|
pinocchio_fk = None
|
||||||
|
if urdf_text:
|
||||||
|
logger.info(" Building Pinocchio FK from URDF...")
|
||||||
|
pinocchio_fk = PinocchioFK(urdf_text)
|
||||||
|
else:
|
||||||
|
logger.warning(" No URDF in metadata - FK will not work!")
|
||||||
|
|
||||||
|
return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
|
||||||
|
parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
|
||||||
|
parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("💃 WBT DANCE POLICY")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Load policy
|
||||||
|
policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)
|
||||||
|
|
||||||
|
# Initialize robot
|
||||||
|
logger.info("Initializing robot...")
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot = UnitreeG1(config)
|
||||||
|
logger.info("Robot connected!")
|
||||||
|
|
||||||
|
# Create controller
|
||||||
|
controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Move to start pose
|
||||||
|
controller.reset_to_motion_pose(duration=3.0)
|
||||||
|
|
||||||
|
# Start dancing
|
||||||
|
controller.start_dance_thread()
|
||||||
|
|
||||||
|
logger.info("Dancing! Press Ctrl+C to stop.")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
# Log status periodically
|
||||||
|
while True:
|
||||||
|
time.sleep(2.0)
|
||||||
|
logger.info(f"timestep={controller.timestep}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping...")
|
||||||
|
finally:
|
||||||
|
controller.stop_dance_thread()
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
print("\nDone!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Binary file not shown.
@@ -0,0 +1,479 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
|
||||||
|
|
||||||
|
This example demonstrates loading Holosoma whole-body locomotion policies
|
||||||
|
and running them on the Unitree G1 robot.
|
||||||
|
|
||||||
|
Supports both:
|
||||||
|
- 23-DOF native policies (82D observations, 23D actions)
|
||||||
|
- 29-DOF policies (100D observations, 29D actions)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 29-DOF Configuration
|
||||||
|
# =============================================================================
|
||||||
|
# fmt: off
|
||||||
|
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||||
|
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
HOLOSOMA_29DOF_KP = np.array([
|
||||||
|
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||||
|
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||||
|
40.179238471, 28.501246196, 28.501246196, # waist
|
||||||
|
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # left arm
|
||||||
|
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
HOLOSOMA_29DOF_KD = np.array([
|
||||||
|
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||||
|
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||||
|
2.557889765, 1.814445687, 1.814445687, # waist
|
||||||
|
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # left arm
|
||||||
|
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
|
||||||
|
# Derived from 29-DOF Holosoma values
|
||||||
|
# =============================================================================
|
||||||
|
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
|
||||||
|
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from 29-DOF)
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from 29-DOF)
|
||||||
|
0.0, # waist_yaw only (from 29-DOF)
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, # left arm first 5 joints (from 29-DOF)
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, # right arm first 5 joints (from 29-DOF)
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
HOLOSOMA_23DOF_KP = np.array([
|
||||||
|
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||||
|
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||||
|
40.179238471, # waist_yaw
|
||||||
|
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # left arm
|
||||||
|
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
HOLOSOMA_23DOF_KD = np.array([
|
||||||
|
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||||
|
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||||
|
2.557889765, # waist_yaw
|
||||||
|
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # left arm
|
||||||
|
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# Maps 23-DOF policy index → 29-DOF motor index
|
||||||
|
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
|
||||||
|
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
|
||||||
|
DOF_23_TO_MOTOR_MAP = [
|
||||||
|
0, 1, 2, 3, 4, 5, # left leg → motor 0-5
|
||||||
|
6, 7, 8, 9, 10, 11, # right leg → motor 6-11
|
||||||
|
12, # waist_yaw → motor 12
|
||||||
|
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw) → motor 15-19
|
||||||
|
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw) → motor 22-26
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# Control parameters
|
||||||
|
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz
|
||||||
|
LOCOMOTION_ACTION_SCALE = 0.25
|
||||||
|
ANG_VEL_SCALE = 0.25
|
||||||
|
DOF_POS_SCALE = 1.0
|
||||||
|
DOF_VEL_SCALE = 0.05
|
||||||
|
GAIT_PERIOD = 1.0
|
||||||
|
|
||||||
|
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||||
|
|
||||||
|
|
||||||
|
def load_holosoma_policy(
|
||||||
|
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
||||||
|
policy_name: str = "fastsac",
|
||||||
|
local_path: str | None = None,
|
||||||
|
) -> tuple[ort.InferenceSession, int]:
|
||||||
|
"""Load Holosoma policy and detect observation dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
|
||||||
|
"""
|
||||||
|
if local_path is not None:
|
||||||
|
logger.info(f"Loading policy from local path: {local_path}")
|
||||||
|
policy_path = local_path
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
|
||||||
|
policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")
|
||||||
|
|
||||||
|
policy = ort.InferenceSession(policy_path)
|
||||||
|
|
||||||
|
# Detect observation dimension from model input shape
|
||||||
|
input_shape = policy.get_inputs()[0].shape
|
||||||
|
obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]
|
||||||
|
|
||||||
|
logger.info(f"Policy loaded successfully")
|
||||||
|
logger.info(f" Input: {policy.get_inputs()[0].name}, shape: {input_shape} → obs_dim={obs_dim}")
|
||||||
|
logger.info(f" Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")
|
||||||
|
|
||||||
|
return policy, obs_dim
|
||||||
|
|
||||||
|
|
||||||
|
class HolosomaLocomotionController:
|
||||||
|
"""
|
||||||
|
Handles Holosoma whole-body locomotion for Unitree G1.
|
||||||
|
Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, policy, robot, config, obs_dim: int = 100):
|
||||||
|
self.policy = policy
|
||||||
|
self.robot = robot
|
||||||
|
self.config = config
|
||||||
|
self.obs_dim = obs_dim
|
||||||
|
|
||||||
|
# Detect policy type from observation dimension
|
||||||
|
self.is_23dof = (obs_dim == 82)
|
||||||
|
self.num_dof = 23 if self.is_23dof else 29
|
||||||
|
|
||||||
|
# Velocity commands
|
||||||
|
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
# State variables sized for policy type
|
||||||
|
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)
|
||||||
|
self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
|
||||||
|
# Select config based on DOF
|
||||||
|
if self.is_23dof:
|
||||||
|
self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
|
||||||
|
self.kp = HOLOSOMA_23DOF_KP
|
||||||
|
self.kd = HOLOSOMA_23DOF_KD
|
||||||
|
self.motor_map = DOF_23_TO_MOTOR_MAP
|
||||||
|
else:
|
||||||
|
self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
|
||||||
|
self.kp = HOLOSOMA_29DOF_KP
|
||||||
|
self.kd = HOLOSOMA_29DOF_KD
|
||||||
|
self.motor_map = list(range(29)) # Identity map for 29-DOF
|
||||||
|
|
||||||
|
# Phase state for gait
|
||||||
|
self.phase = np.zeros((1, 2), dtype=np.float32)
|
||||||
|
self.phase[0, 0] = 0.0
|
||||||
|
self.phase[0, 1] = np.pi
|
||||||
|
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||||
|
self.is_standing = False
|
||||||
|
|
||||||
|
self.counter = 0
|
||||||
|
self.locomotion_running = False
|
||||||
|
self.locomotion_thread = None
|
||||||
|
|
||||||
|
logger.info(f"HolosomaLocomotionController initialized")
|
||||||
|
logger.info(f" Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
|
||||||
|
logger.info(f" Action dim: {self.num_dof}")
|
||||||
|
|
||||||
|
def holosoma_locomotion_run(self):
|
||||||
|
"""Main locomotion loop - handles both 23-DOF and 29-DOF."""
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
if self.counter == 1:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
|
||||||
|
print(f" {self.obs_dim}D observations → {self.num_dof}D actions")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
if robot_state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remote controller
|
||||||
|
if robot_state.wireless_remote is not None:
|
||||||
|
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||||
|
else:
|
||||||
|
self.robot.remote_controller.lx = 0.0
|
||||||
|
self.robot.remote_controller.ly = 0.0
|
||||||
|
self.robot.remote_controller.rx = 0.0
|
||||||
|
self.robot.remote_controller.ry = 0.0
|
||||||
|
|
||||||
|
# Deadzone
|
||||||
|
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||||
|
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||||
|
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||||
|
|
||||||
|
self.locomotion_cmd[0] = ly
|
||||||
|
self.locomotion_cmd[1] = -lx
|
||||||
|
self.locomotion_cmd[2] = -rx
|
||||||
|
|
||||||
|
# Read joint states using motor map
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||||
|
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||||
|
|
||||||
|
# IMU
|
||||||
|
quat = robot_state.imu_state.quaternion
|
||||||
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||||
|
|
||||||
|
# Scale observations
|
||||||
|
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||||
|
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||||
|
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||||
|
|
||||||
|
# Phase update
|
||||||
|
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
|
||||||
|
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
|
||||||
|
|
||||||
|
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
|
||||||
|
self.phase[0, :] = np.pi * np.ones(2)
|
||||||
|
self.is_standing = True
|
||||||
|
elif self.is_standing:
|
||||||
|
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.is_standing = False
|
||||||
|
else:
|
||||||
|
phase_tp1 = self.phase + self.phase_dt
|
||||||
|
self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
|
||||||
|
|
||||||
|
sin_phase = np.sin(self.phase[0, :])
|
||||||
|
cos_phase = np.cos(self.phase[0, :])
|
||||||
|
|
||||||
|
# Build observation (format depends on DOF)
|
||||||
|
if self.is_23dof:
|
||||||
|
# 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
|
||||||
|
self.locomotion_obs[0:23] = self.last_unscaled_action
|
||||||
|
self.locomotion_obs[23:26] = ang_vel_scaled
|
||||||
|
self.locomotion_obs[26] = self.locomotion_cmd[2]
|
||||||
|
self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
|
||||||
|
self.locomotion_obs[29:31] = cos_phase
|
||||||
|
self.locomotion_obs[31:54] = qj_obs
|
||||||
|
self.locomotion_obs[54:77] = dqj_obs
|
||||||
|
self.locomotion_obs[77:80] = gravity_orientation
|
||||||
|
self.locomotion_obs[80:82] = sin_phase
|
||||||
|
else:
|
||||||
|
# 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
|
||||||
|
self.locomotion_obs[0:29] = self.last_unscaled_action
|
||||||
|
self.locomotion_obs[29:32] = ang_vel_scaled
|
||||||
|
self.locomotion_obs[32] = self.locomotion_cmd[2]
|
||||||
|
self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
|
||||||
|
self.locomotion_obs[35:37] = cos_phase
|
||||||
|
self.locomotion_obs[37:66] = qj_obs
|
||||||
|
self.locomotion_obs[66:95] = dqj_obs
|
||||||
|
self.locomotion_obs[95:98] = gravity_orientation
|
||||||
|
self.locomotion_obs[98:100] = sin_phase
|
||||||
|
|
||||||
|
# Policy inference
|
||||||
|
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
|
||||||
|
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
|
||||||
|
ort_outs = self.policy.run(None, ort_inputs)
|
||||||
|
|
||||||
|
raw_action = ort_outs[0].squeeze()
|
||||||
|
clipped_action = np.clip(raw_action, -100.0, 100.0)
|
||||||
|
|
||||||
|
self.last_unscaled_action = clipped_action.copy()
|
||||||
|
self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE
|
||||||
|
|
||||||
|
# Debug
|
||||||
|
if self.counter <= 3:
|
||||||
|
print(f"\n[Holosoma Debug #{self.counter}]")
|
||||||
|
print(f" Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
|
||||||
|
print(f" Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||||
|
print(f" Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
|
||||||
|
|
||||||
|
# Compute target positions
|
||||||
|
target_dof_pos = self.default_angles + self.locomotion_action
|
||||||
|
|
||||||
|
# Send commands to motors via motor map
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
|
||||||
|
if self.is_23dof:
|
||||||
|
missing_motors = [13, 14, 20, 21, 27, 28] # waist_roll, waist_pitch, wrist_pitch/yaw
|
||||||
|
for motor_idx in missing_motors:
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
|
||||||
|
def _locomotion_thread_loop(self):
|
||||||
|
logger.info("Locomotion thread started")
|
||||||
|
while self.locomotion_running:
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
self.holosoma_locomotion_run()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in locomotion loop: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
logger.info("Locomotion thread stopped")
|
||||||
|
|
||||||
|
def start_locomotion_thread(self):
|
||||||
|
if self.locomotion_running:
|
||||||
|
logger.warning("Locomotion thread already running")
|
||||||
|
return
|
||||||
|
logger.info("Starting locomotion control thread...")
|
||||||
|
self.locomotion_running = True
|
||||||
|
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||||
|
self.locomotion_thread.start()
|
||||||
|
logger.info("Locomotion control thread started!")
|
||||||
|
|
||||||
|
def stop_locomotion_thread(self):
|
||||||
|
if not self.locomotion_running:
|
||||||
|
return
|
||||||
|
logger.info("Stopping locomotion control thread...")
|
||||||
|
self.locomotion_running = False
|
||||||
|
if self.locomotion_thread:
|
||||||
|
self.locomotion_thread.join(timeout=2.0)
|
||||||
|
logger.info("Locomotion control thread stopped")
|
||||||
|
|
||||||
|
def reset_robot(self):
|
||||||
|
"""Move joints to default position."""
|
||||||
|
logger.info(f"Moving {self.num_dof} joints to default position...")
|
||||||
|
|
||||||
|
total_time = 3.0
|
||||||
|
num_step = int(total_time / self.robot.control_dt)
|
||||||
|
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
|
||||||
|
# Record current positions
|
||||||
|
init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
init_dof_pos[i] = robot_state.motor_state[motor_idx].q
|
||||||
|
|
||||||
|
# Interpolate to target
|
||||||
|
for step in range(num_step):
|
||||||
|
alpha = step / num_step
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
target = self.default_angles[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Zero missing joints for 23-DOF
|
||||||
|
if self.is_23dof:
|
||||||
|
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(self.robot.control_dt)
|
||||||
|
|
||||||
|
logger.info(f"Reached default position ({self.num_dof} joints)")
|
||||||
|
|
||||||
|
# Hold for 2 seconds
|
||||||
|
logger.info("Holding default position for 2 seconds...")
|
||||||
|
hold_steps = int(2.0 / self.robot.control_dt)
|
||||||
|
for _ in range(hold_steps):
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
if self.is_23dof:
|
||||||
|
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(self.robot.control_dt)
|
||||||
|
|
||||||
|
logger.info("Ready to start locomotion!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||||
|
parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||||
|
parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
|
||||||
|
parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load policy and detect dimensions
|
||||||
|
policy, obs_dim = load_holosoma_policy(
|
||||||
|
repo_id=args.repo_id,
|
||||||
|
policy_name=args.policy,
|
||||||
|
local_path=args.local_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize robot
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot = UnitreeG1(config)
|
||||||
|
|
||||||
|
# Initialize controller with detected obs_dim
|
||||||
|
controller = HolosomaLocomotionController(
|
||||||
|
policy=policy,
|
||||||
|
robot=robot,
|
||||||
|
config=config,
|
||||||
|
obs_dim=obs_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
#controller.reset_robot()
|
||||||
|
controller.start_locomotion_thread()
|
||||||
|
|
||||||
|
logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
|
||||||
|
logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
|
||||||
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
time.sleep(1.0)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping locomotion...")
|
||||||
|
controller.stop_locomotion_thread()
|
||||||
|
print("Done!")
|
||||||
@@ -0,0 +1,607 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Locomotion ↔ Dance Toggle for Unitree G1
|
||||||
|
|
||||||
|
Press Enter to instantly switch between locomotion and dance modes.
|
||||||
|
- Starts in LOCOMOTION mode (joystick control)
|
||||||
|
- Press Enter → DANCE mode (resets to frame 0)
|
||||||
|
- Press Enter → LOCOMOTION mode
|
||||||
|
- Repeat...
|
||||||
|
|
||||||
|
Auto-recovery feature:
|
||||||
|
- If robot tilts beyond threshold during dance, auto-switches to locomotion
|
||||||
|
- When robot recovers (tilt below recovery threshold), resumes dance from where it left off
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/unitree_g1/locomotion_to_dance.py
|
||||||
|
python examples/unitree_g1/locomotion_to_dance.py --tilt-threshold 25 --recovery-threshold 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import select
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from xml.etree import ElementTree
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
import pinocchio as pin
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CONFIGURATION
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
NUM_DOFS = 29
|
||||||
|
CONTROL_DT = 0.02 # 50Hz
|
||||||
|
|
||||||
|
# Locomotion config
|
||||||
|
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||||
|
LOCOMOTION_ACTION_SCALE = 0.25
|
||||||
|
ANG_VEL_SCALE = 0.25
|
||||||
|
DOF_POS_SCALE = 1.0
|
||||||
|
DOF_VEL_SCALE = 0.05
|
||||||
|
GAIT_PERIOD = 1.0
|
||||||
|
|
||||||
|
# Dance config
|
||||||
|
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||||
|
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||||
|
FROZEN_KP = 500.0
|
||||||
|
FROZEN_KD = 5.0
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# 29-DOF defaults (holosoma training)
|
||||||
|
DEFAULT_29DOF_ANGLES = np.array([
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||||
|
0.0, 0.0, 0.0, # waist
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
DEFAULT_29DOF_KP = np.array([
|
||||||
|
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||||
|
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||||
|
40.179, 28.501, 28.501,
|
||||||
|
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||||
|
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
DEFAULT_29DOF_KD = np.array([
|
||||||
|
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||||
|
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||||
|
2.558, 1.814, 1.814,
|
||||||
|
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||||
|
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# 23-DOF config (no waist_roll/pitch, no wrist_pitch/yaw)
|
||||||
|
DEFAULT_23DOF_ANGLES = np.array([
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||||
|
0.0, # waist_yaw only
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, # left arm (5 joints)
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, # right arm (5 joints)
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
DEFAULT_23DOF_KP = np.array([
|
||||||
|
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||||
|
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||||
|
40.179,
|
||||||
|
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||||
|
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
DEFAULT_23DOF_KD = np.array([
|
||||||
|
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||||
|
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||||
|
2.558,
|
||||||
|
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||||
|
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# 23-DOF policy index → 29-DOF motor index
|
||||||
|
DOF_23_TO_MOTOR = [
|
||||||
|
0, 1, 2, 3, 4, 5, # left leg
|
||||||
|
6, 7, 8, 9, 10, 11, # right leg
|
||||||
|
12, # waist_yaw
|
||||||
|
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw)
|
||||||
|
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw)
|
||||||
|
]
|
||||||
|
MISSING_23DOF_MOTORS = [13, 14, 20, 21, 27, 28]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# QUATERNION UTILITIES
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def quat_inverse(q):
|
||||||
|
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||||
|
|
||||||
|
def quat_mul(a, b):
|
||||||
|
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||||
|
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||||
|
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||||
|
ww = (z1 + x1) * (x2 + y2)
|
||||||
|
yy = (w1 - y1) * (w2 + z2)
|
||||||
|
zz = (w1 + y1) * (w2 - z2)
|
||||||
|
xx = ww + yy + zz
|
||||||
|
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||||
|
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||||
|
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||||
|
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||||
|
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||||
|
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||||
|
|
||||||
|
def subtract_frame_transforms(q01, q02):
|
||||||
|
return quat_mul(quat_inverse(q01), q02)
|
||||||
|
|
||||||
|
def matrix_from_quat(q):
|
||||||
|
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||||
|
two_s = 2.0 / (q * q).sum(-1)
|
||||||
|
o = np.stack((
|
||||||
|
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||||
|
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||||
|
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||||
|
), -1)
|
||||||
|
return o.reshape(q.shape[:-1] + (3, 3))
|
||||||
|
|
||||||
|
def xyzw_to_wxyz(xyzw):
|
||||||
|
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||||
|
|
||||||
|
def quat_to_rpy(q):
|
||||||
|
w, x, y, z = q
|
||||||
|
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||||
|
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||||
|
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||||
|
return roll, pitch, yaw
|
||||||
|
|
||||||
|
def rpy_to_quat(rpy):
|
||||||
|
roll, pitch, yaw = rpy
|
||||||
|
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||||
|
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||||
|
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||||
|
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||||
|
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PINOCCHIO FK
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
DOF_NAMES = (
|
||||||
|
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||||
|
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||||
|
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||||
|
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||||
|
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||||
|
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||||
|
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||||
|
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||||
|
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PinocchioFK:
|
||||||
|
def __init__(self, urdf_text: str):
|
||||||
|
root = ElementTree.fromstring(urdf_text)
|
||||||
|
for parent in root.iter():
|
||||||
|
for child in list(parent):
|
||||||
|
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||||
|
parent.remove(child)
|
||||||
|
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||||
|
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||||
|
self.data = self.model.createData()
|
||||||
|
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||||
|
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||||
|
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||||
|
|
||||||
|
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||||
|
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||||
|
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||||
|
pin.framesForwardKinematics(self.model, self.data, config)
|
||||||
|
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||||
|
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||||
|
|
||||||
|
def get_torso_tilt(self, pos, quat_wxyz, dof_pos):
|
||||||
|
"""Get torso tilt angle from upright (degrees). Uses roll and pitch."""
|
||||||
|
torso_q = self.get_torso_quat(pos, quat_wxyz, dof_pos)
|
||||||
|
roll, pitch, _ = quat_to_rpy(torso_q.flatten())
|
||||||
|
# Tilt is the angle from vertical - combine roll and pitch
|
||||||
|
tilt_rad = np.sqrt(roll**2 + pitch**2)
|
||||||
|
return np.degrees(tilt_rad), np.degrees(roll), np.degrees(pitch)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LOCOMOTION CONTROLLER
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class LocomotionController:
|
||||||
|
"""Holosoma whole-body locomotion (23-DOF or 29-DOF)."""
|
||||||
|
|
||||||
|
def __init__(self, policy, robot, obs_dim: int):
|
||||||
|
self.policy = policy
|
||||||
|
self.robot = robot
|
||||||
|
self.obs_dim = obs_dim
|
||||||
|
|
||||||
|
# Detect DOF mode
|
||||||
|
self.is_23dof = (obs_dim == 82)
|
||||||
|
self.num_dof = 23 if self.is_23dof else 29
|
||||||
|
|
||||||
|
if self.is_23dof:
|
||||||
|
self.default_angles = DEFAULT_23DOF_ANGLES
|
||||||
|
self.kp = DEFAULT_23DOF_KP
|
||||||
|
self.kd = DEFAULT_23DOF_KD
|
||||||
|
self.motor_map = DOF_23_TO_MOTOR
|
||||||
|
logger.info("Locomotion: 23-DOF (82D obs)")
|
||||||
|
else:
|
||||||
|
self.default_angles = DEFAULT_29DOF_ANGLES
|
||||||
|
self.kp = DEFAULT_29DOF_KP
|
||||||
|
self.kd = DEFAULT_29DOF_KD
|
||||||
|
self.motor_map = list(range(29))
|
||||||
|
logger.info("Locomotion: 29-DOF (100D obs)")
|
||||||
|
|
||||||
|
self.cmd = np.zeros(3, dtype=np.float32)
|
||||||
|
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
self.obs = np.zeros(obs_dim, dtype=np.float32)
|
||||||
|
self.last_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||||
|
|
||||||
|
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||||
|
self.is_standing = True
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
"""Single locomotion step."""
|
||||||
|
state = self.robot.lowstate_buffer.get_data()
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Joystick
|
||||||
|
if state.wireless_remote is not None:
|
||||||
|
self.robot.remote_controller.set(state.wireless_remote)
|
||||||
|
|
||||||
|
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||||
|
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||||
|
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||||
|
self.cmd[0], self.cmd[1], self.cmd[2] = ly, -lx, -rx
|
||||||
|
|
||||||
|
# Read joints via motor map
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
self.qj[i] = state.motor_state[self.motor_map[i]].q
|
||||||
|
self.dqj[i] = state.motor_state[self.motor_map[i]].dq
|
||||||
|
|
||||||
|
# IMU
|
||||||
|
quat = state.imu_state.quaternion
|
||||||
|
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
gravity = self.robot.get_gravity_orientation(quat)
|
||||||
|
|
||||||
|
# Scale
|
||||||
|
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||||
|
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||||
|
ang_vel_s = ang_vel * ANG_VEL_SCALE
|
||||||
|
|
||||||
|
# Phase
|
||||||
|
cmd_mag = np.linalg.norm(self.cmd[:2])
|
||||||
|
ang_mag = abs(self.cmd[2])
|
||||||
|
if cmd_mag < 0.01 and ang_mag < 0.01:
|
||||||
|
self.phase[0, :] = np.pi
|
||||||
|
self.is_standing = True
|
||||||
|
elif self.is_standing:
|
||||||
|
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.is_standing = False
|
||||||
|
else:
|
||||||
|
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2*np.pi) - np.pi
|
||||||
|
|
||||||
|
sin_ph, cos_ph = np.sin(self.phase[0]), np.cos(self.phase[0])
|
||||||
|
|
||||||
|
# Build obs
|
||||||
|
if self.is_23dof:
|
||||||
|
self.obs[0:23] = self.last_action
|
||||||
|
self.obs[23:26] = ang_vel_s
|
||||||
|
self.obs[26] = self.cmd[2]
|
||||||
|
self.obs[27:29] = self.cmd[:2]
|
||||||
|
self.obs[29:31] = cos_ph
|
||||||
|
self.obs[31:54] = qj_obs
|
||||||
|
self.obs[54:77] = dqj_obs
|
||||||
|
self.obs[77:80] = gravity
|
||||||
|
self.obs[80:82] = sin_ph
|
||||||
|
else:
|
||||||
|
self.obs[0:29] = self.last_action
|
||||||
|
self.obs[29:32] = ang_vel_s
|
||||||
|
self.obs[32] = self.cmd[2]
|
||||||
|
self.obs[33:35] = self.cmd[:2]
|
||||||
|
self.obs[35:37] = cos_ph
|
||||||
|
self.obs[37:66] = qj_obs
|
||||||
|
self.obs[66:95] = dqj_obs
|
||||||
|
self.obs[95:98] = gravity
|
||||||
|
self.obs[98:100] = sin_ph
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
obs_in = self.obs.reshape(1, -1).astype(np.float32)
|
||||||
|
ort_in = {self.policy.get_inputs()[0].name: obs_in}
|
||||||
|
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||||
|
clipped = np.clip(raw_action, -100.0, 100.0)
|
||||||
|
self.last_action = clipped.copy()
|
||||||
|
scaled = clipped * LOCOMOTION_ACTION_SCALE
|
||||||
|
target = self.default_angles + scaled
|
||||||
|
|
||||||
|
# Send commands
|
||||||
|
for i in range(self.num_dof):
|
||||||
|
motor_idx = self.motor_map[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = float(target[i])
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Zero missing joints for 23-DOF
|
||||||
|
if self.is_23dof:
|
||||||
|
for idx in MISSING_23DOF_MOTORS:
|
||||||
|
self.robot.msg.motor_cmd[idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[idx].kp = 40.0
|
||||||
|
self.robot.msg.motor_cmd[idx].kd = 2.0
|
||||||
|
self.robot.msg.motor_cmd[idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset state for fresh start."""
|
||||||
|
self.last_action.fill(0)
|
||||||
|
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.is_standing = True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# DANCE CONTROLLER
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class DanceController:
|
||||||
|
"""WBT dance policy with FK for torso tracking."""
|
||||||
|
|
||||||
|
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||||
|
self.policy = policy
|
||||||
|
self.robot = robot
|
||||||
|
self.pinocchio_fk = pinocchio_fk
|
||||||
|
self.motor_kp = motor_kp
|
||||||
|
self.motor_kd = motor_kd
|
||||||
|
self.action_scale = action_scale
|
||||||
|
|
||||||
|
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||||
|
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||||
|
self.motion_command = None
|
||||||
|
self.ref_quat_xyzw = None
|
||||||
|
self.timestep = 0
|
||||||
|
self.yaw_offset = 0.0
|
||||||
|
|
||||||
|
logger.info(f"Dance: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||||
|
|
||||||
|
def initialize(self, reset_to_frame_0: bool = True):
|
||||||
|
"""Initialize dance. If reset_to_frame_0=True, starts from frame 0. Otherwise resumes."""
|
||||||
|
if reset_to_frame_0:
|
||||||
|
self.timestep = 0
|
||||||
|
self.last_action.fill(0)
|
||||||
|
|
||||||
|
# Get initial motion data at frame 0
|
||||||
|
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||||
|
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||||
|
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[2]
|
||||||
|
logger.info("Dance: reset to frame 0")
|
||||||
|
else:
|
||||||
|
# Resume from current timestep - just update motion command for current frame
|
||||||
|
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||||
|
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": dummy, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||||
|
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[2]
|
||||||
|
logger.info(f"Dance: resuming from frame {self.timestep}")
|
||||||
|
|
||||||
|
# Capture yaw offset
|
||||||
|
state = self.robot.lowstate_buffer.get_data()
|
||||||
|
if state and self.pinocchio_fk:
|
||||||
|
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||||
|
dof = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||||
|
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||||
|
logger.info(f"Dance yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||||
|
|
||||||
|
def _remove_yaw_offset(self, quat_wxyz):
|
||||||
|
if abs(self.yaw_offset) < 1e-6:
|
||||||
|
return quat_wxyz
|
||||||
|
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||||
|
return quat_mul(yaw_q, quat_wxyz)
|
||||||
|
|
||||||
|
def run_step(self):
|
||||||
|
"""Single dance step."""
|
||||||
|
state = self.robot.lowstate_buffer.get_data()
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||||
|
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
dof_pos = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
dof_vel = np.array([state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||||
|
|
||||||
|
# FK for torso orientation
|
||||||
|
if self.pinocchio_fk:
|
||||||
|
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||||
|
torso_q = self._remove_yaw_offset(torso_q)
|
||||||
|
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||||
|
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||||
|
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||||
|
else:
|
||||||
|
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||||
|
|
||||||
|
dof_rel = (dof_pos - DEFAULT_29DOF_ANGLES).reshape(1, -1)
|
||||||
|
|
||||||
|
# Build obs (alphabetical)
|
||||||
|
obs_dict = {
|
||||||
|
"actions": self.last_action,
|
||||||
|
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||||
|
"dof_pos": dof_rel,
|
||||||
|
"dof_vel": dof_vel.reshape(1, -1),
|
||||||
|
"motion_command": self.motion_command,
|
||||||
|
"motion_ref_ori_b": ori_b,
|
||||||
|
}
|
||||||
|
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||||
|
obs = np.clip(obs, -100, 100)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||||
|
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||||
|
action = np.clip(outs[0], -100, 100)
|
||||||
|
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||||
|
self.ref_quat_xyzw = outs[3]
|
||||||
|
self.last_action = action.copy()
|
||||||
|
|
||||||
|
target = DEFAULT_29DOF_ANGLES + action.flatten() * self.action_scale
|
||||||
|
|
||||||
|
# Send commands
|
||||||
|
for i in range(NUM_DOFS):
|
||||||
|
if i in FROZEN_JOINTS:
|
||||||
|
self.robot.msg.motor_cmd[i].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||||
|
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||||
|
else:
|
||||||
|
self.robot.msg.motor_cmd[i].q = float(target[i])
|
||||||
|
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||||
|
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||||
|
self.robot.msg.motor_cmd[i].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[i].tau = 0
|
||||||
|
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
self.timestep += 1
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MAIN
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Locomotion ↔ Dance Toggle")
|
||||||
|
parser.add_argument("--loco-repo", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||||
|
parser.add_argument("--dance-onnx", type=str, default=DANCE_ONNX_PATH)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("🚶 LOCOMOTION ↔ 💃 DANCE")
|
||||||
|
print("=" * 70)
|
||||||
|
print("Press ENTER to toggle between modes")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Load locomotion policy
|
||||||
|
logger.info("Loading locomotion policy...")
|
||||||
|
loco_path = hf_hub_download(repo_id=args.loco_repo, filename="fastsac_g1_29dof.onnx")
|
||||||
|
loco_policy = ort.InferenceSession(loco_path)
|
||||||
|
loco_obs_dim = loco_policy.get_inputs()[0].shape[1]
|
||||||
|
logger.info(f"Locomotion: {loco_obs_dim}D obs")
|
||||||
|
|
||||||
|
# Load dance policy
|
||||||
|
logger.info("Loading dance policy...")
|
||||||
|
dance_policy = ort.InferenceSession(args.dance_onnx)
|
||||||
|
dance_model = onnx.load(args.dance_onnx)
|
||||||
|
dance_meta = {p.key: json.loads(p.value) for p in dance_model.metadata_props}
|
||||||
|
dance_kp = np.array(dance_meta.get("kp", DEFAULT_29DOF_KP), dtype=np.float32)
|
||||||
|
dance_kd = np.array(dance_meta.get("kd", DEFAULT_29DOF_KD), dtype=np.float32)
|
||||||
|
dance_action_scale = float(dance_meta.get("action_scale", 1.0))
|
||||||
|
logger.info(f"Dance: {dance_policy.get_inputs()[0].shape[1]}D obs, scale={dance_action_scale}")
|
||||||
|
|
||||||
|
# Build Pinocchio FK
|
||||||
|
pinocchio_fk = None
|
||||||
|
if "robot_urdf" in dance_meta:
|
||||||
|
logger.info("Building Pinocchio FK...")
|
||||||
|
pinocchio_fk = PinocchioFK(dance_meta["robot_urdf"])
|
||||||
|
|
||||||
|
# Initialize robot
|
||||||
|
logger.info("Initializing robot...")
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot = UnitreeG1(config)
|
||||||
|
logger.info("Robot connected!")
|
||||||
|
|
||||||
|
# Create controllers
|
||||||
|
loco_ctrl = LocomotionController(loco_policy, robot, loco_obs_dim)
|
||||||
|
dance_ctrl = DanceController(dance_policy, robot, pinocchio_fk, dance_kp, dance_kd, dance_action_scale)
|
||||||
|
|
||||||
|
# State
|
||||||
|
mode = "locomotion"
|
||||||
|
toggle_event = threading.Event()
|
||||||
|
shutdown = threading.Event()
|
||||||
|
|
||||||
|
# Input thread
|
||||||
|
def input_loop():
|
||||||
|
while not shutdown.is_set():
|
||||||
|
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||||
|
sys.stdin.readline()
|
||||||
|
toggle_event.set()
|
||||||
|
|
||||||
|
input_thread = threading.Thread(target=input_loop, daemon=True)
|
||||||
|
input_thread.start()
|
||||||
|
|
||||||
|
print("\n🚶 LOCOMOTION MODE - Use joystick to walk")
|
||||||
|
print(" Press ENTER to switch to DANCE")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
step = 0
|
||||||
|
try:
|
||||||
|
while not shutdown.is_set():
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Check toggle
|
||||||
|
if toggle_event.is_set():
|
||||||
|
toggle_event.clear()
|
||||||
|
if mode == "locomotion":
|
||||||
|
mode = "dance"
|
||||||
|
dance_ctrl.initialize()
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("💃 DANCE MODE (frame 0)")
|
||||||
|
print(" Press ENTER to switch to LOCOMOTION")
|
||||||
|
print("=" * 70)
|
||||||
|
else:
|
||||||
|
mode = "locomotion"
|
||||||
|
loco_ctrl.reset()
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("🚶 LOCOMOTION MODE")
|
||||||
|
print(" Press ENTER to switch to DANCE")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
# Run controller
|
||||||
|
if mode == "locomotion":
|
||||||
|
loco_ctrl.run_step()
|
||||||
|
else:
|
||||||
|
dance_ctrl.run_step()
|
||||||
|
|
||||||
|
# Log
|
||||||
|
if step % 100 == 0:
|
||||||
|
if mode == "locomotion":
|
||||||
|
print(f"[LOCO ] step={step:5d} cmd=[{loco_ctrl.cmd[0]:.2f},{loco_ctrl.cmd[1]:.2f},{loco_ctrl.cmd[2]:.2f}]")
|
||||||
|
else:
|
||||||
|
print(f"[DANCE] step={step:5d} timestep={dance_ctrl.timestep}")
|
||||||
|
|
||||||
|
step += 1
|
||||||
|
elapsed = time.time() - t0
|
||||||
|
if elapsed < CONTROL_DT:
|
||||||
|
time.sleep(CONTROL_DT - elapsed)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\nStopping...")
|
||||||
|
finally:
|
||||||
|
shutdown.set()
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,447 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
|
||||||
|
|
||||||
|
This example demonstrates loading a 12-DOF legs-only locomotion policy
|
||||||
|
(TorchScript .pt format) and running it on the Unitree G1 robot.
|
||||||
|
|
||||||
|
Key characteristics:
|
||||||
|
- Single TorchScript policy (.pt)
|
||||||
|
- 47D observations, 12D actions (legs only)
|
||||||
|
- Phase-based gait timing
|
||||||
|
- Arms and waist held at fixed positions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
|
||||||
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 12-DOF leg joint configuration
|
||||||
|
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
|
||||||
|
# R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
|
||||||
|
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||||
|
|
||||||
|
# Default leg angles for standing
|
||||||
|
DEFAULT_LEG_ANGLES = np.array([
|
||||||
|
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg
|
||||||
|
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
# KP/KD for leg joints
|
||||||
|
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
|
||||||
|
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)
|
||||||
|
|
||||||
|
# Waist configuration (held at zero)
|
||||||
|
WAIST_JOINT_INDICES = [12, 13, 14] # yaw, roll, pitch
|
||||||
|
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
|
||||||
|
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)
|
||||||
|
|
||||||
|
# Arm configuration (indices 15-28, held at initial position)
|
||||||
|
ARM_JOINT_INDICES = list(range(15, 29))
|
||||||
|
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40, # left arm (shoulder + wrist)
|
||||||
|
80, 80, 80, 80, 40, 40, 40], dtype=np.float32) # right arm
|
||||||
|
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
|
||||||
|
3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)
|
||||||
|
|
||||||
|
# Control parameters
|
||||||
|
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz control rate
|
||||||
|
LOCOMOTION_ACTION_SCALE = 0.25
|
||||||
|
ANG_VEL_SCALE = 0.25
|
||||||
|
DOF_POS_SCALE = 1.0
|
||||||
|
DOF_VEL_SCALE = 0.05
|
||||||
|
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
|
||||||
|
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32) # max vx, vy, yaw_rate
|
||||||
|
|
||||||
|
# Gait parameters
|
||||||
|
GAIT_PERIOD = 0.8 # seconds
|
||||||
|
|
||||||
|
DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
|
||||||
|
|
||||||
|
|
||||||
|
def load_torchscript_policy(
|
||||||
|
repo_id: str = DEFAULT_REPO_ID,
|
||||||
|
filename: str = "motion.pt",
|
||||||
|
) -> torch.jit.ScriptModule:
|
||||||
|
"""Load TorchScript locomotion policy from Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: Hugging Face Hub repository ID containing the policy.
|
||||||
|
filename: Policy filename (default: motion.pt).
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")
|
||||||
|
|
||||||
|
policy_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy = torch.jit.load(policy_path)
|
||||||
|
policy.eval()
|
||||||
|
|
||||||
|
logger.info("TorchScript policy loaded successfully")
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
|
class UnitreeRLLocomotionController:
|
||||||
|
"""
|
||||||
|
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
|
||||||
|
|
||||||
|
This controller manages:
|
||||||
|
- Single TorchScript policy
|
||||||
|
- 47D observations (single frame)
|
||||||
|
- 12D action output (legs only)
|
||||||
|
- Arms and waist held at fixed positions
|
||||||
|
- Phase-based gait timing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, policy, robot, config):
|
||||||
|
self.policy = policy
|
||||||
|
self.robot = robot
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Velocity commands (vx, vy, yaw_rate)
|
||||||
|
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
# State variables (12 DOF legs)
|
||||||
|
self.qj = np.zeros(12, dtype=np.float32)
|
||||||
|
self.dqj = np.zeros(12, dtype=np.float32)
|
||||||
|
self.locomotion_action = np.zeros(12, dtype=np.float32)
|
||||||
|
self.locomotion_obs = np.zeros(47, dtype=np.float32)
|
||||||
|
|
||||||
|
# Initial arm positions (captured on reset)
|
||||||
|
self.initial_arm_positions = np.zeros(14, dtype=np.float32)
|
||||||
|
|
||||||
|
# Counter for phase calculation
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
# Thread management
|
||||||
|
self.locomotion_running = False
|
||||||
|
self.locomotion_thread = None
|
||||||
|
|
||||||
|
logger.info("UnitreeRLLocomotionController initialized")
|
||||||
|
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
|
||||||
|
|
||||||
|
def locomotion_run(self):
|
||||||
|
"""12-DOF legs-only locomotion policy loop."""
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
if self.counter == 1:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
|
||||||
|
print(" 47D observations → 12D actions (legs only)")
|
||||||
|
print(" Arms and waist held at fixed positions")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
# Get current observation
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
if robot_state is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get command from remote controller
|
||||||
|
if robot_state.wireless_remote is not None:
|
||||||
|
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||||
|
else:
|
||||||
|
self.robot.remote_controller.lx = 0.0
|
||||||
|
self.robot.remote_controller.ly = 0.0
|
||||||
|
self.robot.remote_controller.rx = 0.0
|
||||||
|
self.robot.remote_controller.ry = 0.0
|
||||||
|
|
||||||
|
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||||
|
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
|
||||||
|
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
|
||||||
|
|
||||||
|
# Get leg joint positions and velocities (12 DOF)
|
||||||
|
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||||
|
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||||
|
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||||
|
|
||||||
|
# Get IMU data
|
||||||
|
quat = robot_state.imu_state.quaternion
|
||||||
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
|
||||||
|
# Scale observations
|
||||||
|
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||||
|
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
||||||
|
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||||
|
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||||
|
|
||||||
|
# Calculate phase
|
||||||
|
count = self.counter * LOCOMOTION_CONTROL_DT
|
||||||
|
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
|
||||||
|
sin_phase = np.sin(2 * np.pi * phase)
|
||||||
|
cos_phase = np.cos(2 * np.pi * phase)
|
||||||
|
|
||||||
|
# Build 47D observation vector
|
||||||
|
# [0:3] - angular velocity (scaled)
|
||||||
|
# [3:6] - gravity orientation
|
||||||
|
# [6:9] - velocity command (scaled)
|
||||||
|
# [9:21] - joint positions (12D, relative to default)
|
||||||
|
# [21:33] - joint velocities (12D, scaled)
|
||||||
|
# [33:45] - previous actions (12D)
|
||||||
|
# [45] - sin_phase
|
||||||
|
# [46] - cos_phase
|
||||||
|
self.locomotion_obs[0:3] = ang_vel_scaled
|
||||||
|
self.locomotion_obs[3:6] = gravity_orientation
|
||||||
|
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
|
||||||
|
self.locomotion_obs[9:21] = qj_obs
|
||||||
|
self.locomotion_obs[21:33] = dqj_obs
|
||||||
|
self.locomotion_obs[33:45] = self.locomotion_action
|
||||||
|
self.locomotion_obs[45] = sin_phase
|
||||||
|
self.locomotion_obs[46] = cos_phase
|
||||||
|
|
||||||
|
# Run policy inference (TorchScript)
|
||||||
|
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
|
||||||
|
with torch.no_grad():
|
||||||
|
action_tensor = self.policy(obs_tensor)
|
||||||
|
self.locomotion_action = action_tensor.squeeze().numpy()
|
||||||
|
|
||||||
|
# Transform action to target joint positions
|
||||||
|
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
|
||||||
|
|
||||||
|
# Debug logging (first 3 iterations)
|
||||||
|
if self.counter <= 3:
|
||||||
|
print(f"\n[Unitree RL Debug #{self.counter}]")
|
||||||
|
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
|
||||||
|
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||||
|
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
|
||||||
|
|
||||||
|
# Send commands to LEG motors (0-11)
|
||||||
|
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold WAIST motors at zero (12, 13, 14)
|
||||||
|
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold ARM motors at initial position (15-28)
|
||||||
|
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Send command
|
||||||
|
self.robot.send_action(self.robot.msg)
|
||||||
|
|
||||||
|
def _locomotion_thread_loop(self):
|
||||||
|
"""Background thread that runs the locomotion policy at specified rate."""
|
||||||
|
logger.info("Locomotion thread started")
|
||||||
|
while self.locomotion_running:
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
self.locomotion_run()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in locomotion loop: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# Sleep to maintain control rate
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
logger.info("Locomotion thread stopped")
|
||||||
|
|
||||||
|
def start_locomotion_thread(self):
|
||||||
|
if self.locomotion_running:
|
||||||
|
logger.warning("Locomotion thread already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting locomotion control thread...")
|
||||||
|
self.locomotion_running = True
|
||||||
|
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||||
|
self.locomotion_thread.start()
|
||||||
|
|
||||||
|
logger.info("Locomotion control thread started!")
|
||||||
|
|
||||||
|
def stop_locomotion_thread(self):
|
||||||
|
if not self.locomotion_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Stopping locomotion control thread...")
|
||||||
|
self.locomotion_running = False
|
||||||
|
if self.locomotion_thread:
|
||||||
|
self.locomotion_thread.join(timeout=2.0)
|
||||||
|
logger.info("Locomotion control thread stopped")
|
||||||
|
|
||||||
|
def reset_robot(self):
|
||||||
|
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
|
||||||
|
logger.info("Moving legs to default position...")
|
||||||
|
|
||||||
|
total_time = 2.0
|
||||||
|
num_step = int(total_time / self.robot.control_dt)
|
||||||
|
|
||||||
|
# Get current state
|
||||||
|
robot_state = self.robot.get_observation()
|
||||||
|
|
||||||
|
# Capture initial arm positions (to hold during locomotion)
|
||||||
|
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||||
|
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
|
||||||
|
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
|
||||||
|
|
||||||
|
# Record current leg positions
|
||||||
|
init_leg_pos = np.zeros(12, dtype=np.float32)
|
||||||
|
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||||
|
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
|
||||||
|
|
||||||
|
# Interpolate legs to default position
|
||||||
|
for step in range(num_step):
|
||||||
|
alpha = step / num_step
|
||||||
|
|
||||||
|
# Interpolate leg positions
|
||||||
|
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||||
|
target_pos = DEFAULT_LEG_ANGLES[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||||
|
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
|
||||||
|
)
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold waist at zero
|
||||||
|
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold arms at initial position
|
||||||
|
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(self.robot.control_dt)
|
||||||
|
|
||||||
|
logger.info("Reached default leg position")
|
||||||
|
|
||||||
|
# Hold position for 2 seconds
|
||||||
|
logger.info("Holding default position for 2 seconds...")
|
||||||
|
hold_time = 2.0
|
||||||
|
num_hold_steps = int(hold_time / self.robot.control_dt)
|
||||||
|
|
||||||
|
for _ in range(num_hold_steps):
|
||||||
|
# Hold legs at default
|
||||||
|
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold waist at zero
|
||||||
|
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
# Hold arms at initial position
|
||||||
|
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||||
|
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
|
||||||
|
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||||
|
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||||
|
time.sleep(self.robot.control_dt)
|
||||||
|
|
||||||
|
logger.info("Ready to start locomotion!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_REPO_ID,
|
||||||
|
help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
default="motion.pt",
|
||||||
|
help="Policy filename (default: motion.pt)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load policy
|
||||||
|
policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)
|
||||||
|
|
||||||
|
# Initialize robot
|
||||||
|
config = UnitreeG1Config()
|
||||||
|
robot = UnitreeG1(config)
|
||||||
|
|
||||||
|
# Initialize locomotion controller
|
||||||
|
locomotion_controller = UnitreeRLLocomotionController(
|
||||||
|
policy=policy,
|
||||||
|
robot=robot,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reset robot and start locomotion thread
|
||||||
|
try:
|
||||||
|
locomotion_controller.reset_robot()
|
||||||
|
locomotion_controller.start_locomotion_thread()
|
||||||
|
|
||||||
|
# Log status
|
||||||
|
logger.info("Robot initialized with Unitree RL locomotion policy")
|
||||||
|
logger.info("Locomotion controller running in background thread")
|
||||||
|
logger.info("Use remote controller to command velocity:")
|
||||||
|
logger.info(" Left stick Y: forward/backward")
|
||||||
|
logger.info(" Left stick X: left/right")
|
||||||
|
logger.info(" Right stick X: rotate")
|
||||||
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
|
# Keep robot alive
|
||||||
|
while True:
|
||||||
|
time.sleep(1.0)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping locomotion...")
|
||||||
|
locomotion_controller.stop_locomotion_thread()
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
|
|
||||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||||
@@ -52,7 +54,10 @@ class UnitreeG1Config(RobotConfig):
|
|||||||
control_dt: float = 1.0 / 250.0 # 250Hz
|
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||||
|
|
||||||
# launch mujoco simulation
|
# launch mujoco simulation
|
||||||
is_simulation: bool = True
|
is_simulation: bool = False
|
||||||
|
|
||||||
# socket config for ZMQ bridge
|
# socket config for ZMQ bridge
|
||||||
robot_ip: str = "192.168.123.164"
|
robot_ip: str = "172.18.129.215"
|
||||||
|
|
||||||
|
# cameras (optional)
|
||||||
|
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||||
|
|||||||
@@ -0,0 +1,302 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Standalone keyboard control script for Unitree G1 robot.
|
||||||
|
|
||||||
|
This script provides keyboard-based velocity control for the G1 robot's
|
||||||
|
locomotion system. It can be run alongside the main robot control to
|
||||||
|
provide manual movement commands.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python keyboard_control.py [--robot-ip IP] [--simulation]
|
||||||
|
|
||||||
|
Controls:
|
||||||
|
W/S: Forward/Backward
|
||||||
|
A/D: Strafe Left/Right
|
||||||
|
Q/E: Rotate Left/Right
|
||||||
|
R/F: Raise/Lower Height (GR00T policies only)
|
||||||
|
Z: Stop (zero all velocity commands)
|
||||||
|
ESC/Ctrl+C: Exit
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import select
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Terminal handling for non-blocking keyboard input
|
||||||
|
try:
|
||||||
|
import termios
|
||||||
|
import tty
|
||||||
|
HAS_TERMIOS = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TERMIOS = False
|
||||||
|
print("Warning: termios not available. Keyboard controls require Linux/macOS.")
|
||||||
|
|
||||||
|
|
||||||
|
class KeyboardController:
|
||||||
|
"""Handles keyboard input and converts to locomotion commands."""
|
||||||
|
|
||||||
|
def __init__(self, callback=None):
|
||||||
|
"""
|
||||||
|
Initialize keyboard controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Optional function called when commands change.
|
||||||
|
Signature: callback(vx, vy, yaw, height)
|
||||||
|
"""
|
||||||
|
self.callback = callback
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Locomotion commands
|
||||||
|
self.vx = 0.0 # Forward/backward velocity
|
||||||
|
self.vy = 0.0 # Left/right velocity (strafe)
|
||||||
|
self.yaw = 0.0 # Rotation rate
|
||||||
|
self.height = 0.74 # Base height (for GR00T policies)
|
||||||
|
|
||||||
|
# Command limits
|
||||||
|
self.vx_limit = (-0.8, 0.8)
|
||||||
|
self.vy_limit = (-0.5, 0.5)
|
||||||
|
self.yaw_limit = (-1.0, 1.0)
|
||||||
|
self.height_limit = (0.50, 1.00)
|
||||||
|
|
||||||
|
# Increments per keypress
|
||||||
|
self.vx_increment = 0.4
|
||||||
|
self.vy_increment = 0.25
|
||||||
|
self.yaw_increment = 0.5
|
||||||
|
self.height_increment = 0.05
|
||||||
|
|
||||||
|
self._old_terminal_settings = None
|
||||||
|
|
||||||
|
def get_commands(self) -> tuple[float, float, float, float]:
|
||||||
|
"""Get current command values as tuple (vx, vy, yaw, height)."""
|
||||||
|
return (self.vx, self.vy, self.yaw, self.height)
|
||||||
|
|
||||||
|
def get_commands_array(self) -> np.ndarray:
|
||||||
|
"""Get velocity commands as numpy array [vx, vy, yaw]."""
|
||||||
|
return np.array([self.vx, self.vy, self.yaw], dtype=np.float32)
|
||||||
|
|
||||||
|
def reset_commands(self):
|
||||||
|
"""Reset all commands to zero (stop)."""
|
||||||
|
self.vx = 0.0
|
||||||
|
self.vy = 0.0
|
||||||
|
self.yaw = 0.0
|
||||||
|
self._notify_callback()
|
||||||
|
|
||||||
|
def _clamp(self, value: float, limits: tuple[float, float]) -> float:
|
||||||
|
"""Clamp value to limits."""
|
||||||
|
return max(limits[0], min(limits[1], value))
|
||||||
|
|
||||||
|
def _notify_callback(self):
|
||||||
|
"""Call callback with current commands if set."""
|
||||||
|
if self.callback:
|
||||||
|
self.callback(self.vx, self.vy, self.yaw, self.height)
|
||||||
|
|
||||||
|
def process_key(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Process a single key press and update commands.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Single character key that was pressed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if key was handled, False otherwise.
|
||||||
|
"""
|
||||||
|
key = key.lower()
|
||||||
|
handled = True
|
||||||
|
|
||||||
|
if key == 'w':
|
||||||
|
self.vx = self._clamp(self.vx + self.vx_increment, self.vx_limit)
|
||||||
|
elif key == 's':
|
||||||
|
self.vx = self._clamp(self.vx - self.vx_increment, self.vx_limit)
|
||||||
|
elif key == 'a':
|
||||||
|
self.vy = self._clamp(self.vy + self.vy_increment, self.vy_limit)
|
||||||
|
elif key == 'd':
|
||||||
|
self.vy = self._clamp(self.vy - self.vy_increment, self.vy_limit)
|
||||||
|
elif key == 'q':
|
||||||
|
self.yaw = self._clamp(self.yaw + self.yaw_increment, self.yaw_limit)
|
||||||
|
elif key == 'e':
|
||||||
|
self.yaw = self._clamp(self.yaw - self.yaw_increment, self.yaw_limit)
|
||||||
|
elif key == 'r':
|
||||||
|
self.height = self._clamp(self.height + self.height_increment, self.height_limit)
|
||||||
|
elif key == 'f':
|
||||||
|
self.height = self._clamp(self.height - self.height_increment, self.height_limit)
|
||||||
|
elif key == 'z':
|
||||||
|
self.reset_commands()
|
||||||
|
return True # Already notified in reset_commands
|
||||||
|
else:
|
||||||
|
handled = False
|
||||||
|
|
||||||
|
if handled:
|
||||||
|
self._notify_callback()
|
||||||
|
|
||||||
|
return handled
|
||||||
|
|
||||||
|
def _setup_terminal(self):
|
||||||
|
"""Set terminal to raw mode for single character input."""
|
||||||
|
if HAS_TERMIOS:
|
||||||
|
self._old_terminal_settings = termios.tcgetattr(sys.stdin)
|
||||||
|
tty.setcbreak(sys.stdin.fileno())
|
||||||
|
|
||||||
|
def _restore_terminal(self):
|
||||||
|
"""Restore terminal to original settings."""
|
||||||
|
if HAS_TERMIOS and self._old_terminal_settings is not None:
|
||||||
|
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._old_terminal_settings)
|
||||||
|
self._old_terminal_settings = None
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""Run the keyboard listener loop (blocking)."""
|
||||||
|
if not HAS_TERMIOS:
|
||||||
|
print("Error: Keyboard controls require termios (Linux/macOS)")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self._print_controls()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._setup_terminal()
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
# Check for keyboard input with timeout
|
||||||
|
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||||
|
key = sys.stdin.read(1)
|
||||||
|
|
||||||
|
# Handle escape sequences (arrow keys, etc.)
|
||||||
|
if key == '\x1b': # ESC
|
||||||
|
self.running = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.process_key(key):
|
||||||
|
self._print_status()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nInterrupted by user")
|
||||||
|
finally:
|
||||||
|
self._restore_terminal()
|
||||||
|
print("\nKeyboard controls stopped")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the keyboard listener."""
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def _print_controls(self):
|
||||||
|
"""Print control instructions."""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("KEYBOARD CONTROLS ACTIVE")
|
||||||
|
print("=" * 60)
|
||||||
|
print(" W/S: Forward/Backward")
|
||||||
|
print(" A/D: Strafe Left/Right")
|
||||||
|
print(" Q/E: Rotate Left/Right")
|
||||||
|
print(" R/F: Raise/Lower Height (±5cm)")
|
||||||
|
print(" Z: Stop (zero all commands)")
|
||||||
|
print(" ESC: Exit")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
def _print_status(self):
|
||||||
|
"""Print current command status."""
|
||||||
|
print(f"[CMD] vx={self.vx:+.2f}, vy={self.vy:+.2f}, yaw={self.yaw:+.2f} | height={self.height:.3f}m")
|
||||||
|
|
||||||
|
|
||||||
|
class RobotKeyboardController(KeyboardController):
|
||||||
|
"""Keyboard controller that directly updates a robot's locomotion commands."""
|
||||||
|
|
||||||
|
def __init__(self, robot):
|
||||||
|
"""
|
||||||
|
Initialize with a UnitreeG1 robot instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot: UnitreeG1 robot instance with locomotion_cmd attribute.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.robot = robot
|
||||||
|
|
||||||
|
# Initialize from robot's current state if available
|
||||||
|
if hasattr(robot, 'locomotion_cmd'):
|
||||||
|
self.vx = robot.locomotion_cmd[0]
|
||||||
|
self.vy = robot.locomotion_cmd[1]
|
||||||
|
self.yaw = robot.locomotion_cmd[2]
|
||||||
|
|
||||||
|
if hasattr(robot, 'groot_height_cmd'):
|
||||||
|
self.height = robot.groot_height_cmd
|
||||||
|
|
||||||
|
def _notify_callback(self):
|
||||||
|
"""Update robot's locomotion commands directly."""
|
||||||
|
if hasattr(self.robot, 'locomotion_cmd'):
|
||||||
|
self.robot.locomotion_cmd[0] = self.vx
|
||||||
|
self.robot.locomotion_cmd[1] = self.vy
|
||||||
|
self.robot.locomotion_cmd[2] = self.yaw
|
||||||
|
|
||||||
|
if hasattr(self.robot, 'groot_height_cmd'):
|
||||||
|
self.robot.groot_height_cmd = self.height
|
||||||
|
|
||||||
|
|
||||||
|
def start_keyboard_control_thread(robot) -> tuple:
|
||||||
|
"""
|
||||||
|
Start keyboard controls for a robot in a background thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
robot: UnitreeG1 robot instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (controller, thread) for later stopping.
|
||||||
|
"""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
controller = RobotKeyboardController(robot)
|
||||||
|
thread = threading.Thread(target=controller.run, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return controller, thread
|
||||||
|
|
||||||
|
|
||||||
|
def stop_keyboard_control_thread(controller, thread, timeout: float = 2.0):
|
||||||
|
"""
|
||||||
|
Stop the keyboard control thread.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
controller: KeyboardController instance.
|
||||||
|
thread: Thread running the controller.
|
||||||
|
timeout: Max time to wait for thread to stop.
|
||||||
|
"""
|
||||||
|
controller.stop()
|
||||||
|
thread.join(timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Standalone keyboard control with optional robot connection."""
|
||||||
|
parser = argparse.ArgumentParser(description="Keyboard control for Unitree G1")
|
||||||
|
parser.add_argument("--standalone", action="store_true",
|
||||||
|
help="Run in standalone mode (just print commands, no robot)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.standalone:
|
||||||
|
# Standalone mode - just demonstrate keyboard input
|
||||||
|
def print_callback(vx, vy, yaw, height):
|
||||||
|
print(f" → Would send: vx={vx:+.2f}, vy={vy:+.2f}, yaw={yaw:+.2f}, height={height:.3f}")
|
||||||
|
|
||||||
|
controller = KeyboardController(callback=print_callback)
|
||||||
|
print("Running in STANDALONE mode (no robot connection)")
|
||||||
|
controller.run()
|
||||||
|
else:
|
||||||
|
print("To use with a robot, import and use RobotKeyboardController:")
|
||||||
|
print("")
|
||||||
|
print(" from lerobot.robots.unitree_g1.keyboard_control import (")
|
||||||
|
print(" RobotKeyboardController,")
|
||||||
|
print(" start_keyboard_control_thread,")
|
||||||
|
print(" stop_keyboard_control_thread")
|
||||||
|
print(" )")
|
||||||
|
print("")
|
||||||
|
print(" # Start keyboard controls")
|
||||||
|
print(" controller, thread = start_keyboard_control_thread(robot)")
|
||||||
|
print("")
|
||||||
|
print(" # ... robot runs ...")
|
||||||
|
print("")
|
||||||
|
print(" # Stop keyboard controls")
|
||||||
|
print(" stop_keyboard_control_thread(controller, thread)")
|
||||||
|
print("")
|
||||||
|
print("Or run with --standalone to test keyboard input without a robot.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
@@ -101,6 +101,7 @@ from lerobot.robots import ( # noqa: F401
|
|||||||
so100_follower,
|
so100_follower,
|
||||||
so101_follower,
|
so101_follower,
|
||||||
)
|
)
|
||||||
|
from lerobot.robots.unitree_g1 import config_unitree_g1 # noqa: F401
|
||||||
from lerobot.teleoperators import ( # noqa: F401
|
from lerobot.teleoperators import ( # noqa: F401
|
||||||
Teleoperator,
|
Teleoperator,
|
||||||
TeleoperatorConfig,
|
TeleoperatorConfig,
|
||||||
@@ -197,9 +198,8 @@ class RecordConfig:
|
|||||||
cli_overrides = parser.get_cli_overrides("policy")
|
cli_overrides = parser.get_cli_overrides("policy")
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
# Note: teleop and policy can both be None for robots with built-in control (e.g. unitree_g1)
|
||||||
if self.teleop is None and self.policy is None:
|
# This is validated in record() after the robot is instantiated
|
||||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
@@ -340,6 +340,13 @@ def record_loop(
|
|||||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||||
act_processed_teleop = teleop_action_processor((act, obs))
|
act_processed_teleop = teleop_action_processor((act, obs))
|
||||||
|
elif policy is None and teleop is None and dataset is not None:
|
||||||
|
# Observation-only recording (robot controls itself, e.g. unitree_g1)
|
||||||
|
# Record observations, extract action-relevant values (positions) from obs
|
||||||
|
# Filter obs_processed to only include keys that match action_features
|
||||||
|
action_keys = set(robot.action_features.keys())
|
||||||
|
action_values = {k: v for k, v in obs_processed.items() if k in action_keys}
|
||||||
|
robot_action_to_send = None
|
||||||
else:
|
else:
|
||||||
logging.info(
|
logging.info(
|
||||||
"No policy or teleoperator provided, skipping action generation."
|
"No policy or teleoperator provided, skipping action generation."
|
||||||
@@ -352,15 +359,17 @@ def record_loop(
|
|||||||
if policy is not None and act_processed_policy is not None:
|
if policy is not None and act_processed_policy is not None:
|
||||||
action_values = act_processed_policy
|
action_values = act_processed_policy
|
||||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||||
else:
|
elif teleop is not None:
|
||||||
action_values = act_processed_teleop
|
action_values = act_processed_teleop
|
||||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||||
|
# else: observation-only mode, action_values already set above
|
||||||
|
|
||||||
# Send action to robot
|
# Send action to robot (skip if observation-only mode)
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
if robot_action_to_send is not None:
|
||||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
# Action can eventually be clipped using `max_relative_target`,
|
||||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||||
_sent_action = robot.send_action(robot_action_to_send)
|
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||||
|
_sent_action = robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
# Write to dataset
|
# Write to dataset
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user