mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
download policy from the hub in examples/unitree_g1/gr00t_locomotion
This commit is contained in:
@@ -7,6 +7,7 @@ This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -15,6 +16,7 @@ from collections import deque
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
@@ -72,15 +74,32 @@ DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
def load_groot_policies() -> tuple:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from ONNX files."""
|
||||
logger.info("Loading GR00T dual-policy system...")
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
|
||||
def load_groot_policies(
|
||||
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||
"""
|
||||
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||
|
||||
# Download ONNX policies from Hugging Face Hub
|
||||
balance_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||
)
|
||||
walk_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||
)
|
||||
|
||||
# Load ONNX policies
|
||||
policy_balance = ort.InferenceSession(
|
||||
"examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx"
|
||||
)
|
||||
policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx")
|
||||
policy_balance = ort.InferenceSession(balance_path)
|
||||
policy_walk = ort.InferenceSession(walk_path)
|
||||
|
||||
logger.info("GR00T policies loaded successfully")
|
||||
|
||||
@@ -99,7 +118,6 @@ class GrootLocomotionController:
|
||||
"""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
@@ -128,7 +146,6 @@ class GrootLocomotionController:
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def groot_locomotion_run(self):
|
||||
|
||||
# get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
@@ -158,7 +175,6 @@ class GrootLocomotionController:
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
|
||||
|
||||
# adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
@@ -178,7 +194,6 @@ class GrootLocomotionController:
|
||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
|
||||
# build single frame observation
|
||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||
self.groot_obs_single[3] = self.groot_height_cmd
|
||||
@@ -298,7 +313,9 @@ class GrootLocomotionController:
|
||||
alpha = i / num_step
|
||||
for motor_idx in range(dof_size):
|
||||
target_pos = default_pos[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
@@ -308,10 +325,19 @@ class GrootLocomotionController:
|
||||
time.sleep(self.robot.control_dt)
|
||||
logger.info("Reached default position (legs only)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_GROOT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# load policies
|
||||
policy_balance, policy_walk = load_groot_policies()
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||
|
||||
# initialize robot
|
||||
config = UnitreeG1Config()
|
||||
|
||||
@@ -15,9 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -27,27 +24,85 @@ from ..config import RobotConfig
|
||||
class UnitreeG1Config(RobotConfig):
|
||||
# id: str = "unitree_g1"
|
||||
|
||||
kp: list = field(default_factory=lambda: [
|
||||
150, 150, 150, 300, 40, 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
150, 150, 150, 300, 40, 40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
250, 250, 250, # Waist yaw, roll, pitch
|
||||
80, 80, 80, 80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
|
||||
40, 40, 40, # Left wrist roll, pitch, yaw (kp_wrist)
|
||||
80, 80, 80, 80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
|
||||
40, 40, 40, # Right wrist roll, pitch, yaw (kp_wrist)
|
||||
80, 80, 80, 80, 80, 80, # Other
|
||||
])
|
||||
kp: list = field(
|
||||
default_factory=lambda: [
|
||||
150,
|
||||
150,
|
||||
150,
|
||||
300,
|
||||
40,
|
||||
40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
150,
|
||||
150,
|
||||
150,
|
||||
300,
|
||||
40,
|
||||
40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
250,
|
||||
250,
|
||||
250, # Waist yaw, roll, pitch
|
||||
80,
|
||||
80,
|
||||
80,
|
||||
80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
|
||||
40,
|
||||
40,
|
||||
40, # Left wrist roll, pitch, yaw (kp_wrist)
|
||||
80,
|
||||
80,
|
||||
80,
|
||||
80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
|
||||
40,
|
||||
40,
|
||||
40, # Right wrist roll, pitch, yaw (kp_wrist)
|
||||
80,
|
||||
80,
|
||||
80,
|
||||
80,
|
||||
80,
|
||||
80, # Other
|
||||
]
|
||||
)
|
||||
|
||||
kd: list = field(default_factory=lambda: [
|
||||
2, 2, 2, 4, 2, 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
2, 2, 2, 4, 2, 2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
5, 5, 5, # Waist yaw, roll, pitch
|
||||
3, 3, 3, 3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
|
||||
1.5, 1.5, 1.5, # Left wrist roll, pitch, yaw (kd_wrist)
|
||||
3, 3, 3, 3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
|
||||
1.5, 1.5, 1.5, # Right wrist roll, pitch, yaw (kd_wrist)
|
||||
3, 3, 3, 3, 3, 3, # Other
|
||||
])
|
||||
kd: list = field(
|
||||
default_factory=lambda: [
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
4,
|
||||
2,
|
||||
2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
4,
|
||||
2,
|
||||
2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||
5,
|
||||
5,
|
||||
5, # Waist yaw, roll, pitch
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
|
||||
1.5,
|
||||
1.5,
|
||||
1.5, # Left wrist roll, pitch, yaw (kd_wrist)
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
|
||||
1.5,
|
||||
1.5,
|
||||
1.5, # Right wrist roll, pitch, yaw (kd_wrist)
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
3, # Other
|
||||
]
|
||||
)
|
||||
|
||||
control_dt = 1.0 / 250.0 # 250Hz
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
@@ -19,8 +20,8 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
kLeftHipRoll = 1
|
||||
|
||||
@@ -16,7 +16,9 @@ LOWCMD_PORT = 6000
|
||||
LOWSTATE_PORT = 6001
|
||||
|
||||
|
||||
def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read observation from DDS and send to server
|
||||
def state_forward_loop(
|
||||
lowstate_sub, lowstate_sock, state_period: float
|
||||
): # read observation from DDS and send to server
|
||||
last_state_time = 0.0
|
||||
|
||||
while True:
|
||||
@@ -38,7 +40,6 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o
|
||||
|
||||
|
||||
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC): # send action to robot
|
||||
|
||||
while True:
|
||||
payload = lowcmd_sock.recv()
|
||||
topic, cmd = pickle.loads(payload)
|
||||
@@ -52,7 +53,6 @@ def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to ro
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
# initialize DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
@@ -52,7 +51,6 @@ H1_2_Num_Motors = 35
|
||||
H1_Num_Motors = 20
|
||||
|
||||
|
||||
|
||||
class MotorState:
|
||||
def __init__(self):
|
||||
self.q = None # position
|
||||
@@ -69,6 +67,7 @@ class IMUState:
|
||||
self.rpy = None # [roll, pitch, yaw] (rad)
|
||||
self.temperature = None # IMU temperature
|
||||
|
||||
|
||||
# g1 observation class
|
||||
class G1_29_LowState:
|
||||
def __init__(self):
|
||||
@@ -97,7 +96,6 @@ class UnitreeG1(Robot):
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
@@ -256,7 +254,9 @@ class UnitreeG1(Robot):
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
|
||||
def transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):#transform imu data from torso to pelvis frame
|
||||
def transform_imu_data(
|
||||
self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega
|
||||
): # transform imu data from torso to pelvis frame
|
||||
"""Transform IMU data from torso to pelvis frame."""
|
||||
RzWaist = R.from_euler("z", waist_yaw).as_matrix()
|
||||
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pickle
|
||||
|
||||
import zmq
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
@@ -12,8 +13,7 @@ LOWSTATE_PORT = 6001
|
||||
|
||||
|
||||
def ChannelFactoryInitialize(*args, **kwargs): # DDS to socket bridge
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock\
|
||||
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
# read socket config
|
||||
config = UnitreeG1Config()
|
||||
robot_ip = config.robot_ip
|
||||
|
||||
Reference in New Issue
Block a user