simplified robot class

This commit is contained in:
Martino Russi
2025-11-26 15:11:45 +01:00
parent c65866ddd8
commit c7834c3db8
+71 -724
View File
@@ -6,15 +6,9 @@ from typing import Any
from pathlib import Path
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
FeetechMotorsBus,
)
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
import json
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_unitree_g1 import UnitreeG1Config
import numpy as np
@@ -36,7 +30,6 @@ import onnxruntime as ort
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.utils.crc import CRC
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
MotionSwitcherClient,
)
@@ -45,11 +38,6 @@ from lerobot.envs.factory import make_env
from scipy.spatial.transform import Rotation as R
import struct
import yaml
from typing import Union
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
import torch
@@ -57,7 +45,6 @@ import torch
logger = logging.getLogger(__name__)
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowCommand_Motion = "rt/arm_sdk"
kTopicLowState = "rt/lowstate"
G1_29_Num_Motors = 35
@@ -102,8 +89,6 @@ class DataBuffer:
with self.lock:
self.data = data
# TODO: eventually observations should include everything (motor torques, etc.)
# TODO: consider a dedicated motor class for Unitree?
class UnitreeG1(Robot):
config_class = UnitreeG1Config
@@ -118,7 +103,6 @@ class UnitreeG1(Robot):
self.cameras = make_cameras_from_configs(config.cameras)
self.q_target = np.zeros(14)
self.tauff_target = np.zeros(14)
self.motion_mode = config.motion_mode
self.simulation_mode = config.simulation_mode
self.kp_high = config.kp_high
self.kd_high = config.kd_high
@@ -131,7 +115,6 @@ class UnitreeG1(Robot):
self.arm_velocity_limit = config.arm_velocity_limit
self.control_dt = config.control_dt
self._speed_gradual_max = config.speed_gradual_max
self._gradual_start_time = config.gradual_start_time
self._gradual_time = config.gradual_time
@@ -143,7 +126,6 @@ class UnitreeG1(Robot):
self.freeze_body = config.freeze_body
self.gravity_compensation = config.gravity_compensation
self.calibrated = False
self.calibrate()
@@ -155,39 +137,33 @@ class UnitreeG1(Robot):
else:
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
if not self.config.simulation_mode:
self.msc = MotionSwitcherClient()
self.msc.SetTimeout(5.0)
self.msc.Init()
status, result = self.msc.CheckMode()
print(status, result)
#check if result name first
if result is not None and "name" in result:
while result["name"]:
self.msc.ReleaseMode()
status, result = self.msc.CheckMode()
print(status, result)
time.sleep(1)
# initialize lowcmd publisher and lowstate subscriber
if self.simulation_mode:
ChannelFactoryInitialize(0, "lo")
# Launch MuJoCo simulation environment
logger.info("Launching MuJoCo simulation environment...")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
logger.info("MuJoCo environment launched successfully!")
else:
ChannelFactoryInitialize(0)
if not self.config.simulation_mode:
pass
# self.msc = MotionSwitcherClient()
# self.msc.SetTimeout(5.0)
# self.msc.Init()
# status, result = self.msc.CheckMode()
# print(status, result)
# #check if result name first
# if result is not None and "name" in result:
# while result["name"]:
# self.msc.ReleaseMode()
# status, result = self.msc.CheckMode()
# print(status, result)
# time.sleep(1)
if self.motion_mode:
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
else:
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
# Always use debug mode (direct motor control)
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
self.lowcmd_publisher.Init()
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
self.lowstate_subscriber.Init()
@@ -203,9 +179,6 @@ class UnitreeG1(Robot):
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
logger.info("[UnitreeG1] Subscribe dds ok.")
# initialize audio client for LED, TTS, and audio playback
# initialize hg's lowcmd msg
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
@@ -238,15 +211,6 @@ class UnitreeG1(Robot):
self.msg.motor_cmd[id].q = self.all_motor_q[id]
#print current motor q, kp, kd
if config.audio_client:
pass
# self.audio_client = AudioClient()
# self.audio_client.SetTimeout(10.0)
# self.audio_client.Init()
# logger.info("[UnitreeG1] Audio client initialized!")
logger.info("Lock OK!\n") #motors are not locked x
# for i in range(10000):
# print(self.get_current_motor_q())
@@ -257,62 +221,24 @@ class UnitreeG1(Robot):
self.keyboard_running = False
self.locomotion_thread = None
self.locomotion_running = False
self.motion_imitation_thread = None
self.motion_imitation_running = False
# Initialize publish thread for arm control
# Note: This thread runs alongside locomotion/motion_imitation threads
# Note: This thread runs alongside locomotion thread
# - Arm thread: controls arms (indices 15-28)
# - Locomotion thread: controls legs (0-11), waist (12-14)
# Both update different parts of self.msg, both call Write()
self.publish_thread = None
self.ctrl_lock = threading.Lock()
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
self.publish_thread.daemon = True
self.publish_thread.start()
logger.info("Arm control publish thread started")
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
self.publish_thread.daemon = True
self.publish_thread.start()
logger.info("Arm control publish thread started")
# Load locomotion policy if enabled
self.policy = None
self.policy_type = None # 'torchscript', 'onnx', or 'motion_imitation'
self.motion_loader = None
self.policy_type = None # 'torchscript' or 'onnx'
if config.motion_imitation_control:
# Motion imitation mode (dance, etc.)
if config.motion_file_path is None:
raise ValueError("motion_imitation_control is True but motion_file_path is not set")
logger.info(f"Loading motion reference from {config.motion_file_path}")
# Load motion file
self.motion_loader = self.MotionLoader(config.motion_file_path, config.motion_fps)
# Load ONNX policy (optional for now - can run in direct playback mode)
if config.motion_policy_path and Path(config.motion_policy_path).exists():
logger.info(f"Loading motion imitation policy from {config.motion_policy_path}")
self.policy = ort.InferenceSession(config.motion_policy_path)
self.policy_type = 'motion_imitation'
logger.info("Motion imitation ONNX policy loaded successfully")
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
else:
logger.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
self.policy = None
self.policy_type = 'motion_playback'
# Initialize motion imitation variables
self.motion_counter = 0
self.motion_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints from robot
self.motion_dqj_all = np.zeros(29, dtype=np.float32)
self.motion_action = np.zeros(29, dtype=np.float32) # 29D action output
self.motion_obs = np.zeros(154, dtype=np.float32) # 154D observation
self.motion_elapsed_time = 0.0
# Initialize motion and start
self.init_motion_imitation()
elif config.locomotion_control:
if config.locomotion_control:
if config.policy_path is None:
raise ValueError("locomotion_control is True but policy_path is not set")
@@ -326,11 +252,33 @@ class UnitreeG1(Robot):
logger.info("TorchScript policy loaded successfully")
elif config.policy_path.endswith('.onnx'):
logger.info("Detected ONNX (.onnx) policy")
self.policy = ort.InferenceSession(config.policy_path)
# For GR00T-style policies, load both Balance and Walk policies
# Balance policy for standing (low velocity commands)
# Walk policy for locomotion (high velocity commands)
balance_policy_path = config.policy_path.replace('Walk.onnx', 'Balance.onnx')
walk_policy_path = config.policy_path
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
logger.info("Loading dual-policy system (Balance + Walk)")
self.policy_balance = ort.InferenceSession(balance_policy_path)
self.policy_walk = ort.InferenceSession(walk_policy_path)
self.policy = None # Not used when dual policies are loaded
logger.info(f"Balance policy loaded from: {balance_policy_path}")
logger.info(f"Walk policy loaded from: {walk_policy_path}")
logger.info(f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}")
else:
# Fallback to single policy
logger.info("Loading single ONNX policy")
self.policy = ort.InferenceSession(config.policy_path)
self.policy_balance = None
self.policy_walk = None
logger.info("ONNX policy loaded successfully")
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
self.policy_type = 'onnx'
logger.info("ONNX policy loaded successfully")
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
else:
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
@@ -415,12 +363,7 @@ class UnitreeG1(Robot):
return cliped_arm_q_target
def _ctrl_motor_state(self):
"""Arm control thread - publishes commands for arms only.
NOTE: This thread is NOT started when motion_imitation_control or locomotion_control is True.
Those modes handle their own publishing."""
if self.motion_mode:
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = 1.0
"""Arm control thread - publishes commands for arms only."""
while True:
start_time = time.time()
@@ -452,10 +395,6 @@ class UnitreeG1(Robot):
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
if self._speed_gradual_max is True:
t_elapsed = start_time - self._gradual_start_time
self.arm_velocity_limit = 20.0 + (10.0 * min(1.0, t_elapsed / 5.0))
current_time = time.time()
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed))
@@ -485,37 +424,6 @@ class UnitreeG1(Robot):
"""Return current state dq of the left and right arm motors."""
return np.array([self.lowstate_buffer.GetData().motor_state[id].dq for id in G1_29_JointArmIndex])
def ctrl_dual_arm_go_home(self):
    """Command both arms to the zero (home) pose and wait for them to arrive.

    All 14 arm joint targets are set to zero while holding ``ctrl_lock``; the
    feed-forward torque target is left untouched. The measured arm joint
    angles are then polled until every joint is within the tolerance of zero
    or the attempt budget runs out. A warning is logged on timeout so the
    failure is not silent.

    When ``self.motion_mode`` is set, the value written to the
    ``kNotUsedJoint0`` command slot is ramped from 1 to 0 once home is reached.
    """
    logger.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
    max_attempts = 100
    poll_interval = 0.05  # seconds between two position checks
    # Threshold for "close to zero", in the unit of the joint readings; can be
    # adjusted based on the motors' precision requirements.
    tolerance = 0.05

    with self.ctrl_lock:
        self.q_target = np.zeros(14)

    for _ in range(max_attempts):
        current_q = self.get_current_dual_arm_q()
        if np.all(np.abs(current_q) < tolerance):
            if self.motion_mode:
                for weight in np.linspace(1, 0, num=101):
                    self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = weight
                    time.sleep(0.02)
            logger.info("[G1_29_ArmController] both arms have reached the home position.")
            break
        time.sleep(poll_interval)
    else:
        # Loop finished without ever seeing all joints inside the tolerance.
        logger.warning(
            "[G1_29_ArmController] arms did not reach the home position after "
            f"{max_attempts} attempts."
        )
