separate groot locomotion logic

This commit is contained in:
Martino Russi
2025-11-26 17:17:02 +01:00
parent d7481f653e
commit 3385350f2d
6 changed files with 367 additions and 638 deletions
+345
View File
@@ -0,0 +1,345 @@
#!/usr/bin/env python
"""
Example: GR00T Locomotion with Pre-loaded Policies
This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class.
"""
import logging
import threading
import time
from collections import deque
import numpy as np
import onnxruntime as ort
import torch
from scipy.spatial.transform import Rotation as R
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logger = logging.getLogger(__name__)
def load_groot_policies(
    balance_path: str = "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Balance.onnx",
    walk_path: str = "examples/unitree_g1/locomotion/GR00T-WholeBodyControl-Walk.onnx",
) -> tuple:
    """Load the GR00T dual-policy system (Balance + Walk) from ONNX files.

    The paths default to the repo-relative locations used by this example but
    can be overridden, e.g. when running from a different working directory.

    Args:
        balance_path: Path to the balance/standing policy ONNX file.
        walk_path: Path to the walking policy ONNX file.

    Returns:
        Tuple ``(policy_balance, policy_walk)`` of ONNX ``InferenceSession`` objects.
    """
    logger.info("Loading GR00T dual-policy system...")
    # Load ONNX policies (onnxruntime raises if a file is missing/unreadable)
    policy_balance = ort.InferenceSession(balance_path)
    policy_walk = ort.InferenceSession(walk_path)
    logger.info("GR00T policies loaded successfully")
    # Both policies share the same I/O layout; log the balance policy's shapes
    logger.info(f" Input shape: {policy_balance.get_inputs()[0].shape}")
    logger.info(f" Output shape: {policy_balance.get_outputs()[0].shape}")
    return policy_balance, policy_walk
