From 3385350f2dafc419453cf84b3566a12f0a9e9cb4 Mon Sep 17 00:00:00 2001 From: Martino Russi Date: Wed, 26 Nov 2025 17:17:02 +0100 Subject: [PATCH] separate groot locomotion logic --- examples/unitree_g1/gr00t_locomotion.py | 345 ++++++++++ .../GR00T-WholeBodyControl-Balance.onnx | Bin .../GR00T-WholeBodyControl-Walk.onnx | Bin .../robots/unitree_g1/config_unitree_g1.py | 44 +- src/lerobot/robots/unitree_g1/g1_utils.py | 6 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 610 +----------------- 6 files changed, 367 insertions(+), 638 deletions(-) create mode 100644 examples/unitree_g1/gr00t_locomotion.py rename {src/lerobot/robots/unitree_g1/assets/g1 => examples/unitree_g1}/locomotion/GR00T-WholeBodyControl-Balance.onnx (100%) rename {src/lerobot/robots/unitree_g1/assets/g1 => examples/unitree_g1}/locomotion/GR00T-WholeBodyControl-Walk.onnx (100%) diff --git a/examples/unitree_g1/gr00t_locomotion.py b/examples/unitree_g1/gr00t_locomotion.py new file mode 100644 index 000000000..7a5d94157 --- /dev/null +++ b/examples/unitree_g1/gr00t_locomotion.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python + +""" +Example: GR00T Locomotion with Pre-loaded Policies + +This example demonstrates the NEW pattern for loading GR00T policies externally +and passing them, together with the robot, to a standalone GrootLocomotionController. 
+""" + +import logging +import threading +import time +from collections import deque + +import numpy as np +import onnxruntime as ort +import torch +from scipy.spatial.transform import Rotation as R + +from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 + +logger = logging.getLogger(__name__) + + +def load_groot_policies() -> tuple: + """Load GR00T dual-policy system (Balance + Walk) from ONNX files.""" + logger.info("Loading GR00T dual-policy system...") + + # Load ONNX policies + policy_balance = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx") + policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx") + + logger.info("GR00T policies loaded successfully") + logger.info(f" Input shape: {policy_balance.get_inputs()[0].shape}") + logger.info(f" Output shape: {policy_balance.get_outputs()[0].shape}") + + return policy_balance, policy_walk + + +class GrootLocomotionController: + """ + Handles GR00T-style locomotion control for the Unitree G1 robot. + + This controller manages: + - Dual-policy system (Balance + Walk) + - 29-joint observation processing + - 15D action output (legs + waist) + - Policy inference and motor command generation + """ + + # GR00T default angles for all 29 joints + GROOT_DEFAULT_ANGLES = np.array([ + -0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # left leg + -0.1, 0.0, 0.0, 0.3, -0.2, 0.0, # right leg + 0.0, 0.0, 0.0, # waist + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # left arm (zeroed) + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # right arm (zeroed) + ], dtype=np.float32) + + # Joints to zero out in observations and commands + JOINTS_TO_ZERO = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw + PROBLEMATIC_JOINTS = [12, 14, 20, 21, 27, 28] + + def __init__(self, policy_balance, policy_walk, robot, config): + """ + Initialize the GR00T locomotion controller. 
+ + Args: + policy_balance: ONNX InferenceSession for balance/standing policy + policy_walk: ONNX InferenceSession for walking policy + robot: Reference to the UnitreeG1 robot instance + config: UnitreeG1Config object with locomotion parameters + """ + self.policy_balance = policy_balance + self.policy_walk = policy_walk + self.robot = robot + self.config = config + + # Locomotion state + self.locomotion_counter = 0 + self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, yaw_rate + + # GR00T-specific state + self.groot_qj_all = np.zeros(29, dtype=np.float32) + self.groot_dqj_all = np.zeros(29, dtype=np.float32) + self.groot_action = np.zeros(15, dtype=np.float32) + self.groot_obs_single = np.zeros(86, dtype=np.float32) + self.groot_obs_history = deque(maxlen=6) + self.groot_obs_stacked = np.zeros(516, dtype=np.float32) + self.groot_height_cmd = 0.74 # Default base height + self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) + + # Initialize history with zeros + for _ in range(6): + self.groot_obs_history.append(np.zeros(86, dtype=np.float32)) + + # Thread management + self.locomotion_running = False + self.locomotion_thread = None + + logger.info("GrootLocomotionController initialized") + + def groot_locomotion_run(self): + """GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action.""" + self.locomotion_counter += 1 + + # Get current lowstate + lowstate = self.robot.lowstate_buffer.GetData() + if lowstate is None: + return + + # Update remote controller from lowstate + if lowstate.wireless_remote is not None: + self.robot.remote_controller.set(lowstate.wireless_remote) + + # R1/R2 buttons for height control on real robot (button indices 0 and 4) + if self.robot.remote_controller.button[0]: # R1 - raise height + self.groot_height_cmd += 0.001 # Small increment per timestep (~0.05m per second at 50Hz) + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) + if 
self.robot.remote_controller.button[4]: # R2 - lower height + self.groot_height_cmd -= 0.001 # Small decrement per timestep + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) + else: + # Default to zero commands if no remote data + self.robot.remote_controller.lx = 0.0 + self.robot.remote_controller.ly = 0.0 + self.robot.remote_controller.rx = 0.0 + self.robot.remote_controller.ry = 0.0 + + # Get ALL 29 joint positions and velocities + for i in range(29): + self.groot_qj_all[i] = lowstate.motor_state[i].q + self.groot_dqj_all[i] = lowstate.motor_state[i].dq + + # Get IMU data + quat = lowstate.imu_state.quaternion + ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) + + # Transform IMU if using torso IMU + if self.config.locomotion_imu_type == "torso": + waist_yaw = lowstate.motor_state[12].q # Waist yaw index + waist_yaw_omega = lowstate.motor_state[12].dq + quat, ang_vel_3d = self.robot.locomotion_transform_imu_data( + waist_yaw, waist_yaw_omega, quat, np.array([ang_vel]) + ) + ang_vel = ang_vel_3d.flatten() + + # Create observation + gravity_orientation = self.robot.locomotion_get_gravity_orientation(quat) + + # Zero out specific joints in observation + for idx in self.JOINTS_TO_ZERO: + self.groot_qj_all[idx] = 0.0 + self.groot_dqj_all[idx] = 0.0 + + # Scale joint positions and velocities + qj_obs = self.groot_qj_all.copy() + dqj_obs = self.groot_dqj_all.copy() + + qj_obs = (qj_obs - self.GROOT_DEFAULT_ANGLES) * self.config.dof_pos_scale + dqj_obs = dqj_obs * self.config.dof_vel_scale + ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale + + # Get velocity commands (keyboard or remote) + if not self.robot.simulation_mode: + self.locomotion_cmd[0] = self.robot.remote_controller.ly + self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 + self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 + + # Build 86D single frame observation (GR00T format) + self.groot_obs_single[:3] = self.locomotion_cmd * 
np.array(self.config.groot_cmd_scale) + self.groot_obs_single[3] = self.groot_height_cmd + self.groot_obs_single[4:7] = self.groot_orientation_cmd + self.groot_obs_single[7:10] = ang_vel_scaled + self.groot_obs_single[10:13] = gravity_orientation + self.groot_obs_single[13:42] = qj_obs # 29D joint positions + self.groot_obs_single[42:71] = dqj_obs # 29D joint velocities + self.groot_obs_single[71:86] = self.groot_action # 15D previous actions + + # Add to history and stack observations (6 frames × 86D = 516D) + self.groot_obs_history.append(self.groot_obs_single.copy()) + + # Stack all 6 frames into 516D vector + for i, obs_frame in enumerate(self.groot_obs_history): + start_idx = i * 86 + end_idx = start_idx + 86 + self.groot_obs_stacked[start_idx:end_idx] = obs_frame + + # Run policy inference (ONNX) with 516D stacked observation + obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0) + + # Select appropriate policy based on command magnitude (dual-policy system) + cmd_magnitude = np.linalg.norm(self.locomotion_cmd) + if cmd_magnitude < 0.05: + # Use balance/standing policy for small commands + selected_policy = self.policy_balance + else: + # Use walking policy for movement commands + selected_policy = self.policy_walk + + ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()} + ort_outs = selected_policy.run(None, ort_inputs) + self.groot_action = ort_outs[0].squeeze() + + # Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11) + self.groot_action[12] = 0.0 # Waist yaw + self.groot_action[13] = 0.0 # Waist roll + self.groot_action[14] = 0.0 # Waist pitch + + # Transform action to target joint positions (15D: legs + waist) + target_dof_pos_15 = ( + self.GROOT_DEFAULT_ANGLES[:15] + self.groot_action * self.config.locomotion_action_scale + ) + + # Send commands to LEG motors (0-11) + for i in range(12): + motor_idx = i + self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i] + 
self.robot.msg.motor_cmd[motor_idx].qd = 0 + self.robot.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i] + self.robot.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i] + self.robot.msg.motor_cmd[motor_idx].tau = 0 + + # Send WAIST commands - but SKIP waist yaw (12) and waist pitch (14) + # Only send waist roll (13) + waist_roll_idx = 13 + waist_roll_action_idx = 13 + self.robot.msg.motor_cmd[waist_roll_idx].q = target_dof_pos_15[waist_roll_action_idx] + self.robot.msg.motor_cmd[waist_roll_idx].qd = 0 + self.robot.msg.motor_cmd[waist_roll_idx].kp = self.config.locomotion_arm_waist_kps[1] + self.robot.msg.motor_cmd[waist_roll_idx].kd = self.config.locomotion_arm_waist_kds[1] + self.robot.msg.motor_cmd[waist_roll_idx].tau = 0 + + # Zero out the problematic joints (waist yaw, waist pitch, wrist pitch/yaw) + for joint_idx in self.PROBLEMATIC_JOINTS: + self.robot.msg.motor_cmd[joint_idx].q = 0.0 + self.robot.msg.motor_cmd[joint_idx].qd = 0 + if joint_idx in [12, 14]: # waist + kp_idx = 0 if joint_idx == 12 else 2 # yaw or pitch + self.robot.msg.motor_cmd[joint_idx].kp = self.config.locomotion_arm_waist_kps[kp_idx] + self.robot.msg.motor_cmd[joint_idx].kd = self.config.locomotion_arm_waist_kds[kp_idx] + else: # wrists (20, 21, 27, 28) + self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp_wrist + self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd_wrist + self.robot.msg.motor_cmd[joint_idx].tau = 0 + + # Send command. NOTE(review): the robot's own publish thread also appears to write self.robot.msg; confirm whether robot.ctrl_lock should be held here + self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg) + self.robot.lowcmd_publisher.Write(self.robot.msg) + + def _locomotion_thread_loop(self): + """Call groot_locomotion_run() repeatedly until locomotion_running is cleared, pacing iterations to config.locomotion_control_dt.""" + logger.info("Locomotion thread started") + while self.locomotion_running: + start_time = time.time() + try: + self.groot_locomotion_run() + except Exception as e: # deliberate best-effort: one failed step must not kill the control loop + logger.exception("Error in locomotion loop: %s", e) # ERROR level, same message, traceback attached + + # Sleep to maintain control rate + elapsed = time.time() - start_time + sleep_time = max(0, 
self.config.locomotion_control_dt - elapsed) + time.sleep(sleep_time) + logger.info("Locomotion thread stopped") + + def start_locomotion_thread(self): + """Start the background locomotion control thread (daemon); no-op with a warning if it is already running.""" + if self.locomotion_running: + logger.warning("Locomotion thread already running") + return + + logger.info("Starting locomotion control thread...") + self.locomotion_running = True + self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True) + self.locomotion_thread.start() + logger.info("Locomotion control thread started!") + + def stop_locomotion_thread(self): + """Clear the running flag and wait up to 2 seconds for the locomotion thread to exit; no-op if not running.""" + if not self.locomotion_running: + return + + logger.info("Stopping locomotion control thread...") + self.locomotion_running = False + if self.locomotion_thread: + self.locomotion_thread.join(timeout=2.0) + logger.info("Locomotion control thread stopped") + + def init_groot_locomotion(self): + """Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions).""" + logger.info("Starting GR00T locomotion initialization...") + + # Move legs to default position + self.robot.locomotion_move_to_default_pos() + + # Wait 3 seconds + time.sleep(3.0) + + # Hold default leg position for 2 seconds + self.robot.locomotion_default_pos_state() + + # Start locomotion policy thread + logger.info("Starting GR00T locomotion policy control...") + self.start_locomotion_thread() + + logger.info("GR00T locomotion initialization complete! Policy is now running.") + logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)") + + + +if __name__ == "__main__": + # 1. Load policies externally (separate from robot initialization) + policy_balance, policy_walk = load_groot_policies() + + # 2. Create config with defaults (the robot class no longer loads policies; the external controller below owns them) + config = UnitreeG1Config() + + # 3. Initialize robot + robot = UnitreeG1(config) + + # 4. 
Create GR00T locomotion controller with loaded policies + groot_controller = GrootLocomotionController( + policy_balance=policy_balance, + policy_walk=policy_walk, + robot=robot, + config=config, + ) + + # 5. Initialize and start locomotion + groot_controller.init_groot_locomotion() + + # Robot is now ready with locomotion control! + print("Robot initialized with GR00T locomotion policies") + print("Locomotion controller running in background thread") + print("Press Ctrl+C to stop") + + try: + while True: + time.sleep(1.0) + except KeyboardInterrupt: + print("\nStopping locomotion...") + groot_controller.stop_locomotion_thread() + print("Done!") \ No newline at end of file diff --git a/src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Balance.onnx b/examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx similarity index 100% rename from src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Balance.onnx rename to examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx diff --git a/src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx b/examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx similarity index 100% rename from src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx rename to examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 6e711ed53..63f000512 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -15,6 +15,7 @@ # limitations under the License. 
from dataclasses import dataclass, field +from typing import Any from lerobot.cameras import CameraConfig @@ -24,7 +25,7 @@ from ..config import RobotConfig @RobotConfig.register_subclass("unitree_g1") @dataclass class UnitreeG1Config(RobotConfig): - # id: str = "unitree_g1" + # id: str = "unitree_g1" simulation_mode: bool = False kp_high = 40.0 kd_high = 3.0 @@ -52,49 +53,38 @@ class UnitreeG1Config(RobotConfig): # This robot class ONLY uses sockets to communicate with a bridge on the Orin # Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP # Example: socket_host="192.168.123.164" (Orin's wlan0 IP) - socket_host: str | None = None # = "172.18.129.215" + socket_host: str | None = None# = "172.18.129.215" socket_port: int | None = None # Locomotion control locomotion_control: bool = True - # policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt" - policy_path: str | None = None - + # Pre-loaded policies (preferred method for GR00T locomotion) + policy_walk: Any = None # Pre-loaded walk policy (ONNX InferenceSession) + policy_balance: Any = None # Pre-loaded balance policy (ONNX InferenceSession) + # Locomotion parameters (from g1.yaml) locomotion_control_dt: float = 0.02 - + leg_joint2motor_idx: list = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) - locomotion_kps: list = field( - default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40] - ) + locomotion_kps: list = field(default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40]) locomotion_kds: list = field(default_factory=lambda: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2]) - default_leg_angles: list = field( - default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0] - ) - - arm_waist_joint2motor_idx: list = field( - default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28] - ) - locomotion_arm_waist_kps: list = field( - 
default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20] - ) - locomotion_arm_waist_kds: list = field( - default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1] - ) - locomotion_arm_waist_target: list = field( - default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ) + default_leg_angles: list = field(default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]) + arm_waist_joint2motor_idx: list = field(default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]) + locomotion_arm_waist_kps: list = field(default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20]) + locomotion_arm_waist_kds: list = field(default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1]) + locomotion_arm_waist_target: list = field(default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + ang_vel_scale: float = 0.25 dof_pos_scale: float = 1.0 dof_vel_scale: float = 0.05 locomotion_action_scale: float = 0.25 cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) - + # GR00T-specific scaling (different from regular locomotion!) 
groot_ang_vel_scale: float = 0.25 # GR00T uses 0.5, not 0.25 groot_cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) # yaw is 0.5 for GR00T num_locomotion_actions: int = 12 num_locomotion_obs: int = 47 max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57]) - locomotion_imu_type: str = "pelvis" # "torso" or "pelvis" + locomotion_imu_type: str = "pelvis" # "torso" or "pelvis" \ No newline at end of file diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index 0e6fa89fb..cfaba98bf 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -1,5 +1,6 @@ from enum import IntEnum + class G1_29_JointArmIndex(IntEnum): # Left arm kLeftShoulderPitch = 15 @@ -19,6 +20,7 @@ class G1_29_JointArmIndex(IntEnum): kRightWristPitch = 27 kRightWristYaw = 28 + class G1_29_JointIndex(IntEnum): # Left leg kLeftHipPitch = 0 @@ -36,7 +38,7 @@ class G1_29_JointIndex(IntEnum): kRightAnklePitch = 10 kRightAnkleRoll = 11 - kWaistYaw = 12 #we're c + kWaistYaw = 12 # we're c kWaistRoll = 13 kWaistPitch = 14 @@ -64,4 +66,4 @@ class G1_29_JointIndex(IntEnum): kNotUsedJoint2 = 31 kNotUsedJoint3 = 32 kNotUsedJoint4 = 33 - kNotUsedJoint5 = 34 \ No newline at end of file + kNotUsedJoint5 = 34 diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index cb19acbed..8bc9722d1 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -9,7 +9,6 @@ import time import tty from collections import deque from functools import cached_property -from pathlib import Path from typing import Any import numpy as np @@ -172,12 +171,6 @@ class UnitreeG1(Robot): self.msg.motor_cmd[id].kp = self.kp_high self.msg.motor_cmd[id].kd = self.kd_high self.msg.motor_cmd[id].q = self.all_motor_q[id] - # print current motor q, kp, kd - - logger.warning("Lock OK!\n") # motors are not locked x - # for i in 
range(10000): - # print(self.get_current_motor_q()) - # time.sleep(0.05) # Initialize control flags BEFORE starting threads self.keyboard_thread = None @@ -185,10 +178,6 @@ class UnitreeG1(Robot): self.locomotion_thread = None self.locomotion_running = False - # Initialize publish thread for arm control - # Note: This thread runs alongside locomotion thread - # - Arm thread: controls arms (indices 15-28) - # - Locomotion thread: controls legs (0-11), waist (12-14) # Both update different parts of self.msg, both call Write() self.publish_thread = None self.ctrl_lock = threading.Lock() @@ -196,102 +185,8 @@ class UnitreeG1(Robot): self.publish_thread.daemon = True self.publish_thread.start() logger.warning("Arm control publish thread started") + self.remote_controller = self.RemoteController() - # Load locomotion policy if enabled - self.policy = None - self.policy_type = None # 'torchscript' or 'onnx' - print(config) - if config.locomotion_control: - if config.policy_path is None: - raise ValueError("locomotion_control is True but policy_path is not set") - - logger.warning(f"Loading locomotion policy from {config.policy_path}") - - # Check file extension and load accordingly - if config.policy_path.endswith(".pt"): - logger.warning("Detected TorchScript (.pt) policy") - self.policy = torch.jit.load(config.policy_path) - self.policy_type = "torchscript" - logger.info("TorchScript policy loaded successfully") - elif config.policy_path.endswith(".onnx"): - logger.warning("Detected ONNX (.onnx) policy") - - # For GR00T-style policies, load both Balance and Walk policies - # Balance policy for standing (low velocity commands) - # Walk policy for locomotion (high velocity commands) - balance_policy_path = config.policy_path.replace("Walk.onnx", "Balance.onnx") - walk_policy_path = config.policy_path - - if Path(balance_policy_path).exists() and Path(walk_policy_path).exists(): - logger.info("Loading dual-policy system (Balance + Walk)") - self.policy_balance = 
ort.InferenceSession(balance_policy_path) - self.policy_walk = ort.InferenceSession(walk_policy_path) - self.policy = None # Not used when dual policies are loaded - logger.info(f"Balance policy loaded from: {balance_policy_path}") - logger.info(f"Walk policy loaded from: {walk_policy_path}") - logger.info( - f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}" - ) - logger.info( - f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}" - ) - else: - # Fallback to single policy - logger.info("Loading single ONNX policy") - self.policy = ort.InferenceSession(config.policy_path) - self.policy_balance = None - self.policy_walk = None - logger.info("ONNX policy loaded successfully") - logger.info( - f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}" - ) - logger.info( - f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}" - ) - - self.policy_type = "onnx" - else: - raise ValueError( - f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported." 
- ) - - # Initialize locomotion variables - self.remote_controller = self.RemoteController() - self.locomotion_counter = 0 - self.qj = np.zeros(config.num_locomotion_actions, dtype=np.float32) - self.dqj = np.zeros(config.num_locomotion_actions, dtype=np.float32) - self.locomotion_action = np.zeros(config.num_locomotion_actions, dtype=np.float32) - self.locomotion_obs = np.zeros(config.num_locomotion_obs, dtype=np.float32) - self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) - - # GR00T-specific variables (for ONNX policies with 29 joints) - if self.policy_type == "onnx": - self.groot_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints - self.groot_dqj_all = np.zeros(29, dtype=np.float32) - self.groot_action = np.zeros(15, dtype=np.float32) # 15D action (legs + waist) - self.groot_obs_single = np.zeros(86, dtype=np.float32) # 86D single frame observation - self.groot_obs_history = deque(maxlen=6) # 6-frame history buffer - self.groot_obs_stacked = np.zeros(516, dtype=np.float32) # 86D × 6 = 516D stacked observation - self.groot_height_cmd = 0.74 # Default base height - self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # roll, pitch, yaw - - # Initialize history with zeros - for _ in range(6): - self.groot_obs_history.append(np.zeros(86, dtype=np.float32)) - - # Start keyboard controls if in simulation mode - if self.simulation_mode: - logger.info("Starting keyboard controls for simulation...") - self.start_keyboard_controls() - - # Use different init based on policy type - if self.policy_type == "onnx": - self.init_groot_locomotion() - else: - self.init_locomotion() - elif self.simulation_mode: - # Even without locomotion, provide keyboard feedback in sim - logger.info("Simulation mode active (locomotion disabled)") logger.info("Initialize G1 OK!\n") @@ -766,506 +661,3 @@ class UnitreeG1(Robot): R_pelvis = np.dot(R_torso, RzWaist.T) w = np.dot(RzWaist, imu_omega[0]) - np.array([0, 0, waist_yaw_omega]) return 
R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w - - def locomotion_run(self): - """Main locomotion policy loop - runs policy and sends leg commands.""" - self.locomotion_counter += 1 - - # Get current lowstate - lowstate = self.lowstate_buffer.GetData() - if lowstate is None: - return - - # Update remote controller from lowstate - if lowstate.wireless_remote is not None: - self.remote_controller.set(lowstate.wireless_remote) - else: - # Default to zero commands if no remote data - self.remote_controller.lx = 0.0 - self.remote_controller.ly = 0.0 - self.remote_controller.rx = 0.0 - self.remote_controller.ry = 0.0 - - # Get the current joint position and velocity (LEGS ONLY) - for i in range(len(self.config.leg_joint2motor_idx)): - self.qj[i] = lowstate.motor_state[self.config.leg_joint2motor_idx[i]].q - self.dqj[i] = lowstate.motor_state[self.config.leg_joint2motor_idx[i]].dq - - # Get IMU data - quat = lowstate.imu_state.quaternion - ang_vel = np.array([lowstate.imu_state.gyroscope], dtype=np.float32) - - if self.config.locomotion_imu_type == "torso": - # Transform IMU data from torso to pelvis frame - waist_yaw = lowstate.motor_state[self.config.arm_waist_joint2motor_idx[0]].q - waist_yaw_omega = lowstate.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq - quat, ang_vel = self.locomotion_transform_imu_data(waist_yaw, waist_yaw_omega, quat, ang_vel) - - # Create observation - gravity_orientation = self.locomotion_get_gravity_orientation(quat) - qj_obs = self.qj.copy() - dqj_obs = self.dqj.copy() - qj_obs = (qj_obs - np.array(self.config.default_leg_angles)) * self.config.dof_pos_scale - dqj_obs = dqj_obs * self.config.dof_vel_scale - ang_vel = ang_vel * self.config.ang_vel_scale - - # Calculate phase - period = 0.8 - count = self.locomotion_counter * self.config.locomotion_control_dt - phase = count % period / period - sin_phase = np.sin(2 * np.pi * phase) - cos_phase = np.cos(2 * np.pi * phase) - - # Get velocity commands from remote controller (only if 
NOT in simulation mode) - # In simulation mode, keyboard controls set self.locomotion_cmd directly - if not self.simulation_mode: - self.locomotion_cmd[0] = self.remote_controller.ly - self.locomotion_cmd[1] = self.remote_controller.lx * -1 - self.locomotion_cmd[2] = self.remote_controller.rx * -1 - - # Debug: print remote controller values every 50 iterations (~1 second at 50Hz) - if self.locomotion_counter % 50 == 0: - logger.debug( - f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}" - ) - - # Build observation vector - num_actions = self.config.num_locomotion_actions - self.locomotion_obs[:3] = ang_vel - self.locomotion_obs[3:6] = gravity_orientation - self.locomotion_obs[6:9] = ( - self.locomotion_cmd * np.array(self.config.cmd_scale) * np.array(self.config.max_cmd) - ) - self.locomotion_obs[9 : 9 + num_actions] = qj_obs - self.locomotion_obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs - self.locomotion_obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.locomotion_action - self.locomotion_obs[9 + num_actions * 3] = sin_phase - self.locomotion_obs[9 + num_actions * 3 + 1] = cos_phase - - # Get action from policy network - obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0) - - if self.policy_type == "torchscript": - # TorchScript inference - self.locomotion_action = self.policy(obs_tensor).detach().numpy().squeeze() - elif self.policy_type == "onnx": - # ONNX inference - ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()} - ort_outs = self.policy.run(None, ort_inputs) - self.locomotion_action = ort_outs[0].squeeze() - else: - raise ValueError(f"Unknown policy type: {self.policy_type}") - - # Transform action to target joint positions - target_dof_pos = ( - np.array(self.config.default_leg_angles) - + self.locomotion_action * self.config.locomotion_action_scale - ) - - # Send commands to LEG motors only - for i in 
range(len(self.config.leg_joint2motor_idx)): - motor_idx = self.config.leg_joint2motor_idx[i] - self.msg.motor_cmd[motor_idx].q = target_dof_pos[i] - self.msg.motor_cmd[motor_idx].qd = 0 - self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i] - self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i] - self.msg.motor_cmd[motor_idx].tau = 0 - - # Hold WAIST motors at 0 (indices 12, 13, 14 = WaistYaw, WaistRoll, WaistPitch) - waist_indices = self.config.arm_waist_joint2motor_idx[:3] # First 3 are waist - for i, motor_idx in enumerate(waist_indices): - self.msg.motor_cmd[motor_idx].q = 0.0 - self.msg.motor_cmd[motor_idx].qd = 0 - self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_arm_waist_kps[i] - self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_arm_waist_kds[i] - self.msg.motor_cmd[motor_idx].tau = 0 - - # Send command - self.msg.crc = self.crc.Crc(self.msg) - self.lowcmd_publisher.Write(self.msg) - - def groot_locomotion_run(self): - """GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action.""" - self.locomotion_counter += 1 - - # Get current lowstate - lowstate = self.lowstate_buffer.GetData() - if lowstate is None: - return - - # Update remote controller from lowstate - if lowstate.wireless_remote is not None: - self.remote_controller.set(lowstate.wireless_remote) - - # R1/R2 buttons for height control on real robot (button indices 4 and 5) - if self.remote_controller.button[0]: # R1 - raise height - self.groot_height_cmd += 0.001 # Small increment per timestep (~0.05m per second at 50Hz) - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - if self.remote_controller.button[4]: # R2 - lower height - self.groot_height_cmd -= 0.001 # Small decrement per timestep - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - else: - # Default to zero commands if no remote data - self.remote_controller.lx = 0.0 - self.remote_controller.ly = 0.0 - 
self.remote_controller.rx = 0.0 - self.remote_controller.ry = 0.0 - - # Get ALL 29 joint positions and velocities - for i in range(29): - self.groot_qj_all[i] = lowstate.motor_state[i].q - self.groot_dqj_all[i] = lowstate.motor_state[i].dq - - # Get IMU data - quat = lowstate.imu_state.quaternion - ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) - - # Transform IMU if using torso IMU - if self.config.locomotion_imu_type == "torso": - waist_yaw = lowstate.motor_state[12].q # Waist yaw index - waist_yaw_omega = lowstate.motor_state[12].dq - quat, ang_vel_3d = self.locomotion_transform_imu_data( - waist_yaw, waist_yaw_omega, quat, np.array([ang_vel]) - ) - ang_vel = ang_vel_3d.flatten() - - # Create observation - gravity_orientation = self.locomotion_get_gravity_orientation(quat) - joints_to_zero_obs = [12, 14, 20, 21, 27, 28] # Note: NOT 13 (waist roll exists) - for idx in joints_to_zero_obs: - self.groot_qj_all[idx] = 0.0 - self.groot_dqj_all[idx] = 0.0 - # Scale joint positions and velocities - qj_obs = self.groot_qj_all.copy() - dqj_obs = self.groot_dqj_all.copy() - - # Subtract default angles for legs + waist (15 joints) - # GR00T default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0, 0.0, 0.0, 0.0] - groot_default_angles = np.array( - [ - -0.1, - 0.0, - 0.0, - 0.3, - -0.2, - 0.0, # left leg - -0.1, - 0.0, - 0.0, - 0.3, - -0.2, - 0.0, # right leg - 0.0, - 0.0, - 0.0, # waist - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, # left arm (zeroed) - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ], - dtype=np.float32, - ) # right arm (zeroed) - - qj_obs = (qj_obs - groot_default_angles) * self.config.dof_pos_scale - dqj_obs = dqj_obs * self.config.dof_vel_scale - ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale # Use GR00T-specific scaling! 
- - # Get velocity commands (keyboard or remote) - if not self.simulation_mode: - self.locomotion_cmd[0] = self.remote_controller.ly - self.locomotion_cmd[1] = self.remote_controller.lx * -1 - self.locomotion_cmd[2] = self.remote_controller.rx * -1 - - # Build 86D single frame observation (GR00T format) - self.groot_obs_single[:3] = self.locomotion_cmd * np.array( - self.config.groot_cmd_scale - ) # cmd - use GR00T scaling! - self.groot_obs_single[3] = self.groot_height_cmd # height_cmd - self.groot_obs_single[4:7] = self.groot_orientation_cmd # roll, pitch, yaw cmd - self.groot_obs_single[7:10] = ang_vel_scaled # angular velocity - self.groot_obs_single[10:13] = gravity_orientation # gravity - self.groot_obs_single[13:42] = qj_obs # joint positions (29D) - self.groot_obs_single[42:71] = dqj_obs # joint velocities (29D) - self.groot_obs_single[71:86] = self.groot_action # previous actions (15D) - - # Add to history and stack observations (6 frames × 86D = 516D) - self.groot_obs_history.append(self.groot_obs_single.copy()) - - # Stack all 6 frames into 516D vector - for i, obs_frame in enumerate(self.groot_obs_history): - start_idx = i * 86 - end_idx = start_idx + 86 - self.groot_obs_stacked[start_idx:end_idx] = obs_frame - - # Run policy inference (ONNX) with 516D stacked observation - obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0) - - # Select appropriate policy based on command magnitude (dual-policy system) - if self.policy_balance is not None and self.policy_walk is not None: - # Dual-policy mode: switch between Balance and Walk - cmd_magnitude = np.linalg.norm(self.locomotion_cmd) - if cmd_magnitude < 0.05: - # Use balance/standing policy for small commands - selected_policy = self.policy_balance - else: - # Use walking policy for movement commands - selected_policy = self.policy_walk - else: - # Single policy mode (fallback) - selected_policy = self.policy - - ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()} - 
ort_outs = selected_policy.run(None, ort_inputs) - self.groot_action = ort_outs[0].squeeze() - - # Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11) - # This ensures action history in observations matches what's actually executed - self.groot_action[12] = 0.0 # Waist yaw - self.groot_action[13] = 0.0 # Waist roll - self.groot_action[14] = 0.0 # Waist pitch - - # Transform action to target joint positions (15D: legs + waist, but waist actions are zeroed) - target_dof_pos_15 = ( - groot_default_angles[:15] + self.groot_action * self.config.locomotion_action_scale - ) - - # Send commands to LEG motors (0-11) - for i in range(12): - motor_idx = i - self.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i] - self.msg.motor_cmd[motor_idx].qd = 0 - self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i] - self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i] - self.msg.motor_cmd[motor_idx].tau = 0 - - # Send WAIST commands - but SKIP waist yaw (12) and waist pitch (14) - # Only send waist roll (13) - waist_roll_idx = 13 - waist_roll_action_idx = 13 # In the 15D action - self.msg.motor_cmd[waist_roll_idx].q = target_dof_pos_15[waist_roll_action_idx] - self.msg.motor_cmd[waist_roll_idx].qd = 0 - self.msg.motor_cmd[waist_roll_idx].kp = self.config.locomotion_arm_waist_kps[ - 1 - ] # index 1 is waist roll - self.msg.motor_cmd[waist_roll_idx].kd = self.config.locomotion_arm_waist_kds[1] - self.msg.motor_cmd[waist_roll_idx].tau = 0 - - # Zero out the problematic joints (waist yaw, waist pitch, wrist pitch/yaw) - problematic_joints = [12, 14, 20, 21, 27, 28] - for joint_idx in problematic_joints: - self.msg.motor_cmd[joint_idx].q = 0.0 - self.msg.motor_cmd[joint_idx].qd = 0 - if joint_idx in [12, 14]: # waist - kp_idx = 0 if joint_idx == 12 else 2 # yaw or pitch - self.msg.motor_cmd[joint_idx].kp = self.config.locomotion_arm_waist_kps[kp_idx] - self.msg.motor_cmd[joint_idx].kd = self.config.locomotion_arm_waist_kds[kp_idx] - else: 
# wrists (20, 21, 27, 28) - self.msg.motor_cmd[joint_idx].kp = self.kp_wrist - self.msg.motor_cmd[joint_idx].kd = self.kd_wrist - self.msg.motor_cmd[joint_idx].tau = 0 - - # Send command - self.msg.crc = self.crc.Crc(self.msg) - self.lowcmd_publisher.Write(self.msg) - - def _locomotion_thread_loop(self): - """Background thread that runs the locomotion policy at specified rate.""" - logger.info("Locomotion thread started") - while self.locomotion_running: - start_time = time.time() - try: - # Use different run function based on policy type - if self.policy_type == "onnx": - self.groot_locomotion_run() - else: - self.locomotion_run() - except Exception as e: - logger.error(f"Error in locomotion loop: {e}") - - # Sleep to maintain control rate - elapsed = time.time() - start_time - sleep_time = max(0, self.config.locomotion_control_dt - elapsed) - time.sleep(sleep_time) - logger.info("Locomotion thread stopped") - - def start_locomotion_thread(self): - """Start the background locomotion control thread.""" - if not self.config.locomotion_control: - logger.warning("locomotion_control is False, cannot start thread") - return - - if self.locomotion_running: - logger.warning("Locomotion thread already running") - return - - logger.info("Starting locomotion control thread...") - self.locomotion_running = True - self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True) - self.locomotion_thread.start() - logger.info("Locomotion control thread started!") - - def stop_locomotion_thread(self): - """Stop the background locomotion control thread.""" - if not self.locomotion_running: - return - - logger.info("Stopping locomotion control thread...") - self.locomotion_running = False - if self.locomotion_thread: - self.locomotion_thread.join(timeout=2.0) - logger.info("Locomotion control thread stopped") - - # Also stop keyboard thread if running - if self.keyboard_running: - self.stop_keyboard_controls() - - def _keyboard_listener_thread(self): - 
"""Background thread that listens for keyboard input (sim mode only).""" - print("\n" + "=" * 60) - print("KEYBOARD CONTROLS ACTIVE!") - print(" W/S: Forward/Backward") - print(" A/D: Left/Right") - print(" Q/E: Rotate Left/Right") - print(" R/F: Raise/Lower Height (±5cm)") - print(" Z: Stop (zero velocity commands)") - print("=" * 60 + "\n") - - # Save terminal settings - old_settings = None - try: - old_settings = termios.tcgetattr(sys.stdin) - tty.setcbreak(sys.stdin.fileno()) - - while self.keyboard_running: - if select.select([sys.stdin], [], [], 0.1)[0]: - key = sys.stdin.read(1).lower() - - # Velocity commands - if key == "w": - self.locomotion_cmd[0] += 0.4 # Forward - elif key == "s": - self.locomotion_cmd[0] -= 0.4 # Backward - elif key == "a": - self.locomotion_cmd[1] += 0.25 # Left - elif key == "d": - self.locomotion_cmd[1] -= 0.25 # Right - elif key == "q": - self.locomotion_cmd[2] += 0.5 # Rotate left - elif key == "e": - self.locomotion_cmd[2] -= 0.5 # Rotate right - elif key == "z": - self.locomotion_cmd[:] = 0.0 # Stop - - # Height commands (only for GR00T ONNX policies) - elif key == "r": - self.groot_height_cmd += 0.05 # Raise 5cm - elif key == "f": - self.groot_height_cmd -= 0.05 # Lower 5cm - - # Clamp commands to reasonable limits - self.locomotion_cmd[0] = np.clip(self.locomotion_cmd[0], -0.8, 0.8) # vx - self.locomotion_cmd[1] = np.clip(self.locomotion_cmd[1], -0.5, 0.5) # vy - self.locomotion_cmd[2] = np.clip(self.locomotion_cmd[2], -1.0, 1.0) # yaw_rate - - # Clamp height (reasonable range: 0.5m to 1.0m) - if hasattr(self, "groot_height_cmd"): - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - - # Print current commands - print( - f"[VEL CMD] vx={self.locomotion_cmd[0]:.2f}, vy={self.locomotion_cmd[1]:.2f}, yaw={self.locomotion_cmd[2]:.2f}", - end="", - ) - if hasattr(self, "groot_height_cmd"): - print(f" | [HEIGHT] {self.groot_height_cmd:.3f}m", end="") - print() # Newline - - finally: - # Restore terminal settings - 
if old_settings is not None: - termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) - print("\nKeyboard controls stopped") - - def start_keyboard_controls(self): - """Start the keyboard control thread (sim mode only).""" - if not self.simulation_mode: - logger.warning("Keyboard controls only available in simulation mode") - return - - if self.keyboard_running: - logger.warning("Keyboard controls already running") - return - - self.keyboard_running = True - self.keyboard_thread = threading.Thread(target=self._keyboard_listener_thread, daemon=True) - self.keyboard_thread.start() - logger.info("Keyboard controls started!") - - def stop_keyboard_controls(self): - """Stop the keyboard control thread.""" - if not self.keyboard_running: - return - - logger.info("Stopping keyboard controls...") - self.keyboard_running = False - if self.keyboard_thread: - self.keyboard_thread.join(timeout=2.0) - logger.info("Keyboard controls stopped") - - def init_locomotion(self): - """Test locomotion control sequence: home arms -> move legs to default -> start policy thread.""" - if not self.config.locomotion_control: - logger.warning("locomotion_control is False, cannot run test sequence") - return - - logger.info("Starting locomotion test sequence...") - - # 2. Move legs to default position - self.locomotion_move_to_default_pos() - - # 3. Wait 3 seconds - time.sleep(3.0) - - # 4. Hold default leg position for 2 seconds - self.locomotion_default_pos_state() - - # 5. Start locomotion policy thread (runs in background) - logger.info("Starting locomotion policy control...") - self.start_locomotion_thread() - - logger.info("Locomotion test sequence complete! 
Policy is now running in background.") - logger.info("Use robot.stop_locomotion_thread() to stop the policy.") - - def init_groot_locomotion(self): - """Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions).""" - if not self.config.locomotion_control: - logger.warning("locomotion_control is False, cannot run GR00T init") - return - - logger.info("Starting GR00T locomotion initialization...") - - # Move legs to default position (same as regular locomotion) - self.locomotion_move_to_default_pos() - - # Wait 3 seconds - time.sleep(3.0) - - # Hold default leg position for 2 seconds - self.locomotion_default_pos_state() - - # Start locomotion policy thread (will use groot_locomotion_run) - logger.info("Starting GR00T locomotion policy control...") - self.start_locomotion_thread() - - logger.info("GR00T locomotion initialization complete! Policy is now running.") - logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")