mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
separate groot locomotion logic
This commit is contained in:
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Example: GR00T Locomotion with Pre-loaded Policies
|
||||
|
||||
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_groot_policies(
    balance_path: str = "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx",
    walk_path: str = "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx",
) -> tuple:
    """Load GR00T dual-policy system (Balance + Walk) from ONNX files.

    Args:
        balance_path: Path of the ONNX file holding the balance/standing policy.
            The default is a relative path, so it is resolved against the
            current working directory.
        walk_path: Path of the ONNX file holding the walking policy. Same
            remark about relative paths applies.

    Returns:
        A ``(policy_balance, policy_walk)`` tuple of ``ort.InferenceSession``
        objects, in that order.
    """
    logger.info("Loading GR00T dual-policy system...")

    # Load ONNX policies
    policy_balance = ort.InferenceSession(balance_path)
    policy_walk = ort.InferenceSession(walk_path)

    logger.info("GR00T policies loaded successfully")
    # Only the balance policy's first input/output is reported.
    logger.info(f" Input shape: {policy_balance.get_inputs()[0].shape}")
    logger.info(f" Output shape: {policy_balance.get_outputs()[0].shape}")

    return policy_balance, policy_walk
|
||||
|
||||
|
||||
class GrootLocomotionController:
    """
    Handles GR00T-style locomotion control for the Unitree G1 robot.

    This controller manages:
    - Dual-policy system (Balance + Walk)
    - 29-joint observation processing
    - 15D action output (legs + waist)
    - Policy inference and motor command generation

    Each call to :meth:`groot_locomotion_run` reads the robot's low-level
    state, builds one 86D observation frame, stacks the last 6 frames into a
    516D vector, runs one of the two ONNX policies on it and writes the
    resulting targets into ``robot.msg`` before publishing that message.
    """

    # Sizes of the observation/action layout used by the policies
    NUM_JOINTS = 29
    NUM_ACTIONS = 15
    NUM_LEG_JOINTS = 12
    OBS_FRAME_DIM = 86
    OBS_HISTORY_LEN = 6

    # Motor indices of the three waist joints
    WAIST_YAW_IDX = 12
    WAIST_ROLL_IDX = 13
    WAIST_PITCH_IDX = 14

    # Base-height command: per-step increment and allowed range
    HEIGHT_CMD_STEP = 0.001
    HEIGHT_CMD_MIN = 0.50
    HEIGHT_CMD_MAX = 1.00

    # Command norm below which the balance policy is used instead of walk
    WALK_CMD_THRESHOLD = 0.05

    # GR00T default angles for all 29 joints
    GROOT_DEFAULT_ANGLES = np.array([
        -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # left leg
        -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # right leg
        0.0, 0.0, 0.0,  # waist
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # left arm (zeroed)
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # right arm (zeroed)
    ], dtype=np.float32)

    # Joints to zero out in observations and commands
    JOINTS_TO_ZERO = [12, 14, 20, 21, 27, 28]  # waist yaw/pitch, wrist pitch/yaw
    PROBLEMATIC_JOINTS = [12, 14, 20, 21, 27, 28]

    def __init__(self, policy_balance, policy_walk, robot, config):
        """
        Initialize the GR00T locomotion controller.

        Args:
            policy_balance: ONNX InferenceSession for balance/standing policy
            policy_walk: ONNX InferenceSession for walking policy
            robot: Reference to the UnitreeG1 robot instance
            config: UnitreeG1Config object with locomotion parameters
        """
        self.policy_balance = policy_balance
        self.policy_walk = policy_walk
        self.robot = robot
        self.config = config

        # Locomotion state
        self.locomotion_counter = 0
        self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # vx, vy, yaw_rate

        # GR00T-specific state
        self.groot_qj_all = np.zeros(self.NUM_JOINTS, dtype=np.float32)
        self.groot_dqj_all = np.zeros(self.NUM_JOINTS, dtype=np.float32)
        self.groot_action = np.zeros(self.NUM_ACTIONS, dtype=np.float32)
        self.groot_obs_single = np.zeros(self.OBS_FRAME_DIM, dtype=np.float32)
        self.groot_obs_history = deque(maxlen=self.OBS_HISTORY_LEN)
        self.groot_obs_stacked = np.zeros(self.OBS_FRAME_DIM * self.OBS_HISTORY_LEN, dtype=np.float32)
        self.groot_height_cmd = 0.74  # Default base height
        self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)

        # Initialize history with zeros so the stacked observation is full-size from step one
        for _ in range(self.OBS_HISTORY_LEN):
            self.groot_obs_history.append(np.zeros(self.OBS_FRAME_DIM, dtype=np.float32))

        # Thread management
        self.locomotion_running = False
        self.locomotion_thread = None

        logger.info("GrootLocomotionController initialized")

    def groot_locomotion_run(self):
        """GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action.

        Performs one control step. Returns without doing anything further
        (other than bumping ``locomotion_counter``) when the robot has no
        low-level state available yet.
        """
        self.locomotion_counter += 1

        # Get current lowstate
        lowstate = self.robot.lowstate_buffer.GetData()
        if lowstate is None:
            return

        self._update_remote_controller(lowstate)
        self._update_observation(lowstate)
        self._run_policy()
        self._write_motor_commands()

        # Send command
        self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
        self.robot.lowcmd_publisher.Write(self.robot.msg)

    def _adjust_height_cmd(self, delta):
        """Shift the base-height command by ``delta`` and clamp it to the allowed range."""
        self.groot_height_cmd = np.clip(
            self.groot_height_cmd + delta, self.HEIGHT_CMD_MIN, self.HEIGHT_CMD_MAX
        )

    def _update_remote_controller(self, lowstate):
        """Refresh the robot's remote-controller state and the height command from ``lowstate``."""
        remote = self.robot.remote_controller

        if lowstate.wireless_remote is None:
            # Default to zero commands if no remote data
            remote.lx = 0.0
            remote.ly = 0.0
            remote.rx = 0.0
            remote.ry = 0.0
            return

        remote.set(lowstate.wireless_remote)

        # R1/R2 buttons for height control on real robot (button indices 0 and 4)
        if remote.button[0]:  # R1 - raise height
            # Small increment per timestep (~0.05m per second at 50Hz)
            self._adjust_height_cmd(self.HEIGHT_CMD_STEP)
        if remote.button[4]:  # R2 - lower height
            self._adjust_height_cmd(-self.HEIGHT_CMD_STEP)

    def _update_observation(self, lowstate):
        """Build the current 86D frame, push it into the history and refresh the 516D stack."""
        motor_state = lowstate.motor_state

        # Get ALL 29 joint positions and velocities
        for i in range(self.NUM_JOINTS):
            self.groot_qj_all[i] = motor_state[i].q
            self.groot_dqj_all[i] = motor_state[i].dq

        # Get IMU data
        quat = lowstate.imu_state.quaternion
        ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)

        # Transform IMU if using torso IMU
        if self.config.locomotion_imu_type == "torso":
            waist_yaw = motor_state[self.WAIST_YAW_IDX].q
            waist_yaw_omega = motor_state[self.WAIST_YAW_IDX].dq
            # The helper receives the angular velocity wrapped in an extra leading axis.
            quat, ang_vel_3d = self.robot.locomotion_transform_imu_data(
                waist_yaw, waist_yaw_omega, quat, np.array([ang_vel])
            )
            ang_vel = ang_vel_3d.flatten()

        gravity_orientation = self.robot.locomotion_get_gravity_orientation(quat)

        # Zero out specific joints in observation.
        # list(...) keeps NumPy from treating a tuple override as a multi-dimensional index.
        zeroed = list(self.JOINTS_TO_ZERO)
        self.groot_qj_all[zeroed] = 0.0
        self.groot_dqj_all[zeroed] = 0.0

        # Scale joint positions and velocities
        qj_obs = (self.groot_qj_all - self.GROOT_DEFAULT_ANGLES) * self.config.dof_pos_scale
        dqj_obs = self.groot_dqj_all * self.config.dof_vel_scale
        ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale

        # Get velocity commands from the remote; in simulation mode locomotion_cmd
        # is left untouched here, so it must be set from elsewhere.
        if not self.robot.simulation_mode:
            remote = self.robot.remote_controller
            self.locomotion_cmd[0] = remote.ly
            self.locomotion_cmd[1] = remote.lx * -1
            self.locomotion_cmd[2] = remote.rx * -1

        # Build 86D single frame observation (GR00T format)
        obs = self.groot_obs_single
        obs[:3] = self.locomotion_cmd * np.array(self.config.groot_cmd_scale)
        obs[3] = self.groot_height_cmd
        obs[4:7] = self.groot_orientation_cmd
        obs[7:10] = ang_vel_scaled
        obs[10:13] = gravity_orientation
        obs[13:42] = qj_obs  # 29D joint positions
        obs[42:71] = dqj_obs  # 29D joint velocities
        obs[71:86] = self.groot_action  # 15D previous actions

        # Add to history (the deque drops the oldest frame) and stack
        # all 6 frames, oldest first, into the 516D vector
        self.groot_obs_history.append(obs.copy())
        self.groot_obs_stacked[:] = np.concatenate(list(self.groot_obs_history))

    def _run_policy(self):
        """Pick the balance or walk policy, run it on the stacked observation, store the action."""
        # Select appropriate policy based on command magnitude (dual-policy system)
        cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
        if cmd_magnitude < self.WALK_CMD_THRESHOLD:
            # Use balance/standing policy for small commands
            selected_policy = self.policy_balance
        else:
            # Use walking policy for movement commands
            selected_policy = self.policy_walk

        # The session consumes NumPy arrays, so add the batch axis directly
        # instead of round-tripping through a torch tensor.
        obs_batch = self.groot_obs_stacked[np.newaxis, :]
        ort_inputs = {selected_policy.get_inputs()[0].name: obs_batch}
        ort_outs = selected_policy.run(None, ort_inputs)
        self.groot_action = ort_outs[0].squeeze()

        # Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
        self.groot_action[self.WAIST_YAW_IDX] = 0.0
        self.groot_action[self.WAIST_ROLL_IDX] = 0.0
        self.groot_action[self.WAIST_PITCH_IDX] = 0.0

    def _set_motor_cmd(self, motor_idx, q, kp, kd):
        """Write a position target with zero velocity and zero feed-forward torque for one motor."""
        cmd = self.robot.msg.motor_cmd[motor_idx]
        cmd.q = q
        cmd.qd = 0
        cmd.kp = kp
        cmd.kd = kd
        cmd.tau = 0

    def _write_motor_commands(self):
        """Translate the latest action into motor commands inside ``robot.msg`` (not yet published)."""
        waist_kps = self.config.locomotion_arm_waist_kps
        waist_kds = self.config.locomotion_arm_waist_kds

        # Transform action to target joint positions (15D: legs + waist)
        target_dof_pos_15 = (
            self.GROOT_DEFAULT_ANGLES[: self.NUM_ACTIONS]
            + self.groot_action * self.config.locomotion_action_scale
        )

        # Send commands to LEG motors (0-11)
        for i in range(self.NUM_LEG_JOINTS):
            self._set_motor_cmd(
                i, target_dof_pos_15[i], self.config.locomotion_kps[i], self.config.locomotion_kds[i]
            )

        # Send WAIST commands - but SKIP waist yaw (12) and waist pitch (14)
        # Only send waist roll (13); gain index 1 is its slot in the arm/waist gain lists
        self._set_motor_cmd(
            self.WAIST_ROLL_IDX, target_dof_pos_15[self.WAIST_ROLL_IDX], waist_kps[1], waist_kds[1]
        )

        # Zero out the problematic joints (waist yaw, waist pitch, wrist pitch/yaw)
        for joint_idx in self.PROBLEMATIC_JOINTS:
            if joint_idx in (self.WAIST_YAW_IDX, self.WAIST_PITCH_IDX):  # waist
                gain_idx = 0 if joint_idx == self.WAIST_YAW_IDX else 2  # yaw or pitch
                kp, kd = waist_kps[gain_idx], waist_kds[gain_idx]
            else:  # wrists (20, 21, 27, 28)
                kp, kd = self.robot.kp_wrist, self.robot.kd_wrist
            self._set_motor_cmd(joint_idx, 0.0, kp, kd)

    def _locomotion_thread_loop(self):
        """Background thread that runs the locomotion policy at specified rate."""
        logger.info("Locomotion thread started")
        while self.locomotion_running:
            # Monotonic clock: the pacing must not be disturbed by wall-clock adjustments.
            start_time = time.monotonic()
            try:
                self.groot_locomotion_run()
            except Exception as e:
                # Best effort: keep the loop alive, but record the full traceback.
                logger.exception(f"Error in locomotion loop: {e}")

            # Sleep to maintain control rate
            elapsed = time.monotonic() - start_time
            sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
            time.sleep(sleep_time)
        logger.info("Locomotion thread stopped")

    def start_locomotion_thread(self):
        """Start the background locomotion control thread (no-op with a warning if already running)."""
        if self.locomotion_running:
            logger.warning("Locomotion thread already running")
            return

        logger.info("Starting locomotion control thread...")
        self.locomotion_running = True
        self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
        self.locomotion_thread.start()
        logger.info("Locomotion control thread started!")

    def stop_locomotion_thread(self):
        """Stop the background locomotion control thread.

        Does nothing when the thread is not running. Otherwise clears the
        running flag and waits up to 2 seconds for the thread to exit.
        """
        if not self.locomotion_running:
            return

        logger.info("Stopping locomotion control thread...")
        self.locomotion_running = False
        if self.locomotion_thread:
            self.locomotion_thread.join(timeout=2.0)
            if self.locomotion_thread.is_alive():
                # join() timed out; do not claim the thread has stopped.
                logger.warning("Locomotion control thread did not stop within 2.0 seconds")
                return
        logger.info("Locomotion control thread stopped")

    def init_groot_locomotion(self):
        """Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions).

        Blocks while the robot is moved to its default pose, then starts the
        background policy thread.
        """
        logger.info("Starting GR00T locomotion initialization...")

        # Move legs to default position
        self.robot.locomotion_move_to_default_pos()

        # Wait 3 seconds
        time.sleep(3.0)

        # Hold default leg position for 2 seconds
        self.robot.locomotion_default_pos_state()

        # Start locomotion policy thread
        logger.info("Starting GR00T locomotion policy control...")
        self.start_locomotion_thread()

        logger.info("GR00T locomotion initialization complete! Policy is now running.")
        logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # 1. Load policies externally (separate from robot initialization)
    policy_balance, policy_walk = load_groot_policies()

    # 2. Create config (no locomotion_control=True since we're using external controller)
    config = UnitreeG1Config()

    # 3. Initialize robot
    robot = UnitreeG1(config)

    # 4. Create GR00T locomotion controller with loaded policies
    groot_controller = GrootLocomotionController(
        policy_balance=policy_balance,
        policy_walk=policy_walk,
        robot=robot,
        config=config,
    )

    try:
        # 5. Initialize and start locomotion. This blocks for several seconds, so it
        # lives inside the try block: Ctrl+C during startup is handled like Ctrl+C later.
        groot_controller.init_groot_locomotion()

        # Robot is now ready with locomotion control!
        print("Robot initialized with GR00T locomotion policies")
        print("Locomotion controller running in background thread")
        print("Press Ctrl+C to stop")

        # Keep the main thread alive; the control loop runs in a daemon thread.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("\nStopping locomotion...")
    finally:
        # Always ask the control thread to stop, whatever ended the main loop.
        # This is a no-op when the thread was never started.
        groot_controller.stop_locomotion_thread()
    print("Done!")
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
@@ -24,7 +25,7 @@ from ..config import RobotConfig
|
||||
@RobotConfig.register_subclass("unitree_g1")
|
||||
@dataclass
|
||||
class UnitreeG1Config(RobotConfig):
|
||||
# id: str = "unitree_g1"
|
||||
# id: str = "unitree_g1"
|
||||
simulation_mode: bool = False
|
||||
kp_high = 40.0
|
||||
kd_high = 3.0
|
||||
@@ -52,49 +53,38 @@ class UnitreeG1Config(RobotConfig):
|
||||
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
|
||||
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
|
||||
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
|
||||
socket_host: str | None = None # = "172.18.129.215"
|
||||
socket_host: str | None = None# = "172.18.129.215"
|
||||
socket_port: int | None = None
|
||||
|
||||
# Locomotion control
|
||||
locomotion_control: bool = True
|
||||
# policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
|
||||
policy_path: str | None = None
|
||||
|
||||
# Pre-loaded policies (preferred method for GR00T locomotion)
|
||||
policy_walk: Any = None # Pre-loaded walk policy (ONNX InferenceSession)
|
||||
policy_balance: Any = None # Pre-loaded balance policy (ONNX InferenceSession)
|
||||
|
||||
# Locomotion parameters (from g1.yaml)
|
||||
locomotion_control_dt: float = 0.02
|
||||
|
||||
|
||||
leg_joint2motor_idx: list = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
|
||||
locomotion_kps: list = field(
|
||||
default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40]
|
||||
)
|
||||
locomotion_kps: list = field(default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40])
|
||||
locomotion_kds: list = field(default_factory=lambda: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2])
|
||||
default_leg_angles: list = field(
|
||||
default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
|
||||
)
|
||||
|
||||
arm_waist_joint2motor_idx: list = field(
|
||||
default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
|
||||
)
|
||||
locomotion_arm_waist_kps: list = field(
|
||||
default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20]
|
||||
)
|
||||
locomotion_arm_waist_kds: list = field(
|
||||
default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1]
|
||||
)
|
||||
locomotion_arm_waist_target: list = field(
|
||||
default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
)
|
||||
default_leg_angles: list = field(default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0])
|
||||
|
||||
arm_waist_joint2motor_idx: list = field(default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28])
|
||||
locomotion_arm_waist_kps: list = field(default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20])
|
||||
locomotion_arm_waist_kds: list = field(default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1])
|
||||
locomotion_arm_waist_target: list = field(default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
ang_vel_scale: float = 0.25
|
||||
dof_pos_scale: float = 1.0
|
||||
dof_vel_scale: float = 0.05
|
||||
locomotion_action_scale: float = 0.25
|
||||
cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25])
|
||||
|
||||
|
||||
# GR00T-specific scaling (different from regular locomotion!)
|
||||
groot_ang_vel_scale: float = 0.25 # GR00T uses 0.5, not 0.25
|
||||
groot_cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) # yaw is 0.5 for GR00T
|
||||
num_locomotion_actions: int = 12
|
||||
num_locomotion_obs: int = 47
|
||||
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
|
||||
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
||||
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
@@ -19,6 +20,7 @@ class G1_29_JointArmIndex(IntEnum):
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
@@ -36,7 +38,7 @@ class G1_29_JointIndex(IntEnum):
|
||||
kRightAnklePitch = 10
|
||||
kRightAnkleRoll = 11
|
||||
|
||||
kWaistYaw = 12 #we're c
|
||||
kWaistYaw = 12 # we're c
|
||||
kWaistRoll = 13
|
||||
kWaistPitch = 14
|
||||
|
||||
@@ -64,4 +66,4 @@ class G1_29_JointIndex(IntEnum):
|
||||
kNotUsedJoint2 = 31
|
||||
kNotUsedJoint3 = 32
|
||||
kNotUsedJoint4 = 33
|
||||
kNotUsedJoint5 = 34
|
||||
kNotUsedJoint5 = 34
|
||||
|
||||
@@ -9,7 +9,6 @@ import time
|
||||
import tty
|
||||
from collections import deque
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -172,12 +171,6 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[id].kp = self.kp_high
|
||||
self.msg.motor_cmd[id].kd = self.kd_high
|
||||
self.msg.motor_cmd[id].q = self.all_motor_q[id]
|
||||
# print current motor q, kp, kd
|
||||
|
||||
logger.warning("Lock OK!\n") # motors are not locked x
|
||||
# for i in range(10000):
|
||||
# print(self.get_current_motor_q())
|
||||
# time.sleep(0.05)
|
||||
|
||||
# Initialize control flags BEFORE starting threads
|
||||
self.keyboard_thread = None
|
||||
@@ -185,10 +178,6 @@ class UnitreeG1(Robot):
|
||||
self.locomotion_thread = None
|
||||
self.locomotion_running = False
|
||||
|
||||
# Initialize publish thread for arm control
|
||||
# Note: This thread runs alongside locomotion thread
|
||||
# - Arm thread: controls arms (indices 15-28)
|
||||
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
||||
# Both update different parts of self.msg, both call Write()
|
||||
self.publish_thread = None
|
||||
self.ctrl_lock = threading.Lock()
|
||||
@@ -196,102 +185,8 @@ class UnitreeG1(Robot):
|
||||
self.publish_thread.daemon = True
|
||||
self.publish_thread.start()
|
||||
logger.warning("Arm control publish thread started")
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
# Load locomotion policy if enabled
|
||||
self.policy = None
|
||||
self.policy_type = None # 'torchscript' or 'onnx'
|
||||
print(config)
|
||||
if config.locomotion_control:
|
||||
if config.policy_path is None:
|
||||
raise ValueError("locomotion_control is True but policy_path is not set")
|
||||
|
||||
logger.warning(f"Loading locomotion policy from {config.policy_path}")
|
||||
|
||||
# Check file extension and load accordingly
|
||||
if config.policy_path.endswith(".pt"):
|
||||
logger.warning("Detected TorchScript (.pt) policy")
|
||||
self.policy = torch.jit.load(config.policy_path)
|
||||
self.policy_type = "torchscript"
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
elif config.policy_path.endswith(".onnx"):
|
||||
logger.warning("Detected ONNX (.onnx) policy")
|
||||
|
||||
# For GR00T-style policies, load both Balance and Walk policies
|
||||
# Balance policy for standing (low velocity commands)
|
||||
# Walk policy for locomotion (high velocity commands)
|
||||
balance_policy_path = config.policy_path.replace("Walk.onnx", "Balance.onnx")
|
||||
walk_policy_path = config.policy_path
|
||||
|
||||
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
|
||||
logger.info("Loading dual-policy system (Balance + Walk)")
|
||||
self.policy_balance = ort.InferenceSession(balance_policy_path)
|
||||
self.policy_walk = ort.InferenceSession(walk_policy_path)
|
||||
self.policy = None # Not used when dual policies are loaded
|
||||
logger.info(f"Balance policy loaded from: {balance_policy_path}")
|
||||
logger.info(f"Walk policy loaded from: {walk_policy_path}")
|
||||
logger.info(
|
||||
f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}"
|
||||
)
|
||||
logger.info(
|
||||
f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}"
|
||||
)
|
||||
else:
|
||||
# Fallback to single policy
|
||||
logger.info("Loading single ONNX policy")
|
||||
self.policy = ort.InferenceSession(config.policy_path)
|
||||
self.policy_balance = None
|
||||
self.policy_walk = None
|
||||
logger.info("ONNX policy loaded successfully")
|
||||
logger.info(
|
||||
f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}"
|
||||
)
|
||||
logger.info(
|
||||
f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}"
|
||||
)
|
||||
|
||||
self.policy_type = "onnx"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported."
|
||||
)
|
||||
|
||||
# Initialize locomotion variables
|
||||
self.remote_controller = self.RemoteController()
|
||||
self.locomotion_counter = 0
|
||||
self.qj = np.zeros(config.num_locomotion_actions, dtype=np.float32)
|
||||
self.dqj = np.zeros(config.num_locomotion_actions, dtype=np.float32)
|
||||
self.locomotion_action = np.zeros(config.num_locomotion_actions, dtype=np.float32)
|
||||
self.locomotion_obs = np.zeros(config.num_locomotion_obs, dtype=np.float32)
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# GR00T-specific variables (for ONNX policies with 29 joints)
|
||||
if self.policy_type == "onnx":
|
||||
self.groot_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints
|
||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_action = np.zeros(15, dtype=np.float32) # 15D action (legs + waist)
|
||||
self.groot_obs_single = np.zeros(86, dtype=np.float32) # 86D single frame observation
|
||||
self.groot_obs_history = deque(maxlen=6) # 6-frame history buffer
|
||||
self.groot_obs_stacked = np.zeros(516, dtype=np.float32) # 86D × 6 = 516D stacked observation
|
||||
self.groot_height_cmd = 0.74 # Default base height
|
||||
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # roll, pitch, yaw
|
||||
|
||||
# Initialize history with zeros
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
# Start keyboard controls if in simulation mode
|
||||
if self.simulation_mode:
|
||||
logger.info("Starting keyboard controls for simulation...")
|
||||
self.start_keyboard_controls()
|
||||
|
||||
# Use different init based on policy type
|
||||
if self.policy_type == "onnx":
|
||||
self.init_groot_locomotion()
|
||||
else:
|
||||
self.init_locomotion()
|
||||
elif self.simulation_mode:
|
||||
# Even without locomotion, provide keyboard feedback in sim
|
||||
logger.info("Simulation mode active (locomotion disabled)")
|
||||
|
||||
logger.info("Initialize G1 OK!\n")
|
||||
|
||||
@@ -766,506 +661,3 @@ class UnitreeG1(Robot):
|
||||
R_pelvis = np.dot(R_torso, RzWaist.T)
|
||||
w = np.dot(RzWaist, imu_omega[0]) - np.array([0, 0, waist_yaw_omega])
|
||||
return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w
|
||||
|
||||
def locomotion_run(self):
    """Main locomotion policy loop - runs policy and sends leg commands.

    Performs one control step: reads the low-level state, builds the
    observation vector in ``self.locomotion_obs``, runs the loaded policy
    (TorchScript or ONNX, according to ``self.policy_type``), writes leg
    targets and zero waist targets into ``self.msg`` and publishes it.
    Returns early when no low-level state is available.

    Raises:
        ValueError: If ``self.policy_type`` is neither "torchscript" nor "onnx".
    """
    self.locomotion_counter += 1

    # Get current lowstate
    lowstate = self.lowstate_buffer.GetData()
    if lowstate is None:
        return

    # Update remote controller from lowstate
    if lowstate.wireless_remote is not None:
        self.remote_controller.set(lowstate.wireless_remote)
    else:
        # Default to zero commands if no remote data
        self.remote_controller.lx = 0.0
        self.remote_controller.ly = 0.0
        self.remote_controller.rx = 0.0
        self.remote_controller.ry = 0.0

    leg_motor_ids = self.config.leg_joint2motor_idx
    # Converted once per step; used for both the observation and the targets.
    default_leg_angles = np.array(self.config.default_leg_angles)

    # Get the current joint position and velocity (LEGS ONLY)
    for i, motor_idx in enumerate(leg_motor_ids):
        self.qj[i] = lowstate.motor_state[motor_idx].q
        self.dqj[i] = lowstate.motor_state[motor_idx].dq

    # Get IMU data (the gyroscope reading is wrapped in an extra leading axis)
    quat = lowstate.imu_state.quaternion
    ang_vel = np.array([lowstate.imu_state.gyroscope], dtype=np.float32)

    if self.config.locomotion_imu_type == "torso":
        # Transform IMU data from torso to pelvis frame
        waist_yaw_motor = self.config.arm_waist_joint2motor_idx[0]
        waist_yaw = lowstate.motor_state[waist_yaw_motor].q
        waist_yaw_omega = lowstate.motor_state[waist_yaw_motor].dq
        quat, ang_vel = self.locomotion_transform_imu_data(waist_yaw, waist_yaw_omega, quat, ang_vel)

    # Create observation
    gravity_orientation = self.locomotion_get_gravity_orientation(quat)
    qj_obs = (self.qj - default_leg_angles) * self.config.dof_pos_scale
    dqj_obs = self.dqj * self.config.dof_vel_scale
    ang_vel = ang_vel * self.config.ang_vel_scale

    # Calculate phase of a fixed 0.8 s cycle from the step counter
    period = 0.8
    count = self.locomotion_counter * self.config.locomotion_control_dt
    phase = count % period / period
    sin_phase = np.sin(2 * np.pi * phase)
    cos_phase = np.cos(2 * np.pi * phase)

    # Get velocity commands from remote controller (only if NOT in simulation mode)
    # In simulation mode, keyboard controls set self.locomotion_cmd directly
    if not self.simulation_mode:
        self.locomotion_cmd[0] = self.remote_controller.ly
        self.locomotion_cmd[1] = self.remote_controller.lx * -1
        self.locomotion_cmd[2] = self.remote_controller.rx * -1

    # Debug: print remote controller values every 50 iterations (~1 second at 50Hz)
    if self.locomotion_counter % 50 == 0:
        logger.debug(
            f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}"
        )

    # Build observation vector
    num_actions = self.config.num_locomotion_actions
    self.locomotion_obs[:3] = ang_vel
    self.locomotion_obs[3:6] = gravity_orientation
    self.locomotion_obs[6:9] = (
        self.locomotion_cmd * np.array(self.config.cmd_scale) * np.array(self.config.max_cmd)
    )
    self.locomotion_obs[9 : 9 + num_actions] = qj_obs
    self.locomotion_obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
    self.locomotion_obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.locomotion_action
    self.locomotion_obs[9 + num_actions * 3] = sin_phase
    self.locomotion_obs[9 + num_actions * 3 + 1] = cos_phase

    # Get action from policy network
    if self.policy_type == "torchscript":
        # TorchScript inference; no gradients are needed for a control step
        obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0)
        with torch.no_grad():
            self.locomotion_action = self.policy(obs_tensor).detach().numpy().squeeze()
    elif self.policy_type == "onnx":
        # ONNX inference takes NumPy arrays, so add the batch axis directly
        obs_batch = self.locomotion_obs[np.newaxis, :]
        ort_inputs = {self.policy.get_inputs()[0].name: obs_batch}
        ort_outs = self.policy.run(None, ort_inputs)
        self.locomotion_action = ort_outs[0].squeeze()
    else:
        raise ValueError(f"Unknown policy type: {self.policy_type}")

    # Transform action to target joint positions
    target_dof_pos = default_leg_angles + self.locomotion_action * self.config.locomotion_action_scale

    # Send commands to LEG motors only
    for i, motor_idx in enumerate(leg_motor_ids):
        leg_cmd = self.msg.motor_cmd[motor_idx]
        leg_cmd.q = target_dof_pos[i]
        leg_cmd.qd = 0
        leg_cmd.kp = self.config.locomotion_kps[i]
        leg_cmd.kd = self.config.locomotion_kds[i]
        leg_cmd.tau = 0

    # Hold WAIST motors at 0 (indices 12, 13, 14 = WaistYaw, WaistRoll, WaistPitch)
    waist_indices = self.config.arm_waist_joint2motor_idx[:3]  # First 3 are waist
    for i, motor_idx in enumerate(waist_indices):
        waist_cmd = self.msg.motor_cmd[motor_idx]
        waist_cmd.q = 0.0
        waist_cmd.qd = 0
        waist_cmd.kp = self.config.locomotion_arm_waist_kps[i]
        waist_cmd.kd = self.config.locomotion_arm_waist_kds[i]
        waist_cmd.tau = 0

    # Send command
    self.msg.crc = self.crc.Crc(self.msg)
    self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
