Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-17 09:39:47 +00:00
download policy from the hub in examples/unitree_g1/gr00t_locomotion
@@ -7,6 +7,7 @@ This example demonstrates the NEW pattern for loading GR00T policies externally
 and passing them to the robot class.
 """
 
+import argparse
 import logging
 import threading
 import time
@@ -15,6 +16,7 @@ from collections import deque
 import numpy as np
 import onnxruntime as ort
 import torch
+from huggingface_hub import hf_hub_download
 
 from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
 from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
@@ -72,15 +74,32 @@ DOF_VEL_SCALE: float = 0.05
 CMD_SCALE: list = [2.0, 2.0, 0.25]
 
 
-def load_groot_policies() -> tuple:
-    """Load GR00T dual-policy system (Balance + Walk) from ONNX files."""
-    logger.info("Loading GR00T dual-policy system...")
+DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
+
+
+def load_groot_policies(
+    repo_id: str = DEFAULT_GROOT_REPO_ID,
+) -> tuple[ort.InferenceSession, ort.InferenceSession]:
+    """Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
+
+    Args:
+        repo_id: Hugging Face Hub repository ID containing the ONNX policies.
+    """
+    logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
+
+    # Download ONNX policies from Hugging Face Hub
+    balance_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="GR00T-WholeBodyControl-Balance.onnx",
+    )
+    walk_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="GR00T-WholeBodyControl-Walk.onnx",
+    )
 
     # Load ONNX policies
-    policy_balance = ort.InferenceSession(
-        "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx"
-    )
-    policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx")
+    policy_balance = ort.InferenceSession(balance_path)
+    policy_walk = ort.InferenceSession(walk_path)
 
     logger.info("GR00T policies loaded successfully")
 
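Note: the new loading path relies on hf_hub_download returning a local file path (the file is cached after the first download), which is then handed to onnxruntime. A minimal standalone sketch of the same pattern, assuming the repo ID and filenames from the hunk above are actually available on the Hub:

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    # Download (or reuse the cached copy of) one policy file and open an inference session.
    balance_path = hf_hub_download(
        repo_id="nepyope/GR00T-WholeBodyControl_g1",
        filename="GR00T-WholeBodyControl-Balance.onnx",
    )
    policy_balance = ort.InferenceSession(balance_path)
    # Inspect the model's expected inputs before wiring it into the controller.
    print([(inp.name, inp.shape) for inp in policy_balance.get_inputs()])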
@@ -99,7 +118,6 @@ class GrootLocomotionController:
     """
 
     def __init__(self, policy_balance, policy_walk, robot, config):
-
         self.policy_balance = policy_balance
         self.policy_walk = policy_walk
         self.robot = robot
@@ -128,7 +146,6 @@ class GrootLocomotionController:
         logger.info("GrootLocomotionController initialized")
 
     def groot_locomotion_run(self):
-
         # get current observation
         robot_state = self.robot.get_observation()
 
@@ -150,15 +167,14 @@ class GrootLocomotionController:
             self.robot.remote_controller.rx = 0.0
             self.robot.remote_controller.ry = 0.0
 
-        self.locomotion_cmd[0] = self.robot.remote_controller.ly  # forward/backward
-        self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1  # left/right
-        self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1  # rotation rate
+        self.locomotion_cmd[0] = self.robot.remote_controller.ly  # forward/backward
+        self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1  # left/right
+        self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1  # rotation rate
 
         for i in range(29):
             self.groot_qj_all[i] = robot_state.motor_state[i].q
             self.groot_dqj_all[i] = robot_state.motor_state[i].dq
 
-
         # adapt observation for g1_23dof
         for idx in MISSING_JOINTS:
             self.groot_qj_all[idx] = 0.0
@@ -173,12 +189,11 @@ class GrootLocomotionController:
         ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
         gravity_orientation = self.robot.get_gravity_orientation(quat)
 
-        #scale joint positions and velocities before policy inference
+        # scale joint positions and velocities before policy inference
         qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
         dqj_obs = dqj_obs * DOF_VEL_SCALE
         ang_vel_scaled = ang_vel * ANG_VEL_SCALE
 
-
         # build single frame observation
         self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
         self.groot_obs_single[3] = self.groot_height_cmd
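Note: the scaling above puts raw sensor and command values into the range the policy was trained on. A small worked example, using the CMD_SCALE and DOF_VEL_SCALE constants shown earlier in the file and made-up stick/velocity values:

    import numpy as np

    CMD_SCALE = [2.0, 2.0, 0.25]  # from the file constants above
    DOF_VEL_SCALE = 0.05          # from the file constants above

    locomotion_cmd = np.array([1.0, 0.0, -1.0])  # hypothetical full-forward stick plus a right turn
    dqj = np.array([1.0, -2.0])                  # hypothetical joint velocities in rad/s

    print(locomotion_cmd * np.array(CMD_SCALE))  # [ 2.    0.   -0.25]
    print(dqj * DOF_VEL_SCALE)                   # [ 0.05 -0.1 ]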
@@ -202,7 +217,7 @@ class GrootLocomotionController:
         obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
 
         cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
-
+
         if cmd_magnitude < 0.05:
             # balance/standing policy for small commands
             selected_policy = self.policy_balance
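Note: this threshold is what makes the controller dual-policy: command magnitudes below 0.05 route the stacked observation to the Balance session, larger ones to the Walk session. The inference call itself is outside this hunk; a generic onnxruntime call on whichever session was selected would look roughly like the sketch below (the input/output layout is an assumption, not taken from the GR00T models):

    import numpy as np

    def run_selected_policy(session, obs_stacked):
        # Feed the stacked observation to the session's first input and take the first
        # output as the action vector; a batch dimension is added and removed around the call.
        obs = obs_stacked.astype(np.float32)[None, :]
        input_name = session.get_inputs()[0].name
        action = session.run(None, {input_name: obs})[0]
        return np.asarray(action).squeeze(0)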
@@ -218,7 +233,7 @@ class GrootLocomotionController:
         # transform action back to target joint positions
         target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
 
-        # command motors
+        # command motors
         for i in range(15):
             motor_idx = i
             self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
@@ -235,7 +250,7 @@ class GrootLocomotionController:
             self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
             self.robot.msg.motor_cmd[joint_idx].tau = 0
 
-        #send action to robot
+        # send action to robot
        self.robot.send_action(self.robot.msg)
 
     def _locomotion_thread_loop(self):
@@ -298,7 +313,9 @@ class GrootLocomotionController:
             alpha = i / num_step
             for motor_idx in range(dof_size):
                 target_pos = default_pos[motor_idx]
-                self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
+                self.robot.msg.motor_cmd[motor_idx].q = (
+                    init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
+                )
                 self.robot.msg.motor_cmd[motor_idx].qd = 0
                 self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
                 self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
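Note: the reformatted assignment is a plain linear blend. As alpha ramps from 0 to 1 over num_step control ticks, the commanded position moves from the captured initial pose to the default pose: q_cmd = (1 - alpha) * q_init + alpha * q_default. For example, halfway through the ramp (alpha = 0.5), a joint starting at 0.8 rad with a default of 0.2 rad is commanded to 0.5 * 0.8 + 0.5 * 0.2 = 0.5 rad (illustrative numbers).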
@@ -308,16 +325,25 @@ class GrootLocomotionController:
             time.sleep(self.robot.control_dt)
         logger.info("Reached default position (legs only)")
 
-if __name__ == "__main__":
-
-    #load policies
-    policy_balance, policy_walk = load_groot_policies()
 
-    #initialize robot
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default=DEFAULT_GROOT_REPO_ID,
+        help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
+    )
+    args = parser.parse_args()
+
+    # load policies
+    policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
+
+    # initialize robot
     config = UnitreeG1Config()
     robot = UnitreeG1(config)
 
-    #initialize gr00t locomotion controller
+    # initialize gr00t locomotion controller
     groot_controller = GrootLocomotionController(
         policy_balance=policy_balance,
         policy_walk=policy_walk,
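Note: with the argparse entry point above, running the example with no arguments pulls the policies from the default repo, and --repo-id switches to another Hub repository. Because load_groot_policies now takes repo_id as a parameter, it can also be reused programmatically, e.g. (hypothetical repo ID):

    policy_balance, policy_walk = load_groot_policies(repo_id="your-org/your-groot-onnx-repo")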
@@ -325,20 +351,20 @@ if __name__ == "__main__":
         config=config,
     )
 
-    #reset legs and start locomotion thread
+    # reset legs and start locomotion thread
     groot_controller.reset_robot()
     groot_controller.start_locomotion_thread()
 
-    #log status
+    # log status
     logger.info("Robot initialized with GR00T locomotion policies")
     logger.info("Locomotion controller running in background thread")
     logger.info("Press Ctrl+C to stop")
 
-    #keep robot alive
+    # keep robot alive
     try:
         while True:
             time.sleep(1.0)
     except KeyboardInterrupt:
         print("\nStopping locomotion...")
         groot_controller.stop_locomotion_thread()
-        print("Done!")
+        print("Done!")