def speed_gradual_max(self, t=5.0):
    """Arm the gradual arm-velocity ramp.

    Records the current wall-clock time as the ramp start, stores ``t`` (the
    total time in seconds the arm velocity should take to reach its maximum,
    default 5.0) and then switches the ramp flag on.
    """
    # The start time and duration are stored before the flag is raised so a
    # reader of the flag never observes it without them.
    self._gradual_start_time, self._gradual_time = time.time(), t
    self._speed_gradual_max = True
def speed_instant_max(self):
    """Set the arm velocity limit to its maximum value immediately.

    The gradual ramp flag is cleared as well. Otherwise a ramp that is still
    active would recompute ``arm_velocity_limit`` on the next control cycle
    and undo this call.
    """
    self._speed_gradual_max = False
    self.arm_velocity_limit = 30.0
def _Is_weak_motor(self, motor_index):
weak_motors = [
@@ -614,159 +522,6 @@ class UnitreeG1(Robot):
'motors': motors_data,
}
def audio_control(self, command, volume: int = 80):
    """
    Unified audio/LED control function for the G1 robot.

    Args:
        command: Can be one of:
            - tuple[int, int, int]: RGB values (0-255) for LED control
            - str naming an existing file: WAV file to play
            - any other str: Text to speak via TTS
        volume: Volume level 0-100 (default: 80)

    Raises:
        ValueError: If ``command`` is neither a str nor a 3-tuple.

    Examples:
        robot.audio_control("Hello world")  # TTS
        robot.audio_control((255, 0, 0))    # Red LED
        robot.audio_control("audio.wav")    # Play WAV file
    """
    # Set volume
    self.audio_client.SetVolume(volume)

    if isinstance(command, tuple) and len(command) == 3:
        # LED control - RGB tuple
        r, g, b = command
        logger.info(f"Setting LED to RGB({r}, {g}, {b})")
        self.audio_client.LedControl(r, g, b)
        return

    if isinstance(command, str):
        # Probing arbitrary text as a path can fail (for example when the text
        # is longer than the filesystem allows for a file name). Such text can
        # only be something to speak, so treat the failure as "not a file".
        try:
            is_audio_file = Path(command).is_file()
        except (OSError, ValueError):
            is_audio_file = False

        if is_audio_file:
            logger.info(f"Playing audio file: {command}")
            self._play_wav_file(command)
        else:
            logger.info(f"Speaking: {command}")
            self.audio_client.TtsMaker(command, 0)  # 0 for English
        return

    raise ValueError(
        f"Invalid command type: {type(command)}. "
        "Expected str (text/path) or tuple[int, int, int] (RGB)"
    )
