mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ee24f64ae5 | |||
| 123b9f7851 |
@@ -0,0 +1,454 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
WBT (Whole Body Tracking) Dance Policy for Unitree G1
|
||||
|
||||
Uses ONNX model with motion data baked in.
|
||||
Pattern matches gr00t_locomotion.py - uses UnitreeG1 robot class.
|
||||
|
||||
Usage:
|
||||
python examples/unitree_g1/dance.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import pinocchio as pin
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||
CONTROL_DT = 0.02 # 50 Hz
|
||||
NUM_DOFS = 29
|
||||
|
||||
# Default joint positions (holosoma training defaults)
|
||||
DEFAULT_DOF_POS = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Left leg (6)
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # Right leg (6)
|
||||
0.0, 0.0, 0.0, # Waist (3)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Left arm (7)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # Right arm (7)
|
||||
], dtype=np.float32)
|
||||
|
||||
# Stiff hold KP/KD (for initialization)
|
||||
STIFF_KP = np.array([
|
||||
150, 150, 200, 200, 40, 40,
|
||||
150, 150, 200, 200, 40, 40,
|
||||
200, 200, 100,
|
||||
100, 100, 100, 100, 50, 50, 50,
|
||||
100, 100, 100, 100, 50, 50, 50,
|
||||
], dtype=np.float32)
|
||||
|
||||
STIFF_KD = np.array([
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
5.0, 5.0, 5.0,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
|
||||
], dtype=np.float32)
|
||||
|
||||
# Joints to freeze at 0 with high KP
|
||||
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||
FROZEN_KP = 500.0
|
||||
FROZEN_KD = 5.0
|
||||
|
||||
# =============================================================================
|
||||
# QUATERNION UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
def quat_inverse(q):
|
||||
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||
|
||||
def quat_mul(a, b):
|
||||
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||
ww = (z1 + x1) * (x2 + y2)
|
||||
yy = (w1 - y1) * (w2 + z2)
|
||||
zz = (w1 + y1) * (w2 - z2)
|
||||
xx = ww + yy + zz
|
||||
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||
|
||||
def subtract_frame_transforms(q01, q02):
|
||||
return quat_mul(quat_inverse(q01), q02)
|
||||
|
||||
def matrix_from_quat(q):
|
||||
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||
two_s = 2.0 / (q * q).sum(-1)
|
||||
o = np.stack((
|
||||
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||
), -1)
|
||||
return o.reshape(q.shape[:-1] + (3, 3))
|
||||
|
||||
def xyzw_to_wxyz(xyzw):
|
||||
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||
|
||||
def quat_to_rpy(q):
|
||||
w, x, y, z = q
|
||||
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||
return roll, pitch, yaw
|
||||
|
||||
def rpy_to_quat(rpy):
|
||||
roll, pitch, yaw = rpy
|
||||
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||
|
||||
# =============================================================================
|
||||
# PINOCCHIO FK
|
||||
# =============================================================================
|
||||
|
||||
DOF_NAMES = (
|
||||
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||
)
|
||||
|
||||
|
||||
class PinocchioFK:
|
||||
"""Pinocchio forward kinematics for torso_link orientation."""
|
||||
|
||||
def __init__(self, urdf_text: str):
|
||||
root = ElementTree.fromstring(urdf_text)
|
||||
for parent in root.iter():
|
||||
for child in list(parent):
|
||||
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||
parent.remove(child)
|
||||
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||
|
||||
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||
self.data = self.model.createData()
|
||||
|
||||
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||
logger.info(f"Pinocchio FK: {len(pin_names)} joints, torso_link frame={self.ref_frame_id}")
|
||||
|
||||
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||
"""Get torso_link orientation in world frame."""
|
||||
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||
pin.framesForwardKinematics(self.model, self.data, config)
|
||||
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DANCE CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class DanceController:
|
||||
"""
|
||||
Handles WBT dance policy for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- 29-joint observation processing
|
||||
- Pinocchio FK for torso orientation
|
||||
- Policy inference with motion data from ONNX
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.pinocchio_fk = pinocchio_fk
|
||||
self.motor_kp = motor_kp
|
||||
self.motor_kd = motor_kd
|
||||
self.action_scale = action_scale
|
||||
|
||||
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||
self.motion_command = None
|
||||
self.ref_quat_xyzw = None
|
||||
self.timestep = 0
|
||||
self.yaw_offset = 0.0
|
||||
|
||||
# Get initial motion data from ONNX
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
self.motion_start_pose = outs[0].flatten()
|
||||
|
||||
# Thread management
|
||||
self.dance_running = False
|
||||
self.dance_thread = None
|
||||
|
||||
logger.info(f"DanceController: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||
|
||||
def capture_yaw_offset(self):
|
||||
"""Capture robot's current yaw for relative tracking."""
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
if robot_state and self.pinocchio_fk:
|
||||
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||
dof = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||
logger.info(f"Captured yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||
|
||||
def _remove_yaw_offset(self, quat_wxyz):
|
||||
"""Remove stored yaw offset from orientation."""
|
||||
if abs(self.yaw_offset) < 1e-6:
|
||||
return quat_wxyz
|
||||
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||
return quat_mul(yaw_q, quat_wxyz)
|
||||
|
||||
def run_step(self):
|
||||
"""Single dance step - reads state, runs policy, sends commands."""
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Read robot state
|
||||
quat = np.array(robot_state.imu_state.quaternion, dtype=np.float32)
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
dof_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
dof_vel = np.array([robot_state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
|
||||
# Compute motion_ref_ori_b using FK
|
||||
if self.pinocchio_fk:
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||
torso_q = self._remove_yaw_offset(torso_q)
|
||||
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||
else:
|
||||
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||
|
||||
dof_rel = (dof_pos - DEFAULT_DOF_POS).reshape(1, -1)
|
||||
|
||||
# Build observation (alphabetical order)
|
||||
obs_dict = {
|
||||
"actions": self.last_action,
|
||||
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||
"dof_pos": dof_rel,
|
||||
"dof_vel": dof_vel.reshape(1, -1),
|
||||
"motion_command": self.motion_command,
|
||||
"motion_ref_ori_b": ori_b,
|
||||
}
|
||||
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||
obs = np.clip(obs, -100, 100)
|
||||
|
||||
# Run policy
|
||||
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
|
||||
action = np.clip(outs[0], -100, 100)
|
||||
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||
self.ref_quat_xyzw = outs[3]
|
||||
self.last_action = action.copy()
|
||||
|
||||
# Compute target positions
|
||||
target_pos = DEFAULT_DOF_POS + action.flatten() * self.action_scale
|
||||
|
||||
# Send commands
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(target_pos[i])
|
||||
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
self.timestep += 1
|
||||
|
||||
def _dance_thread_loop(self):
|
||||
"""Background thread that runs the dance policy."""
|
||||
logger.info("Dance thread started")
|
||||
while self.dance_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.run_step()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in dance loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Dance thread stopped")
|
||||
|
||||
def start_dance_thread(self):
|
||||
"""Start the dance control thread."""
|
||||
if self.dance_running:
|
||||
logger.warning("Dance thread already running")
|
||||
return
|
||||
|
||||
# Reset state for fresh start
|
||||
self.timestep = 0
|
||||
self.last_action.fill(0)
|
||||
|
||||
# Re-get initial motion data
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
|
||||
self.capture_yaw_offset()
|
||||
|
||||
logger.info("Starting dance control thread...")
|
||||
self.dance_running = True
|
||||
self.dance_thread = threading.Thread(target=self._dance_thread_loop, daemon=True)
|
||||
self.dance_thread.start()
|
||||
|
||||
def stop_dance_thread(self):
|
||||
"""Stop the dance control thread."""
|
||||
if not self.dance_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping dance control thread...")
|
||||
self.dance_running = False
|
||||
if self.dance_thread:
|
||||
self.dance_thread.join(timeout=2.0)
|
||||
logger.info("Dance control thread stopped")
|
||||
|
||||
def reset_to_motion_pose(self, duration: float = 3.0):
|
||||
"""Move robot to initial motion pose over given duration."""
|
||||
logger.info(f"Moving to dance start pose ({duration}s)...")
|
||||
|
||||
robot_state = self.robot.lowstate_buffer.get_data()
|
||||
init_pos = np.array([robot_state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
target_pos = self.motion_start_pose
|
||||
|
||||
num_steps = int(duration / CONTROL_DT)
|
||||
for step in range(num_steps):
|
||||
alpha = step / num_steps
|
||||
interp = init_pos * (1 - alpha) + target_pos * alpha
|
||||
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(interp[i])
|
||||
self.robot.msg.motor_cmd[i].kp = STIFF_KP[i]
|
||||
self.robot.msg.motor_cmd[i].kd = STIFF_KD[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(CONTROL_DT)
|
||||
|
||||
logger.info("At dance start pose!")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN
|
||||
# =============================================================================
|
||||
|
||||
def load_dance_policy(onnx_path: str):
|
||||
"""Load dance policy and extract metadata."""
|
||||
logger.info(f"Loading dance policy: {onnx_path}")
|
||||
|
||||
policy = ort.InferenceSession(onnx_path)
|
||||
model = onnx.load(onnx_path)
|
||||
metadata = {p.key: json.loads(p.value) for p in model.metadata_props}
|
||||
|
||||
motor_kp = np.array(metadata.get("kp", STIFF_KP), dtype=np.float32)
|
||||
motor_kd = np.array(metadata.get("kd", STIFF_KD), dtype=np.float32)
|
||||
action_scale = float(metadata.get("action_scale", 1.0))
|
||||
urdf_text = metadata.get("robot_urdf", None)
|
||||
|
||||
logger.info(f" Obs dim: {policy.get_inputs()[0].shape[1]}")
|
||||
logger.info(f" Action scale: {action_scale}")
|
||||
logger.info(f" KP range: [{motor_kp.min():.1f}, {motor_kp.max():.1f}]")
|
||||
|
||||
# Build Pinocchio FK if URDF available
|
||||
pinocchio_fk = None
|
||||
if urdf_text:
|
||||
logger.info(" Building Pinocchio FK from URDF...")
|
||||
pinocchio_fk = PinocchioFK(urdf_text)
|
||||
else:
|
||||
logger.warning(" No URDF in metadata - FK will not work!")
|
||||
|
||||
return policy, pinocchio_fk, motor_kp, motor_kd, action_scale
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="WBT Dance Policy for Unitree G1")
|
||||
parser.add_argument("--onnx", type=str, default=DANCE_ONNX_PATH, help="Path to dance ONNX model")
|
||||
parser.add_argument("--sim", action="store_true", help="Run in simulation mode")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 70)
|
||||
print("💃 WBT DANCE POLICY")
|
||||
print("=" * 70)
|
||||
|
||||
# Load policy
|
||||
policy, pinocchio_fk, motor_kp, motor_kd, action_scale = load_dance_policy(args.onnx)
|
||||
|
||||
# Initialize robot
|
||||
logger.info("Initializing robot...")
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
logger.info("Robot connected!")
|
||||
|
||||
# Create controller
|
||||
controller = DanceController(policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale)
|
||||
|
||||
try:
|
||||
# Move to start pose
|
||||
controller.reset_to_motion_pose(duration=3.0)
|
||||
|
||||
# Start dancing
|
||||
controller.start_dance_thread()
|
||||
|
||||
logger.info("Dancing! Press Ctrl+C to stop.")
|
||||
print("-" * 70)
|
||||
|
||||
# Log status periodically
|
||||
while True:
|
||||
time.sleep(2.0)
|
||||
logger.info(f"timestep={controller.timestep}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
controller.stop_dance_thread()
|
||||
robot.disconnect()
|
||||
|
||||
print("\nDone!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
@@ -0,0 +1,479 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
|
||||
|
||||
This example demonstrates loading Holosoma whole-body locomotion policies
|
||||
and running them on the Unitree G1 robot.
|
||||
|
||||
Supports both:
|
||||
- 23-DOF native policies (82D observations, 23D actions)
|
||||
- 29-DOF policies (100D observations, 29D actions)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# 29-DOF Configuration
|
||||
# =============================================================================
|
||||
# fmt: off
|
||||
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_29DOF_KP = np.array([
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||
40.179238471, 28.501246196, 28.501246196, # waist
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # left arm
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_29DOF_KD = np.array([
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||
2.557889765, 1.814445687, 1.814445687, # waist
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # left arm
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
# =============================================================================
|
||||
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
|
||||
# Derived from 29-DOF Holosoma values
|
||||
# =============================================================================
|
||||
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
|
||||
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from 29-DOF)
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from 29-DOF)
|
||||
0.0, # waist_yaw only (from 29-DOF)
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, # left arm first 5 joints (from 29-DOF)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, # right arm first 5 joints (from 29-DOF)
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_23DOF_KP = np.array([
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # left leg
|
||||
40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196, # right leg
|
||||
40.179238471, # waist_yaw
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # left arm
|
||||
14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
HOLOSOMA_23DOF_KD = np.array([
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # left leg
|
||||
2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687, # right leg
|
||||
2.557889765, # waist_yaw
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # left arm
|
||||
0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
# Maps 23-DOF policy index → 29-DOF motor index
|
||||
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
|
||||
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
|
||||
DOF_23_TO_MOTOR_MAP = [
|
||||
0, 1, 2, 3, 4, 5, # left leg → motor 0-5
|
||||
6, 7, 8, 9, 10, 11, # right leg → motor 6-11
|
||||
12, # waist_yaw → motor 12
|
||||
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw) → motor 15-19
|
||||
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw) → motor 22-26
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# Control parameters
|
||||
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
|
||||
|
||||
def load_holosoma_policy(
|
||||
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
||||
policy_name: str = "fastsac",
|
||||
local_path: str | None = None,
|
||||
) -> tuple[ort.InferenceSession, int]:
|
||||
"""Load Holosoma policy and detect observation dimension.
|
||||
|
||||
Returns:
|
||||
(policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
|
||||
"""
|
||||
if local_path is not None:
|
||||
logger.info(f"Loading policy from local path: {local_path}")
|
||||
policy_path = local_path
|
||||
else:
|
||||
logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
|
||||
policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")
|
||||
|
||||
policy = ort.InferenceSession(policy_path)
|
||||
|
||||
# Detect observation dimension from model input shape
|
||||
input_shape = policy.get_inputs()[0].shape
|
||||
obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]
|
||||
|
||||
logger.info(f"Policy loaded successfully")
|
||||
logger.info(f" Input: {policy.get_inputs()[0].name}, shape: {input_shape} → obs_dim={obs_dim}")
|
||||
logger.info(f" Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")
|
||||
|
||||
return policy, obs_dim
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
|
||||
"""
|
||||
Handles Holosoma whole-body locomotion for Unitree G1.
|
||||
Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, config, obs_dim: int = 100):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
self.obs_dim = obs_dim
|
||||
|
||||
# Detect policy type from observation dimension
|
||||
self.is_23dof = (obs_dim == 82)
|
||||
self.num_dof = 23 if self.is_23dof else 29
|
||||
|
||||
# Velocity commands
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# State variables sized for policy type
|
||||
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)
|
||||
self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
|
||||
# Select config based on DOF
|
||||
if self.is_23dof:
|
||||
self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
|
||||
self.kp = HOLOSOMA_23DOF_KP
|
||||
self.kd = HOLOSOMA_23DOF_KD
|
||||
self.motor_map = DOF_23_TO_MOTOR_MAP
|
||||
else:
|
||||
self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
|
||||
self.kp = HOLOSOMA_29DOF_KP
|
||||
self.kd = HOLOSOMA_29DOF_KD
|
||||
self.motor_map = list(range(29)) # Identity map for 29-DOF
|
||||
|
||||
# Phase state for gait
|
||||
self.phase = np.zeros((1, 2), dtype=np.float32)
|
||||
self.phase[0, 0] = 0.0
|
||||
self.phase[0, 1] = np.pi
|
||||
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||
self.is_standing = False
|
||||
|
||||
self.counter = 0
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info(f"HolosomaLocomotionController initialized")
|
||||
logger.info(f" Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
|
||||
logger.info(f" Action dim: {self.num_dof}")
|
||||
|
||||
def holosoma_locomotion_run(self):
|
||||
"""Main locomotion loop - handles both 23-DOF and 29-DOF."""
|
||||
self.counter += 1
|
||||
|
||||
if self.counter == 1:
|
||||
print("\n" + "=" * 60)
|
||||
print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
|
||||
print(f" {self.obs_dim}D observations → {self.num_dof}D actions")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
robot_state = self.robot.get_observation()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
# Deadzone
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
|
||||
self.locomotion_cmd[0] = ly
|
||||
self.locomotion_cmd[1] = -lx
|
||||
self.locomotion_cmd[2] = -rx
|
||||
|
||||
# Read joint states using motor map
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||
|
||||
# IMU
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale observations
|
||||
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Phase update
|
||||
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
|
||||
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
|
||||
|
||||
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
|
||||
self.phase[0, :] = np.pi * np.ones(2)
|
||||
self.is_standing = True
|
||||
elif self.is_standing:
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = False
|
||||
else:
|
||||
phase_tp1 = self.phase + self.phase_dt
|
||||
self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
|
||||
|
||||
sin_phase = np.sin(self.phase[0, :])
|
||||
cos_phase = np.cos(self.phase[0, :])
|
||||
|
||||
# Build observation (format depends on DOF)
|
||||
if self.is_23dof:
|
||||
# 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
|
||||
self.locomotion_obs[0:23] = self.last_unscaled_action
|
||||
self.locomotion_obs[23:26] = ang_vel_scaled
|
||||
self.locomotion_obs[26] = self.locomotion_cmd[2]
|
||||
self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
|
||||
self.locomotion_obs[29:31] = cos_phase
|
||||
self.locomotion_obs[31:54] = qj_obs
|
||||
self.locomotion_obs[54:77] = dqj_obs
|
||||
self.locomotion_obs[77:80] = gravity_orientation
|
||||
self.locomotion_obs[80:82] = sin_phase
|
||||
else:
|
||||
# 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
|
||||
self.locomotion_obs[0:29] = self.last_unscaled_action
|
||||
self.locomotion_obs[29:32] = ang_vel_scaled
|
||||
self.locomotion_obs[32] = self.locomotion_cmd[2]
|
||||
self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
|
||||
self.locomotion_obs[35:37] = cos_phase
|
||||
self.locomotion_obs[37:66] = qj_obs
|
||||
self.locomotion_obs[66:95] = dqj_obs
|
||||
self.locomotion_obs[95:98] = gravity_orientation
|
||||
self.locomotion_obs[98:100] = sin_phase
|
||||
|
||||
# Policy inference
|
||||
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
|
||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
|
||||
ort_outs = self.policy.run(None, ort_inputs)
|
||||
|
||||
raw_action = ort_outs[0].squeeze()
|
||||
clipped_action = np.clip(raw_action, -100.0, 100.0)
|
||||
|
||||
self.last_unscaled_action = clipped_action.copy()
|
||||
self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# Debug
|
||||
if self.counter <= 3:
|
||||
print(f"\n[Holosoma Debug #{self.counter}]")
|
||||
print(f" Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
|
||||
print(f" Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||
print(f" Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
|
||||
|
||||
# Compute target positions
|
||||
target_dof_pos = self.default_angles + self.locomotion_action
|
||||
|
||||
# Send commands to motors via motor map
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
|
||||
if self.is_23dof:
|
||||
missing_motors = [13, 14, 20, 21, 27, 28] # waist_roll, waist_pitch, wrist_pitch/yaw
|
||||
for motor_idx in missing_motors:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.holosoma_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move joints to default position."""
|
||||
logger.info(f"Moving {self.num_dof} joints to default position...")
|
||||
|
||||
total_time = 3.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Record current positions
|
||||
init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
init_dof_pos[i] = robot_state.motor_state[motor_idx].q
|
||||
|
||||
# Interpolate to target
|
||||
for step in range(num_step):
|
||||
alpha = step / num_step
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
target = self.default_angles[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Zero missing joints for 23-DOF
|
||||
if self.is_23dof:
|
||||
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info(f"Reached default position ({self.num_dof} joints)")
|
||||
|
||||
# Hold for 2 seconds
|
||||
logger.info("Holding default position for 2 seconds...")
|
||||
hold_steps = int(2.0 / self.robot.control_dt)
|
||||
for _ in range(hold_steps):
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
if self.is_23dof:
|
||||
for motor_idx in [13, 14, 20, 21, 27, 28]:
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
|
||||
parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||
parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
|
||||
parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy and detect dimensions
|
||||
policy, obs_dim = load_holosoma_policy(
|
||||
repo_id=args.repo_id,
|
||||
policy_name=args.policy,
|
||||
local_path=args.local_path,
|
||||
)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# Initialize controller with detected obs_dim
|
||||
controller = HolosomaLocomotionController(
|
||||
policy=policy,
|
||||
robot=robot,
|
||||
config=config,
|
||||
obs_dim=obs_dim,
|
||||
)
|
||||
|
||||
try:
|
||||
#controller.reset_robot()
|
||||
controller.start_locomotion_thread()
|
||||
|
||||
logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
|
||||
logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
@@ -0,0 +1,607 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Locomotion ↔ Dance Toggle for Unitree G1
|
||||
|
||||
Press Enter to instantly switch between locomotion and dance modes.
|
||||
- Starts in LOCOMOTION mode (joystick control)
|
||||
- Press Enter → DANCE mode (resets to frame 0)
|
||||
- Press Enter → LOCOMOTION mode
|
||||
- Repeat...
|
||||
|
||||
Auto-recovery feature:
|
||||
- If robot tilts beyond threshold during dance, auto-switches to locomotion
|
||||
- When robot recovers (tilt below recovery threshold), resumes dance from where it left off
|
||||
|
||||
Usage:
|
||||
python examples/unitree_g1/locomotion_to_dance.py
|
||||
python examples/unitree_g1/locomotion_to_dance.py --tilt-threshold 25 --recovery-threshold 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import select
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from xml.etree import ElementTree
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import pinocchio as pin
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# CONFIGURATION
|
||||
# =============================================================================
|
||||
|
||||
NUM_DOFS = 29
|
||||
CONTROL_DT = 0.02 # 50Hz
|
||||
|
||||
# Locomotion config
|
||||
DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
GAIT_PERIOD = 1.0
|
||||
|
||||
# Dance config
|
||||
DANCE_ONNX_PATH = "examples/unitree_g1/fastsac_g1_29dof_dancing.onnx"
|
||||
FROZEN_JOINTS = [13, 14, 20, 21, 27, 28]
|
||||
FROZEN_KP = 500.0
|
||||
FROZEN_KD = 5.0
|
||||
|
||||
# fmt: off
|
||||
# 29-DOF defaults (holosoma training)
|
||||
DEFAULT_29DOF_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, 0.0, 0.0, # waist
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_29DOF_KP = np.array([
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 28.501, 28.501,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251, 16.778, 16.778,
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_29DOF_KD = np.array([
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 1.814, 1.814,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907, 1.068, 1.068,
|
||||
], dtype=np.float32)
|
||||
|
||||
# 23-DOF config (no waist_roll/pitch, no wrist_pitch/yaw)
|
||||
DEFAULT_23DOF_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg
|
||||
0.0, # waist_yaw only
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, # left arm (5 joints)
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, # right arm (5 joints)
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_23DOF_KP = np.array([
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179, 99.098, 40.179, 99.098, 28.501, 28.501,
|
||||
40.179,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||
14.251, 14.251, 14.251, 14.251, 14.251,
|
||||
], dtype=np.float32)
|
||||
|
||||
DEFAULT_23DOF_KD = np.array([
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558, 6.309, 2.558, 6.309, 1.814, 1.814,
|
||||
2.558,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||
0.907, 0.907, 0.907, 0.907, 0.907,
|
||||
], dtype=np.float32)
|
||||
|
||||
# 23-DOF policy index → 29-DOF motor index
|
||||
DOF_23_TO_MOTOR = [
|
||||
0, 1, 2, 3, 4, 5, # left leg
|
||||
6, 7, 8, 9, 10, 11, # right leg
|
||||
12, # waist_yaw
|
||||
15, 16, 17, 18, 19, # left arm (skip wrist_pitch/yaw)
|
||||
22, 23, 24, 25, 26, # right arm (skip wrist_pitch/yaw)
|
||||
]
|
||||
MISSING_23DOF_MOTORS = [13, 14, 20, 21, 27, 28]
|
||||
# fmt: on
|
||||
|
||||
# =============================================================================
|
||||
# QUATERNION UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
def quat_inverse(q):
|
||||
return np.concatenate((q[:, 0:1], -q[:, 1:]), axis=1)
|
||||
|
||||
def quat_mul(a, b):
|
||||
a, b = a.reshape(-1, 4), b.reshape(-1, 4)
|
||||
w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
|
||||
w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
|
||||
ww = (z1 + x1) * (x2 + y2)
|
||||
yy = (w1 - y1) * (w2 + z2)
|
||||
zz = (w1 + y1) * (w2 - z2)
|
||||
xx = ww + yy + zz
|
||||
qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
|
||||
w = qq - ww + (z1 - y1) * (y2 - z2)
|
||||
x = qq - xx + (x1 + w1) * (x2 + w2)
|
||||
y = qq - yy + (w1 - x1) * (y2 + z2)
|
||||
z = qq - zz + (z1 + y1) * (w2 - x2)
|
||||
return np.stack([w, x, y, z]).T.reshape(a.shape)
|
||||
|
||||
def subtract_frame_transforms(q01, q02):
|
||||
return quat_mul(quat_inverse(q01), q02)
|
||||
|
||||
def matrix_from_quat(q):
|
||||
r, i, j, k = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
||||
two_s = 2.0 / (q * q).sum(-1)
|
||||
o = np.stack((
|
||||
1 - two_s * (j*j + k*k), two_s * (i*j - k*r), two_s * (i*k + j*r),
|
||||
two_s * (i*j + k*r), 1 - two_s * (i*i + k*k), two_s * (j*k - i*r),
|
||||
two_s * (i*k - j*r), two_s * (j*k + i*r), 1 - two_s * (i*i + j*j),
|
||||
), -1)
|
||||
return o.reshape(q.shape[:-1] + (3, 3))
|
||||
|
||||
def xyzw_to_wxyz(xyzw):
|
||||
return np.concatenate([xyzw[:, -1:], xyzw[:, :3]], axis=1)
|
||||
|
||||
def quat_to_rpy(q):
|
||||
w, x, y, z = q
|
||||
roll = np.arctan2(2*(w*x + y*z), 1 - 2*(x**2 + y**2))
|
||||
pitch = np.arcsin(np.clip(2*(w*y - z*x), -1, 1))
|
||||
yaw = np.arctan2(2*(w*z + x*y), 1 - 2*(y**2 + z**2))
|
||||
return roll, pitch, yaw
|
||||
|
||||
def rpy_to_quat(rpy):
|
||||
roll, pitch, yaw = rpy
|
||||
cy, sy = np.cos(yaw*0.5), np.sin(yaw*0.5)
|
||||
cp, sp = np.cos(pitch*0.5), np.sin(pitch*0.5)
|
||||
cr, sr = np.cos(roll*0.5), np.sin(roll*0.5)
|
||||
return np.array([cr*cp*cy + sr*sp*sy, sr*cp*cy - cr*sp*sy,
|
||||
cr*sp*cy + sr*cp*sy, cr*cp*sy - sr*sp*cy])
|
||||
|
||||
# =============================================================================
|
||||
# PINOCCHIO FK
|
||||
# =============================================================================
|
||||
|
||||
DOF_NAMES = (
|
||||
"left_hip_pitch_joint", "left_hip_roll_joint", "left_hip_yaw_joint",
|
||||
"left_knee_joint", "left_ankle_pitch_joint", "left_ankle_roll_joint",
|
||||
"right_hip_pitch_joint", "right_hip_roll_joint", "right_hip_yaw_joint",
|
||||
"right_knee_joint", "right_ankle_pitch_joint", "right_ankle_roll_joint",
|
||||
"waist_yaw_joint", "waist_roll_joint", "waist_pitch_joint",
|
||||
"left_shoulder_pitch_joint", "left_shoulder_roll_joint", "left_shoulder_yaw_joint", "left_elbow_joint",
|
||||
"left_wrist_roll_joint", "left_wrist_pitch_joint", "left_wrist_yaw_joint",
|
||||
"right_shoulder_pitch_joint", "right_shoulder_roll_joint", "right_shoulder_yaw_joint", "right_elbow_joint",
|
||||
"right_wrist_roll_joint", "right_wrist_pitch_joint", "right_wrist_yaw_joint",
|
||||
)
|
||||
|
||||
|
||||
class PinocchioFK:
|
||||
def __init__(self, urdf_text: str):
|
||||
root = ElementTree.fromstring(urdf_text)
|
||||
for parent in root.iter():
|
||||
for child in list(parent):
|
||||
if child.tag.split("}")[-1] in {"visual", "collision"}:
|
||||
parent.remove(child)
|
||||
xml_text = '<?xml version="1.0"?>\n' + ElementTree.tostring(root, encoding="unicode")
|
||||
self.model = pin.buildModelFromXML(xml_text, pin.JointModelFreeFlyer())
|
||||
self.data = self.model.createData()
|
||||
pin_names = [n for n in self.model.names if n not in ["universe", "root_joint"]]
|
||||
self.idx_map = np.array([DOF_NAMES.index(n) for n in pin_names])
|
||||
self.ref_frame_id = self.model.getFrameId("torso_link")
|
||||
|
||||
def get_torso_quat(self, pos, quat_wxyz, dof_pos):
|
||||
quat_xyzw = np.array([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]])
|
||||
config = np.concatenate([pos, quat_xyzw, dof_pos[self.idx_map]])
|
||||
pin.framesForwardKinematics(self.model, self.data, config)
|
||||
coeffs = pin.Quaternion(self.data.oMf[self.ref_frame_id].rotation).coeffs()
|
||||
return np.array([coeffs[3], coeffs[0], coeffs[1], coeffs[2]]).reshape(1, 4)
|
||||
|
||||
def get_torso_tilt(self, pos, quat_wxyz, dof_pos):
|
||||
"""Get torso tilt angle from upright (degrees). Uses roll and pitch."""
|
||||
torso_q = self.get_torso_quat(pos, quat_wxyz, dof_pos)
|
||||
roll, pitch, _ = quat_to_rpy(torso_q.flatten())
|
||||
# Tilt is the angle from vertical - combine roll and pitch
|
||||
tilt_rad = np.sqrt(roll**2 + pitch**2)
|
||||
return np.degrees(tilt_rad), np.degrees(roll), np.degrees(pitch)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOCOMOTION CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class LocomotionController:
|
||||
"""Holosoma whole-body locomotion (23-DOF or 29-DOF)."""
|
||||
|
||||
def __init__(self, policy, robot, obs_dim: int):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.obs_dim = obs_dim
|
||||
|
||||
# Detect DOF mode
|
||||
self.is_23dof = (obs_dim == 82)
|
||||
self.num_dof = 23 if self.is_23dof else 29
|
||||
|
||||
if self.is_23dof:
|
||||
self.default_angles = DEFAULT_23DOF_ANGLES
|
||||
self.kp = DEFAULT_23DOF_KP
|
||||
self.kd = DEFAULT_23DOF_KD
|
||||
self.motor_map = DOF_23_TO_MOTOR
|
||||
logger.info("Locomotion: 23-DOF (82D obs)")
|
||||
else:
|
||||
self.default_angles = DEFAULT_29DOF_ANGLES
|
||||
self.kp = DEFAULT_29DOF_KP
|
||||
self.kd = DEFAULT_29DOF_KD
|
||||
self.motor_map = list(range(29))
|
||||
logger.info("Locomotion: 29-DOF (100D obs)")
|
||||
|
||||
self.cmd = np.zeros(3, dtype=np.float32)
|
||||
self.qj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.dqj = np.zeros(self.num_dof, dtype=np.float32)
|
||||
self.obs = np.zeros(obs_dim, dtype=np.float32)
|
||||
self.last_action = np.zeros(self.num_dof, dtype=np.float32)
|
||||
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
|
||||
self.is_standing = True
|
||||
|
||||
def run_step(self):
|
||||
"""Single locomotion step."""
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state is None:
|
||||
return
|
||||
|
||||
# Joystick
|
||||
if state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(state.wireless_remote)
|
||||
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
self.cmd[0], self.cmd[1], self.cmd[2] = ly, -lx, -rx
|
||||
|
||||
# Read joints via motor map
|
||||
for i in range(self.num_dof):
|
||||
self.qj[i] = state.motor_state[self.motor_map[i]].q
|
||||
self.dqj[i] = state.motor_state[self.motor_map[i]].dq
|
||||
|
||||
# IMU
|
||||
quat = state.imu_state.quaternion
|
||||
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale
|
||||
qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_s = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Phase
|
||||
cmd_mag = np.linalg.norm(self.cmd[:2])
|
||||
ang_mag = abs(self.cmd[2])
|
||||
if cmd_mag < 0.01 and ang_mag < 0.01:
|
||||
self.phase[0, :] = np.pi
|
||||
self.is_standing = True
|
||||
elif self.is_standing:
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = False
|
||||
else:
|
||||
self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2*np.pi) - np.pi
|
||||
|
||||
sin_ph, cos_ph = np.sin(self.phase[0]), np.cos(self.phase[0])
|
||||
|
||||
# Build obs
|
||||
if self.is_23dof:
|
||||
self.obs[0:23] = self.last_action
|
||||
self.obs[23:26] = ang_vel_s
|
||||
self.obs[26] = self.cmd[2]
|
||||
self.obs[27:29] = self.cmd[:2]
|
||||
self.obs[29:31] = cos_ph
|
||||
self.obs[31:54] = qj_obs
|
||||
self.obs[54:77] = dqj_obs
|
||||
self.obs[77:80] = gravity
|
||||
self.obs[80:82] = sin_ph
|
||||
else:
|
||||
self.obs[0:29] = self.last_action
|
||||
self.obs[29:32] = ang_vel_s
|
||||
self.obs[32] = self.cmd[2]
|
||||
self.obs[33:35] = self.cmd[:2]
|
||||
self.obs[35:37] = cos_ph
|
||||
self.obs[37:66] = qj_obs
|
||||
self.obs[66:95] = dqj_obs
|
||||
self.obs[95:98] = gravity
|
||||
self.obs[98:100] = sin_ph
|
||||
|
||||
# Inference
|
||||
obs_in = self.obs.reshape(1, -1).astype(np.float32)
|
||||
ort_in = {self.policy.get_inputs()[0].name: obs_in}
|
||||
raw_action = self.policy.run(None, ort_in)[0].squeeze()
|
||||
clipped = np.clip(raw_action, -100.0, 100.0)
|
||||
self.last_action = clipped.copy()
|
||||
scaled = clipped * LOCOMOTION_ACTION_SCALE
|
||||
target = self.default_angles + scaled
|
||||
|
||||
# Send commands
|
||||
for i in range(self.num_dof):
|
||||
motor_idx = self.motor_map[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = float(target[i])
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Zero missing joints for 23-DOF
|
||||
if self.is_23dof:
|
||||
for idx in MISSING_23DOF_MOTORS:
|
||||
self.robot.msg.motor_cmd[idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[idx].qd = 0
|
||||
self.robot.msg.motor_cmd[idx].kp = 40.0
|
||||
self.robot.msg.motor_cmd[idx].kd = 2.0
|
||||
self.robot.msg.motor_cmd[idx].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def reset(self):
|
||||
"""Reset state for fresh start."""
|
||||
self.last_action.fill(0)
|
||||
self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||
self.is_standing = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DANCE CONTROLLER
|
||||
# =============================================================================
|
||||
|
||||
class DanceController:
|
||||
"""WBT dance policy with FK for torso tracking."""
|
||||
|
||||
def __init__(self, policy, robot, pinocchio_fk, motor_kp, motor_kd, action_scale):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.pinocchio_fk = pinocchio_fk
|
||||
self.motor_kp = motor_kp
|
||||
self.motor_kd = motor_kd
|
||||
self.action_scale = action_scale
|
||||
|
||||
self.obs_dim = policy.get_inputs()[0].shape[1]
|
||||
self.last_action = np.zeros((1, NUM_DOFS), dtype=np.float32)
|
||||
self.motion_command = None
|
||||
self.ref_quat_xyzw = None
|
||||
self.timestep = 0
|
||||
self.yaw_offset = 0.0
|
||||
|
||||
logger.info(f"Dance: obs_dim={self.obs_dim}, action_scale={action_scale}")
|
||||
|
||||
def initialize(self, reset_to_frame_0: bool = True):
|
||||
"""Initialize dance. If reset_to_frame_0=True, starts from frame 0. Otherwise resumes."""
|
||||
if reset_to_frame_0:
|
||||
self.timestep = 0
|
||||
self.last_action.fill(0)
|
||||
|
||||
# Get initial motion data at frame 0
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[0]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
logger.info("Dance: reset to frame 0")
|
||||
else:
|
||||
# Resume from current timestep - just update motion command for current frame
|
||||
dummy = np.zeros((1, self.obs_dim), dtype=np.float32)
|
||||
outs = self.policy.run(["joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": dummy, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
self.motion_command = np.concatenate(outs[0:2], axis=1)
|
||||
self.ref_quat_xyzw = outs[2]
|
||||
logger.info(f"Dance: resuming from frame {self.timestep}")
|
||||
|
||||
# Capture yaw offset
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state and self.pinocchio_fk:
|
||||
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||
dof = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof)
|
||||
_, _, self.yaw_offset = quat_to_rpy(torso_q.flatten())
|
||||
logger.info(f"Dance yaw offset: {np.degrees(self.yaw_offset):.1f}°")
|
||||
|
||||
def _remove_yaw_offset(self, quat_wxyz):
|
||||
if abs(self.yaw_offset) < 1e-6:
|
||||
return quat_wxyz
|
||||
yaw_q = rpy_to_quat((0, 0, -self.yaw_offset)).reshape(1, 4)
|
||||
return quat_mul(yaw_q, quat_wxyz)
|
||||
|
||||
def run_step(self):
|
||||
"""Single dance step."""
|
||||
state = self.robot.lowstate_buffer.get_data()
|
||||
if state is None:
|
||||
return
|
||||
|
||||
quat = np.array(state.imu_state.quaternion, dtype=np.float32)
|
||||
ang_vel = np.array(state.imu_state.gyroscope, dtype=np.float32)
|
||||
dof_pos = np.array([state.motor_state[i].q for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
dof_vel = np.array([state.motor_state[i].dq for i in range(NUM_DOFS)], dtype=np.float32)
|
||||
|
||||
# FK for torso orientation
|
||||
if self.pinocchio_fk:
|
||||
torso_q = self.pinocchio_fk.get_torso_quat(np.zeros(3), quat, dof_pos)
|
||||
torso_q = self._remove_yaw_offset(torso_q)
|
||||
motion_ori = xyzw_to_wxyz(self.ref_quat_xyzw)
|
||||
rel_quat = subtract_frame_transforms(torso_q, motion_ori)
|
||||
ori_b = matrix_from_quat(rel_quat)[..., :2].reshape(1, -1)
|
||||
else:
|
||||
ori_b = np.zeros((1, 6), dtype=np.float32)
|
||||
|
||||
dof_rel = (dof_pos - DEFAULT_29DOF_ANGLES).reshape(1, -1)
|
||||
|
||||
# Build obs (alphabetical)
|
||||
obs_dict = {
|
||||
"actions": self.last_action,
|
||||
"base_ang_vel": ang_vel.reshape(1, 3),
|
||||
"dof_pos": dof_rel,
|
||||
"dof_vel": dof_vel.reshape(1, -1),
|
||||
"motion_command": self.motion_command,
|
||||
"motion_ref_ori_b": ori_b,
|
||||
}
|
||||
obs = np.concatenate([obs_dict[k].astype(np.float32) for k in sorted(obs_dict.keys())], axis=1)
|
||||
obs = np.clip(obs, -100, 100)
|
||||
|
||||
# Inference
|
||||
outs = self.policy.run(["actions", "joint_pos", "joint_vel", "ref_quat_xyzw"],
|
||||
{"obs": obs, "time_step": np.array([[self.timestep]], dtype=np.float32)})
|
||||
action = np.clip(outs[0], -100, 100)
|
||||
self.motion_command = np.concatenate(outs[1:3], axis=1)
|
||||
self.ref_quat_xyzw = outs[3]
|
||||
self.last_action = action.copy()
|
||||
|
||||
target = DEFAULT_29DOF_ANGLES + action.flatten() * self.action_scale
|
||||
|
||||
# Send commands
|
||||
for i in range(NUM_DOFS):
|
||||
if i in FROZEN_JOINTS:
|
||||
self.robot.msg.motor_cmd[i].q = 0.0
|
||||
self.robot.msg.motor_cmd[i].kp = FROZEN_KP
|
||||
self.robot.msg.motor_cmd[i].kd = FROZEN_KD
|
||||
else:
|
||||
self.robot.msg.motor_cmd[i].q = float(target[i])
|
||||
self.robot.msg.motor_cmd[i].kp = self.motor_kp[i]
|
||||
self.robot.msg.motor_cmd[i].kd = self.motor_kd[i]
|
||||
self.robot.msg.motor_cmd[i].qd = 0
|
||||
self.robot.msg.motor_cmd[i].tau = 0
|
||||
|
||||
self.robot.send_action(self.robot.msg)
|
||||
self.timestep += 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MAIN
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Locomotion ↔ Dance Toggle")
|
||||
parser.add_argument("--loco-repo", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
|
||||
parser.add_argument("--dance-onnx", type=str, default=DANCE_ONNX_PATH)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 70)
|
||||
print("🚶 LOCOMOTION ↔ 💃 DANCE")
|
||||
print("=" * 70)
|
||||
print("Press ENTER to toggle between modes")
|
||||
print("=" * 70)
|
||||
|
||||
# Load locomotion policy
|
||||
logger.info("Loading locomotion policy...")
|
||||
loco_path = hf_hub_download(repo_id=args.loco_repo, filename="fastsac_g1_29dof.onnx")
|
||||
loco_policy = ort.InferenceSession(loco_path)
|
||||
loco_obs_dim = loco_policy.get_inputs()[0].shape[1]
|
||||
logger.info(f"Locomotion: {loco_obs_dim}D obs")
|
||||
|
||||
# Load dance policy
|
||||
logger.info("Loading dance policy...")
|
||||
dance_policy = ort.InferenceSession(args.dance_onnx)
|
||||
dance_model = onnx.load(args.dance_onnx)
|
||||
dance_meta = {p.key: json.loads(p.value) for p in dance_model.metadata_props}
|
||||
dance_kp = np.array(dance_meta.get("kp", DEFAULT_29DOF_KP), dtype=np.float32)
|
||||
dance_kd = np.array(dance_meta.get("kd", DEFAULT_29DOF_KD), dtype=np.float32)
|
||||
dance_action_scale = float(dance_meta.get("action_scale", 1.0))
|
||||
logger.info(f"Dance: {dance_policy.get_inputs()[0].shape[1]}D obs, scale={dance_action_scale}")
|
||||
|
||||
# Build Pinocchio FK
|
||||
pinocchio_fk = None
|
||||
if "robot_urdf" in dance_meta:
|
||||
logger.info("Building Pinocchio FK...")
|
||||
pinocchio_fk = PinocchioFK(dance_meta["robot_urdf"])
|
||||
|
||||
# Initialize robot
|
||||
logger.info("Initializing robot...")
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
logger.info("Robot connected!")
|
||||
|
||||
# Create controllers
|
||||
loco_ctrl = LocomotionController(loco_policy, robot, loco_obs_dim)
|
||||
dance_ctrl = DanceController(dance_policy, robot, pinocchio_fk, dance_kp, dance_kd, dance_action_scale)
|
||||
|
||||
# State
|
||||
mode = "locomotion"
|
||||
toggle_event = threading.Event()
|
||||
shutdown = threading.Event()
|
||||
|
||||
# Input thread
|
||||
def input_loop():
|
||||
while not shutdown.is_set():
|
||||
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||
sys.stdin.readline()
|
||||
toggle_event.set()
|
||||
|
||||
input_thread = threading.Thread(target=input_loop, daemon=True)
|
||||
input_thread.start()
|
||||
|
||||
print("\n🚶 LOCOMOTION MODE - Use joystick to walk")
|
||||
print(" Press ENTER to switch to DANCE")
|
||||
print("-" * 70)
|
||||
|
||||
step = 0
|
||||
try:
|
||||
while not shutdown.is_set():
|
||||
t0 = time.time()
|
||||
|
||||
# Check toggle
|
||||
if toggle_event.is_set():
|
||||
toggle_event.clear()
|
||||
if mode == "locomotion":
|
||||
mode = "dance"
|
||||
dance_ctrl.initialize()
|
||||
print("\n" + "=" * 70)
|
||||
print("💃 DANCE MODE (frame 0)")
|
||||
print(" Press ENTER to switch to LOCOMOTION")
|
||||
print("=" * 70)
|
||||
else:
|
||||
mode = "locomotion"
|
||||
loco_ctrl.reset()
|
||||
print("\n" + "=" * 70)
|
||||
print("🚶 LOCOMOTION MODE")
|
||||
print(" Press ENTER to switch to DANCE")
|
||||
print("=" * 70)
|
||||
|
||||
# Run controller
|
||||
if mode == "locomotion":
|
||||
loco_ctrl.run_step()
|
||||
else:
|
||||
dance_ctrl.run_step()
|
||||
|
||||
# Log
|
||||
if step % 100 == 0:
|
||||
if mode == "locomotion":
|
||||
print(f"[LOCO ] step={step:5d} cmd=[{loco_ctrl.cmd[0]:.2f},{loco_ctrl.cmd[1]:.2f},{loco_ctrl.cmd[2]:.2f}]")
|
||||
else:
|
||||
print(f"[DANCE] step={step:5d} timestep={dance_ctrl.timestep}")
|
||||
|
||||
step += 1
|
||||
elapsed = time.time() - t0
|
||||
if elapsed < CONTROL_DT:
|
||||
time.sleep(CONTROL_DT - elapsed)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopping...")
|
||||
finally:
|
||||
shutdown.set()
|
||||
robot.disconnect()
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,447 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
|
||||
|
||||
This example demonstrates loading a 12-DOF legs-only locomotion policy
|
||||
(TorchScript .pt format) and running it on the Unitree G1 robot.
|
||||
|
||||
Key characteristics:
|
||||
- Single TorchScript policy (.pt)
|
||||
- 47D observations, 12D actions (legs only)
|
||||
- Phase-based gait timing
|
||||
- Arms and waist held at fixed positions
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 12-DOF leg joint configuration
|
||||
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
|
||||
# R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
|
||||
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
||||
|
||||
# Default leg angles for standing
|
||||
DEFAULT_LEG_ANGLES = np.array([
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg
|
||||
-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg
|
||||
], dtype=np.float32)
|
||||
|
||||
# KP/KD for leg joints
|
||||
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
|
||||
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)
|
||||
|
||||
# Waist configuration (held at zero)
|
||||
WAIST_JOINT_INDICES = [12, 13, 14] # yaw, roll, pitch
|
||||
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
|
||||
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)
|
||||
|
||||
# Arm configuration (indices 15-28, held at initial position)
|
||||
ARM_JOINT_INDICES = list(range(15, 29))
|
||||
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40, # left arm (shoulder + wrist)
|
||||
80, 80, 80, 80, 40, 40, 40], dtype=np.float32) # right arm
|
||||
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
|
||||
3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)
|
||||
|
||||
# Control parameters
|
||||
LOCOMOTION_CONTROL_DT = 0.02 # 50Hz control rate
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
ANG_VEL_SCALE = 0.25
|
||||
DOF_POS_SCALE = 1.0
|
||||
DOF_VEL_SCALE = 0.05
|
||||
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
|
||||
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32) # max vx, vy, yaw_rate
|
||||
|
||||
# Gait parameters
|
||||
GAIT_PERIOD = 0.8 # seconds
|
||||
|
||||
DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
|
||||
|
||||
|
||||
def load_torchscript_policy(
|
||||
repo_id: str = DEFAULT_REPO_ID,
|
||||
filename: str = "motion.pt",
|
||||
) -> torch.jit.ScriptModule:
|
||||
"""Load TorchScript locomotion policy from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the policy.
|
||||
filename: Policy filename (default: motion.pt).
|
||||
"""
|
||||
logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")
|
||||
|
||||
policy_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
policy = torch.jit.load(policy_path)
|
||||
policy.eval()
|
||||
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
class UnitreeRLLocomotionController:
|
||||
"""
|
||||
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Single TorchScript policy
|
||||
- 47D observations (single frame)
|
||||
- 12D action output (legs only)
|
||||
- Arms and waist held at fixed positions
|
||||
- Phase-based gait timing
|
||||
"""
|
||||
|
||||
def __init__(self, policy, robot, config):
|
||||
self.policy = policy
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
# Velocity commands (vx, vy, yaw_rate)
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# State variables (12 DOF legs)
|
||||
self.qj = np.zeros(12, dtype=np.float32)
|
||||
self.dqj = np.zeros(12, dtype=np.float32)
|
||||
self.locomotion_action = np.zeros(12, dtype=np.float32)
|
||||
self.locomotion_obs = np.zeros(47, dtype=np.float32)
|
||||
|
||||
# Initial arm positions (captured on reset)
|
||||
self.initial_arm_positions = np.zeros(14, dtype=np.float32)
|
||||
|
||||
# Counter for phase calculation
|
||||
self.counter = 0
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("UnitreeRLLocomotionController initialized")
|
||||
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
|
||||
|
||||
def locomotion_run(self):
|
||||
"""12-DOF legs-only locomotion policy loop."""
|
||||
self.counter += 1
|
||||
|
||||
if self.counter == 1:
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
|
||||
print(" 47D observations → 12D actions (legs only)")
|
||||
print(" Arms and waist held at fixed positions")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
|
||||
|
||||
# Get leg joint positions and velocities (12 DOF)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||
|
||||
# Get IMU data
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
|
||||
# Scale observations
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Calculate phase
|
||||
count = self.counter * LOCOMOTION_CONTROL_DT
|
||||
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
# Build 47D observation vector
|
||||
# [0:3] - angular velocity (scaled)
|
||||
# [3:6] - gravity orientation
|
||||
# [6:9] - velocity command (scaled)
|
||||
# [9:21] - joint positions (12D, relative to default)
|
||||
# [21:33] - joint velocities (12D, scaled)
|
||||
# [33:45] - previous actions (12D)
|
||||
# [45] - sin_phase
|
||||
# [46] - cos_phase
|
||||
self.locomotion_obs[0:3] = ang_vel_scaled
|
||||
self.locomotion_obs[3:6] = gravity_orientation
|
||||
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
|
||||
self.locomotion_obs[9:21] = qj_obs
|
||||
self.locomotion_obs[21:33] = dqj_obs
|
||||
self.locomotion_obs[33:45] = self.locomotion_action
|
||||
self.locomotion_obs[45] = sin_phase
|
||||
self.locomotion_obs[46] = cos_phase
|
||||
|
||||
# Run policy inference (TorchScript)
|
||||
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
action_tensor = self.policy(obs_tensor)
|
||||
self.locomotion_action = action_tensor.squeeze().numpy()
|
||||
|
||||
# Transform action to target joint positions
|
||||
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# Debug logging (first 3 iterations)
|
||||
if self.counter <= 3:
|
||||
print(f"\n[Unitree RL Debug #{self.counter}]")
|
||||
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
|
||||
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
|
||||
|
||||
# Send commands to LEG motors (0-11)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold WAIST motors at zero (12, 13, 14)
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold ARM motors at initial position (15-28)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Send command
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
|
||||
logger.info("Moving legs to default position...")
|
||||
|
||||
total_time = 2.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Get current state
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Capture initial arm positions (to hold during locomotion)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
|
||||
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
|
||||
|
||||
# Record current leg positions
|
||||
init_leg_pos = np.zeros(12, dtype=np.float32)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
|
||||
|
||||
# Interpolate legs to default position
|
||||
for step in range(num_step):
|
||||
alpha = step / num_step
|
||||
|
||||
# Interpolate leg positions
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
target_pos = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Reached default leg position")
|
||||
|
||||
# Hold position for 2 seconds
|
||||
logger.info("Holding default position for 2 seconds...")
|
||||
hold_time = 2.0
|
||||
num_hold_steps = int(hold_time / self.robot.control_dt)
|
||||
|
||||
for _ in range(num_hold_steps):
|
||||
# Hold legs at default
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filename",
|
||||
type=str,
|
||||
default="motion.pt",
|
||||
help="Policy filename (default: motion.pt)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load policy
|
||||
policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)
|
||||
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# Initialize locomotion controller
|
||||
locomotion_controller = UnitreeRLLocomotionController(
|
||||
policy=policy,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Reset robot and start locomotion thread
|
||||
try:
|
||||
locomotion_controller.reset_robot()
|
||||
locomotion_controller.start_locomotion_thread()
|
||||
|
||||
# Log status
|
||||
logger.info("Robot initialized with Unitree RL locomotion policy")
|
||||
logger.info("Locomotion controller running in background thread")
|
||||
logger.info("Use remote controller to command velocity:")
|
||||
logger.info(" Left stick Y: forward/backward")
|
||||
logger.info(" Left stick X: left/right")
|
||||
logger.info(" Right stick X: rotate")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Keep robot alive
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
locomotion_controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||
@@ -52,7 +54,10 @@ class UnitreeG1Config(RobotConfig):
|
||||
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||
|
||||
# launch mujoco simulation
|
||||
is_simulation: bool = True
|
||||
is_simulation: bool = False
|
||||
|
||||
# socket config for ZMQ bridge
|
||||
robot_ip: str = "192.168.123.164"
|
||||
robot_ip: str = "172.18.129.215"
|
||||
|
||||
# cameras (optional)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Standalone keyboard control script for Unitree G1 robot.
|
||||
|
||||
This script provides keyboard-based velocity control for the G1 robot's
|
||||
locomotion system. It can be run alongside the main robot control to
|
||||
provide manual movement commands.
|
||||
|
||||
Usage:
|
||||
python keyboard_control.py [--robot-ip IP] [--simulation]
|
||||
|
||||
Controls:
|
||||
W/S: Forward/Backward
|
||||
A/D: Strafe Left/Right
|
||||
Q/E: Rotate Left/Right
|
||||
R/F: Raise/Lower Height (GR00T policies only)
|
||||
Z: Stop (zero all velocity commands)
|
||||
ESC/Ctrl+C: Exit
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import select
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
# Terminal handling for non-blocking keyboard input
|
||||
try:
|
||||
import termios
|
||||
import tty
|
||||
HAS_TERMIOS = True
|
||||
except ImportError:
|
||||
HAS_TERMIOS = False
|
||||
print("Warning: termios not available. Keyboard controls require Linux/macOS.")
|
||||
|
||||
|
||||
class KeyboardController:
|
||||
"""Handles keyboard input and converts to locomotion commands."""
|
||||
|
||||
def __init__(self, callback=None):
|
||||
"""
|
||||
Initialize keyboard controller.
|
||||
|
||||
Args:
|
||||
callback: Optional function called when commands change.
|
||||
Signature: callback(vx, vy, yaw, height)
|
||||
"""
|
||||
self.callback = callback
|
||||
self.running = False
|
||||
|
||||
# Locomotion commands
|
||||
self.vx = 0.0 # Forward/backward velocity
|
||||
self.vy = 0.0 # Left/right velocity (strafe)
|
||||
self.yaw = 0.0 # Rotation rate
|
||||
self.height = 0.74 # Base height (for GR00T policies)
|
||||
|
||||
# Command limits
|
||||
self.vx_limit = (-0.8, 0.8)
|
||||
self.vy_limit = (-0.5, 0.5)
|
||||
self.yaw_limit = (-1.0, 1.0)
|
||||
self.height_limit = (0.50, 1.00)
|
||||
|
||||
# Increments per keypress
|
||||
self.vx_increment = 0.4
|
||||
self.vy_increment = 0.25
|
||||
self.yaw_increment = 0.5
|
||||
self.height_increment = 0.05
|
||||
|
||||
self._old_terminal_settings = None
|
||||
|
||||
def get_commands(self) -> tuple[float, float, float, float]:
|
||||
"""Get current command values as tuple (vx, vy, yaw, height)."""
|
||||
return (self.vx, self.vy, self.yaw, self.height)
|
||||
|
||||
def get_commands_array(self) -> np.ndarray:
|
||||
"""Get velocity commands as numpy array [vx, vy, yaw]."""
|
||||
return np.array([self.vx, self.vy, self.yaw], dtype=np.float32)
|
||||
|
||||
def reset_commands(self):
|
||||
"""Reset all commands to zero (stop)."""
|
||||
self.vx = 0.0
|
||||
self.vy = 0.0
|
||||
self.yaw = 0.0
|
||||
self._notify_callback()
|
||||
|
||||
def _clamp(self, value: float, limits: tuple[float, float]) -> float:
|
||||
"""Clamp value to limits."""
|
||||
return max(limits[0], min(limits[1], value))
|
||||
|
||||
def _notify_callback(self):
|
||||
"""Call callback with current commands if set."""
|
||||
if self.callback:
|
||||
self.callback(self.vx, self.vy, self.yaw, self.height)
|
||||
|
||||
def process_key(self, key: str) -> bool:
|
||||
"""
|
||||
Process a single key press and update commands.
|
||||
|
||||
Args:
|
||||
key: Single character key that was pressed.
|
||||
|
||||
Returns:
|
||||
True if key was handled, False otherwise.
|
||||
"""
|
||||
key = key.lower()
|
||||
handled = True
|
||||
|
||||
if key == 'w':
|
||||
self.vx = self._clamp(self.vx + self.vx_increment, self.vx_limit)
|
||||
elif key == 's':
|
||||
self.vx = self._clamp(self.vx - self.vx_increment, self.vx_limit)
|
||||
elif key == 'a':
|
||||
self.vy = self._clamp(self.vy + self.vy_increment, self.vy_limit)
|
||||
elif key == 'd':
|
||||
self.vy = self._clamp(self.vy - self.vy_increment, self.vy_limit)
|
||||
elif key == 'q':
|
||||
self.yaw = self._clamp(self.yaw + self.yaw_increment, self.yaw_limit)
|
||||
elif key == 'e':
|
||||
self.yaw = self._clamp(self.yaw - self.yaw_increment, self.yaw_limit)
|
||||
elif key == 'r':
|
||||
self.height = self._clamp(self.height + self.height_increment, self.height_limit)
|
||||
elif key == 'f':
|
||||
self.height = self._clamp(self.height - self.height_increment, self.height_limit)
|
||||
elif key == 'z':
|
||||
self.reset_commands()
|
||||
return True # Already notified in reset_commands
|
||||
else:
|
||||
handled = False
|
||||
|
||||
if handled:
|
||||
self._notify_callback()
|
||||
|
||||
return handled
|
||||
|
||||
def _setup_terminal(self):
|
||||
"""Set terminal to raw mode for single character input."""
|
||||
if HAS_TERMIOS:
|
||||
self._old_terminal_settings = termios.tcgetattr(sys.stdin)
|
||||
tty.setcbreak(sys.stdin.fileno())
|
||||
|
||||
def _restore_terminal(self):
|
||||
"""Restore terminal to original settings."""
|
||||
if HAS_TERMIOS and self._old_terminal_settings is not None:
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, self._old_terminal_settings)
|
||||
self._old_terminal_settings = None
|
||||
|
||||
def run(self):
|
||||
"""Run the keyboard listener loop (blocking)."""
|
||||
if not HAS_TERMIOS:
|
||||
print("Error: Keyboard controls require termios (Linux/macOS)")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._print_controls()
|
||||
|
||||
try:
|
||||
self._setup_terminal()
|
||||
|
||||
while self.running:
|
||||
# Check for keyboard input with timeout
|
||||
if select.select([sys.stdin], [], [], 0.1)[0]:
|
||||
key = sys.stdin.read(1)
|
||||
|
||||
# Handle escape sequences (arrow keys, etc.)
|
||||
if key == '\x1b': # ESC
|
||||
self.running = False
|
||||
break
|
||||
|
||||
if self.process_key(key):
|
||||
self._print_status()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nInterrupted by user")
|
||||
finally:
|
||||
self._restore_terminal()
|
||||
print("\nKeyboard controls stopped")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
self.running = False
|
||||
|
||||
def _print_controls(self):
|
||||
"""Print control instructions."""
|
||||
print("\n" + "=" * 60)
|
||||
print("KEYBOARD CONTROLS ACTIVE")
|
||||
print("=" * 60)
|
||||
print(" W/S: Forward/Backward")
|
||||
print(" A/D: Strafe Left/Right")
|
||||
print(" Q/E: Rotate Left/Right")
|
||||
print(" R/F: Raise/Lower Height (±5cm)")
|
||||
print(" Z: Stop (zero all commands)")
|
||||
print(" ESC: Exit")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
def _print_status(self):
|
||||
"""Print current command status."""
|
||||
print(f"[CMD] vx={self.vx:+.2f}, vy={self.vy:+.2f}, yaw={self.yaw:+.2f} | height={self.height:.3f}m")
|
||||
|
||||
|
||||
class RobotKeyboardController(KeyboardController):
|
||||
"""Keyboard controller that directly updates a robot's locomotion commands."""
|
||||
|
||||
def __init__(self, robot):
|
||||
"""
|
||||
Initialize with a UnitreeG1 robot instance.
|
||||
|
||||
Args:
|
||||
robot: UnitreeG1 robot instance with locomotion_cmd attribute.
|
||||
"""
|
||||
super().__init__()
|
||||
self.robot = robot
|
||||
|
||||
# Initialize from robot's current state if available
|
||||
if hasattr(robot, 'locomotion_cmd'):
|
||||
self.vx = robot.locomotion_cmd[0]
|
||||
self.vy = robot.locomotion_cmd[1]
|
||||
self.yaw = robot.locomotion_cmd[2]
|
||||
|
||||
if hasattr(robot, 'groot_height_cmd'):
|
||||
self.height = robot.groot_height_cmd
|
||||
|
||||
def _notify_callback(self):
|
||||
"""Update robot's locomotion commands directly."""
|
||||
if hasattr(self.robot, 'locomotion_cmd'):
|
||||
self.robot.locomotion_cmd[0] = self.vx
|
||||
self.robot.locomotion_cmd[1] = self.vy
|
||||
self.robot.locomotion_cmd[2] = self.yaw
|
||||
|
||||
if hasattr(self.robot, 'groot_height_cmd'):
|
||||
self.robot.groot_height_cmd = self.height
|
||||
|
||||
|
||||
def start_keyboard_control_thread(robot) -> tuple:
|
||||
"""
|
||||
Start keyboard controls for a robot in a background thread.
|
||||
|
||||
Args:
|
||||
robot: UnitreeG1 robot instance.
|
||||
|
||||
Returns:
|
||||
Tuple of (controller, thread) for later stopping.
|
||||
"""
|
||||
import threading
|
||||
|
||||
controller = RobotKeyboardController(robot)
|
||||
thread = threading.Thread(target=controller.run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
return controller, thread
|
||||
|
||||
|
||||
def stop_keyboard_control_thread(controller, thread, timeout: float = 2.0):
|
||||
"""
|
||||
Stop the keyboard control thread.
|
||||
|
||||
Args:
|
||||
controller: KeyboardController instance.
|
||||
thread: Thread running the controller.
|
||||
timeout: Max time to wait for thread to stop.
|
||||
"""
|
||||
controller.stop()
|
||||
thread.join(timeout=timeout)
|
||||
|
||||
|
||||
def main():
|
||||
"""Standalone keyboard control with optional robot connection."""
|
||||
parser = argparse.ArgumentParser(description="Keyboard control for Unitree G1")
|
||||
parser.add_argument("--standalone", action="store_true",
|
||||
help="Run in standalone mode (just print commands, no robot)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.standalone:
|
||||
# Standalone mode - just demonstrate keyboard input
|
||||
def print_callback(vx, vy, yaw, height):
|
||||
print(f" → Would send: vx={vx:+.2f}, vy={vy:+.2f}, yaw={yaw:+.2f}, height={height:.3f}")
|
||||
|
||||
controller = KeyboardController(callback=print_callback)
|
||||
print("Running in STANDALONE mode (no robot connection)")
|
||||
controller.run()
|
||||
else:
|
||||
print("To use with a robot, import and use RobotKeyboardController:")
|
||||
print("")
|
||||
print(" from lerobot.robots.unitree_g1.keyboard_control import (")
|
||||
print(" RobotKeyboardController,")
|
||||
print(" start_keyboard_control_thread,")
|
||||
print(" stop_keyboard_control_thread")
|
||||
print(" )")
|
||||
print("")
|
||||
print(" # Start keyboard controls")
|
||||
print(" controller, thread = start_keyboard_control_thread(robot)")
|
||||
print("")
|
||||
print(" # ... robot runs ...")
|
||||
print("")
|
||||
print(" # Stop keyboard controls")
|
||||
print(" stop_keyboard_control_thread(controller, thread)")
|
||||
print("")
|
||||
print("Or run with --standalone to test keyboard input without a robot.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -101,6 +101,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.unitree_g1 import config_unitree_g1 # noqa: F401
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
@@ -197,9 +198,8 @@ class RecordConfig:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
if self.teleop is None and self.policy is None:
|
||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||
# Note: teleop and policy can both be None for robots with built-in control (e.g. unitree_g1)
|
||||
# This is validated in record() after the robot is instantiated
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
@@ -340,6 +340,13 @@ def record_loop(
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
elif policy is None and teleop is None and dataset is not None:
|
||||
# Observation-only recording (robot controls itself, e.g. unitree_g1)
|
||||
# Record observations, extract action-relevant values (positions) from obs
|
||||
# Filter obs_processed to only include keys that match action_features
|
||||
action_keys = set(robot.action_features.keys())
|
||||
action_values = {k: v for k, v in obs_processed.items() if k in action_keys}
|
||||
robot_action_to_send = None
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
@@ -352,15 +359,17 @@ def record_loop(
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
elif teleop is not None:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
# else: observation-only mode, action_values already set above
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
# Send action to robot (skip if observation-only mode)
|
||||
if robot_action_to_send is not None:
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
|
||||
Reference in New Issue
Block a user