def groot_locomotion_run(self):
    """Run one GR00T locomotion control tick for the ONNX policies.

    The tick reads the latest low-level state, builds one 86-value observation
    frame, appends it to the observation history, stacks the history into
    ``groot_obs_stacked``, runs the selected ONNX policy on it and publishes
    position commands for the legs and waist plus zero-position holds for the
    joints listed in ``zeroed_joints``.

    Observation frame layout (indices into ``groot_obs_single``):
        0:3    velocity command scaled by ``config.groot_cmd_scale``
        3      height command
        4:7    orientation command (roll, pitch, yaw)
        7:10   angular velocity scaled by ``config.groot_ang_vel_scale``
        10:13  gravity orientation
        13:42  joint positions (29 values, default angles subtracted)
        42:71  joint velocities (29 values)
        71:86  action produced by the previous tick (15 values)

    Returns:
        None. When no low-level state is available the tick ends right after
        ``locomotion_counter`` has been incremented.
    """
    frame_dim = 86
    num_joints = 29
    num_leg_joints = 12
    num_actions = 15
    waist_yaw_idx, waist_roll_idx, waist_pitch_idx = 12, 13, 14
    # Waist yaw, waist pitch and the wrist joints are hidden from the policy
    # and held at zero. Waist roll (13) is deliberately NOT in this list.
    zeroed_joints = (12, 14, 20, 21, 27, 28)

    self.locomotion_counter += 1

    lowstate = self.lowstate_buffer.GetData()
    if lowstate is None:
        return

    remote = self.remote_controller
    if lowstate.wireless_remote is not None:
        remote.set(lowstate.wireless_remote)

        # Height control on the real robot. The code treats button[0] as R1
        # (raise) and button[4] as R2 (lower). Each tick moves the height
        # command by 0.001, i.e. about 0.05 per second if the loop runs at 50 Hz.
        if remote.button[0]:
            self.groot_height_cmd += 0.001
            self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
        if remote.button[4]:
            self.groot_height_cmd -= 0.001
            self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
    else:
        # No remote data: fall back to zero stick commands.
        remote.lx = 0.0
        remote.ly = 0.0
        remote.rx = 0.0
        remote.ry = 0.0

    # Positions and velocities of all 29 joints.
    for joint in range(num_joints):
        motor = lowstate.motor_state[joint]
        self.groot_qj_all[joint] = motor.q
        self.groot_dqj_all[joint] = motor.dq

    # IMU data, optionally transformed when the IMU type is "torso".
    quat = lowstate.imu_state.quaternion
    ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
    if self.config.locomotion_imu_type == "torso":
        waist_yaw = lowstate.motor_state[waist_yaw_idx].q
        waist_yaw_omega = lowstate.motor_state[waist_yaw_idx].dq
        quat, ang_vel_3d = self.locomotion_transform_imu_data(
            waist_yaw, waist_yaw_omega, quat, np.array([ang_vel])
        )
        ang_vel = ang_vel_3d.flatten()

    gravity_orientation = self.locomotion_get_gravity_orientation(quat)

    for joint in zeroed_joints:
        self.groot_qj_all[joint] = 0.0
        self.groot_dqj_all[joint] = 0.0

    # GR00T default angles: the same six leg values for the left and the right
    # leg; waist and both arms are zero.
    groot_default_angles = np.zeros(num_joints, dtype=np.float32)
    groot_default_angles[:num_leg_joints] = [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0] * 2

    qj_obs = (self.groot_qj_all - groot_default_angles) * self.config.dof_pos_scale
    dqj_obs = self.groot_dqj_all * self.config.dof_vel_scale
    ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale  # GR00T-specific scaling

    # Outside simulation the velocity command comes from the remote sticks;
    # in simulation mode ``locomotion_cmd`` is left as it is.
    if not self.simulation_mode:
        self.locomotion_cmd[0] = remote.ly
        self.locomotion_cmd[1] = remote.lx * -1
        self.locomotion_cmd[2] = remote.rx * -1

    # Build the 86-value single-frame observation.
    frame = self.groot_obs_single
    frame[:3] = self.locomotion_cmd * np.array(self.config.groot_cmd_scale)
    frame[3] = self.groot_height_cmd
    frame[4:7] = self.groot_orientation_cmd
    frame[7:10] = ang_vel_scaled
    frame[10:13] = gravity_orientation
    frame[13:42] = qj_obs
    frame[42:71] = dqj_obs
    frame[71:86] = self.groot_action

    # Append to the history and stack every stored frame back to back.
    self.groot_obs_history.append(frame.copy())
    for slot, past_frame in enumerate(self.groot_obs_history):
        self.groot_obs_stacked[slot * frame_dim : (slot + 1) * frame_dim] = past_frame

    # Dual-policy mode: Balance for near-zero commands, Walk otherwise.
    # With only one policy loaded, fall back to ``self.policy``.
    if self.policy_balance is not None and self.policy_walk is not None:
        cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
        selected_policy = self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
    else:
        selected_policy = self.policy

    # ONNX Runtime consumes NumPy arrays directly, so just add the batch axis
    # instead of wrapping the observation in a torch tensor and unwrapping it.
    obs_batch = np.expand_dims(self.groot_obs_stacked, axis=0)
    input_name = selected_policy.get_inputs()[0].name
    self.groot_action = selected_policy.run(None, {input_name: obs_batch})[0].squeeze()

    # Zero the waist actions so that the action fed back into the next
    # observation matches what is actually executed (legs only).
    self.groot_action[num_leg_joints:num_actions] = 0.0

    target_dof_pos_15 = (
        groot_default_angles[:num_actions] + self.groot_action * self.config.locomotion_action_scale
    )

    def _write_cmd(motor_idx, q, kp, kd):
        # Position command with zero velocity target and zero feed-forward torque.
        cmd = self.msg.motor_cmd[motor_idx]
        cmd.q = q
        cmd.qd = 0
        cmd.kp = kp
        cmd.kd = kd
        cmd.tau = 0

    # Leg motors (0-11).
    for leg in range(num_leg_joints):
        _write_cmd(
            leg,
            target_dof_pos_15[leg],
            self.config.locomotion_kps[leg],
            self.config.locomotion_kds[leg],
        )

    # Waist roll is the only waist joint driven from the 15D target
    # (gain index 1 in the arm/waist gain lists).
    _write_cmd(
        waist_roll_idx,
        target_dof_pos_15[waist_roll_idx],
        self.config.locomotion_arm_waist_kps[1],
        self.config.locomotion_arm_waist_kds[1],
    )

    # Hold waist yaw, waist pitch and the wrist joints at zero.
    for joint in zeroed_joints:
        if joint in (waist_yaw_idx, waist_pitch_idx):
            gain_idx = 0 if joint == waist_yaw_idx else 2
            kp = self.config.locomotion_arm_waist_kps[gain_idx]
            kd = self.config.locomotion_arm_waist_kds[gain_idx]
        else:  # wrists (20, 21, 27, 28)
            kp = self.kp_wrist
            kd = self.kd_wrist
        _write_cmd(joint, 0.0, kp, kd)

    # Send command.
    self.msg.crc = self.crc.Crc(self.msg)
    self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Use different run function based on policy type
