mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
download policy from the hub in examples/unitree_g1/gr00t_locomotion
This commit is contained in:
@@ -7,6 +7,7 @@ This example demonstrates the NEW pattern for loading GR00T policies externally
|
|||||||
and passing them to the robot class.
|
and passing them to the robot class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -15,6 +16,7 @@ from collections import deque
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||||
@@ -72,15 +74,32 @@ DOF_VEL_SCALE: float = 0.05
|
|||||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||||
|
|
||||||
|
|
||||||
def load_groot_policies() -> tuple:
|
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||||
"""Load GR00T dual-policy system (Balance + Walk) from ONNX files."""
|
|
||||||
logger.info("Loading GR00T dual-policy system...")
|
|
||||||
|
def load_groot_policies(
|
||||||
|
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||||
|
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||||
|
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||||
|
"""
|
||||||
|
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||||
|
|
||||||
|
# Download ONNX policies from Hugging Face Hub
|
||||||
|
balance_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||||
|
)
|
||||||
|
walk_path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||||
|
)
|
||||||
|
|
||||||
# Load ONNX policies
|
# Load ONNX policies
|
||||||
policy_balance = ort.InferenceSession(
|
policy_balance = ort.InferenceSession(balance_path)
|
||||||
"examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx"
|
policy_walk = ort.InferenceSession(walk_path)
|
||||||
)
|
|
||||||
policy_walk = ort.InferenceSession("examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx")
|
|
||||||
|
|
||||||
logger.info("GR00T policies loaded successfully")
|
logger.info("GR00T policies loaded successfully")
|
||||||
|
|
||||||
@@ -99,7 +118,6 @@ class GrootLocomotionController:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||||
|
|
||||||
self.policy_balance = policy_balance
|
self.policy_balance = policy_balance
|
||||||
self.policy_walk = policy_walk
|
self.policy_walk = policy_walk
|
||||||
self.robot = robot
|
self.robot = robot
|
||||||
@@ -128,7 +146,6 @@ class GrootLocomotionController:
|
|||||||
logger.info("GrootLocomotionController initialized")
|
logger.info("GrootLocomotionController initialized")
|
||||||
|
|
||||||
def groot_locomotion_run(self):
|
def groot_locomotion_run(self):
|
||||||
|
|
||||||
# get current observation
|
# get current observation
|
||||||
robot_state = self.robot.get_observation()
|
robot_state = self.robot.get_observation()
|
||||||
|
|
||||||
@@ -150,15 +167,14 @@ class GrootLocomotionController:
|
|||||||
self.robot.remote_controller.rx = 0.0
|
self.robot.remote_controller.rx = 0.0
|
||||||
self.robot.remote_controller.ry = 0.0
|
self.robot.remote_controller.ry = 0.0
|
||||||
|
|
||||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||||
|
|
||||||
for i in range(29):
|
for i in range(29):
|
||||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||||
|
|
||||||
|
|
||||||
# adapt observation for g1_23dof
|
# adapt observation for g1_23dof
|
||||||
for idx in MISSING_JOINTS:
|
for idx in MISSING_JOINTS:
|
||||||
self.groot_qj_all[idx] = 0.0
|
self.groot_qj_all[idx] = 0.0
|
||||||
@@ -173,12 +189,11 @@ class GrootLocomotionController:
|
|||||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||||
|
|
||||||
#scale joint positions and velocities before policy inference
|
# scale joint positions and velocities before policy inference
|
||||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||||
|
|
||||||
|
|
||||||
# build single frame observation
|
# build single frame observation
|
||||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||||
self.groot_obs_single[3] = self.groot_height_cmd
|
self.groot_obs_single[3] = self.groot_height_cmd
|
||||||
@@ -202,7 +217,7 @@ class GrootLocomotionController:
|
|||||||
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
|
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
|
||||||
|
|
||||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||||
|
|
||||||
if cmd_magnitude < 0.05:
|
if cmd_magnitude < 0.05:
|
||||||
# balance/standing policy for small commands
|
# balance/standing policy for small commands
|
||||||
selected_policy = self.policy_balance
|
selected_policy = self.policy_balance
|
||||||
@@ -218,7 +233,7 @@ class GrootLocomotionController:
|
|||||||
# transform action back to target joint positions
|
# transform action back to target joint positions
|
||||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||||
|
|
||||||
# command motors
|
# command motors
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
motor_idx = i
|
motor_idx = i
|
||||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||||
@@ -235,7 +250,7 @@ class GrootLocomotionController:
|
|||||||
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||||
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||||
|
|
||||||
#send action to robot
|
# send action to robot
|
||||||
self.robot.send_action(self.robot.msg)
|
self.robot.send_action(self.robot.msg)
|
||||||
|
|
||||||
def _locomotion_thread_loop(self):
|
def _locomotion_thread_loop(self):
|
||||||
@@ -298,7 +313,9 @@ class GrootLocomotionController:
|
|||||||
alpha = i / num_step
|
alpha = i / num_step
|
||||||
for motor_idx in range(dof_size):
|
for motor_idx in range(dof_size):
|
||||||
target_pos = default_pos[motor_idx]
|
target_pos = default_pos[motor_idx]
|
||||||
self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||||
|
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||||
|
)
|
||||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||||
@@ -308,16 +325,25 @@ class GrootLocomotionController:
|
|||||||
time.sleep(self.robot.control_dt)
|
time.sleep(self.robot.control_dt)
|
||||||
logger.info("Reached default position (legs only)")
|
logger.info("Reached default position (legs only)")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
#load policies
|
|
||||||
policy_balance, policy_walk = load_groot_policies()
|
|
||||||
|
|
||||||
#initialize robot
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_GROOT_REPO_ID,
|
||||||
|
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# load policies
|
||||||
|
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||||
|
|
||||||
|
# initialize robot
|
||||||
config = UnitreeG1Config()
|
config = UnitreeG1Config()
|
||||||
robot = UnitreeG1(config)
|
robot = UnitreeG1(config)
|
||||||
|
|
||||||
#initialize gr00t locomotion controller
|
# initialize gr00t locomotion controller
|
||||||
groot_controller = GrootLocomotionController(
|
groot_controller = GrootLocomotionController(
|
||||||
policy_balance=policy_balance,
|
policy_balance=policy_balance,
|
||||||
policy_walk=policy_walk,
|
policy_walk=policy_walk,
|
||||||
@@ -325,20 +351,20 @@ if __name__ == "__main__":
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
#reset legs and start locomotion thread
|
# reset legs and start locomotion thread
|
||||||
groot_controller.reset_robot()
|
groot_controller.reset_robot()
|
||||||
groot_controller.start_locomotion_thread()
|
groot_controller.start_locomotion_thread()
|
||||||
|
|
||||||
#log status
|
# log status
|
||||||
logger.info("Robot initialized with GR00T locomotion policies")
|
logger.info("Robot initialized with GR00T locomotion policies")
|
||||||
logger.info("Locomotion controller running in background thread")
|
logger.info("Locomotion controller running in background thread")
|
||||||
logger.info("Press Ctrl+C to stop")
|
logger.info("Press Ctrl+C to stop")
|
||||||
|
|
||||||
#keep robot alive
|
# keep robot alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
time.sleep(1.0)
|
time.sleep(1.0)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nStopping locomotion...")
|
print("\nStopping locomotion...")
|
||||||
groot_controller.stop_locomotion_thread()
|
groot_controller.stop_locomotion_thread()
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
|
||||||
|
|
||||||
from ..config import RobotConfig
|
from ..config import RobotConfig
|
||||||
|
|
||||||
@@ -27,29 +24,87 @@ from ..config import RobotConfig
|
|||||||
class UnitreeG1Config(RobotConfig):
|
class UnitreeG1Config(RobotConfig):
|
||||||
# id: str = "unitree_g1"
|
# id: str = "unitree_g1"
|
||||||
|
|
||||||
kp: list = field(default_factory=lambda: [
|
kp: list = field(
|
||||||
150, 150, 150, 300, 40, 40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
default_factory=lambda: [
|
||||||
150, 150, 150, 300, 40, 40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
150,
|
||||||
250, 250, 250, # Waist yaw, roll, pitch
|
150,
|
||||||
80, 80, 80, 80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
|
150,
|
||||||
40, 40, 40, # Left wrist roll, pitch, yaw (kp_wrist)
|
300,
|
||||||
80, 80, 80, 80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
|
40,
|
||||||
40, 40, 40, # Right wrist roll, pitch, yaw (kp_wrist)
|
40, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||||
80, 80, 80, 80, 80, 80, # Other
|
150,
|
||||||
])
|
150,
|
||||||
|
150,
|
||||||
kd: list = field(default_factory=lambda: [
|
300,
|
||||||
2, 2, 2, 4, 2, 2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
40,
|
||||||
2, 2, 2, 4, 2, 2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
40, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||||
5, 5, 5, # Waist yaw, roll, pitch
|
250,
|
||||||
3, 3, 3, 3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
|
250,
|
||||||
1.5, 1.5, 1.5, # Left wrist roll, pitch, yaw (kd_wrist)
|
250, # Waist yaw, roll, pitch
|
||||||
3, 3, 3, 3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
|
80,
|
||||||
1.5, 1.5, 1.5, # Right wrist roll, pitch, yaw (kd_wrist)
|
80,
|
||||||
3, 3, 3, 3, 3, 3, # Other
|
80,
|
||||||
])
|
80, # Left shoulder pitch, roll, yaw, elbow (kp_low)
|
||||||
|
40,
|
||||||
|
40,
|
||||||
|
40, # Left wrist roll, pitch, yaw (kp_wrist)
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80, # Right shoulder pitch, roll, yaw, elbow (kp_low)
|
||||||
|
40,
|
||||||
|
40,
|
||||||
|
40, # Right wrist roll, pitch, yaw (kp_wrist)
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80,
|
||||||
|
80, # Other
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
control_dt = 1.0 / 250.0 # 250Hz
|
kd: list = field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2, # Left leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
2,
|
||||||
|
2, # Right leg pitch, roll, yaw, knee, ankle pitch, ankle roll
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
5, # Waist yaw, roll, pitch
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3, # Left shoulder pitch, roll, yaw, elbow (kd_low)
|
||||||
|
1.5,
|
||||||
|
1.5,
|
||||||
|
1.5, # Left wrist roll, pitch, yaw (kd_wrist)
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3, # Right shoulder pitch, roll, yaw, elbow (kd_low)
|
||||||
|
1.5,
|
||||||
|
1.5,
|
||||||
|
1.5, # Right wrist roll, pitch, yaw (kd_wrist)
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3,
|
||||||
|
3, # Other
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
control_dt = 1.0 / 250.0 # 250Hz
|
||||||
|
|
||||||
# socket config for ZMQ bridge
|
# socket config for ZMQ bridge
|
||||||
robot_ip: str = "172.18.129.215"
|
robot_ip: str = "172.18.129.215"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
class G1_29_JointArmIndex(IntEnum):
|
class G1_29_JointArmIndex(IntEnum):
|
||||||
# Left arm
|
# Left arm
|
||||||
kLeftShoulderPitch = 15
|
kLeftShoulderPitch = 15
|
||||||
@@ -19,8 +20,8 @@ class G1_29_JointArmIndex(IntEnum):
|
|||||||
kRightWristPitch = 27
|
kRightWristPitch = 27
|
||||||
kRightWristYaw = 28
|
kRightWristYaw = 28
|
||||||
|
|
||||||
class G1_29_JointIndex(IntEnum):
|
|
||||||
|
|
||||||
|
class G1_29_JointIndex(IntEnum):
|
||||||
# Left leg
|
# Left leg
|
||||||
kLeftHipPitch = 0
|
kLeftHipPitch = 0
|
||||||
kLeftHipRoll = 1
|
kLeftHipRoll = 1
|
||||||
|
|||||||
@@ -9,14 +9,16 @@ from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublish
|
|||||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||||
from unitree_sdk2py.utils.crc import CRC
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
|
|
||||||
kTopicLowCommand_Debug = "rt/lowcmd" #action to robot
|
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||||
kTopicLowState = "rt/lowstate" #observation from robot
|
kTopicLowState = "rt/lowstate" # observation from robot
|
||||||
|
|
||||||
LOWCMD_PORT = 6000
|
LOWCMD_PORT = 6000
|
||||||
LOWSTATE_PORT = 6001
|
LOWSTATE_PORT = 6001
|
||||||
|
|
||||||
|
|
||||||
def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read observation from DDS and send to server
|
def state_forward_loop(
|
||||||
|
lowstate_sub, lowstate_sock, state_period: float
|
||||||
|
): # read observation from DDS and send to server
|
||||||
last_state_time = 0.0
|
last_state_time = 0.0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -27,7 +29,7 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o
|
|||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# optional downsampling (if robot dds rate > state_period)
|
# optional downsampling (if robot dds rate > state_period)
|
||||||
if now - last_state_time >= state_period:
|
if now - last_state_time >= state_period:
|
||||||
payload = pickle.dumps((kTopicLowState, msg), protocol=pickle.HIGHEST_PROTOCOL)
|
payload = pickle.dumps((kTopicLowState, msg), protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
try:
|
try:
|
||||||
lowstate_sock.send(payload, zmq.NOBLOCK)
|
lowstate_sock.send(payload, zmq.NOBLOCK)
|
||||||
@@ -37,8 +39,7 @@ def state_forward_loop(lowstate_sub, lowstate_sock, state_period: float):#read o
|
|||||||
last_state_time = now
|
last_state_time = now
|
||||||
|
|
||||||
|
|
||||||
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to robot
|
def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC): # send action to robot
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
payload = lowcmd_sock.recv()
|
payload = lowcmd_sock.recv()
|
||||||
topic, cmd = pickle.loads(payload)
|
topic, cmd = pickle.loads(payload)
|
||||||
@@ -50,7 +51,6 @@ def cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc: CRC):#send action to ro
|
|||||||
lowcmd_pub_debug.Write(cmd)
|
lowcmd_pub_debug.Write(cmd)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -73,7 +73,7 @@ def main():
|
|||||||
# initialize DDS publisher
|
# initialize DDS publisher
|
||||||
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||||
lowcmd_pub_debug.Init()
|
lowcmd_pub_debug.Init()
|
||||||
|
|
||||||
# initialize DDS subscriber
|
# initialize DDS subscriber
|
||||||
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||||
lowstate_sub.Init()
|
lowstate_sub.Init()
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
@@ -52,11 +51,10 @@ H1_2_Num_Motors = 35
|
|||||||
H1_Num_Motors = 20
|
H1_Num_Motors = 20
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MotorState:
|
class MotorState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.q = None # position
|
self.q = None # position
|
||||||
self.dq = None # velocity
|
self.dq = None # velocity
|
||||||
self.tau_est = None # estimated torque
|
self.tau_est = None # estimated torque
|
||||||
self.temperature = None # motor temperature
|
self.temperature = None # motor temperature
|
||||||
|
|
||||||
@@ -69,7 +67,8 @@ class IMUState:
|
|||||||
self.rpy = None # [roll, pitch, yaw] (rad)
|
self.rpy = None # [roll, pitch, yaw] (rad)
|
||||||
self.temperature = None # IMU temperature
|
self.temperature = None # IMU temperature
|
||||||
|
|
||||||
#g1 observation class
|
|
||||||
|
# g1 observation class
|
||||||
class G1_29_LowState:
|
class G1_29_LowState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.motor_state = [MotorState() for _ in range(G1_29_Num_Motors)]
|
self.motor_state = [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||||
@@ -95,9 +94,8 @@ class UnitreeG1(Robot):
|
|||||||
config_class = UnitreeG1Config
|
config_class = UnitreeG1Config
|
||||||
name = "unitree_g1"
|
name = "unitree_g1"
|
||||||
|
|
||||||
#unitree remote controller
|
# unitree remote controller
|
||||||
class RemoteController:
|
class RemoteController:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.lx = 0
|
self.lx = 0
|
||||||
self.ly = 0
|
self.ly = 0
|
||||||
@@ -165,7 +163,7 @@ class UnitreeG1(Robot):
|
|||||||
# Initialize remote controller
|
# Initialize remote controller
|
||||||
self.remote_controller = self.RemoteController()
|
self.remote_controller = self.RemoteController()
|
||||||
|
|
||||||
def _subscribe_motor_state(self): #polls robot state @ 250Hz
|
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||||
while True:
|
while True:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
msg = self.lowstate_subscriber.Read()
|
msg = self.lowstate_subscriber.Read()
|
||||||
@@ -200,13 +198,13 @@ class UnitreeG1(Robot):
|
|||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||||
|
|
||||||
def calibrate(self) -> None:#robot is already calibrated
|
def calibrate(self) -> None: # robot is already calibrated
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def connect(self, calibrate: bool = True) -> None: #connect to DDS
|
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||||
ChannelFactoryInitialize(0)
|
ChannelFactoryInitialize(0)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
@@ -243,7 +241,7 @@ class UnitreeG1(Robot):
|
|||||||
self.msg.crc = self.crc.Crc(action)
|
self.msg.crc = self.crc.Crc(action)
|
||||||
self.lowcmd_publisher.Write(action)
|
self.lowcmd_publisher.Write(action)
|
||||||
|
|
||||||
def get_gravity_orientation(self, quaternion):#get gravity orientation from quaternion
|
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||||
"""Get gravity orientation from quaternion."""
|
"""Get gravity orientation from quaternion."""
|
||||||
qw = quaternion[0]
|
qw = quaternion[0]
|
||||||
qx = quaternion[1]
|
qx = quaternion[1]
|
||||||
@@ -256,7 +254,9 @@ class UnitreeG1(Robot):
|
|||||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||||
return gravity_orientation
|
return gravity_orientation
|
||||||
|
|
||||||
def transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):#transform imu data from torso to pelvis frame
|
def transform_imu_data(
|
||||||
|
self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega
|
||||||
|
): # transform imu data from torso to pelvis frame
|
||||||
"""Transform IMU data from torso to pelvis frame."""
|
"""Transform IMU data from torso to pelvis frame."""
|
||||||
RzWaist = R.from_euler("z", waist_yaw).as_matrix()
|
RzWaist = R.from_euler("z", waist_yaw).as_matrix()
|
||||||
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
|
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||||
@@ -11,18 +12,17 @@ LOWCMD_PORT = 6000
|
|||||||
LOWSTATE_PORT = 6001
|
LOWSTATE_PORT = 6001
|
||||||
|
|
||||||
|
|
||||||
def ChannelFactoryInitialize(*args, **kwargs):#DDS to socket bridge
|
def ChannelFactoryInitialize(*args, **kwargs): # DDS to socket bridge
|
||||||
global _ctx, _lowcmd_sock, _lowstate_sock\
|
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||||
|
|
||||||
# read socket config
|
# read socket config
|
||||||
config = UnitreeG1Config()
|
config = UnitreeG1Config()
|
||||||
robot_ip = config.robot_ip
|
robot_ip = config.robot_ip
|
||||||
|
|
||||||
_ctx = zmq.Context.instance()
|
_ctx = zmq.Context.instance()
|
||||||
|
|
||||||
# lowcmd: robot action
|
# lowcmd: robot action
|
||||||
_lowcmd_sock = _ctx.socket(zmq.PUSH)
|
_lowcmd_sock = _ctx.socket(zmq.PUSH)
|
||||||
_lowcmd_sock.setsockopt(zmq.CONFLATE, 1)#keep only last message
|
_lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||||
_lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
|
_lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
|
||||||
|
|
||||||
# lowstate: robot observation
|
# lowstate: robot observation
|
||||||
@@ -32,7 +32,7 @@ def ChannelFactoryInitialize(*args, **kwargs):#DDS to socket bridge
|
|||||||
_lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
_lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
|
|
||||||
|
|
||||||
class ChannelPublisher: #send action to robot
|
class ChannelPublisher: # send action to robot
|
||||||
def __init__(self, topic, msg_type):
|
def __init__(self, topic, msg_type):
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.msg_type = msg_type
|
self.msg_type = msg_type
|
||||||
@@ -44,7 +44,7 @@ class ChannelPublisher: #send action to robot
|
|||||||
_lowcmd_sock.send(pickle.dumps((self.topic, msg)))
|
_lowcmd_sock.send(pickle.dumps((self.topic, msg)))
|
||||||
|
|
||||||
|
|
||||||
class ChannelSubscriber: #read observation from robot
|
class ChannelSubscriber: # read observation from robot
|
||||||
def __init__(self, topic, msg_type):
|
def __init__(self, topic, msg_type):
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.msg_type = msg_type
|
self.msg_type = msg_type
|
||||||
|
|||||||
Reference in New Issue
Block a user