def _read_wav_file(self, filename: str):
    """Read a WAV file and return its PCM data as bytes.

    Only uncompressed PCM, 16-bit, mono, 16000 Hz files are accepted. A single
    ``JUNK`` chunk before the ``fmt `` chunk is skipped, and any chunks between
    ``fmt `` and ``data`` are skipped as well.

    Args:
        filename: Path of the WAV file to read.

    Returns:
        The raw bytes stored in the ``data`` chunk.

    Raises:
        ValueError: If the file is not a valid RIFF/WAVE file, uses an
            unsupported format, is truncated, or has no ``data`` chunk.
    """
    with open(filename, 'rb') as f:
        def read(fmt):
            # A short read means the file ended early; report it as a
            # ValueError like every other malformed-file condition here.
            size = struct.calcsize(fmt)
            raw = f.read(size)
            if len(raw) != size:
                raise ValueError("Unexpected end of WAV file")
            return struct.unpack(fmt, raw)

        def skip_chunk(size):
            # RIFF chunks are word-aligned: an odd-sized payload is followed
            # by one pad byte that is not counted in the chunk size.
            f.seek(size + (size & 1), 1)

        # Read RIFF header
        chunk_id, = read('<I')
        if chunk_id != 0x46464952:  # "RIFF"
            raise ValueError("Not a valid WAV file (invalid RIFF header)")

        _chunk_size, = read('<I')
        format_tag, = read('<I')
        if format_tag != 0x45564157:  # "WAVE"
            raise ValueError("Not a valid WAV file (invalid WAVE format)")

        # Read fmt chunk
        subchunk1_id, subchunk1_size = read('<II')

        # Skip JUNK chunk if present
        if subchunk1_id == 0x4B4E554A:  # "JUNK"
            skip_chunk(subchunk1_size)
            subchunk1_id, subchunk1_size = read('<II')

        if subchunk1_id != 0x20746D66:  # "fmt "
            raise ValueError("Invalid fmt chunk")

        if subchunk1_size not in (16, 18):
            raise ValueError(f"Unsupported fmt chunk size: {subchunk1_size}")

        audio_format, = read('<H')
        if audio_format != 1:
            raise ValueError(f"Only PCM format supported, got format {audio_format}")

        num_channels, = read('<H')
        sample_rate, = read('<I')
        _byte_rate, = read('<I')
        _block_align, = read('<H')
        bits_per_sample, = read('<H')

        if bits_per_sample != 16:
            raise ValueError(f"Only 16-bit samples supported, got {bits_per_sample}-bit")
        if sample_rate != 16000:
            raise ValueError(f"Sample rate must be 16000 Hz, got {sample_rate} Hz")
        if num_channels != 1:
            raise ValueError(f"Must be mono (1 channel), got {num_channels} channels")

        if subchunk1_size == 18:
            extra_size, = read('<H')
            if extra_size != 0:
                f.seek(extra_size, 1)

        # Find data chunk
        while True:
            header = f.read(8)
            if len(header) != 8:
                raise ValueError("No data chunk found in WAV file")
            subchunk2_id, subchunk2_size = struct.unpack('<II', header)
            if subchunk2_id == 0x61746164:  # "data"
                break
            skip_chunk(subchunk2_size)

        # Read PCM data
        raw_pcm = f.read(subchunk2_size)
        if len(raw_pcm) != subchunk2_size:
            raise ValueError("Failed to read full PCM data")

        return raw_pcm