|
||||
if self.policy_type == "onnx":
|
||||
self.groot_locomotion_run()
|
||||
else:
|
||||
self.locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
    """Launch the daemon thread that runs ``_locomotion_thread_loop``.

    Nothing is started, and a warning is logged, when
    ``config.locomotion_control`` is falsy or when ``locomotion_running`` is
    already set.
    """
    refusal = None
    if not self.config.locomotion_control:
        refusal = "locomotion_control is False, cannot start thread"
    elif self.locomotion_running:
        refusal = "Locomotion thread already running"

    if refusal is not None:
        logger.warning(refusal)
        return

    logger.info("Starting locomotion control thread...")
    # Raise the flag first: the loop body checks it on every iteration.
    self.locomotion_running = True
    worker = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
    self.locomotion_thread = worker
    worker.start()
    logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
    """Ask the locomotion thread to stop and wait up to two seconds for it.

    Does nothing when ``locomotion_running`` is not set. Otherwise the flag is
    cleared, the thread (if any) is joined with a 2.0 s timeout, and the
    keyboard controls are stopped too when ``keyboard_running`` is set.

    A warning is logged instead of the "stopped" message when the thread is
    still alive after the timeout, since the join alone cannot tell whether
    the worker really exited.
    """
    if not self.locomotion_running:
        return

    logger.info("Stopping locomotion control thread...")
    self.locomotion_running = False

    worker = self.locomotion_thread
    if worker:
        worker.join(timeout=2.0)

    # join() returns silently on timeout, so check liveness before claiming
    # that the control loop has stopped.
    if worker and worker.is_alive():
        logger.warning("Locomotion control thread did not stop within 2.0 seconds")
    else:
        logger.info("Locomotion control thread stopped")

    # Also stop keyboard thread if running.
    if self.keyboard_running:
        self.stop_keyboard_controls()
