download policy from the hub in examples/unitree_g1/gr00t_locomotion

This commit is contained in:
Michel Aractingi
2025-11-27 10:23:02 +01:00
parent 288cfc7f8e
commit 36ed02adfa
6 changed files with 164 additions and 82 deletions
+39 -13
View File
@@ -7,6 +7,7 @@ This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class. and passing them to the robot class.
""" """
import argparse
import logging import logging
import threading import threading
import time import time
@@ -15,6 +16,7 @@ from collections import deque
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
import torch import torch
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
@@ -72,15 +74,32 @@ DOF_VEL_SCALE: float = 0.05
CMD_SCALE: list = [2.0, 2.0, 0.25] CMD_SCALE: list = [2.0, 2.0, 0.25]
def load_groot_policies() -> tuple: DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
"""Load GR00T dual-policy system (Balance + Walk) from ONNX files."""
logger.info("Loading GR00T dual-policy system...")
def load_groot_policies(
repo_id: str = DEFAULT_GROOT_REPO_ID,
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
"""
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
# Download ONNX policies from Hugging Face Hub
balance_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Balance.onnx",
)
walk_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Walk.onnx",
)
# Load ONNX policies # Load ONNX policies
policy_balance = ort.InferenceSession( policy_balance = ort.InferenceSession(balance_path)
"examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx" policy_walk = ort.InferenceSession(walk_path)
)
policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx")
logger.info("GR00T policies loaded successfully") logger.info("GR00T policies loaded successfully")
@@ -99,7 +118,6 @@ class GrootLocomotionController:
""" """
def __init__(self, policy_balance, policy_walk, robot, config): def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance self.policy_balance = policy_balance
self.policy_walk = policy_walk self.policy_walk = policy_walk
self.robot = robot self.robot = robot
@@ -128,7 +146,6 @@ class GrootLocomotionController:
logger.info("GrootLocomotionController initialized") logger.info("GrootLocomotionController initialized")
def groot_locomotion_run(self): def groot_locomotion_run(self):
# get current observation # get current observation
robot_state = self.robot.get_observation() robot_state = self.robot.get_observation()
@@ -158,7 +175,6 @@ class GrootLocomotionController:
self.groot_qj_all[i] = robot_state.motor_state[i].q self.groot_qj_all[i] = robot_state.motor_state[i].q
self.groot_dqj_all[i] = robot_state.motor_state[i].dq self.groot_dqj_all[i] = robot_state.motor_state[i].dq
# adapt observation for g1_23dof # adapt observation for g1_23dof
for idx in MISSING_JOINTS: for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0 self.groot_qj_all[idx] = 0.0
@@ -178,7 +194,6 @@ class GrootLocomotionController:
dqj_obs = dqj_obs * DOF_VEL_SCALE dqj_obs = dqj_obs * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# build single frame observation # build single frame observation
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE) self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
self.groot_obs_single[3] = self.groot_height_cmd self.groot_obs_single[3] = self.groot_height_cmd
@@ -298,7 +313,9 @@ class GrootLocomotionController:
alpha = i / num_step alpha = i / num_step
for motor_idx in range(dof_size): for motor_idx in range(dof_size):
target_pos = default_pos[motor_idx] target_pos = default_pos[motor_idx]
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha self.robot.msg.motor_cmd[motor_idx].q = (
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0 self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx] self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx] self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
@@ -308,10 +325,19 @@ class GrootLocomotionController:
time.sleep(self.robot.control_dt) time.sleep(self.robot.control_dt)
logger.info("Reached default position (legs only)") logger.info("Reached default position (legs only)")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_GROOT_REPO_ID,
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
)
args = parser.parse_args()
# load policies # load policies
policy_balance, policy_walk = load_groot_policies() policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
# initialize robot # initialize robot
config = UnitreeG1Config() config = UnitreeG1Config()
@@ -15,9 +15,6 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
@@ -27,27 +24,85 @@ from ..config import RobotConfig
class UnitreeG1Config(RobotConfig): class UnitreeG1Config(RobotConfig):
# id: str = "unitree_g1" # id: str = "unitree_g1"
kp: list = field(default_factory=lambda: [ kp: list = field(
150, 150, 150, 300, 40, 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll default_factory=lambda: [
150, 150, 150, 300, 40, 40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll 150,
250, 250, 250, # Waist yaw, roll, pitch 150,
80, 80, 80, 80, # Left shoulder pitch, roll, yaw, elbow (kp_low) 150,
40, 40, 40, # Left wrist roll, pitch, yaw (kp_wrist) 300,
80, 80, 80, 80, # Right shoulder pitch, roll, yaw, elbow (kp_low) 40,
40, 40, 40, # Right wrist roll, pitch, yaw (kp_wrist) 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
80, 80, 80, 80, 80, 80, # Other 150,
]) 150,
150,
300,
40,
40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
250,
250,
250, # Waist yaw, roll, pitch
80,
80,
80,
80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
40,
40,
40, # Left wrist roll, pitch, yaw (kp_wrist)
80,
80,
80,
80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
40,
40,
40, # Right wrist roll, pitch, yaw (kp_wrist)
80,
80,
80,
80,
80,
80, # Other
]
)
kd: list = field(default_factory=lambda: [ kd: list = field(
2, 2, 2, 4, 2, 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll default_factory=lambda: [
2, 2, 2, 4, 2, 2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll 2,
5, 5, 5, # Waist yaw, roll, pitch 2,
3, 3, 3, 3, # Left shoulder pitch, roll, yaw, elbow (kd_low) 2,
1.5, 1.5, 1.5, # Left wrist roll, pitch, yaw (kd_wrist) 4,
3, 3, 3, 3, # Right shoulder pitch, roll, yaw, elbow (kd_low) 2,
1.5, 1.5, 1.5, # Right wrist roll, pitch, yaw (kd_wrist) 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
3, 3, 3, 3, 3, 3, # Other 2,
]) 2,
2,
4,
2,
2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
5,
5,
5, # Waist yaw, roll, pitch
3,
3,
3,
3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
1.5,
1.5,
1.5, # Left wrist roll, pitch, yaw (kd_wrist)
3,
3,
3,
3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
1.5,
1.5,
1.5, # Right wrist roll, pitch, yaw (kd_wrist)
3,
3,
3,
3,
3,
3, # Other
]
)
control_dt = 1.0 / 250.0 # 250Hz control_dt = 1.0 / 250.0 # 250Hz
+2 -1
View File
@@ -1,5 +1,6 @@
from enum import IntEnum from enum import IntEnum
class G1_29_JointArmIndex(IntEnum): class G1_29_JointArmIndex(IntEnum):
# Left arm # Left arm
kLeftShoulderPitch = 15 kLeftShoulderPitch = 15
@@ -19,8 +20,8 @@ class G1_29_JointArmIndex(IntEnum):
kRightWristPitch = 27 kRightWristPitch = 27
kRightWristYaw = 28 kRightWristYaw = 28
class G1_29_JointIndex(IntEnum):
class G1_29_JointIndex(IntEnum):
# Left leg # Left leg
kLeftHipPitch = 0 kLeftHipPitch = 0
kLeftHipRoll = 1 kLeftHipRoll = 1
@@ -16,7 +16,9 @@ LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001 LOWSTATE_PORT = 6001
def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read observation from DDS and send to server def state_forward_loop(
lowstate_sub, lowstate_sock, state_period: float
): # read observation from DDS and send to server
last_state_time = 0.0 last_state_time = 0.0
while True: while True:
@@ -38,7 +40,6 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC): # send action to robot def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC): # send action to robot
while True: while True:
payload = lowcmd_sock.recv() payload = lowcmd_sock.recv()
topic, cmd = pickle.loads(payload) topic, cmd = pickle.loads(payload)
@@ -52,7 +53,6 @@ def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to ro
pass pass
def main(): def main():
# initialize DDS # initialize DDS
ChannelFactoryInitialize(0) ChannelFactoryInitialize(0)
+4 -4
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import struct import struct
import threading import threading
@@ -52,7 +51,6 @@ H1_2_Num_Motors = 35
H1_Num_Motors = 20 H1_Num_Motors = 20
class MotorState: class MotorState:
def __init__(self): def __init__(self):
self.q = None # position self.q = None # position
@@ -69,6 +67,7 @@ class IMUState:
self.rpy = None # [roll, pitch, yaw] (rad) self.rpy = None # [roll, pitch, yaw] (rad)
self.temperature = None # IMU temperature self.temperature = None # IMU temperature
# g1 observation class # g1 observation class
class G1_29_LowState: class G1_29_LowState:
def __init__(self): def __init__(self):
@@ -97,7 +96,6 @@ class UnitreeG1(Robot):
# unitree remote controller # unitree remote controller
class RemoteController: class RemoteController:
def __init__(self): def __init__(self):
self.lx = 0 self.lx = 0
self.ly = 0 self.ly = 0
@@ -256,7 +254,9 @@ class UnitreeG1(Robot):
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation return gravity_orientation
def transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):#transform imu data from torso to pelvis frame def transform_imu_data(
self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega
): # transform imu data from torso to pelvis frame
"""Transform IMU data from torso to pelvis frame.""" """Transform IMU data from torso to pelvis frame."""
RzWaist = R.from_euler("z", waist_yaw).as_matrix() RzWaist = R.from_euler("z", waist_yaw).as_matrix()
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix() R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
@@ -1,4 +1,5 @@
import pickle import pickle
import zmq import zmq
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
@@ -12,8 +13,7 @@ LOWSTATE_PORT = 6001
def ChannelFactoryInitialize(*args, **kwargs): # DDS to socket bridge def ChannelFactoryInitialize(*args, **kwargs): # DDS to socket bridge
global _ctx, _lowcmd_sock, _lowstate_sock\ global _ctx, _lowcmd_sock, _lowstate_sock
# read socket config # read socket config
config = UnitreeG1Config() config = UnitreeG1Config()
robot_ip = config.robot_ip robot_ip = config.robot_ip