From 36ed02adfa703b827edb5cf1c4d0fecf89eee297 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 27 Nov 2025 10:23:02 +0100 Subject: [PATCH] download policy from the hub in `examples/unitree_g1/gr00t_locomotion` --- examples/unitree_g1/gr00t_locomotion.py | 84 +++++++++----- .../robots/unitree_g1/config_unitree_g1.py | 105 +++++++++++++----- src/lerobot/robots/unitree_g1/g1_utils.py | 3 +- src/lerobot/robots/unitree_g1/robot_server.py | 16 +-- src/lerobot/robots/unitree_g1/unitree_g1.py | 24 ++-- .../robots/unitree_g1/unitree_sdk2_socket.py | 14 +-- 6 files changed, 164 insertions(+), 82 deletions(-) diff --git a/examples/unitree_g1/gr00t_locomotion.py b/examples/unitree_g1/gr00t_locomotion.py index 3406539c2..b8510d694 100644 --- a/examples/unitree_g1/gr00t_locomotion.py +++ b/examples/unitree_g1/gr00t_locomotion.py @@ -7,6 +7,7 @@ This example demonstrates the NEW pattern for loading GR00T policies externally and passing them to the robot class. """ +import argparse import logging import threading import time @@ -15,6 +16,7 @@ from collections import deque import numpy as np import onnxruntime as ort import torch +from huggingface_hub import hf_hub_download from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 @@ -72,15 +74,32 @@ DOF_VEL_SCALE: float = 0.05 CMD_SCALE: list = [2.0, 2.0, 0.25] -def load_groot_policies() -> tuple: - """Load GR00T dual-policy system (Balance + Walk) from ONNX files.""" - logger.info("Loading GR00T dual-policy system...") +DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1" + + +def load_groot_policies( + repo_id: str = DEFAULT_GROOT_REPO_ID, +) -> tuple[ort.InferenceSession, ort.InferenceSession]: + """Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub. + + Args: + repo_id: Hugging Face Hub repository ID containing the ONNX policies. + """ + logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...") + + # Download ONNX policies from Hugging Face Hub + balance_path = hf_hub_download( + repo_id=repo_id, + filename="GR00T-WholeBodyControl-Balance.onnx", + ) + walk_path = hf_hub_download( + repo_id=repo_id, + filename="GR00T-WholeBodyControl-Walk.onnx", + ) # Load ONNX policies - policy_balance = ort.InferenceSession( - "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx" - ) - policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx") + policy_balance = ort.InferenceSession(balance_path) + policy_walk = ort.InferenceSession(walk_path) logger.info("GR00T policies loaded successfully") @@ -99,7 +118,6 @@ class GrootLocomotionController: """ def __init__(self, policy_balance, policy_walk, robot, config): - self.policy_balance = policy_balance self.policy_walk = policy_walk self.robot = robot @@ -128,7 +146,6 @@ class GrootLocomotionController: logger.info("GrootLocomotionController initialized") def groot_locomotion_run(self): - # get current observation robot_state = self.robot.get_observation() @@ -150,15 +167,14 @@ class GrootLocomotionController: self.robot.remote_controller.rx = 0.0 self.robot.remote_controller.ry = 0.0 - self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward - self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right - self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate + self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward + self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right + self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate for i in range(29): self.groot_qj_all[i] = robot_state.motor_state[i].q self.groot_dqj_all[i] = robot_state.motor_state[i].dq - # adapt observation for g1_23dof for idx in MISSING_JOINTS: self.groot_qj_all[idx] = 0.0 @@ -173,12 +189,11 @@ class GrootLocomotionController: ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) gravity_orientation = self.robot.get_gravity_orientation(quat) - #scale joint positions and velocities before policy inference + # scale joint positions and velocities before policy inference qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE dqj_obs = dqj_obs * DOF_VEL_SCALE ang_vel_scaled = ang_vel * ANG_VEL_SCALE - # build single frame observation self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE) self.groot_obs_single[3] = self.groot_height_cmd @@ -202,7 +217,7 @@ class GrootLocomotionController: obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0) cmd_magnitude = np.linalg.norm(self.locomotion_cmd) - + if cmd_magnitude < 0.05: # balance/standing policy for small commands selected_policy = self.policy_balance @@ -218,7 +233,7 @@ class GrootLocomotionController: # transform action back to target joint positions target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE - # command motors + # command motors for i in range(15): motor_idx = i self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i] @@ -235,7 +250,7 @@ class GrootLocomotionController: self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx] self.robot.msg.motor_cmd[joint_idx].tau = 0 - #send action to robot + # send action to robot self.robot.send_action(self.robot.msg) def _locomotion_thread_loop(self): @@ -298,7 +313,9 @@ class GrootLocomotionController: alpha = i / num_step for motor_idx in range(dof_size): target_pos = default_pos[motor_idx] - self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha + self.robot.msg.motor_cmd[motor_idx].q = ( + init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha + ) self.robot.msg.motor_cmd[motor_idx].qd = 0 self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx] self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx] @@ -308,16 +325,25 @@ class GrootLocomotionController: time.sleep(self.robot.control_dt) logger.info("Reached default position (legs only)") -if __name__ == "__main__": - - #load policies - policy_balance, policy_walk = load_groot_policies() - #initialize robot +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1") + parser.add_argument( + "--repo-id", + type=str, + default=DEFAULT_GROOT_REPO_ID, + help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})", + ) + args = parser.parse_args() + + # load policies + policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id) + + # initialize robot config = UnitreeG1Config() robot = UnitreeG1(config) - #initialize gr00t locomotion controller + # initialize gr00t locomotion controller groot_controller = GrootLocomotionController( policy_balance=policy_balance, policy_walk=policy_walk, @@ -325,20 +351,20 @@ if __name__ == "__main__": config=config, ) - #reset legs and start locomotion thread + # reset legs and start locomotion thread groot_controller.reset_robot() groot_controller.start_locomotion_thread() - #log status + # log status logger.info("Robot initialized with GR00T locomotion policies") logger.info("Locomotion controller running in background thread") logger.info("Press Ctrl+C to stop") - #keep robot alive + # keep robot alive try: while True: time.sleep(1.0) except KeyboardInterrupt: print("\nStopping locomotion...") groot_controller.stop_locomotion_thread() - print("Done!") \ No newline at end of file + print("Done!") diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index f43d4ce79..7b06d186a 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -15,9 +15,6 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any - -from lerobot.cameras import CameraConfig from ..config import RobotConfig @@ -27,29 +24,87 @@ from ..config import RobotConfig class UnitreeG1Config(RobotConfig): # id: str = "unitree_g1" - kp: list = field(default_factory=lambda: [ - 150, 150, 150, 300, 40, 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll - 150, 150, 150, 300, 40, 40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll - 250, 250, 250, # Waist yaw, roll, pitch - 80, 80, 80, 80, # Left shoulder pitch, roll, yaw, elbow (kp_low) - 40, 40, 40, # Left wrist roll, pitch, yaw (kp_wrist) - 80, 80, 80, 80, # Right shoulder pitch, roll, yaw, elbow (kp_low) - 40, 40, 40, # Right wrist roll, pitch, yaw (kp_wrist) - 80, 80, 80, 80, 80, 80, # Other - ]) - - kd: list = field(default_factory=lambda: [ - 2, 2, 2, 4, 2, 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll - 2, 2, 2, 4, 2, 2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll - 5, 5, 5, # Waist yaw, roll, pitch - 3, 3, 3, 3, # Left shoulder pitch, roll, yaw, elbow (kd_low) - 1.5, 1.5, 1.5, # Left wrist roll, pitch, yaw (kd_wrist) - 3, 3, 3, 3, # Right shoulder pitch, roll, yaw, elbow (kd_low) - 1.5, 1.5, 1.5, # Right wrist roll, pitch, yaw (kd_wrist) - 3, 3, 3, 3, 3, 3, # Other - ]) + kp: list = field( + default_factory=lambda: [ + 150, + 150, + 150, + 300, + 40, + 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll + 150, + 150, + 150, + 300, + 40, + 40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll + 250, + 250, + 250, # Waist yaw, roll, pitch + 80, + 80, + 80, + 80, # Left shoulder pitch, roll, yaw, elbow (kp_low) + 40, + 40, + 40, # Left wrist roll, pitch, yaw (kp_wrist) + 80, + 80, + 80, + 80, # Right shoulder pitch, roll, yaw, elbow (kp_low) + 40, + 40, + 40, # Right wrist roll, pitch, yaw (kp_wrist) + 80, + 80, + 80, + 80, + 80, + 80, # Other + ] + ) - control_dt = 1.0 / 250.0 # 250Hz + kd: list = field( + default_factory=lambda: [ + 2, + 2, + 2, + 4, + 2, + 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll + 2, + 2, + 2, + 4, + 2, + 2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll + 5, + 5, + 5, # Waist yaw, roll, pitch + 3, + 3, + 3, + 3, # Left shoulder pitch, roll, yaw, elbow (kd_low) + 1.5, + 1.5, + 1.5, # Left wrist roll, pitch, yaw (kd_wrist) + 3, + 3, + 3, + 3, # Right shoulder pitch, roll, yaw, elbow (kd_low) + 1.5, + 1.5, + 1.5, # Right wrist roll, pitch, yaw (kd_wrist) + 3, + 3, + 3, + 3, + 3, + 3, # Other + ] + ) + + control_dt = 1.0 / 250.0 # 250Hz # socket config for ZMQ bridge robot_ip: str = "172.18.129.215" diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index f2e39d042..91485fe6e 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -1,5 +1,6 @@ from enum import IntEnum + class G1_29_JointArmIndex(IntEnum): # Left arm kLeftShoulderPitch = 15 @@ -19,8 +20,8 @@ class G1_29_JointArmIndex(IntEnum): kRightWristPitch = 27 kRightWristYaw = 28 -class G1_29_JointIndex(IntEnum): +class G1_29_JointIndex(IntEnum): # Left leg kLeftHipPitch = 0 kLeftHipRoll = 1 diff --git a/src/lerobot/robots/unitree_g1/robot_server.py b/src/lerobot/robots/unitree_g1/robot_server.py index 2e8680eda..8eef81b9e 100644 --- a/src/lerobot/robots/unitree_g1/robot_server.py +++ b/src/lerobot/robots/unitree_g1/robot_server.py @@ -9,14 +9,16 @@ from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublish from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState from unitree_sdk2py.utils.crc import CRC -kTopicLowCommand_Debug = "rt/lowcmd" #action to robot -kTopicLowState = "rt/lowstate" #observation from robot +kTopicLowCommand_Debug = "rt/lowcmd" # action to robot +kTopicLowState = "rt/lowstate" # observation from robot LOWCMD_PORT = 6000 LOWSTATE_PORT = 6001 -def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read observation from DDS and send to server +def state_forward_loop( + lowstate_sub, lowstate_sock, state_period: float +): # read observation from DDS and send to server last_state_time = 0.0 while True: @@ -27,7 +29,7 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o now = time.time() # optional downsampling (if robot dds rate > state_period) - if now - last_state_time >= state_period: + if now - last_state_time >= state_period: payload = pickle.dumps((kTopicLowState, msg), protocol=pickle.HIGHEST_PROTOCOL) try: lowstate_sock.send(payload, zmq.NOBLOCK) @@ -37,8 +39,7 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o last_state_time = now -def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to robot - +def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC): # send action to robot while True: payload = lowcmd_sock.recv() topic, cmd = pickle.loads(payload) @@ -50,7 +51,6 @@ def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to ro lowcmd_pub_debug.Write(cmd) else: pass - def main(): @@ -73,7 +73,7 @@ def main(): # initialize DDS publisher lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) lowcmd_pub_debug.Init() - + # initialize DDS subscriber lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState) lowstate_sub.Init() diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index bac603309..0a68833f4 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import struct import threading @@ -52,11 +51,10 @@ H1_2_Num_Motors = 35 H1_Num_Motors = 20 - class MotorState: def __init__(self): - self.q = None # position - self.dq = None # velocity + self.q = None # position + self.dq = None # velocity self.tau_est = None # estimated torque self.temperature = None # motor temperature @@ -69,7 +67,8 @@ class IMUState: self.rpy = None # [roll, pitch, yaw] (rad) self.temperature = None # IMU temperature -#g1 observation class + +# g1 observation class class G1_29_LowState: def __init__(self): self.motor_state = [MotorState() for _ in range(G1_29_Num_Motors)] @@ -95,9 +94,8 @@ class UnitreeG1(Robot): config_class = UnitreeG1Config name = "unitree_g1" - #unitree remote controller + # unitree remote controller class RemoteController: - def __init__(self): self.lx = 0 self.ly = 0 @@ -165,7 +163,7 @@ class UnitreeG1(Robot): # Initialize remote controller self.remote_controller = self.RemoteController() - def _subscribe_motor_state(self): #polls robot state @ 250Hz + def _subscribe_motor_state(self): # polls robot state @ 250Hz while True: start_time = time.time() msg = self.lowstate_subscriber.Read() @@ -200,13 +198,13 @@ class UnitreeG1(Robot): def action_features(self) -> dict[str, type]: return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} - def calibrate(self) -> None:#robot is already calibrated + def calibrate(self) -> None: # robot is already calibrated pass def configure(self) -> None: pass - def connect(self, calibrate: bool = True) -> None: #connect to DDS + def connect(self, calibrate: bool = True) -> None: # connect to DDS ChannelFactoryInitialize(0) def disconnect(self): @@ -243,7 +241,7 @@ class UnitreeG1(Robot): self.msg.crc = self.crc.Crc(action) self.lowcmd_publisher.Write(action) - def get_gravity_orientation(self, quaternion):#get gravity orientation from quaternion + def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion """Get gravity orientation from quaternion.""" qw = quaternion[0] qx = quaternion[1] @@ -256,7 +254,9 @@ class UnitreeG1(Robot): gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) return gravity_orientation - def transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):#transform imu data from torso to pelvis frame + def transform_imu_data( + self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega + ): # transform imu data from torso to pelvis frame """Transform IMU data from torso to pelvis frame.""" RzWaist = R.from_euler("z", waist_yaw).as_matrix() R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix() diff --git a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py index e69abd520..40bcdca34 100644 --- a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py +++ b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py @@ -1,4 +1,5 @@ import pickle + import zmq from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config @@ -11,18 +12,17 @@ LOWCMD_PORT = 6000 LOWSTATE_PORT = 6001 -def ChannelFactoryInitialize(*args, **kwargs):#DDS to socket bridge - global _ctx, _lowcmd_sock, _lowstate_sock\ - +def ChannelFactoryInitialize(*args, **kwargs): # DDS to socket bridge + global _ctx, _lowcmd_sock, _lowstate_sock # read socket config config = UnitreeG1Config() robot_ip = config.robot_ip - + _ctx = zmq.Context.instance() # lowcmd: robot action _lowcmd_sock = _ctx.socket(zmq.PUSH) - _lowcmd_sock.setsockopt(zmq.CONFLATE, 1)#keep only last message + _lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message _lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}") # lowstate: robot observation @@ -32,7 +32,7 @@ def ChannelFactoryInitialize(*args, **kwargs):#DDS to socket bridge _lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "") -class ChannelPublisher: #send action to robot +class ChannelPublisher: # send action to robot def __init__(self, topic, msg_type): self.topic = topic self.msg_type = msg_type @@ -44,7 +44,7 @@ class ChannelPublisher: #send action to robot _lowcmd_sock.send(pickle.dumps((self.topic, msg))) -class ChannelSubscriber: #read observation from robot +class ChannelSubscriber: # read observation from robot def __init__(self, topic, msg_type): self.topic = topic self.msg_type = msg_type