|
||||
|
||||
def _keyboard_listener_thread(self):
    """Poll stdin for single key presses and update the motion commands (sim mode only).

    Key bindings (also printed at start-up): W/S and A/D change
    ``locomotion_cmd[0]`` and ``locomotion_cmd[1]``, Q/E change
    ``locomotion_cmd[2]``, Z zeroes all three, and R/F raise or lower
    ``groot_height_cmd`` by 0.05. After every key the commands are clamped and
    printed.

    The terminal is switched to cbreak mode for the lifetime of the loop and
    its previous settings are restored on exit, including exit via an
    exception.
    """
    print("\n" + "=" * 60)
    print("KEYBOARD CONTROLS ACTIVE!")
    print(" W/S: Forward/Backward")
    print(" A/D: Left/Right")
    print(" Q/E: Rotate Left/Right")
    print(" R/F: Raise/Lower Height (±5cm)")
    print(" Z: Stop (zero velocity commands)")
    print("=" * 60 + "\n")

    # key -> (index into locomotion_cmd, increment applied per key press)
    velocity_steps = {
        "w": (0, 0.4),  # Forward
        "s": (0, -0.4),  # Backward
        "a": (1, 0.25),  # Left
        "d": (1, -0.25),  # Right
        "q": (2, 0.5),  # Rotate left
        "e": (2, -0.5),  # Rotate right
    }
    # key -> height increment (only meaningful for GR00T ONNX policies)
    height_steps = {"r": 0.05, "f": -0.05}

    # Save terminal settings so they can be restored when the thread ends.
    old_settings = None
    try:
        old_settings = termios.tcgetattr(sys.stdin)
        tty.setcbreak(sys.stdin.fileno())

        while self.keyboard_running:
            # Wait at most 0.1 s so the loop notices keyboard_running going False.
            if not select.select([sys.stdin], [], [], 0.1)[0]:
                continue

            key = sys.stdin.read(1).lower()
            has_height = hasattr(self, "groot_height_cmd")

            if key in velocity_steps:
                axis, step = velocity_steps[key]
                self.locomotion_cmd[axis] += step
            elif key == "z":
                self.locomotion_cmd[:] = 0.0  # Stop
            elif key in height_steps and has_height:
                # Guarded like the clamp and the print below: without a
                # height command attribute, R/F must not raise
                # AttributeError and kill the listener thread.
                self.groot_height_cmd += height_steps[key]

            # Clamp commands to reasonable limits.
            self.locomotion_cmd[0] = np.clip(self.locomotion_cmd[0], -0.8, 0.8)  # vx
            self.locomotion_cmd[1] = np.clip(self.locomotion_cmd[1], -0.5, 0.5)  # vy
            self.locomotion_cmd[2] = np.clip(self.locomotion_cmd[2], -1.0, 1.0)  # yaw_rate

            # Clamp height to the range 0.5 to 1.0.
            if has_height:
                self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)

            # Print current commands on a single line.
            status = (
                f"[VEL CMD] vx={self.locomotion_cmd[0]:.2f}, "
                f"vy={self.locomotion_cmd[1]:.2f}, yaw={self.locomotion_cmd[2]:.2f}"
            )
            if has_height:
                status += f" | [HEIGHT] {self.groot_height_cmd:.3f}m"
            print(status)

    finally:
        # Restore terminal settings.
        if old_settings is not None:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
        print("\nKeyboard controls stopped")