def _play_wav_file(self, filename: str, chunk_size: int = 96000):
    """
    Play a WAV file through the robot's speaker.

    The PCM payload is read with ``_read_wav_file`` and sent to the audio
    client in consecutive chunks, pausing one second after each chunk. Sending
    stops at the first chunk the audio client rejects.

    Args:
        filename: Path to WAV file (must be 16kHz, mono, 16-bit PCM)
        chunk_size: Bytes per chunk (default: 96000 = ~3 seconds at 16kHz)

    Raises:
        ValueError: If ``chunk_size`` is not positive (it would never advance
            through the data), or if the file is rejected by
            ``_read_wav_file``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Read WAV file
    pcm_data = self._read_wav_file(filename)
    total_size = len(pcm_data)
    total_chunks = -(-total_size // chunk_size)  # ceiling division

    stream_id = str(int(time.time() * 1000))
    app_name = "lerobot"

    logger.info(f"Playing audio: {total_size} bytes in {total_chunks} chunks")

    # Send audio in chunks
    for chunk_index, offset in enumerate(range(0, total_size, chunk_size)):
        chunk = pcm_data[offset:offset + chunk_size]

        ret_code, _ = self.audio_client.PlayStream(app_name, stream_id, list(chunk))
        if ret_code != 0:
            logger.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
            break
        logger.debug(f"Sent chunk {chunk_index + 1}/{total_chunks}")

        time.sleep(1.0)  # Wait between chunks

    # Calculate playback duration
    duration_seconds = total_size / (16000 * 2)  # 16kHz, 16-bit (2 bytes)
    logger.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
def get_observation(self) -> dict[str, Any]:
obs_array = self.get_current_dual_arm_q()
obs_dict = {f"{G1_29_JointArmIndex(motor).name}.pos": val for motor, val in zip(G1_29_JointArmIndex, obs_array, strict=True)}
@@ -991,78 +746,6 @@ class UnitreeG1(Robot):
self.ry = struct.unpack("f", data[12:16])[0]
self.ly = struct.unpack("f", data[20:24])[0]
class MotionLoader:
    """Load a reference motion from a CSV file and sample it at a playback time."""

    def __init__(self, motion_file: str, fps: float = 60.0):
        """Load motion from CSV file.

        CSV format: [root_pos(3), root_quat_xyzw(4), joint_dof(29)] per row

        Args:
            motion_file: Path of the comma-separated motion file.
            fps: Frame rate of the recording; consecutive rows are ``1 / fps``
                seconds apart.

        Raises:
            ValueError: If ``fps`` is not positive or the file has no rows.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")
        self.dt = 1.0 / fps

        # ndmin=2 keeps a single-row file as shape (1, columns) instead of 1-D.
        data = np.loadtxt(motion_file, delimiter=',', ndmin=2)
        if data.shape[0] == 0:
            raise ValueError(f"Motion file contains no frames: {motion_file}")
        self.num_frames = data.shape[0]
        self.duration = self.num_frames * self.dt

        # Split data
        self.root_positions = data[:, 0:3]  # (N, 3)
        self.root_quaternions_xyzw = data[:, 3:7]  # (N, 4) [x, y, z, w]
        self.dof_positions = data[:, 7:]  # (N, remaining columns)

        # Compute velocities (finite differences); the first row is prepended
        # to itself, so the first velocity is zero.
        self.dof_velocities = np.diff(
            self.dof_positions, axis=0, prepend=self.dof_positions[0:1]
        ) / self.dt

        # Current playback state
        self.current_time = 0.0
        self.index_0 = 0
        self.index_1 = 0
        self.blend = 0.0

        logger.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")

    def update(self, time: float):
        """Update motion to specific time (loops at duration).

        Selects the two frames bracketing ``time`` and the blend factor between
        them. The frame index and the blend factor are both derived from
        ``current_time / dt``, so the blend always stays within [0, 1].
        """
        self.current_time = time % self.duration  # Loop
        frame_position = self.current_time / self.dt
        last_frame = self.num_frames - 1

        self.index_0 = min(int(frame_position), last_frame)
        # The final frame has no successor; it is held rather than wrapped.
        self.index_1 = min(self.index_0 + 1, last_frame)
        self.blend = min(max(frame_position - self.index_0, 0.0), 1.0)

    def get_joint_pos(self) -> np.ndarray:
        """Get linearly interpolated joint positions for the current time."""
        return self.dof_positions[self.index_0] * (1 - self.blend) + \
            self.dof_positions[self.index_1] * self.blend

    def get_joint_vel(self) -> np.ndarray:
        """Get linearly interpolated joint velocities for the current time."""
        return self.dof_velocities[self.index_0] * (1 - self.blend) + \
            self.dof_velocities[self.index_1] * self.blend

    def get_root_quat_wxyz(self) -> np.ndarray:
        """Get interpolated root quaternion [w, x, y, z].

        Uses normalized linear interpolation between the two bracketing frames
        (stored as [x, y, z, w]).
        """
        q0 = self.root_quaternions_xyzw[self.index_0]  # [x, y, z, w]
        q1 = self.root_quaternions_xyzw[self.index_1]

        # q and -q encode the same rotation. Blend within one hemisphere so the
        # interpolated vector cannot collapse towards zero before normalizing.
        if np.dot(q0, q1) < 0.0:
            q1 = -q1

        quat_xyzw = q0 * (1 - self.blend) + q1 * self.blend
        quat_xyzw = quat_xyzw / np.linalg.norm(quat_xyzw)

        # Convert to [w, x, y, z]
        return np.array(
            [quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]], dtype=np.float32
        )
