mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
simplified robot class
This commit is contained in:
@@ -6,15 +6,9 @@ from typing import Any
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.motors import Motor, MotorNormMode
|
|
||||||
from lerobot.motors.calibration_gui import RangeFinderGUI
|
|
||||||
from lerobot.motors.feetech import (
|
|
||||||
FeetechMotorsBus,
|
|
||||||
)
|
|
||||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
|
||||||
import json
|
import json
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
from ..utils import ensure_safe_goal_position
|
|
||||||
from .config_unitree_g1 import UnitreeG1Config
|
from .config_unitree_g1 import UnitreeG1Config
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -36,7 +30,6 @@ import onnxruntime as ort
|
|||||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
|
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
|
||||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||||
from unitree_sdk2py.utils.crc import CRC
|
from unitree_sdk2py.utils.crc import CRC
|
||||||
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
|
|
||||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
|
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
|
||||||
MotionSwitcherClient,
|
MotionSwitcherClient,
|
||||||
)
|
)
|
||||||
@@ -45,11 +38,6 @@ from lerobot.envs.factory import make_env
|
|||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -57,7 +45,6 @@ import torch
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||||
kTopicLowCommand_Motion = "rt/arm_sdk"
|
|
||||||
kTopicLowState = "rt/lowstate"
|
kTopicLowState = "rt/lowstate"
|
||||||
|
|
||||||
G1_29_Num_Motors = 35
|
G1_29_Num_Motors = 35
|
||||||
@@ -102,8 +89,6 @@ class DataBuffer:
|
|||||||
with self.lock:
|
with self.lock:
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
#eventually observations should be everything: motor torques etc etc
|
|
||||||
#motor class for unitree?
|
|
||||||
class UnitreeG1(Robot):
|
class UnitreeG1(Robot):
|
||||||
|
|
||||||
config_class = UnitreeG1Config
|
config_class = UnitreeG1Config
|
||||||
@@ -118,7 +103,6 @@ class UnitreeG1(Robot):
|
|||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
self.q_target = np.zeros(14)
|
self.q_target = np.zeros(14)
|
||||||
self.tauff_target = np.zeros(14)
|
self.tauff_target = np.zeros(14)
|
||||||
self.motion_mode = config.motion_mode
|
|
||||||
self.simulation_mode = config.simulation_mode
|
self.simulation_mode = config.simulation_mode
|
||||||
self.kp_high = config.kp_high
|
self.kp_high = config.kp_high
|
||||||
self.kd_high = config.kd_high
|
self.kd_high = config.kd_high
|
||||||
@@ -131,7 +115,6 @@ class UnitreeG1(Robot):
|
|||||||
self.arm_velocity_limit = config.arm_velocity_limit
|
self.arm_velocity_limit = config.arm_velocity_limit
|
||||||
self.control_dt = config.control_dt
|
self.control_dt = config.control_dt
|
||||||
|
|
||||||
self._speed_gradual_max = config.speed_gradual_max
|
|
||||||
self._gradual_start_time = config.gradual_start_time
|
self._gradual_start_time = config.gradual_start_time
|
||||||
self._gradual_time = config.gradual_time
|
self._gradual_time = config.gradual_time
|
||||||
|
|
||||||
@@ -143,7 +126,6 @@ class UnitreeG1(Robot):
|
|||||||
self.freeze_body = config.freeze_body
|
self.freeze_body = config.freeze_body
|
||||||
self.gravity_compensation = config.gravity_compensation
|
self.gravity_compensation = config.gravity_compensation
|
||||||
|
|
||||||
|
|
||||||
self.calibrated = False
|
self.calibrated = False
|
||||||
|
|
||||||
self.calibrate()
|
self.calibrate()
|
||||||
@@ -155,39 +137,33 @@ class UnitreeG1(Robot):
|
|||||||
else:
|
else:
|
||||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||||
|
|
||||||
|
if not self.config.simulation_mode:
|
||||||
|
self.msc = MotionSwitcherClient()
|
||||||
|
self.msc.SetTimeout(5.0)
|
||||||
|
self.msc.Init()
|
||||||
|
|
||||||
|
status, result = self.msc.CheckMode()
|
||||||
|
print(status, result)
|
||||||
|
#check if result name first
|
||||||
|
if result is not None and "name" in result:
|
||||||
|
while result["name"]:
|
||||||
|
self.msc.ReleaseMode()
|
||||||
|
status, result = self.msc.CheckMode()
|
||||||
|
print(status, result)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
# initialize lowcmd nd lowstate subscriber
|
# initialize lowcmd nd lowstate subscriber
|
||||||
if self.simulation_mode:
|
if self.simulation_mode:
|
||||||
ChannelFactoryInitialize(0, "lo")
|
ChannelFactoryInitialize(0, "lo")
|
||||||
|
|
||||||
# Launch MuJoCo simulation environment
|
|
||||||
logger.info("Launching MuJoCo simulation environment...")
|
logger.info("Launching MuJoCo simulation environment...")
|
||||||
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||||
logger.info("MuJoCo environment launched successfully!")
|
logger.info("MuJoCo environment launched successfully!")
|
||||||
else:
|
else:
|
||||||
ChannelFactoryInitialize(0)
|
ChannelFactoryInitialize(0)
|
||||||
|
|
||||||
|
# Always use debug mode (direct motor control)
|
||||||
|
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||||
if not self.config.simulation_mode:
|
|
||||||
pass
|
|
||||||
# self.msc = MotionSwitcherClient()
|
|
||||||
# self.msc.SetTimeout(5.0)
|
|
||||||
# self.msc.Init()
|
|
||||||
|
|
||||||
# status, result = self.msc.CheckMode()
|
|
||||||
# print(status, result)
|
|
||||||
# #check if result name first
|
|
||||||
# if result is not None and "name" in result:
|
|
||||||
# while result["name"]:
|
|
||||||
# self.msc.ReleaseMode()
|
|
||||||
# status, result = self.msc.CheckMode()
|
|
||||||
# print(status, result)
|
|
||||||
# time.sleep(1)
|
|
||||||
|
|
||||||
if self.motion_mode:
|
|
||||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
|
|
||||||
else:
|
|
||||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
|
||||||
self.lowcmd_publisher.Init()
|
self.lowcmd_publisher.Init()
|
||||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||||
self.lowstate_subscriber.Init()
|
self.lowstate_subscriber.Init()
|
||||||
@@ -203,9 +179,6 @@ class UnitreeG1(Robot):
|
|||||||
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
|
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
|
||||||
logger.info("[UnitreeG1] Subscribe dds ok.")
|
logger.info("[UnitreeG1] Subscribe dds ok.")
|
||||||
|
|
||||||
# initialize audio client for LED, TTS, and audio playback
|
|
||||||
|
|
||||||
|
|
||||||
# initialize hg's lowcmd msg
|
# initialize hg's lowcmd msg
|
||||||
self.crc = CRC()
|
self.crc = CRC()
|
||||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||||
@@ -238,15 +211,6 @@ class UnitreeG1(Robot):
|
|||||||
self.msg.motor_cmd[id].q = self.all_motor_q[id]
|
self.msg.motor_cmd[id].q = self.all_motor_q[id]
|
||||||
#print current motor q, kp, kd
|
#print current motor q, kp, kd
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if config.audio_client:
|
|
||||||
pass
|
|
||||||
# self.audio_client = AudioClient()
|
|
||||||
# self.audio_client.SetTimeout(10.0)
|
|
||||||
# self.audio_client.Init()
|
|
||||||
# logger.info("[UnitreeG1] Audio client initialized!")
|
|
||||||
|
|
||||||
logger.info("Lock OK!\n") #motors are not locked x
|
logger.info("Lock OK!\n") #motors are not locked x
|
||||||
# for i in range(10000):
|
# for i in range(10000):
|
||||||
# print(self.get_current_motor_q())
|
# print(self.get_current_motor_q())
|
||||||
@@ -257,62 +221,24 @@ class UnitreeG1(Robot):
|
|||||||
self.keyboard_running = False
|
self.keyboard_running = False
|
||||||
self.locomotion_thread = None
|
self.locomotion_thread = None
|
||||||
self.locomotion_running = False
|
self.locomotion_running = False
|
||||||
self.motion_imitation_thread = None
|
|
||||||
self.motion_imitation_running = False
|
|
||||||
|
|
||||||
# Initialize publish thread for arm control
|
# Initialize publish thread for arm control
|
||||||
# Note: This thread runs alongside locomotion/motion_imitation threads
|
# Note: This thread runs alongside locomotion thread
|
||||||
# - Arm thread: controls arms (indices 15-28)
|
# - Arm thread: controls arms (indices 15-28)
|
||||||
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
||||||
# Both update different parts of self.msg, both call Write()
|
# Both update different parts of self.msg, both call Write()
|
||||||
self.publish_thread = None
|
self.publish_thread = None
|
||||||
self.ctrl_lock = threading.Lock()
|
self.ctrl_lock = threading.Lock()
|
||||||
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
|
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
self.publish_thread.daemon = True
|
||||||
self.publish_thread.daemon = True
|
self.publish_thread.start()
|
||||||
self.publish_thread.start()
|
logger.info("Arm control publish thread started")
|
||||||
logger.info("Arm control publish thread started")
|
|
||||||
|
|
||||||
# Load locomotion policy if enabled
|
# Load locomotion policy if enabled
|
||||||
self.policy = None
|
self.policy = None
|
||||||
self.policy_type = None # 'torchscript', 'onnx', or 'motion_imitation'
|
self.policy_type = None # 'torchscript' or 'onnx'
|
||||||
self.motion_loader = None
|
|
||||||
|
|
||||||
if config.motion_imitation_control:
|
if config.locomotion_control:
|
||||||
# Motion imitation mode (dance, etc.)
|
|
||||||
if config.motion_file_path is None:
|
|
||||||
raise ValueError("motion_imitation_control is True but motion_file_path is not set")
|
|
||||||
|
|
||||||
logger.info(f"Loading motion reference from {config.motion_file_path}")
|
|
||||||
|
|
||||||
# Load motion file
|
|
||||||
self.motion_loader = self.MotionLoader(config.motion_file_path, config.motion_fps)
|
|
||||||
|
|
||||||
# Load ONNX policy (optional for now - can run in direct playback mode)
|
|
||||||
if config.motion_policy_path and Path(config.motion_policy_path).exists():
|
|
||||||
logger.info(f"Loading motion imitation policy from {config.motion_policy_path}")
|
|
||||||
self.policy = ort.InferenceSession(config.motion_policy_path)
|
|
||||||
self.policy_type = 'motion_imitation'
|
|
||||||
logger.info("Motion imitation ONNX policy loaded successfully")
|
|
||||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
|
||||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
|
||||||
else:
|
|
||||||
logger.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
|
|
||||||
self.policy = None
|
|
||||||
self.policy_type = 'motion_playback'
|
|
||||||
|
|
||||||
# Initialize motion imitation variables
|
|
||||||
self.motion_counter = 0
|
|
||||||
self.motion_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints from robot
|
|
||||||
self.motion_dqj_all = np.zeros(29, dtype=np.float32)
|
|
||||||
self.motion_action = np.zeros(29, dtype=np.float32) # 29D action output
|
|
||||||
self.motion_obs = np.zeros(154, dtype=np.float32) # 154D observation
|
|
||||||
self.motion_elapsed_time = 0.0
|
|
||||||
|
|
||||||
# Initialize motion and start
|
|
||||||
self.init_motion_imitation()
|
|
||||||
|
|
||||||
elif config.locomotion_control:
|
|
||||||
if config.policy_path is None:
|
if config.policy_path is None:
|
||||||
raise ValueError("locomotion_control is True but policy_path is not set")
|
raise ValueError("locomotion_control is True but policy_path is not set")
|
||||||
|
|
||||||
@@ -326,11 +252,33 @@ class UnitreeG1(Robot):
|
|||||||
logger.info("TorchScript policy loaded successfully")
|
logger.info("TorchScript policy loaded successfully")
|
||||||
elif config.policy_path.endswith('.onnx'):
|
elif config.policy_path.endswith('.onnx'):
|
||||||
logger.info("Detected ONNX (.onnx) policy")
|
logger.info("Detected ONNX (.onnx) policy")
|
||||||
self.policy = ort.InferenceSession(config.policy_path)
|
|
||||||
|
# For GR00T-style policies, load both Balance and Walk policies
|
||||||
|
# Balance policy for standing (low velocity commands)
|
||||||
|
# Walk policy for locomotion (high velocity commands)
|
||||||
|
balance_policy_path = config.policy_path.replace('Walk.onnx', 'Balance.onnx')
|
||||||
|
walk_policy_path = config.policy_path
|
||||||
|
|
||||||
|
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
|
||||||
|
logger.info("Loading dual-policy system (Balance + Walk)")
|
||||||
|
self.policy_balance = ort.InferenceSession(balance_policy_path)
|
||||||
|
self.policy_walk = ort.InferenceSession(walk_policy_path)
|
||||||
|
self.policy = None # Not used when dual policies are loaded
|
||||||
|
logger.info(f"Balance policy loaded from: {balance_policy_path}")
|
||||||
|
logger.info(f"Walk policy loaded from: {walk_policy_path}")
|
||||||
|
logger.info(f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}")
|
||||||
|
logger.info(f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}")
|
||||||
|
else:
|
||||||
|
# Fallback to single policy
|
||||||
|
logger.info("Loading single ONNX policy")
|
||||||
|
self.policy = ort.InferenceSession(config.policy_path)
|
||||||
|
self.policy_balance = None
|
||||||
|
self.policy_walk = None
|
||||||
|
logger.info("ONNX policy loaded successfully")
|
||||||
|
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||||
|
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||||
|
|
||||||
self.policy_type = 'onnx'
|
self.policy_type = 'onnx'
|
||||||
logger.info("ONNX policy loaded successfully")
|
|
||||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
|
||||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
|
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
|
||||||
|
|
||||||
@@ -415,12 +363,7 @@ class UnitreeG1(Robot):
|
|||||||
return cliped_arm_q_target
|
return cliped_arm_q_target
|
||||||
|
|
||||||
def _ctrl_motor_state(self):
|
def _ctrl_motor_state(self):
|
||||||
"""Arm control thread - publishes commands for arms only.
|
"""Arm control thread - publishes commands for arms only."""
|
||||||
NOTE: This thread is NOT started when motion_imitation_control or locomotion_control is True.
|
|
||||||
Those modes handle their own publishing."""
|
|
||||||
if self.motion_mode:
|
|
||||||
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = 1.0
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -452,10 +395,6 @@ class UnitreeG1(Robot):
|
|||||||
self.msg.crc = self.crc.Crc(self.msg)
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
|
||||||
if self._speed_gradual_max is True:
|
|
||||||
t_elapsed = start_time - self._gradual_start_time
|
|
||||||
self.arm_velocity_limit = 20.0 + (10.0 * min(1.0, t_elapsed / 5.0))
|
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
all_t_elapsed = current_time - start_time
|
all_t_elapsed = current_time - start_time
|
||||||
sleep_time = max(0, (self.control_dt - all_t_elapsed))
|
sleep_time = max(0, (self.control_dt - all_t_elapsed))
|
||||||
@@ -485,37 +424,6 @@ class UnitreeG1(Robot):
|
|||||||
"""Return current state dq of the left and right arm motors."""
|
"""Return current state dq of the left and right arm motors."""
|
||||||
return np.array([self.lowstate_buffer.GetData().motor_state[id].dq for id in G1_29_JointArmIndex])
|
return np.array([self.lowstate_buffer.GetData().motor_state[id].dq for id in G1_29_JointArmIndex])
|
||||||
|
|
||||||
def ctrl_dual_arm_go_home(self):
|
|
||||||
"""Move both the left and right arms of the robot to their home position by setting the target joint angles (q) and torques (tau) to zero."""
|
|
||||||
logger.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
|
|
||||||
max_attempts = 100
|
|
||||||
current_attempts = 0
|
|
||||||
with self.ctrl_lock:
|
|
||||||
self.q_target = np.zeros(14)
|
|
||||||
#self.q_target[G1_29_JointIndex.kLeftElbow] = 0.5
|
|
||||||
# self.tauff_target = np.zeros(14)
|
|
||||||
tolerance = 0.05 # Tolerance threshold for joint angles to determine "close to zero", can be adjusted based on your motor's precision requirements
|
|
||||||
while current_attempts < max_attempts:
|
|
||||||
current_q = self.get_current_dual_arm_q()
|
|
||||||
if np.all(np.abs(current_q) < tolerance):
|
|
||||||
if self.motion_mode:
|
|
||||||
for weight in np.linspace(1, 0, num=101):
|
|
||||||
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = weight
|
|
||||||
time.sleep(0.02)
|
|
||||||
logger.info("[G1_29_ArmController] both arms have reached the home position.")
|
|
||||||
break
|
|
||||||
current_attempts += 1
|
|
||||||
time.sleep(0.05)
|
|
||||||
|
|
||||||
def speed_gradual_max(self, t=5.0):
|
|
||||||
"""Parameter t is the total time required for arms velocity to gradually increase to its maximum value, in seconds. The default is 5.0."""
|
|
||||||
self._gradual_start_time = time.time()
|
|
||||||
self._gradual_time = t
|
|
||||||
self._speed_gradual_max = True
|
|
||||||
|
|
||||||
def speed_instant_max(self):
|
|
||||||
"""set arms velocity to the maximum value immediately, instead of gradually increasing."""
|
|
||||||
self.arm_velocity_limit = 30.0
|
|
||||||
|
|
||||||
def _Is_weak_motor(self, motor_index):
|
def _Is_weak_motor(self, motor_index):
|
||||||
weak_motors = [
|
weak_motors = [
|
||||||
@@ -614,159 +522,6 @@ class UnitreeG1(Robot):
|
|||||||
'motors': motors_data,
|
'motors': motors_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
def audio_control(self, command, volume: int = 80):
|
|
||||||
"""
|
|
||||||
Unified audio/LED control function for the G1 robot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Can be one of:
|
|
||||||
- str: Text to speak via TTS
|
|
||||||
- tuple[int, int, int]: RGB values (0-255) for LED control
|
|
||||||
- str (path): Path to WAV file to play
|
|
||||||
volume: Volume level 0-100 (default: 80)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
robot.audio_control("Hello world") # TTS
|
|
||||||
robot.audio_control((255, 0, 0)) # Red LED
|
|
||||||
robot.audio_control("audio.wav") # Play WAV file
|
|
||||||
"""
|
|
||||||
# Set volume
|
|
||||||
self.audio_client.SetVolume(volume)
|
|
||||||
|
|
||||||
# Detect command type and execute
|
|
||||||
if isinstance(command, tuple) and len(command) == 3:
|
|
||||||
# LED control - RGB tuple
|
|
||||||
r, g, b = command
|
|
||||||
logger.info(f"Setting LED to RGB({r}, {g}, {b})")
|
|
||||||
self.audio_client.LedControl(r, g, b)
|
|
||||||
|
|
||||||
elif isinstance(command, str):
|
|
||||||
# Check if it's a file path
|
|
||||||
if Path(command).exists():
|
|
||||||
# Play WAV file
|
|
||||||
logger.info(f"Playing audio file: {command}")
|
|
||||||
self._play_wav_file(command)
|
|
||||||
else:
|
|
||||||
# Text-to-speech
|
|
||||||
logger.info(f"Speaking: {command}")
|
|
||||||
self.audio_client.TtsMaker(command, 0) # 0 for English
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid command type: {type(command)}. "
|
|
||||||
"Expected str (text/path) or tuple[int, int, int] (RGB)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _read_wav_file(self, filename: str):
|
|
||||||
"""Read WAV file and return PCM data as bytes."""
|
|
||||||
with open(filename, 'rb') as f:
|
|
||||||
def read(fmt):
|
|
||||||
return struct.unpack(fmt, f.read(struct.calcsize(fmt)))
|
|
||||||
|
|
||||||
# Read RIFF header
|
|
||||||
chunk_id, = read('<I')
|
|
||||||
if chunk_id != 0x46464952: # "RIFF"
|
|
||||||
raise ValueError("Not a valid WAV file (invalid RIFF header)")
|
|
||||||
|
|
||||||
_chunk_size, = read('<I')
|
|
||||||
format_tag, = read('<I')
|
|
||||||
if format_tag != 0x45564157: # "WAVE"
|
|
||||||
raise ValueError("Not a valid WAV file (invalid WAVE format)")
|
|
||||||
|
|
||||||
# Read fmt chunk
|
|
||||||
subchunk1_id, = read('<I')
|
|
||||||
subchunk1_size, = read('<I')
|
|
||||||
|
|
||||||
# Skip JUNK chunk if present
|
|
||||||
if subchunk1_id == 0x4B4E554A: # "JUNK"
|
|
||||||
f.seek(subchunk1_size, 1)
|
|
||||||
subchunk1_id, = read('<I')
|
|
||||||
subchunk1_size, = read('<I')
|
|
||||||
|
|
||||||
if subchunk1_id != 0x20746D66: # "fmt "
|
|
||||||
raise ValueError("Invalid fmt chunk")
|
|
||||||
|
|
||||||
if subchunk1_size not in [16, 18]:
|
|
||||||
raise ValueError(f"Unsupported fmt chunk size: {subchunk1_size}")
|
|
||||||
|
|
||||||
audio_format, = read('<H')
|
|
||||||
if audio_format != 1:
|
|
||||||
raise ValueError(f"Only PCM format supported, got format {audio_format}")
|
|
||||||
|
|
||||||
num_channels, = read('<H')
|
|
||||||
sample_rate, = read('<I')
|
|
||||||
_byte_rate, = read('<I')
|
|
||||||
_block_align, = read('<H')
|
|
||||||
bits_per_sample, = read('<H')
|
|
||||||
|
|
||||||
if bits_per_sample != 16:
|
|
||||||
raise ValueError(f"Only 16-bit samples supported, got {bits_per_sample}-bit")
|
|
||||||
|
|
||||||
if sample_rate != 16000:
|
|
||||||
raise ValueError(f"Sample rate must be 16000 Hz, got {sample_rate} Hz")
|
|
||||||
|
|
||||||
if num_channels != 1:
|
|
||||||
raise ValueError(f"Must be mono (1 channel), got {num_channels} channels")
|
|
||||||
|
|
||||||
if subchunk1_size == 18:
|
|
||||||
extra_size, = read('<H')
|
|
||||||
if extra_size != 0:
|
|
||||||
f.seek(extra_size, 1)
|
|
||||||
|
|
||||||
# Find data chunk
|
|
||||||
while True:
|
|
||||||
subchunk2_id, subchunk2_size = read('<II')
|
|
||||||
if subchunk2_id == 0x61746164: # "data"
|
|
||||||
break
|
|
||||||
f.seek(subchunk2_size, 1)
|
|
||||||
|
|
||||||
# Read PCM data
|
|
||||||
raw_pcm = f.read(subchunk2_size)
|
|
||||||
if len(raw_pcm) != subchunk2_size:
|
|
||||||
raise ValueError("Failed to read full PCM data")
|
|
||||||
|
|
||||||
return raw_pcm
|
|
||||||
|
|
||||||
def _play_wav_file(self, filename: str, chunk_size: int = 96000):
|
|
||||||
"""
|
|
||||||
Play a WAV file through the robot's speaker.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: Path to WAV file (must be 16kHz, mono, 16-bit PCM)
|
|
||||||
chunk_size: Bytes per chunk (default: 96000 = ~3 seconds at 16kHz)
|
|
||||||
"""
|
|
||||||
# Read WAV file
|
|
||||||
pcm_data = self._read_wav_file(filename)
|
|
||||||
|
|
||||||
stream_id = str(int(time.time() * 1000))
|
|
||||||
app_name = "lerobot"
|
|
||||||
offset = 0
|
|
||||||
chunk_index = 0
|
|
||||||
total_size = len(pcm_data)
|
|
||||||
|
|
||||||
logger.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
|
|
||||||
|
|
||||||
# Send audio in chunks
|
|
||||||
while offset < total_size:
|
|
||||||
remaining = total_size - offset
|
|
||||||
current_chunk_size = min(chunk_size, remaining)
|
|
||||||
chunk = pcm_data[offset:offset + current_chunk_size]
|
|
||||||
|
|
||||||
# Send chunk
|
|
||||||
ret_code, _ = self.audio_client.PlayStream(app_name, stream_id, list(chunk))
|
|
||||||
if ret_code != 0:
|
|
||||||
logger.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logger.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
|
|
||||||
|
|
||||||
offset += current_chunk_size
|
|
||||||
chunk_index += 1
|
|
||||||
time.sleep(1.0) # Wait between chunks
|
|
||||||
|
|
||||||
# Calculate playback duration
|
|
||||||
duration_seconds = len(pcm_data) / (16000 * 2) # 16kHz, 16-bit (2 bytes)
|
|
||||||
logger.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Any]:
|
def get_observation(self) -> dict[str, Any]:
|
||||||
obs_array = self.get_current_dual_arm_q()
|
obs_array = self.get_current_dual_arm_q()
|
||||||
obs_dict = {f"{G1_29_JointArmIndex(motor).name}.pos": val for motor, val in zip(G1_29_JointArmIndex, obs_array, strict=True)}
|
obs_dict = {f"{G1_29_JointArmIndex(motor).name}.pos": val for motor, val in zip(G1_29_JointArmIndex, obs_array, strict=True)}
|
||||||
@@ -991,78 +746,6 @@ class UnitreeG1(Robot):
|
|||||||
self.ry = struct.unpack("f", data[12:16])[0]
|
self.ry = struct.unpack("f", data[12:16])[0]
|
||||||
self.ly = struct.unpack("f", data[20:24])[0]
|
self.ly = struct.unpack("f", data[20:24])[0]
|
||||||
|
|
||||||
class MotionLoader:
|
|
||||||
"""Load and interpolate motion from CSV file for motion imitation."""
|
|
||||||
def __init__(self, motion_file: str, fps: float = 60.0):
|
|
||||||
"""Load motion from CSV file.
|
|
||||||
|
|
||||||
CSV format: [root_pos(3), root_quat_xyzw(4), joint_dof(29)] per row
|
|
||||||
"""
|
|
||||||
self.dt = 1.0 / fps
|
|
||||||
|
|
||||||
# Load CSV
|
|
||||||
data = np.loadtxt(motion_file, delimiter=',')
|
|
||||||
self.num_frames = data.shape[0]
|
|
||||||
self.duration = self.num_frames * self.dt
|
|
||||||
|
|
||||||
# Split data
|
|
||||||
self.root_positions = data[:, 0:3] # (N, 3)
|
|
||||||
self.root_quaternions_xyzw = data[:, 3:7] # (N, 4) [x, y, z, w]
|
|
||||||
self.dof_positions = data[:, 7:] # (N, 29)
|
|
||||||
|
|
||||||
# Compute velocities (finite differences)
|
|
||||||
self.dof_velocities = np.diff(self.dof_positions, axis=0, prepend=self.dof_positions[0:1]) / self.dt
|
|
||||||
|
|
||||||
# Current playback state
|
|
||||||
self.current_time = 0.0
|
|
||||||
self.index_0 = 0
|
|
||||||
self.index_1 = 0
|
|
||||||
self.blend = 0.0
|
|
||||||
|
|
||||||
logger.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
|
|
||||||
|
|
||||||
def update(self, time: float):
|
|
||||||
"""Update motion to specific time (loops at duration)."""
|
|
||||||
self.current_time = time % self.duration # Loop
|
|
||||||
phase = self.current_time / self.duration
|
|
||||||
|
|
||||||
self.index_0 = int(phase * (self.num_frames - 1))
|
|
||||||
self.index_1 = min(self.index_0 + 1, self.num_frames - 1)
|
|
||||||
self.blend = (self.current_time - self.index_0 * self.dt) / self.dt
|
|
||||||
|
|
||||||
def get_joint_pos(self) -> np.ndarray:
|
|
||||||
"""Get interpolated joint positions (29D)."""
|
|
||||||
return self.dof_positions[self.index_0] * (1 - self.blend) + \
|
|
||||||
self.dof_positions[self.index_1] * self.blend
|
|
||||||
|
|
||||||
def get_joint_vel(self) -> np.ndarray:
|
|
||||||
"""Get interpolated joint velocities (29D)."""
|
|
||||||
return self.dof_velocities[self.index_0] * (1 - self.blend) + \
|
|
||||||
self.dof_velocities[self.index_1] * self.blend
|
|
||||||
|
|
||||||
def get_root_quat_wxyz(self) -> np.ndarray:
|
|
||||||
"""Get interpolated root quaternion [w, x, y, z]."""
|
|
||||||
# Spherical linear interpolation (SLERP)
|
|
||||||
q0 = self.root_quaternions_xyzw[self.index_0] # [x, y, z, w]
|
|
||||||
q1 = self.root_quaternions_xyzw[self.index_1]
|
|
||||||
|
|
||||||
# Convert to scipy format [x, y, z, w]
|
|
||||||
r0 = R.from_quat(q0)
|
|
||||||
r1 = R.from_quat(q1)
|
|
||||||
|
|
||||||
# SLERP
|
|
||||||
key_times = [0, 1]
|
|
||||||
key_rots = R.from_quat([q0, q1])
|
|
||||||
slerp = R.from_quat(key_rots.as_quat()) # Simplified - just use linear for now
|
|
||||||
|
|
||||||
# Linear interpolation for simplicity
|
|
||||||
quat_xyzw = q0 * (1 - self.blend) + q1 * self.blend
|
|
||||||
# Normalize
|
|
||||||
quat_xyzw = quat_xyzw / np.linalg.norm(quat_xyzw)
|
|
||||||
|
|
||||||
# Convert to [w, x, y, z]
|
|
||||||
return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]], dtype=np.float32)
|
|
||||||
|
|
||||||
def locomotion_get_gravity_orientation(self, quaternion):
|
def locomotion_get_gravity_orientation(self, quaternion):
|
||||||
"""Get gravity orientation from quaternion."""
|
"""Get gravity orientation from quaternion."""
|
||||||
qw = quaternion[0]
|
qw = quaternion[0]
|
||||||
@@ -1287,8 +970,23 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
# Run policy inference (ONNX) with 516D stacked observation
|
# Run policy inference (ONNX) with 516D stacked observation
|
||||||
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
|
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
|
||||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
|
||||||
ort_outs = self.policy.run(None, ort_inputs)
|
# Select appropriate policy based on command magnitude (dual-policy system)
|
||||||
|
if self.policy_balance is not None and self.policy_walk is not None:
|
||||||
|
# Dual-policy mode: switch between Balance and Walk
|
||||||
|
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||||
|
if cmd_magnitude < 0.05:
|
||||||
|
# Use balance/standing policy for small commands
|
||||||
|
selected_policy = self.policy_balance
|
||||||
|
else:
|
||||||
|
# Use walking policy for movement commands
|
||||||
|
selected_policy = self.policy_walk
|
||||||
|
else:
|
||||||
|
# Single policy mode (fallback)
|
||||||
|
selected_policy = self.policy
|
||||||
|
|
||||||
|
ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
||||||
|
ort_outs = selected_policy.run(None, ort_inputs)
|
||||||
self.groot_action = ort_outs[0].squeeze()
|
self.groot_action = ort_outs[0].squeeze()
|
||||||
|
|
||||||
# Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
|
# Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
|
||||||
@@ -1488,9 +1186,6 @@ class UnitreeG1(Robot):
|
|||||||
|
|
||||||
logger.info("Starting locomotion test sequence...")
|
logger.info("Starting locomotion test sequence...")
|
||||||
|
|
||||||
# 1. Home the arms first
|
|
||||||
logger.info("Homing arms to zero position...")
|
|
||||||
#self.ctrl_dual_arm_go_home()
|
|
||||||
|
|
||||||
# 2. Move legs to default position
|
# 2. Move legs to default position
|
||||||
self.locomotion_move_to_default_pos()
|
self.locomotion_move_to_default_pos()
|
||||||
@@ -1532,354 +1227,6 @@ class UnitreeG1(Robot):
|
|||||||
logger.info("GR00T locomotion initialization complete! Policy is now running.")
|
logger.info("GR00T locomotion initialization complete! Policy is now running.")
|
||||||
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
||||||
|
|
||||||
def motion_imitation_run(self):
|
|
||||||
"""Motion imitation policy loop - tracks reference motion (dance_102, etc)."""
|
|
||||||
self.motion_counter += 1
|
|
||||||
self.motion_elapsed_time = self.motion_counter * self.config.motion_control_dt
|
|
||||||
|
|
||||||
# Update motion loader to current time
|
|
||||||
self.motion_loader.update(self.motion_elapsed_time)
|
|
||||||
|
|
||||||
# Get current lowstate
|
|
||||||
lowstate = self.lowstate_buffer.GetData()
|
|
||||||
if lowstate is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get ALL 29 joint positions and velocities from robot
|
|
||||||
# IMPORTANT: Convert from motor order to BFS order to match reference motion
|
|
||||||
# The C++ code does: robot_bfs[i] = motor[joint_ids_map[i]]
|
|
||||||
for i in range(29):
|
|
||||||
motor_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
self.motion_qj_all[i] = lowstate.motor_state[motor_idx].q
|
|
||||||
self.motion_dqj_all[i] = lowstate.motor_state[motor_idx].dq
|
|
||||||
|
|
||||||
# ======== 23 DOF MODE CONFIGURATION ========
|
|
||||||
# For real robot - zeros out joints not present in 23 DOF hardware
|
|
||||||
# Waist: yaw(12), pitch(14) | Wrist: L_pitch/yaw(20,21), R_pitch/yaw(27,28)
|
|
||||||
USE_23DOF = True # Set to True for real robot without these joints
|
|
||||||
JOINTS_TO_ZERO_23DOF = [12,14,20, 21, 27, 28]#12, 14, 20, 21, 27, 28]#
|
|
||||||
|
|
||||||
# Apply 23 DOF zeroing to robot observations if enabled
|
|
||||||
if USE_23DOF:
|
|
||||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
|
||||||
self.motion_qj_all[joint_idx] = 0.0
|
|
||||||
self.motion_dqj_all[joint_idx] = 0.0
|
|
||||||
if self.motion_counter == 1:
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info("🤖 23 DOF MODE ENABLED")
|
|
||||||
logger.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
|
|
||||||
logger.info(" Waist: yaw(12), pitch(14)")
|
|
||||||
logger.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
|
|
||||||
logger.info(" Applied to: robot obs, reference motion, policy actions")
|
|
||||||
logger.info("="*60)
|
|
||||||
|
|
||||||
# Get IMU data
|
|
||||||
robot_quat = lowstate.imu_state.quaternion # [w, x, y, z]
|
|
||||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) # 3D
|
|
||||||
|
|
||||||
if self.policy is None:
|
|
||||||
# DIRECT PLAYBACK MODE (no policy)
|
|
||||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos()
|
|
||||||
|
|
||||||
# Zero out missing joints for 23 DOF mode
|
|
||||||
if USE_23DOF:
|
|
||||||
# Convert to BFS to zero out, then convert back
|
|
||||||
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
|
|
||||||
for i in range(29):
|
|
||||||
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
|
||||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
|
||||||
motion_joint_pos_bfs_temp[joint_idx] = 0.0
|
|
||||||
# Convert back to DFS for sending
|
|
||||||
for i in range(29):
|
|
||||||
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
|
|
||||||
|
|
||||||
for i in range(29):
|
|
||||||
motor_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
csv_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].qd = 0
|
|
||||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].tau = 0
|
|
||||||
else:
|
|
||||||
# POLICY MODE - Full observation construction and inference
|
|
||||||
|
|
||||||
# ======== DEBUG TEST MODES ========
|
|
||||||
# Mode 1: Direct playback (no policy) - set motion_policy_path = None in config instead
|
|
||||||
# Mode 2: Send default pos (stand still) - TEST_SEND_DEFAULT_POS = True
|
|
||||||
# Mode 3: Policy with zero reference - TEST_WITH_ZEROS = True, TEST_SEND_DEFAULT_POS = False
|
|
||||||
# Mode 4: Policy with real reference - TEST_WITH_ZEROS = False, TEST_SEND_DEFAULT_POS = False
|
|
||||||
TEST_WITH_ZEROS = False # True = use zero reference motion in observation
|
|
||||||
TEST_SEND_DEFAULT_POS = False # True = bypass policy and send default pos (stand still)
|
|
||||||
TEST_DIRECT_PLAYBACK = False # True = bypass policy and send reference motion directly
|
|
||||||
|
|
||||||
if TEST_DIRECT_PLAYBACK:
|
|
||||||
# DEBUG: Play back reference motion without policy
|
|
||||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D in DFS order
|
|
||||||
|
|
||||||
# Zero out missing joints for 23 DOF mode
|
|
||||||
if USE_23DOF:
|
|
||||||
# Convert to BFS to zero out, then convert back
|
|
||||||
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
|
|
||||||
for i in range(29):
|
|
||||||
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
|
||||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
|
||||||
motion_joint_pos_bfs_temp[joint_idx] = 0.0
|
|
||||||
# Convert back to DFS for sending
|
|
||||||
for i in range(29):
|
|
||||||
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
|
|
||||||
|
|
||||||
# Send directly to motors using joint_ids_map (same as direct playback mode)
|
|
||||||
for i in range(29):
|
|
||||||
motor_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
csv_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].qd = 0
|
|
||||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].tau = 0
|
|
||||||
|
|
||||||
if self.motion_counter == 1:
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
|
|
||||||
logger.info("="*60)
|
|
||||||
|
|
||||||
target_joint_pos_bfs = None # Not used in this mode
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Run observation construction and policy
|
|
||||||
if TEST_WITH_ZEROS:
|
|
||||||
# Send zeros for reference motion
|
|
||||||
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
|
|
||||||
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
|
|
||||||
if self.motion_counter == 1:
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
|
|
||||||
logger.info("="*60)
|
|
||||||
else:
|
|
||||||
# Get reference motion (DFS order from CSV)
|
|
||||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D
|
|
||||||
motion_joint_vel_dfs = self.motion_loader.get_joint_vel() # 29D
|
|
||||||
|
|
||||||
# Convert from DFS to BFS order: bfs[i] = dfs[joint_ids_map[i]]
|
|
||||||
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
|
|
||||||
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
|
|
||||||
for i in range(29):
|
|
||||||
motion_joint_pos_bfs[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
|
||||||
motion_joint_vel_bfs[i] = motion_joint_vel_dfs[self.config.motion_joint_ids_map[i]]
|
|
||||||
|
|
||||||
# Zero out missing joints in reference motion for 23 DOF mode
|
|
||||||
if USE_23DOF:
|
|
||||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
|
||||||
motion_joint_pos_bfs[joint_idx] = 0.0
|
|
||||||
motion_joint_vel_bfs[joint_idx] = 0.0
|
|
||||||
|
|
||||||
# Compute motion_anchor_ori_b (6D rotation matrix representation)
|
|
||||||
motion_quat_wxyz = self.motion_loader.get_root_quat_wxyz()
|
|
||||||
robot_rot = R.from_quat([robot_quat[1], robot_quat[2], robot_quat[3], robot_quat[0]]).as_matrix()
|
|
||||||
motion_rot = R.from_quat([motion_quat_wxyz[1], motion_quat_wxyz[2], motion_quat_wxyz[3], motion_quat_wxyz[0]]).as_matrix()
|
|
||||||
relative_rot = robot_rot.T @ motion_rot
|
|
||||||
motion_anchor_ori_b = np.array([relative_rot[0, 0], relative_rot[0, 1],
|
|
||||||
relative_rot[1, 0], relative_rot[1, 1],
|
|
||||||
relative_rot[2, 0], relative_rot[2, 1]], dtype=np.float32)
|
|
||||||
|
|
||||||
# Compute joint positions and velocities relative to default
|
|
||||||
default_joint_pos = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
|
|
||||||
joint_pos_rel = self.motion_qj_all - default_joint_pos
|
|
||||||
joint_vel_rel = self.motion_dqj_all.copy()
|
|
||||||
|
|
||||||
# Build 154D observation:
|
|
||||||
# motion_command (58D) = joint_pos (29D) + joint_vel (29D) from reference
|
|
||||||
# motion_anchor_ori_b (6D)
|
|
||||||
# base_ang_vel (3D)
|
|
||||||
# joint_pos_rel (29D)
|
|
||||||
# joint_vel_rel (29D)
|
|
||||||
# last_action (29D)
|
|
||||||
self.motion_obs[0:29] = motion_joint_pos_bfs
|
|
||||||
self.motion_obs[29:58] = motion_joint_vel_bfs
|
|
||||||
self.motion_obs[58:64] = motion_anchor_ori_b
|
|
||||||
self.motion_obs[64:67] = ang_vel
|
|
||||||
self.motion_obs[67:96] = joint_pos_rel
|
|
||||||
self.motion_obs[96:125] = joint_vel_rel
|
|
||||||
self.motion_obs[125:154] = self.motion_action
|
|
||||||
|
|
||||||
if TEST_SEND_DEFAULT_POS:
|
|
||||||
# DEBUG: Just send default positions (should make robot stand still)
|
|
||||||
target_joint_pos_bfs = default_joint_pos.copy()
|
|
||||||
if self.motion_counter == 1:
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
|
|
||||||
if self.motion_counter % 50 == 0:
|
|
||||||
logger.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
|
|
||||||
logger.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
|
|
||||||
else:
|
|
||||||
# Run ONNX policy inference
|
|
||||||
obs_tensor = torch.from_numpy(self.motion_obs).unsqueeze(0)
|
|
||||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
|
||||||
ort_outs = self.policy.run(None, ort_inputs)
|
|
||||||
self.motion_action = ort_outs[0].squeeze() # 29D action in BFS order
|
|
||||||
|
|
||||||
# Zero out missing joints in policy actions for 23 DOF mode
|
|
||||||
if USE_23DOF:
|
|
||||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
|
||||||
self.motion_action[joint_idx] = 0.0
|
|
||||||
|
|
||||||
# Process actions: scale and add offset
|
|
||||||
action_scale = np.array(self.config.motion_action_scale, dtype=np.float32)
|
|
||||||
target_joint_pos_bfs = default_joint_pos + self.motion_action * action_scale
|
|
||||||
|
|
||||||
# Send commands to motors: motor[joint_ids_map[i]] = action[i]
|
|
||||||
for i in range(29):
|
|
||||||
motor_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
self.msg.motor_cmd[motor_idx].q = target_joint_pos_bfs[i]
|
|
||||||
self.msg.motor_cmd[motor_idx].qd = 0
|
|
||||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].tau = 0
|
|
||||||
|
|
||||||
# Debug print (only when running policy, not in TEST_SEND_DEFAULT_POS or TEST_DIRECT_PLAYBACK mode)
|
|
||||||
if self.motion_counter == 1 and self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info("POLICY MODE OBSERVATION CHECK (First iteration)")
|
|
||||||
logger.info("="*60)
|
|
||||||
logger.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
|
|
||||||
logger.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
|
||||||
logger.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
|
|
||||||
logger.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
|
|
||||||
logger.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
|
|
||||||
logger.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
|
|
||||||
logger.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
|
|
||||||
logger.info(f"Observation breakdown:")
|
|
||||||
logger.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
|
|
||||||
logger.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
|
|
||||||
logger.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
|
|
||||||
logger.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
|
|
||||||
logger.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
|
|
||||||
logger.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
|
|
||||||
logger.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
|
|
||||||
logger.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
|
|
||||||
logger.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
|
||||||
logger.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
|
|
||||||
logger.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
|
||||||
logger.info("="*60)
|
|
||||||
|
|
||||||
if self.motion_counter % 50 == 0:
|
|
||||||
if self.policy is None:
|
|
||||||
mode = "DIRECT"
|
|
||||||
elif TEST_DIRECT_PLAYBACK:
|
|
||||||
mode = "DIRECT_DEBUG"
|
|
||||||
elif TEST_SEND_DEFAULT_POS:
|
|
||||||
mode = "DEFAULT_POS"
|
|
||||||
elif TEST_WITH_ZEROS:
|
|
||||||
mode = "POLICY_ZEROS"
|
|
||||||
else:
|
|
||||||
mode = "POLICY"
|
|
||||||
logger.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
|
|
||||||
if self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
|
||||||
logger.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
|
||||||
logger.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
|
|
||||||
logger.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
|
||||||
logger.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
|
||||||
|
|
||||||
# Send command
|
|
||||||
self.msg.crc = self.crc.Crc(self.msg)
|
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
|
||||||
|
|
||||||
def _motion_imitation_thread_loop(self):
|
|
||||||
"""Background thread that runs the motion imitation policy at specified rate."""
|
|
||||||
logger.info("Motion imitation thread started")
|
|
||||||
while self.motion_imitation_running:
|
|
||||||
start_time = time.time()
|
|
||||||
try:
|
|
||||||
self.motion_imitation_run()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in motion imitation loop: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
# Sleep to maintain control rate
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
sleep_time = max(0, self.config.motion_control_dt - elapsed)
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
logger.info("Motion imitation thread stopped")
|
|
||||||
|
|
||||||
def start_motion_imitation_thread(self):
|
|
||||||
"""Start the background motion imitation control thread."""
|
|
||||||
if not self.config.motion_imitation_control:
|
|
||||||
logger.warning("motion_imitation_control is False, cannot start thread")
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.motion_imitation_running:
|
|
||||||
logger.warning("Motion imitation thread already running")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Starting motion imitation control thread...")
|
|
||||||
self.motion_imitation_running = True
|
|
||||||
self.motion_imitation_thread = threading.Thread(target=self._motion_imitation_thread_loop, daemon=True)
|
|
||||||
self.motion_imitation_thread.start()
|
|
||||||
logger.info("Motion imitation control thread started!")
|
|
||||||
|
|
||||||
def stop_motion_imitation_thread(self):
|
|
||||||
"""Stop the background motion imitation control thread."""
|
|
||||||
if not self.motion_imitation_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Stopping motion imitation control thread...")
|
|
||||||
self.motion_imitation_running = False
|
|
||||||
if self.motion_imitation_thread:
|
|
||||||
self.motion_imitation_thread.join(timeout=2.0)
|
|
||||||
logger.info("Motion imitation control thread stopped")
|
|
||||||
|
|
||||||
def init_motion_imitation(self):
|
|
||||||
"""Initialize motion imitation - move to default standing pose and start policy."""
|
|
||||||
if not self.config.motion_imitation_control:
|
|
||||||
logger.warning("motion_imitation_control is False, cannot run initialization")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Starting motion imitation initialization...")
|
|
||||||
|
|
||||||
# Move to default standing position
|
|
||||||
logger.info("Moving to default standing position...")
|
|
||||||
total_time = 3.0
|
|
||||||
num_steps = int(total_time / self.config.motion_control_dt)
|
|
||||||
|
|
||||||
# Get current positions (in motor order)
|
|
||||||
current_q_motor = self.get_current_motor_q()
|
|
||||||
|
|
||||||
# target_q is in BFS order from config, need to convert to motor order
|
|
||||||
target_q_bfs = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
|
|
||||||
target_q_motor = np.zeros(29, dtype=np.float32)
|
|
||||||
for i in range(29):
|
|
||||||
motor_idx = self.config.motion_joint_ids_map[i]
|
|
||||||
target_q_motor[motor_idx] = target_q_bfs[i]
|
|
||||||
|
|
||||||
# Interpolate to target (both in motor order now)
|
|
||||||
for i in range(num_steps):
|
|
||||||
alpha = i / num_steps
|
|
||||||
for motor_idx in range(29):
|
|
||||||
self.msg.motor_cmd[motor_idx].q = current_q_motor[motor_idx] * (1 - alpha) + target_q_motor[motor_idx] * alpha
|
|
||||||
self.msg.motor_cmd[motor_idx].qd = 0
|
|
||||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
|
||||||
self.msg.motor_cmd[motor_idx].tau = 0
|
|
||||||
self.msg.crc = self.crc.Crc(self.msg)
|
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
|
||||||
time.sleep(self.config.motion_control_dt)
|
|
||||||
|
|
||||||
logger.info("Reached default position")
|
|
||||||
|
|
||||||
# Wait 2 seconds
|
|
||||||
time.sleep(2.0)
|
|
||||||
|
|
||||||
# Start motion imitation policy thread
|
|
||||||
logger.info("Starting motion imitation policy control...")
|
|
||||||
self.start_motion_imitation_thread()
|
|
||||||
|
|
||||||
logger.info("Motion imitation initialization complete! Policy is now running.")
|
|
||||||
logger.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
|
|
||||||
|
|
||||||
|
|
||||||
class G1_29_JointArmIndex(IntEnum):
|
class G1_29_JointArmIndex(IntEnum):
|
||||||
# Left arm
|
# Left arm
|
||||||
|
|||||||
Reference in New Issue
Block a user