|
||||
|
||||
def start_keyboard_controls(self):
    """Launch the daemon thread that runs ``_keyboard_listener_thread``.

    Only available when ``simulation_mode`` is set. A warning is logged and
    nothing is started otherwise, or when ``keyboard_running`` is already set.
    """
    problem = None
    if not self.simulation_mode:
        problem = "Keyboard controls only available in simulation mode"
    elif self.keyboard_running:
        problem = "Keyboard controls already running"

    if problem is not None:
        logger.warning(problem)
        return

    # The listener loop polls this flag, so set it before the thread exists.
    self.keyboard_running = True
    listener = threading.Thread(target=self._keyboard_listener_thread, daemon=True)
    self.keyboard_thread = listener
    listener.start()
    logger.info("Keyboard controls started!")
|
||||
|
||||
def stop_keyboard_controls(self):
    """Ask the keyboard listener thread to stop and wait up to two seconds for it.

    Does nothing when ``keyboard_running`` is not set. Otherwise the flag is
    cleared and the thread (if any) is joined with a 2.0 s timeout. A warning
    is logged instead of the "stopped" message when the thread is still alive
    after the timeout.
    """
    if not self.keyboard_running:
        return

    logger.info("Stopping keyboard controls...")
    self.keyboard_running = False

    listener = self.keyboard_thread
    if listener:
        listener.join(timeout=2.0)
        # join() gives no indication of a timeout; check before reporting success.
        if listener.is_alive():
            logger.warning("Keyboard control thread did not stop within 2.0 seconds")
            return
    logger.info("Keyboard controls stopped")