def locomotion_get_gravity_orientation(self, quaternion):
"""Get gravity orientation from quaternion."""
qw = quaternion[0]
@@ -1287,8 +970,23 @@ class UnitreeG1(Robot):
# Run policy inference (ONNX) with 516D stacked observation
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = self.policy.run(None, ort_inputs)
# Select appropriate policy based on command magnitude (dual-policy system)
if self.policy_balance is not None and self.policy_walk is not None:
# Dual-policy mode: switch between Balance and Walk
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
if cmd_magnitude < 0.05:
# Use balance/standing policy for small commands
selected_policy = self.policy_balance
else:
# Use walking policy for movement commands
selected_policy = self.policy_walk
else:
# Single policy mode (fallback)
selected_policy = self.policy
ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze()
# Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
@@ -1487,10 +1185,7 @@ class UnitreeG1(Robot):
return
logger.info("Starting locomotion test sequence...")
# 1. Home the arms first
logger.info("Homing arms to zero position...")
#self.ctrl_dual_arm_go_home()
# 2. Move legs to default position
self.locomotion_move_to_default_pos()
@@ -1532,354 +1227,6 @@ class UnitreeG1(Robot):
logger.info("GR00T locomotion initialization complete! Policy is now running.")
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
def motion_imitation_run(self):
"""Run one motion imitation step - tracks reference motion (dance_102, etc).

Advances the motion clock by one control period, samples the reference
motion, reads the latest low-level state and fills ``self.msg.motor_cmd``
for the 29 joints: reference playback when ``self.policy`` is None,
otherwise the output of the ONNX policy. The message is then CRC-stamped
and written to the low-level command publisher. Returns early, without
publishing, when no low-level state is available.
"""
# NOTE: the counter (and therefore the motion clock) advances even when the
# lowstate check below returns early.
self.motion_counter += 1
self.motion_elapsed_time = self.motion_counter * self.config.motion_control_dt
# Update motion loader to current time
self.motion_loader.update(self.motion_elapsed_time)
# Get current lowstate
lowstate = self.lowstate_buffer.GetData()
if lowstate is None:
return
# Get ALL 29 joint positions and velocities from robot
# IMPORTANT: Convert from motor order to BFS order to match reference motion
# The C++ code does: robot_bfs[i] = motor[joint_ids_map[i]]
for i in range(29):
motor_idx = self.config.motion_joint_ids_map[i]
self.motion_qj_all[i] = lowstate.motor_state[motor_idx].q
self.motion_dqj_all[i] = lowstate.motor_state[motor_idx].dq
# ======== 23 DOF MODE CONFIGURATION ========
# For real robot - zeros out joints not present in 23 DOF hardware
# Waist: yaw(12), pitch(14) | Wrist: L_pitch/yaw(20,21), R_pitch/yaw(27,28)
# Both values below are hard-coded locals, not configuration options.
USE_23DOF = True # Set to True for real robot without these joints
JOINTS_TO_ZERO_23DOF = [12,14,20, 21, 27, 28]#12, 14, 20, 21, 27, 28]#
# Apply 23 DOF zeroing to robot observations if enabled
if USE_23DOF:
for joint_idx in JOINTS_TO_ZERO_23DOF:
self.motion_qj_all[joint_idx] = 0.0
self.motion_dqj_all[joint_idx] = 0.0
if self.motion_counter == 1:
logger.info("="*60)
logger.info("🤖 23 DOF MODE ENABLED")
logger.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
logger.info(" Waist: yaw(12), pitch(14)")
logger.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
logger.info(" Applied to: robot obs, reference motion, policy actions")
logger.info("="*60)
# Get IMU data
robot_quat = lowstate.imu_state.quaternion # [w, x, y, z]
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) # 3D
if self.policy is None:
# DIRECT PLAYBACK MODE (no policy)
motion_joint_pos_dfs = self.motion_loader.get_joint_pos()
# Zero out missing joints for 23 DOF mode
if USE_23DOF:
# Convert to BFS to zero out, then convert back
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
for i in range(29):
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
for joint_idx in JOINTS_TO_ZERO_23DOF:
motion_joint_pos_bfs_temp[joint_idx] = 0.0
# Convert back to DFS for sending
for i in range(29):
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
# NOTE(review): motor_idx and csv_idx are the same lookup, so each motor is
# commanded with the reference value stored at its own index - confirm this
# is the intended mapping.
for i in range(29):
motor_idx = self.config.motion_joint_ids_map[i]
csv_idx = self.config.motion_joint_ids_map[i]
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
self.msg.motor_cmd[motor_idx].tau = 0
else:
# POLICY MODE - Full observation construction and inference
# ======== DEBUG TEST MODES ========
# Mode 1: Direct playback (no policy) - set motion_policy_path = None in config instead
# Mode 2: Send default pos (stand still) - TEST_SEND_DEFAULT_POS = True
# Mode 3: Policy with zero reference - TEST_WITH_ZEROS = True, TEST_SEND_DEFAULT_POS = False
# Mode 4: Policy with real reference - TEST_WITH_ZEROS = False, TEST_SEND_DEFAULT_POS = False
# The three switches below are hard-coded; with all of them False the
# policy runs on the real reference motion (Mode 4).
TEST_WITH_ZEROS = False # True = use zero reference motion in observation
TEST_SEND_DEFAULT_POS = False # True = bypass policy and send default pos (stand still)
TEST_DIRECT_PLAYBACK = False # True = bypass policy and send reference motion directly
if TEST_DIRECT_PLAYBACK:
# DEBUG: Play back reference motion without policy
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D in DFS order
# Zero out missing joints for 23 DOF mode
if USE_23DOF:
# Convert to BFS to zero out, then convert back
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
for i in range(29):
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
for joint_idx in JOINTS_TO_ZERO_23DOF:
motion_joint_pos_bfs_temp[joint_idx] = 0.0
# Convert back to DFS for sending
for i in range(29):
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
# Send directly to motors using joint_ids_map (same as direct playback mode)
for i in range(29):
motor_idx = self.config.motion_joint_ids_map[i]
csv_idx = self.config.motion_joint_ids_map[i]
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
self.msg.motor_cmd[motor_idx].tau = 0
if self.motion_counter == 1:
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
logger.info("="*60)
target_joint_pos_bfs = None # Not used in this mode
else:
# Run observation construction and policy
if TEST_WITH_ZEROS:
# Send zeros for reference motion
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
if self.motion_counter == 1:
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
logger.info("="*60)
else:
# Get reference motion (DFS order from CSV)
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D
motion_joint_vel_dfs = self.motion_loader.get_joint_vel() # 29D
# Convert from DFS to BFS order: bfs[i] = dfs[joint_ids_map[i]]
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
for i in range(29):
motion_joint_pos_bfs[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
motion_joint_vel_bfs[i] = motion_joint_vel_dfs[self.config.motion_joint_ids_map[i]]
# Zero out missing joints in reference motion for 23 DOF mode
if USE_23DOF:
for joint_idx in JOINTS_TO_ZERO_23DOF:
motion_joint_pos_bfs[joint_idx] = 0.0
motion_joint_vel_bfs[joint_idx] = 0.0
# Compute motion_anchor_ori_b (6D rotation matrix representation)
# Quaternions are reordered from [w, x, y, z] to the [x, y, z, w] order
# passed to R.from_quat; the 6 values are the first two columns of the
# relative rotation matrix, taken row by row.
motion_quat_wxyz = self.motion_loader.get_root_quat_wxyz()
robot_rot = R.from_quat([robot_quat[1], robot_quat[2], robot_quat[3], robot_quat[0]]).as_matrix()
motion_rot = R.from_quat([motion_quat_wxyz[1], motion_quat_wxyz[2], motion_quat_wxyz[3], motion_quat_wxyz[0]]).as_matrix()
relative_rot = robot_rot.T @ motion_rot
motion_anchor_ori_b = np.array([relative_rot[0, 0], relative_rot[0, 1],
relative_rot[1, 0], relative_rot[1, 1],
relative_rot[2, 0], relative_rot[2, 1]], dtype=np.float32)
# Compute joint positions and velocities relative to default
default_joint_pos = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
joint_pos_rel = self.motion_qj_all - default_joint_pos
joint_vel_rel = self.motion_dqj_all.copy()
# Build 154D observation:
# motion_command (58D) = joint_pos (29D) + joint_vel (29D) from reference
# motion_anchor_ori_b (6D)
# base_ang_vel (3D)
# joint_pos_rel (29D)
# joint_vel_rel (29D)
# last_action (29D)
self.motion_obs[0:29] = motion_joint_pos_bfs
self.motion_obs[29:58] = motion_joint_vel_bfs
self.motion_obs[58:64] = motion_anchor_ori_b
self.motion_obs[64:67] = ang_vel
self.motion_obs[67:96] = joint_pos_rel
self.motion_obs[96:125] = joint_vel_rel
self.motion_obs[125:154] = self.motion_action
if TEST_SEND_DEFAULT_POS:
# DEBUG: Just send default positions (should make robot stand still)
target_joint_pos_bfs = default_joint_pos.copy()
if self.motion_counter == 1:
logger.info("="*60)
logger.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
logger.info("="*60)
logger.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
if self.motion_counter % 50 == 0:
logger.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
logger.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
else:
# Run ONNX policy inference
obs_tensor = torch.from_numpy(self.motion_obs).unsqueeze(0)
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
ort_outs = self.policy.run(None, ort_inputs)
self.motion_action = ort_outs[0].squeeze() # 29D action in BFS order
# Zero out missing joints in policy actions for 23 DOF mode
if USE_23DOF:
for joint_idx in JOINTS_TO_ZERO_23DOF:
self.motion_action[joint_idx] = 0.0
# Process actions: scale and add offset
action_scale = np.array(self.config.motion_action_scale, dtype=np.float32)
target_joint_pos_bfs = default_joint_pos + self.motion_action * action_scale
# Send commands to motors: motor[joint_ids_map[i]] = action[i]
for i in range(29):
motor_idx = self.config.motion_joint_ids_map[i]
self.msg.motor_cmd[motor_idx].q = target_joint_pos_bfs[i]
self.msg.motor_cmd[motor_idx].qd = 0
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
self.msg.motor_cmd[motor_idx].tau = 0
# Debug print (only when running policy, not in TEST_SEND_DEFAULT_POS or TEST_DIRECT_PLAYBACK mode)
if self.motion_counter == 1 and self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
logger.info("="*60)
logger.info("POLICY MODE OBSERVATION CHECK (First iteration)")
logger.info("="*60)
logger.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
logger.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
logger.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
logger.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
logger.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
logger.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
logger.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
logger.info(f"Observation breakdown:")
logger.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
logger.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
logger.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
logger.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
logger.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
logger.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
logger.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
logger.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
logger.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
logger.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger.info("="*60)
# Periodic status line, emitted every 50th step.
if self.motion_counter % 50 == 0:
if self.policy is None:
mode = "DIRECT"
elif TEST_DIRECT_PLAYBACK:
mode = "DIRECT_DEBUG"
elif TEST_SEND_DEFAULT_POS:
mode = "DEFAULT_POS"
elif TEST_WITH_ZEROS:
mode = "POLICY_ZEROS"
else:
mode = "POLICY"
logger.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
if self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
logger.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
logger.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
logger.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
logger.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
# Send command
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
def _motion_imitation_thread_loop(self):
    """Run ``motion_imitation_run`` repeatedly at the configured control rate.

    The loop keeps going for as long as ``self.motion_imitation_running`` is
    truthy. Each iteration performs one control step and then sleeps for
    whatever is left of ``self.config.motion_control_dt``.

    An exception raised by a single control step is logged together with its
    traceback and does not terminate the loop; the next step runs as usual.
    """
    logger.info("Motion imitation thread started")

    while self.motion_imitation_running:
        # Use the monotonic clock for period measurement: wall-clock
        # adjustments (NTP, manual changes) must not stretch or shrink the
        # computed sleep time.
        step_start = time.monotonic()

        try:
            self.motion_imitation_run()
        except Exception as e:
            # Deliberately best-effort: keep the control loop alive, but
            # route the traceback through the logging system so it ends up
            # wherever the rest of the robot logs go.
            logger.exception("Error in motion imitation loop: %s", e)

        # Sleep to maintain control rate
        elapsed = time.monotonic() - step_start
        time.sleep(max(0.0, self.config.motion_control_dt - elapsed))

    logger.info("Motion imitation thread stopped")
def start_motion_imitation_thread(self):
    """Spawn the daemon thread that drives the motion imitation loop.

    Nothing is started (only a warning is logged) when motion imitation
    control is disabled in the config, or when the running flag is already
    set.
    """
    refusal = None
    if not self.config.motion_imitation_control:
        refusal = "motion_imitation_control is False, cannot start thread"
    elif self.motion_imitation_running:
        refusal = "Motion imitation thread already running"

    if refusal is not None:
        logger.warning(refusal)
        return

    logger.info("Starting motion imitation control thread...")

    # The run flag is raised before the thread starts so that the loop
    # condition in the thread target already holds on its first check.
    self.motion_imitation_running = True
    worker = threading.Thread(
        target=self._motion_imitation_thread_loop,
        daemon=True,
    )
    self.motion_imitation_thread = worker
    worker.start()

    logger.info("Motion imitation control thread started!")
def stop_motion_imitation_thread(self):
    """Stop the background motion imitation control thread.

    Clears ``self.motion_imitation_running`` and waits up to 2 seconds for
    the thread stored in ``self.motion_imitation_thread`` to finish. Returns
    immediately when the running flag is not set.

    ``Thread.join`` gives no indication of whether it timed out, so the
    thread's liveness is checked afterwards; if it is still running a
    warning is logged instead of the "stopped" message.
    """
    if not self.motion_imitation_running:
        return

    logger.info("Stopping motion imitation control thread...")
    self.motion_imitation_running = False

    worker = self.motion_imitation_thread
    if worker is not None:
        join_timeout = 2.0
        worker.join(timeout=join_timeout)
        if worker.is_alive():
            # Do not claim success: the loop has not observed the cleared
            # flag yet (e.g. a control step is blocking).
            logger.warning(
                "Motion imitation control thread did not stop within %.1fs",
                join_timeout,
            )
            return

    logger.info("Motion imitation control thread stopped")
def init_motion_imitation(self):
    """Move to the default standing pose, then start the motion imitation policy.

    The current motor positions are linearly blended towards the configured
    default joint positions over roughly 3 seconds, publishing one low-level
    command per ``config.motion_control_dt``. The final published command is
    exactly the target pose. After a 2 second hold the background policy
    thread is started via ``start_motion_imitation_thread``.

    Returns early with a warning when ``config.motion_imitation_control`` is
    False.
    """
    if not self.config.motion_imitation_control:
        logger.warning("motion_imitation_control is False, cannot run initialization")
        return

    logger.info("Starting motion imitation initialization...")

    # Move to default standing position
    logger.info("Moving to default standing position...")
    num_joints = 29
    total_time = 3.0
    control_dt = self.config.motion_control_dt
    num_steps = int(total_time / control_dt)

    # Get current positions (in motor order)
    current_q_motor = self.get_current_motor_q()

    # The configured default pose is in BFS order; scatter it into motor
    # order using the BFS-index -> motor-index map from the config.
    target_q_bfs = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
    target_q_motor = np.zeros(num_joints, dtype=np.float32)
    joint_ids_map = self.config.motion_joint_ids_map
    for bfs_idx in range(num_joints):
        target_q_motor[joint_ids_map[bfs_idx]] = target_q_bfs[bfs_idx]

    stiffness = self.config.motion_stiffness
    damping = self.config.motion_damping

    # Interpolate to target (both in motor order now). Steps are counted
    # from 1 so that alpha ends at exactly 1.0 and the last command sent is
    # the target pose itself rather than a point just short of it.
    for step in range(1, num_steps + 1):
        alpha = step / num_steps
        for motor_idx in range(num_joints):
            self.msg.motor_cmd[motor_idx].q = (
                current_q_motor[motor_idx] * (1 - alpha) + target_q_motor[motor_idx] * alpha
            )
            self.msg.motor_cmd[motor_idx].qd = 0
            self.msg.motor_cmd[motor_idx].kp = stiffness[motor_idx]
            self.msg.motor_cmd[motor_idx].kd = damping[motor_idx]
            self.msg.motor_cmd[motor_idx].tau = 0

        self.msg.crc = self.crc.Crc(self.msg)
        self.lowcmd_publisher.Write(self.msg)
        time.sleep(control_dt)

    logger.info("Reached default position")

    # Hold the pose for 2 seconds before handing control to the policy
    time.sleep(2.0)

    # Start motion imitation policy thread
    logger.info("Starting motion imitation policy control...")
    self.start_motion_imitation_thread()

    logger.info("Motion imitation initialization complete! Policy is now running.")
    logger.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
class G1_29_JointArmIndex(IntEnum):
# Left arm