class GrootLocomotionController:
    """
    Handles GR00T-style locomotion control for the Unitree G1 robot.

    This controller manages:
    - Dual-policy system (Balance + Walk)
    - 29-joint observation processing
    - 15D action output (legs + waist)
    - Policy inference and motor command generation
    """

    # GR00T default angles for all 29 joints (used both to center the joint
    # observation and as the base pose that the 15D action offsets).
    GROOT_DEFAULT_ANGLES = np.array([
        -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # left leg
        -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # right leg
        0.0, 0.0, 0.0,  # waist
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # left arm (zeroed)
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # right arm (zeroed)
    ], dtype=np.float32)

    # Joints to zero out in observations and commands:
    # waist yaw (12), waist pitch (14), wrist pitch/yaw (20, 21, 27, 28).
    # Waist roll (13) is deliberately NOT in this list — it stays active.
    JOINTS_TO_ZERO = [12, 14, 20, 21, 27, 28]
    # The same joints are also held at zero in the outgoing motor commands.
    PROBLEMATIC_JOINTS = [12, 14, 20, 21, 27, 28]

    def __init__(self, policy_balance, policy_walk, robot, config):
        """
        Initialize the GR00T locomotion controller.

        Args:
            policy_balance: ONNX InferenceSession for balance/standing policy
            policy_walk: ONNX InferenceSession for walking policy
            robot: Reference to the UnitreeG1 robot instance
            config: UnitreeG1Config object with locomotion parameters
        """
        self.policy_balance = policy_balance
        self.policy_walk = policy_walk
        self.robot = robot
        self.config = config
        # Locomotion state
        self.locomotion_counter = 0
        self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # vx, vy, yaw_rate
        # GR00T-specific state buffers
        self.groot_qj_all = np.zeros(29, dtype=np.float32)   # joint positions, all 29 DOF
        self.groot_dqj_all = np.zeros(29, dtype=np.float32)  # joint velocities, all 29 DOF
        self.groot_action = np.zeros(15, dtype=np.float32)   # last policy action (legs + waist)
        self.groot_obs_single = np.zeros(86, dtype=np.float32)   # one 86D observation frame
        self.groot_obs_history = deque(maxlen=6)                 # rolling 6-frame history
        self.groot_obs_stacked = np.zeros(516, dtype=np.float32)  # 6 frames x 86D = 516D
        self.groot_height_cmd = 0.74  # Default base height
        self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        # Initialize history with zeros so the first stacked observation is well-defined
        for _ in range(6):
            self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
        # Thread management
        self.locomotion_running = False
        self.locomotion_thread = None
        logger.info("GrootLocomotionController initialized")

    def groot_locomotion_run(self):
        """GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action."""
        self.locomotion_counter += 1
        # Get current lowstate
        lowstate = self.robot.lowstate_buffer.GetData()
        if lowstate is None:
            # No fresh robot state yet — skip this control tick entirely
            return
        # Update remote controller from lowstate
        if lowstate.wireless_remote is not None:
            self.robot.remote_controller.set(lowstate.wireless_remote)
            # R1/R2 buttons for height control on real robot (button indices 0 and 4)
            # NOTE(review): the earlier in-robot version of this comment said
            # indices 4 and 5 — confirm the button mapping on hardware.
            if self.robot.remote_controller.button[0]:  # R1 - raise height
                self.groot_height_cmd += 0.001  # Small increment per timestep (~0.05m per second at 50Hz)
                self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
            if self.robot.remote_controller.button[4]:  # R2 - lower height
                self.groot_height_cmd -= 0.001  # Small decrement per timestep
                self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
        else:
            # Default to zero commands if no remote data
            self.robot.remote_controller.lx = 0.0
            self.robot.remote_controller.ly = 0.0
            self.robot.remote_controller.rx = 0.0
            self.robot.remote_controller.ry = 0.0
        # Get ALL 29 joint positions and velocities
        for i in range(29):
            self.groot_qj_all[i] = lowstate.motor_state[i].q
            self.groot_dqj_all[i] = lowstate.motor_state[i].dq
        # Get IMU data
        quat = lowstate.imu_state.quaternion
        ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
        # Transform IMU if using torso IMU (re-expresses quat/ang_vel in the pelvis frame)
        if self.config.locomotion_imu_type == "torso":
            waist_yaw = lowstate.motor_state[12].q  # Waist yaw index
            waist_yaw_omega = lowstate.motor_state[12].dq
            quat, ang_vel_3d = self.robot.locomotion_transform_imu_data(
                waist_yaw, waist_yaw_omega, quat, np.array([ang_vel])
            )
            ang_vel = ang_vel_3d.flatten()
        # Create observation
        gravity_orientation = self.robot.locomotion_get_gravity_orientation(quat)
        # Zero out specific joints in observation (matches what is commanded below)
        for idx in self.JOINTS_TO_ZERO:
            self.groot_qj_all[idx] = 0.0
            self.groot_dqj_all[idx] = 0.0
        # Scale joint positions and velocities
        qj_obs = self.groot_qj_all.copy()
        dqj_obs = self.groot_dqj_all.copy()
        qj_obs = (qj_obs - self.GROOT_DEFAULT_ANGLES) * self.config.dof_pos_scale
        dqj_obs = dqj_obs * self.config.dof_vel_scale
        ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale
        # Get velocity commands (keyboard or remote). In simulation mode,
        # locomotion_cmd is presumably set elsewhere (keyboard controls).
        if not self.robot.simulation_mode:
            self.locomotion_cmd[0] = self.robot.remote_controller.ly
            self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1
            self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1
        # Build 86D single frame observation (GR00T format):
        # [cmd(3), height(1), orientation_cmd(3), ang_vel(3), gravity(3),
        #  qj(29), dqj(29), prev_action(15)]
        self.groot_obs_single[:3] = self.locomotion_cmd * np.array(self.config.groot_cmd_scale)
        self.groot_obs_single[3] = self.groot_height_cmd
        self.groot_obs_single[4:7] = self.groot_orientation_cmd
        self.groot_obs_single[7:10] = ang_vel_scaled
        self.groot_obs_single[10:13] = gravity_orientation
        self.groot_obs_single[13:42] = qj_obs  # 29D joint positions
        self.groot_obs_single[42:71] = dqj_obs  # 29D joint velocities
        self.groot_obs_single[71:86] = self.groot_action  # 15D previous actions
        # Add to history and stack observations (6 frames × 86D = 516D).
        # copy() is required because groot_obs_single is reused every tick.
        self.groot_obs_history.append(self.groot_obs_single.copy())
        # Stack all 6 frames into 516D vector (oldest frame first)
        for i, obs_frame in enumerate(self.groot_obs_history):
            start_idx = i * 86
            end_idx = start_idx + 86
            self.groot_obs_stacked[start_idx:end_idx] = obs_frame
        # Run policy inference (ONNX) with 516D stacked observation
        obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
        # Select appropriate policy based on command magnitude (dual-policy system)
        cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
        if cmd_magnitude < 0.05:
            # Use balance/standing policy for small commands
            selected_policy = self.policy_balance
        else:
            # Use walking policy for movement commands
            selected_policy = self.policy_walk
        ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
        ort_outs = selected_policy.run(None, ort_inputs)
        self.groot_action = ort_outs[0].squeeze()
        # Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11).
        # This keeps the action history fed back into observations consistent
        # with what is actually executed.
        self.groot_action[12] = 0.0  # Waist yaw
        self.groot_action[13] = 0.0  # Waist roll
        self.groot_action[14] = 0.0  # Waist pitch
        # Transform action to target joint positions (15D: legs + waist)
        target_dof_pos_15 = (
            self.GROOT_DEFAULT_ANGLES[:15] + self.groot_action * self.config.locomotion_action_scale
        )
        # Send commands to LEG motors (0-11)
        for i in range(12):
            motor_idx = i
            self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
            self.robot.msg.motor_cmd[motor_idx].qd = 0
            self.robot.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i]
            self.robot.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i]
            self.robot.msg.motor_cmd[motor_idx].tau = 0
        # Send WAIST commands - but SKIP waist yaw (12) and waist pitch (14)
        # Only send waist roll (13)
        waist_roll_idx = 13
        waist_roll_action_idx = 13
        self.robot.msg.motor_cmd[waist_roll_idx].q = target_dof_pos_15[waist_roll_action_idx]
        self.robot.msg.motor_cmd[waist_roll_idx].qd = 0
        # arm_waist gain index 1 corresponds to waist roll
        self.robot.msg.motor_cmd[waist_roll_idx].kp = self.config.locomotion_arm_waist_kps[1]
        self.robot.msg.motor_cmd[waist_roll_idx].kd = self.config.locomotion_arm_waist_kds[1]
        self.robot.msg.motor_cmd[waist_roll_idx].tau = 0
        # Zero out the problematic joints (waist yaw, waist pitch, wrist pitch/yaw)
        for joint_idx in self.PROBLEMATIC_JOINTS:
            self.robot.msg.motor_cmd[joint_idx].q = 0.0
            self.robot.msg.motor_cmd[joint_idx].qd = 0
            if joint_idx in [12, 14]:  # waist
                kp_idx = 0 if joint_idx == 12 else 2  # yaw or pitch
                self.robot.msg.motor_cmd[joint_idx].kp = self.config.locomotion_arm_waist_kps[kp_idx]
                self.robot.msg.motor_cmd[joint_idx].kd = self.config.locomotion_arm_waist_kds[kp_idx]
            else:  # wrists (20, 21, 27, 28)
                self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp_wrist
                self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd_wrist
            self.robot.msg.motor_cmd[joint_idx].tau = 0
        # Send command (CRC must be recomputed after every mutation of msg)
        self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
        self.robot.lowcmd_publisher.Write(self.robot.msg)

    def _locomotion_thread_loop(self):
        """Background thread that runs the locomotion policy at specified rate."""
        logger.info("Locomotion thread started")
        while self.locomotion_running:
            start_time = time.time()
            try:
                self.groot_locomotion_run()
            except Exception as e:
                # Keep the control loop alive on transient errors; log and continue
                logger.error(f"Error in locomotion loop: {e}")
            # Sleep to maintain control rate (locomotion_control_dt seconds per tick)
            elapsed = time.time() - start_time
            sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
            time.sleep(sleep_time)
        logger.info("Locomotion thread stopped")

    def start_locomotion_thread(self):
        """Start the background locomotion control thread."""
        if self.locomotion_running:
            logger.warning("Locomotion thread already running")
            return
        logger.info("Starting locomotion control thread...")
        self.locomotion_running = True
        # Daemon thread: won't block interpreter shutdown
        self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
        self.locomotion_thread.start()
        logger.info("Locomotion control thread started!")

    def stop_locomotion_thread(self):
        """Stop the background locomotion control thread."""
        if not self.locomotion_running:
            return
        logger.info("Stopping locomotion control thread...")
        self.locomotion_running = False
        if self.locomotion_thread:
            # Bounded join: don't hang forever if the loop is stuck in I/O
            self.locomotion_thread.join(timeout=2.0)
        logger.info("Locomotion control thread stopped")

    def init_groot_locomotion(self):
        """Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
        logger.info("Starting GR00T locomotion initialization...")
        # Move legs to default position
        self.robot.locomotion_move_to_default_pos()
        # Wait 3 seconds
        time.sleep(3.0)
        # Hold default leg position for 2 seconds
        self.robot.locomotion_default_pos_state()
        # Start locomotion policy thread
        logger.info("Starting GR00T locomotion policy control...")
        self.start_locomotion_thread()
        logger.info("GR00T locomotion initialization complete! Policy is now running.")
        logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
if __name__ == "__main__":
    # 1. Load policies externally (separate from robot initialization)
    policy_balance, policy_walk = load_groot_policies()
    # 2. Create config (no locomotion_control=True since we're using external controller)
    config = UnitreeG1Config()
    # 3. Initialize robot
    robot = UnitreeG1(config)
    # 4. Create GR00T locomotion controller with loaded policies
    groot_controller = GrootLocomotionController(
        policy_balance=policy_balance,
        policy_walk=policy_walk,
        robot=robot,
        config=config,
    )
    # 5. Initialize and start locomotion (blocks for the stand-up sequence,
    # then runs the policy loop in a background daemon thread)
    groot_controller.init_groot_locomotion()
    # Robot is now ready with locomotion control!
    print("Robot initialized with GR00T locomotion policies")
    print("Locomotion controller running in background thread")
    print("Press Ctrl+C to stop")
    try:
        # Keep the main thread alive while the daemon control thread runs
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("\nStopping locomotion...")
        groot_controller.stop_locomotion_thread()
        print("Done!")
@@ -15,6 +15,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from lerobot.cameras import CameraConfig
@@ -24,7 +25,7 @@ from ..config import RobotConfig
@RobotConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1Config(RobotConfig):
# id: str = "unitree_g1"
# id: str = "unitree_g1"
simulation_mode: bool = False
kp_high = 40.0
kd_high = 3.0
@@ -52,49 +53,38 @@ class UnitreeG1Config(RobotConfig):
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
socket_host: str | None = None # = "172.18.129.215"
socket_host: str | None = None# = "172.18.129.215"
socket_port: int | None = None
# Locomotion control
locomotion_control: bool = True
# policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
policy_path: str | None = None
# Pre-loaded policies (preferred method for GR00T locomotion)
policy_walk: Any = None # Pre-loaded walk policy (ONNX InferenceSession)
policy_balance: Any = None # Pre-loaded balance policy (ONNX InferenceSession)
# Locomotion parameters (from g1.yaml)
locomotion_control_dt: float = 0.02
leg_joint2motor_idx: list = field(default_factory=lambda: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
locomotion_kps: list = field(
default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40]
)
locomotion_kps: list = field(default_factory=lambda: [150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40])
locomotion_kds: list = field(default_factory=lambda: [2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2])
default_leg_angles: list = field(
default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0]
)
arm_waist_joint2motor_idx: list = field(
default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
)
locomotion_arm_waist_kps: list = field(
default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20]
)
locomotion_arm_waist_kds: list = field(
default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1]
)
locomotion_arm_waist_target: list = field(
default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
default_leg_angles: list = field(default_factory=lambda: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0])
arm_waist_joint2motor_idx: list = field(default_factory=lambda: [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28])
locomotion_arm_waist_kps: list = field(default_factory=lambda: [250, 250, 250, 100, 100, 50, 50, 20, 20, 20, 100, 100, 50, 50, 20, 20, 20])
locomotion_arm_waist_kds: list = field(default_factory=lambda: [5, 5, 5, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1])
locomotion_arm_waist_target: list = field(default_factory=lambda: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
ang_vel_scale: float = 0.25
dof_pos_scale: float = 1.0
dof_vel_scale: float = 0.05
locomotion_action_scale: float = 0.25
cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25])
# GR00T-specific scaling (different from regular locomotion!)
groot_ang_vel_scale: float = 0.25 # NOTE(review): comment previously claimed "GR00T uses 0.5, not 0.25" while the value is 0.25 — confirm the intended scale
groot_cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) # NOTE(review): comment previously claimed yaw scale is 0.5 for GR00T while the value is 0.25 — verify
num_locomotion_actions: int = 12
num_locomotion_obs: int = 47
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
+4 -2
View File
@@ -1,5 +1,6 @@
from enum import IntEnum
class G1_29_JointArmIndex(IntEnum):
# Left arm
kLeftShoulderPitch = 15
@@ -19,6 +20,7 @@ class G1_29_JointArmIndex(IntEnum):
kRightWristPitch = 27
kRightWristYaw = 28
class G1_29_JointIndex(IntEnum):
# Left leg
kLeftHipPitch = 0
@@ -36,7 +38,7 @@ class G1_29_JointIndex(IntEnum):
kRightAnklePitch = 10
kRightAnkleRoll = 11
kWaistYaw = 12 # NOTE(review): original comment truncated ("we're c") — restore intended note
kWaistYaw = 12 # NOTE(review): original comment truncated ("we're c") — restore intended note
kWaistRoll = 13
kWaistPitch = 14
@@ -64,4 +66,4 @@ class G1_29_JointIndex(IntEnum):
kNotUsedJoint2 = 31
kNotUsedJoint3 = 32
kNotUsedJoint4 = 33
kNotUsedJoint5 = 34
kNotUsedJoint5 = 34
+1 -609
View File
@@ -9,7 +9,6 @@ import time
import tty
from collections import deque
from functools import cached_property
from pathlib import Path
from typing import Any
import numpy as np
@@ -172,12 +171,6 @@ class UnitreeG1(Robot):
self.msg.motor_cmd[id].kp = self.kp_high
self.msg.motor_cmd[id].kd = self.kd_high
self.msg.motor_cmd[id].q = self.all_motor_q[id]
# print current motor q, kp, kd
logger.warning("Lock OK!\n") # motors are not locked x
# for i in range(10000):
# print(self.get_current_motor_q())
# time.sleep(0.05)
# Initialize control flags BEFORE starting threads
self.keyboard_thread = None
@@ -185,10 +178,6 @@ class UnitreeG1(Robot):
self.locomotion_thread = None
self.locomotion_running = False
# Initialize publish thread for arm control
# Note: This thread runs alongside locomotion thread
# - Arm thread: controls arms (indices 15-28)
# - Locomotion thread: controls legs (0-11), waist (12-14)
# Both update different parts of self.msg, both call Write()
self.publish_thread = None
self.ctrl_lock = threading.Lock()
@@ -196,102 +185,8 @@ class UnitreeG1(Robot):
self.publish_thread.daemon = True
self.publish_thread.start()
logger.warning("Arm control publish thread started")
self.remote_controller = self.RemoteController()
# Load locomotion policy if enabled
self.policy = None
self.policy_type = None # 'torchscript' or 'onnx'
print(config)
if config.locomotion_control:
if config.policy_path is None:
raise ValueError("locomotion_control is True but policy_path is not set")
logger.warning(f"Loading locomotion policy from {config.policy_path}")
# Check file extension and load accordingly
if config.policy_path.endswith(".pt"):
logger.warning("Detected TorchScript (.pt) policy")
self.policy = torch.jit.load(config.policy_path)
self.policy_type = "torchscript"
logger.info("TorchScript policy loaded successfully")
elif config.policy_path.endswith(".onnx"):
logger.warning("Detected ONNX (.onnx) policy")
# For GR00T-style policies, load both Balance and Walk policies
# Balance policy for standing (low velocity commands)
# Walk policy for locomotion (high velocity commands)
balance_policy_path = config.policy_path.replace("Walk.onnx", "Balance.onnx")
walk_policy_path = config.policy_path
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
logger.info("Loading dual-policy system (Balance + Walk)")
self.policy_balance = ort.InferenceSession(balance_policy_path)
self.policy_walk = ort.InferenceSession(walk_policy_path)
self.policy = None # Not used when dual policies are loaded
logger.info(f"Balance policy loaded from: {balance_policy_path}")
logger.info(f"Walk policy loaded from: {walk_policy_path}")
logger.info(
f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}"
)
logger.info(
f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}"
)
else:
# Fallback to single policy
logger.info("Loading single ONNX policy")
self.policy = ort.InferenceSession(config.policy_path)
self.policy_balance = None
self.policy_walk = None
logger.info("ONNX policy loaded successfully")
logger.info(
f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}"
)
logger.info(
f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}"
)
self.policy_type = "onnx"
else:
raise ValueError(
f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported."
)
# Initialize locomotion variables
self.remote_controller = self.RemoteController()
self.locomotion_counter = 0
self.qj = np.zeros(config.num_locomotion_actions, dtype=np.float32)
self.dqj = np.zeros(config.num_locomotion_actions, dtype=np.float32)
self.locomotion_action = np.zeros(config.num_locomotion_actions, dtype=np.float32)
self.locomotion_obs = np.zeros(config.num_locomotion_obs, dtype=np.float32)
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# GR00T-specific variables (for ONNX policies with 29 joints)
if self.policy_type == "onnx":
self.groot_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
self.groot_action = np.zeros(15, dtype=np.float32) # 15D action (legs + waist)
self.groot_obs_single = np.zeros(86, dtype=np.float32) # 86D single frame observation
self.groot_obs_history = deque(maxlen=6) # 6-frame history buffer
self.groot_obs_stacked = np.zeros(516, dtype=np.float32) # 86D × 6 = 516D stacked observation
self.groot_height_cmd = 0.74 # Default base height
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # roll, pitch, yaw
# Initialize history with zeros
for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
# Start keyboard controls if in simulation mode
if self.simulation_mode:
logger.info("Starting keyboard controls for simulation...")
self.start_keyboard_controls()
# Use different init based on policy type
if self.policy_type == "onnx":
self.init_groot_locomotion()
else:
self.init_locomotion()
elif self.simulation_mode:
# Even without locomotion, provide keyboard feedback in sim
logger.info("Simulation mode active (locomotion disabled)")
logger.info("Initialize G1 OK!\n")
@@ -766,506 +661,3 @@ class UnitreeG1(Robot):
R_pelvis = np.dot(R_torso, RzWaist.T)
w = np.dot(RzWaist, imu_omega[0]) - np.array([0, 0, waist_yaw_omega])
return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w
def locomotion_run(self):
"""Main locomotion policy loop - runs policy and sends leg commands."""
self.locomotion_counter += 1
# Get current lowstate
lowstate = self.lowstate_buffer.GetData()
if lowstate is None:
return
# Update remote controller from lowstate
if lowstate.wireless_remote is not None:
self.remote_controller.set(lowstate.wireless_remote)
else:
# Default to zero commands if no remote data
self.remote_controller.lx = 0.0
self.remote_controller.ly = 0.0
self.remote_controller.rx = 0.0
self.remote_controller.ry = 0.0
# Get the current joint position and velocity (LEGS ONLY)
for i in range(len(self.config.leg_joint2motor_idx)):
self.qj[i] = lowstate.motor_state[self.config.leg_joint2motor_idx[i]].q
self.dqj[i] = lowstate.motor_state[self.config.leg_joint2motor_idx[i]].dq
# Get IMU data
quat = lowstate.imu_state.quaternion
ang_vel = np.array([lowstate.imu_state.gyroscope], dtype=np.float32)
if self.config.locomotion_imu_type == "torso":
# Transform IMU data from torso to pelvis frame
waist_yaw = lowstate.motor_state[self.config.arm_waist_joint2motor_idx[0]].q
waist_yaw_omega = lowstate.motor_state[self.config.arm_waist_joint2motor_idx[0]].dq
quat, ang_vel = self.locomotion_transform_imu_data(waist_yaw, waist_yaw_omega, quat, ang_vel)
# Create observation
gravity_orientation = self.locomotion_get_gravity_orientation(quat)
qj_obs = self.qj.copy()
dqj_obs = self.dqj.copy()
qj_obs = (qj_obs - np.array(self.config.default_leg_angles)) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel = ang_vel * self.config.ang_vel_scale
# Calculate phase
period = 0.8
count = self.locomotion_counter * self.config.locomotion_control_dt
phase = count % period / period
sin_phase = np.sin(2 * np.pi * phase)
cos_phase = np.cos(2 * np.pi * phase)
# Get velocity commands from remote controller (only if NOT in simulation mode)
# In simulation mode, keyboard controls set self.locomotion_cmd directly
if not self.simulation_mode:
self.locomotion_cmd[0] = self.remote_controller.ly
self.locomotion_cmd[1] = self.remote_controller.lx * -1
self.locomotion_cmd[2] = self.remote_controller.rx * -1
# Debug: print remote controller values every 50 iterations (~1 second at 50Hz)
if self.locomotion_counter % 50 == 0:
logger.debug(
f"Remote controller - lx:{self.remote_controller.lx:.2f}, ly:{self.remote_controller.ly:.2f}, rx:{self.remote_controller.rx:.2f}"
)
# Build observation vector
num_actions = self.config.num_locomotion_actions
self.locomotion_obs[:3] = ang_vel
self.locomotion_obs[3:6] = gravity_orientation
self.locomotion_obs[6:9] = (
self.locomotion_cmd * np.array(self.config.cmd_scale) * np.array(self.config.max_cmd)
)
self.locomotion_obs[9 : 9 + num_actions] = qj_obs
self.locomotion_obs[9 + num_actions : 9 + num_actions * 2] = dqj_obs
self.locomotion_obs[9 + num_actions * 2 : 9 + num_actions * 3] = self.locomotion_action
self.locomotion_obs[9 + num_actions * 3] = sin_phase
self.locomotion_obs[9 + num_actions * 3 + 1] = cos_phase
# Get action from policy network
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0)
if self.policy_type == "torchscript":
# TorchScript inference
self.locomotion_action = self.policy(obs_tensor).detach().numpy().squeeze()
elif self.policy_type == "onnx":
# ONNX inference
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = self.policy.run(None, ort_inputs)
self.locomotion_action = ort_outs[0].squeeze()
else:
raise ValueError(f"Unknown policy type: {self.policy_type}")
# Transform action to target joint positions
target_dof_pos = (
np.array(self.config.default_leg_angles)
+ self.locomotion_action * self.config.locomotion_action_scale
)
# Send commands to LEG motors only
for i in range(len(self.config.leg_joint2motor_idx)):
motor_idx = self.config.leg_joint2motor_idx[i]
self.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i]
self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i]
self.msg.motor_cmd[motor_idx].tau = 0
# Hold WAIST motors at 0 (indices 12, 13, 14 = WaistYaw, WaistRoll, WaistPitch)
waist_indices = self.config.arm_waist_joint2motor_idx[:3] # First 3 are waist
for i, motor_idx in enumerate(waist_indices):
self.msg.motor_cmd[motor_idx].q = 0.0
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_arm_waist_kps[i]
self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_arm_waist_kds[i]
self.msg.motor_cmd[motor_idx].tau = 0
# Send command
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
def groot_locomotion_run(self):
"""GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action."""
self.locomotion_counter += 1
# Get current lowstate
lowstate = self.lowstate_buffer.GetData()
if lowstate is None:
return
# Update remote controller from lowstate
if lowstate.wireless_remote is not None:
self.remote_controller.set(lowstate.wireless_remote)
# R1/R2 buttons for height control on real robot — NOTE(review): comment said indices 4 and 5, but the code below reads button[0] and button[4]; confirm mapping
if self.remote_controller.button[0]: # R1 - raise height
self.groot_height_cmd += 0.001 # Small increment per timestep (~0.05m per second at 50Hz)
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if self.remote_controller.button[4]: # R2 - lower height
self.groot_height_cmd -= 0.001 # Small decrement per timestep
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
# Default to zero commands if no remote data
self.remote_controller.lx = 0.0
self.remote_controller.ly = 0.0
self.remote_controller.rx = 0.0
self.remote_controller.ry = 0.0
# Get ALL 29 joint positions and velocities
for i in range(29):
self.groot_qj_all[i] = lowstate.motor_state[i].q
self.groot_dqj_all[i] = lowstate.motor_state[i].dq
# Get IMU data
quat = lowstate.imu_state.quaternion
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
# Transform IMU if using torso IMU
if self.config.locomotion_imu_type == "torso":
waist_yaw = lowstate.motor_state[12].q # Waist yaw index
waist_yaw_omega = lowstate.motor_state[12].dq
quat, ang_vel_3d = self.locomotion_transform_imu_data(
waist_yaw, waist_yaw_omega, quat, np.array([ang_vel])
)
ang_vel = ang_vel_3d.flatten()
# Create observation
gravity_orientation = self.locomotion_get_gravity_orientation(quat)
joints_to_zero_obs = [12, 14, 20, 21, 27, 28] # Note: NOT 13 (waist roll exists)
for idx in joints_to_zero_obs:
self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0
# Scale joint positions and velocities
qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy()
# Subtract default angles for legs + waist (15 joints)
# GR00T default_angles: [-0.1, 0.0, 0.0, 0.3, -0.2, 0.0, -0.1, 0.0, 0.0, 0.3, -0.2, 0.0, 0.0, 0.0, 0.0]
groot_default_angles = np.array(
[
-0.1,
0.0,
0.0,
0.3,
-0.2,
0.0, # left leg
-0.1,
0.0,
0.0,
0.3,
-0.2,
0.0, # right leg
0.0,
0.0,
0.0, # waist
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0, # left arm (zeroed)
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
dtype=np.float32,
) # right arm (zeroed)
qj_obs = (qj_obs - groot_default_angles) * self.config.dof_pos_scale
dqj_obs = dqj_obs * self.config.dof_vel_scale
ang_vel_scaled = ang_vel * self.config.groot_ang_vel_scale # Use GR00T-specific scaling!
# Get velocity commands (keyboard or remote)
if not self.simulation_mode:
self.locomotion_cmd[0] = self.remote_controller.ly
self.locomotion_cmd[1] = self.remote_controller.lx * -1
self.locomotion_cmd[2] = self.remote_controller.rx * -1
# Build 86D single frame observation (GR00T format)
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(
self.config.groot_cmd_scale
) # cmd - use GR00T scaling!
self.groot_obs_single[3] = self.groot_height_cmd # height_cmd
self.groot_obs_single[4:7] = self.groot_orientation_cmd # roll, pitch, yaw cmd
self.groot_obs_single[7:10] = ang_vel_scaled # angular velocity
self.groot_obs_single[10:13] = gravity_orientation # gravity
self.groot_obs_single[13:42] = qj_obs # joint positions (29D)
self.groot_obs_single[42:71] = dqj_obs # joint velocities (29D)
self.groot_obs_single[71:86] = self.groot_action # previous actions (15D)
# Add to history and stack observations (6 frames × 86D = 516D)
self.groot_obs_history.append(self.groot_obs_single.copy())
# Stack all 6 frames into 516D vector
for i, obs_frame in enumerate(self.groot_obs_history):
start_idx = i * 86
end_idx = start_idx + 86
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
# Run policy inference (ONNX) with 516D stacked observation
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
# Select appropriate policy based on command magnitude (dual-policy system)
if self.policy_balance is not None and self.policy_walk is not None:
# Dual-policy mode: switch between Balance and Walk
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
if cmd_magnitude < 0.05:
# Use balance/standing policy for small commands
selected_policy = self.policy_balance
else:
# Use walking policy for movement commands
selected_policy = self.policy_walk
else:
# Single policy mode (fallback)
selected_policy = self.policy
ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze()
# Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
# This ensures action history in observations matches what's actually executed
self.groot_action[12] = 0.0 # Waist yaw
self.groot_action[13] = 0.0 # Waist roll
self.groot_action[14] = 0.0 # Waist pitch
# Transform action to target joint positions (15D: legs + waist, but waist actions are zeroed)
target_dof_pos_15 = (
groot_default_angles[:15] + self.groot_action * self.config.locomotion_action_scale
)
# Send commands to LEG motors (0-11)
for i in range(12):
motor_idx = i
self.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.locomotion_kps[i]
self.msg.motor_cmd[motor_idx].kd = self.config.locomotion_kds[i]
self.msg.motor_cmd[motor_idx].tau = 0
# Send WAIST commands - but SKIP waist yaw (12) and waist pitch (14)
# Only send waist roll (13)
waist_roll_idx = 13
waist_roll_action_idx = 13 # In the 15D action
self.msg.motor_cmd[waist_roll_idx].q = target_dof_pos_15[waist_roll_action_idx]
self.msg.motor_cmd[waist_roll_idx].qd = 0
self.msg.motor_cmd[waist_roll_idx].kp = self.config.locomotion_arm_waist_kps[
1
] # index 1 is waist roll
self.msg.motor_cmd[waist_roll_idx].kd = self.config.locomotion_arm_waist_kds[1]
self.msg.motor_cmd[waist_roll_idx].tau = 0
# Zero out the problematic joints (waist yaw, waist pitch, wrist pitch/yaw)
problematic_joints = [12, 14, 20, 21, 27, 28]
for joint_idx in problematic_joints:
self.msg.motor_cmd[joint_idx].q = 0.0
self.msg.motor_cmd[joint_idx].qd = 0
if joint_idx in [12, 14]: # waist
kp_idx = 0 if joint_idx == 12 else 2 # yaw or pitch
self.msg.motor_cmd[joint_idx].kp = self.config.locomotion_arm_waist_kps[kp_idx]
self.msg.motor_cmd[joint_idx].kd = self.config.locomotion_arm_waist_kds[kp_idx]
else: # wrists (20, 21, 27, 28)
self.msg.motor_cmd[joint_idx].kp = self.kp_wrist
self.msg.motor_cmd[joint_idx].kd = self.kd_wrist
self.msg.motor_cmd[joint_idx].tau = 0
# Send command
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
# Use different run function based on policy type
if self.policy_type == "onnx":
self.groot_locomotion_run()
else:
self.locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, self.config.locomotion_control_dt - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
"""Start the background locomotion control thread."""
if not self.config.locomotion_control:
logger.warning("locomotion_control is False, cannot start thread")
return
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
"""Stop the background locomotion control thread."""
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
# Also stop keyboard thread if running
if self.keyboard_running:
self.stop_keyboard_controls()
def _keyboard_listener_thread(self):
"""Background thread that listens for keyboard input (sim mode only)."""
print("\n" + "=" * 60)
print("KEYBOARD CONTROLS ACTIVE!")
print(" W/S: Forward/Backward")
print(" A/D: Left/Right")
print(" Q/E: Rotate Left/Right")
print(" R/F: Raise/Lower Height (±5cm)")
print(" Z: Stop (zero velocity commands)")
print("=" * 60 + "\n")
# Save terminal settings
old_settings = None
try:
old_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())
while self.keyboard_running:
if select.select([sys.stdin], [], [], 0.1)[0]:
key = sys.stdin.read(1).lower()
# Velocity commands
if key == "w":
self.locomotion_cmd[0] += 0.4 # Forward
elif key == "s":
self.locomotion_cmd[0] -= 0.4 # Backward
elif key == "a":
self.locomotion_cmd[1] += 0.25 # Left
elif key == "d":
self.locomotion_cmd[1] -= 0.25 # Right
elif key == "q":
self.locomotion_cmd[2] += 0.5 # Rotate left
elif key == "e":
self.locomotion_cmd[2] -= 0.5 # Rotate right
elif key == "z":
self.locomotion_cmd[:] = 0.0 # Stop
# Height commands (only for GR00T ONNX policies)
elif key == "r":
self.groot_height_cmd += 0.05 # Raise 5cm
elif key == "f":
self.groot_height_cmd -= 0.05 # Lower 5cm
# Clamp commands to reasonable limits
self.locomotion_cmd[0] = np.clip(self.locomotion_cmd[0], -0.8, 0.8) # vx
self.locomotion_cmd[1] = np.clip(self.locomotion_cmd[1], -0.5, 0.5) # vy
self.locomotion_cmd[2] = np.clip(self.locomotion_cmd[2], -1.0, 1.0) # yaw_rate
# Clamp height (reasonable range: 0.5m to 1.0m)
if hasattr(self, "groot_height_cmd"):
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
# Print current commands
print(
f"[VEL CMD] vx={self.locomotion_cmd[0]:.2f}, vy={self.locomotion_cmd[1]:.2f}, yaw={self.locomotion_cmd[2]:.2f}",
end="",
)
if hasattr(self, "groot_height_cmd"):
print(f" | [HEIGHT] {self.groot_height_cmd:.3f}m", end="")
print() # Newline
finally:
# Restore terminal settings
if old_settings is not None:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
print("\nKeyboard controls stopped")
def start_keyboard_controls(self):
"""Start the keyboard control thread (sim mode only)."""
if not self.simulation_mode:
logger.warning("Keyboard controls only available in simulation mode")
return
if self.keyboard_running:
logger.warning("Keyboard controls already running")
return
self.keyboard_running = True
self.keyboard_thread = threading.Thread(target=self._keyboard_listener_thread, daemon=True)
self.keyboard_thread.start()
logger.info("Keyboard controls started!")
def stop_keyboard_controls(self):
"""Stop the keyboard control thread."""
if not self.keyboard_running:
return
logger.info("Stopping keyboard controls...")
self.keyboard_running = False
if self.keyboard_thread:
self.keyboard_thread.join(timeout=2.0)
logger.info("Keyboard controls stopped")
def init_locomotion(self):
"""Test locomotion control sequence: home arms -> move legs to default -> start policy thread."""
if not self.config.locomotion_control:
logger.warning("locomotion_control is False, cannot run test sequence")
return
logger.info("Starting locomotion test sequence...")
# 2. Move legs to default position
self.locomotion_move_to_default_pos()
# 3. Wait 3 seconds
time.sleep(3.0)
# 4. Hold default leg position for 2 seconds
self.locomotion_default_pos_state()
# 5. Start locomotion policy thread (runs in background)
logger.info("Starting locomotion policy control...")
self.start_locomotion_thread()
logger.info("Locomotion test sequence complete! Policy is now running in background.")
logger.info("Use robot.stop_locomotion_thread() to stop the policy.")
def init_groot_locomotion(self):
"""Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
if not self.config.locomotion_control:
logger.warning("locomotion_control is False, cannot run GR00T init")
return
logger.info("Starting GR00T locomotion initialization...")
# Move legs to default position (same as regular locomotion)
self.locomotion_move_to_default_pos()
# Wait 3 seconds
time.sleep(3.0)
# Hold default leg position for 2 seconds
self.locomotion_default_pos_state()
# Start locomotion policy thread (will use groot_locomotion_run)
logger.info("Starting GR00T locomotion policy control...")
self.start_locomotion_thread()
logger.info("GR00T locomotion initialization complete! Policy is now running.")
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")