|
||||
|
||||
def init_locomotion(self):
    """Run the locomotion start-up sequence and launch the policy thread.

    Sequence: move the legs to the default position, wait three seconds, hold
    the default position via ``locomotion_default_pos_state``, then start the
    background locomotion thread. Only a warning is logged when
    ``config.locomotion_control`` is falsy.
    """
    if self.config.locomotion_control:
        logger.info("Starting locomotion test sequence...")

        # Bring the legs to their default pose and let things settle.
        self.locomotion_move_to_default_pos()
        settle_seconds = 3.0
        time.sleep(settle_seconds)

        # Keep holding the default leg pose before the policy takes over.
        self.locomotion_default_pos_state()

        # From here on the policy runs in a background thread.
        logger.info("Starting locomotion policy control...")
        self.start_locomotion_thread()

        for closing_note in (
            "Locomotion test sequence complete! Policy is now running in background.",
            "Use robot.stop_locomotion_thread() to stop the policy.",
        ):
            logger.info(closing_note)
    else:
        logger.warning("locomotion_control is False, cannot run test sequence")
|
||||
|
||||
def init_groot_locomotion(self):
    """Run the GR00T start-up sequence for ONNX policies and launch the policy thread.

    Uses the same steps as the regular locomotion start-up: move the legs to
    the default position, wait three seconds, hold the default position, then
    start the background locomotion thread. Only a warning is logged when
    ``config.locomotion_control`` is falsy.
    """
    locomotion_enabled = self.config.locomotion_control
    if not locomotion_enabled:
        logger.warning("locomotion_control is False, cannot run GR00T init")
        return None

    logger.info("Starting GR00T locomotion initialization...")

    # Default leg pose first, then a fixed pause before holding it.
    pause_seconds = 3.0
    self.locomotion_move_to_default_pos()
    time.sleep(pause_seconds)
    self.locomotion_default_pos_state()

    # Hand control over to the background policy thread.
    logger.info("Starting GR00T locomotion policy control...")
    self.start_locomotion_thread()

    logger.info("GR00T locomotion initialization complete! Policy is now running.")
    logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
    return None
|
||||
|
||||
Reference in New Issue
Block a user