mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
simplified robot class
This commit is contained in:
@@ -6,15 +6,9 @@ from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.calibration_gui import RangeFinderGUI
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
import json
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
import numpy as np
|
||||
@@ -36,7 +30,6 @@ import onnxruntime as ort
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState # idl for g1, h1_2
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
from unitree_sdk2py.g1.audio.g1_audio_client import AudioClient
|
||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import (
|
||||
MotionSwitcherClient,
|
||||
)
|
||||
@@ -45,11 +38,6 @@ from lerobot.envs.factory import make_env
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
import struct
|
||||
|
||||
import yaml
|
||||
|
||||
from typing import Union
|
||||
|
||||
from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK
|
||||
|
||||
import torch
|
||||
@@ -57,7 +45,6 @@ import torch
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowCommand_Motion = "rt/arm_sdk"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
G1_29_Num_Motors = 35
|
||||
@@ -102,8 +89,6 @@ class DataBuffer:
|
||||
with self.lock:
|
||||
self.data = data
|
||||
|
||||
#eventually observations should be everything: motor torques etc etc
|
||||
#motor class for unitree?
|
||||
class UnitreeG1(Robot):
|
||||
|
||||
config_class = UnitreeG1Config
|
||||
@@ -118,7 +103,6 @@ class UnitreeG1(Robot):
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.q_target = np.zeros(14)
|
||||
self.tauff_target = np.zeros(14)
|
||||
self.motion_mode = config.motion_mode
|
||||
self.simulation_mode = config.simulation_mode
|
||||
self.kp_high = config.kp_high
|
||||
self.kd_high = config.kd_high
|
||||
@@ -131,7 +115,6 @@ class UnitreeG1(Robot):
|
||||
self.arm_velocity_limit = config.arm_velocity_limit
|
||||
self.control_dt = config.control_dt
|
||||
|
||||
self._speed_gradual_max = config.speed_gradual_max
|
||||
self._gradual_start_time = config.gradual_start_time
|
||||
self._gradual_time = config.gradual_time
|
||||
|
||||
@@ -143,7 +126,6 @@ class UnitreeG1(Robot):
|
||||
self.freeze_body = config.freeze_body
|
||||
self.gravity_compensation = config.gravity_compensation
|
||||
|
||||
|
||||
self.calibrated = False
|
||||
|
||||
self.calibrate()
|
||||
@@ -155,39 +137,33 @@ class UnitreeG1(Robot):
|
||||
else:
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
|
||||
if not self.config.simulation_mode:
|
||||
self.msc = MotionSwitcherClient()
|
||||
self.msc.SetTimeout(5.0)
|
||||
self.msc.Init()
|
||||
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
#check if result name first
|
||||
if result is not None and "name" in result:
|
||||
while result["name"]:
|
||||
self.msc.ReleaseMode()
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
time.sleep(1)
|
||||
|
||||
# initialize lowcmd nd lowstate subscriber
|
||||
if self.simulation_mode:
|
||||
ChannelFactoryInitialize(0, "lo")
|
||||
|
||||
# Launch MuJoCo simulation environment
|
||||
|
||||
logger.info("Launching MuJoCo simulation environment...")
|
||||
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||
logger.info("MuJoCo environment launched successfully!")
|
||||
else:
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
|
||||
|
||||
if not self.config.simulation_mode:
|
||||
pass
|
||||
# self.msc = MotionSwitcherClient()
|
||||
# self.msc.SetTimeout(5.0)
|
||||
# self.msc.Init()
|
||||
|
||||
# status, result = self.msc.CheckMode()
|
||||
# print(status, result)
|
||||
# #check if result name first
|
||||
# if result is not None and "name" in result:
|
||||
# while result["name"]:
|
||||
# self.msc.ReleaseMode()
|
||||
# status, result = self.msc.CheckMode()
|
||||
# print(status, result)
|
||||
# time.sleep(1)
|
||||
|
||||
if self.motion_mode:
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Motion, hg_LowCmd)
|
||||
else:
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
# Always use debug mode (direct motor control)
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
@@ -203,9 +179,6 @@ class UnitreeG1(Robot):
|
||||
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
|
||||
logger.info("[UnitreeG1] Subscribe dds ok.")
|
||||
|
||||
# initialize audio client for LED, TTS, and audio playback
|
||||
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
@@ -238,15 +211,6 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[id].q = self.all_motor_q[id]
|
||||
#print current motor q, kp, kd
|
||||
|
||||
|
||||
|
||||
if config.audio_client:
|
||||
pass
|
||||
# self.audio_client = AudioClient()
|
||||
# self.audio_client.SetTimeout(10.0)
|
||||
# self.audio_client.Init()
|
||||
# logger.info("[UnitreeG1] Audio client initialized!")
|
||||
|
||||
logger.info("Lock OK!\n") #motors are not locked x
|
||||
# for i in range(10000):
|
||||
# print(self.get_current_motor_q())
|
||||
@@ -257,62 +221,24 @@ class UnitreeG1(Robot):
|
||||
self.keyboard_running = False
|
||||
self.locomotion_thread = None
|
||||
self.locomotion_running = False
|
||||
self.motion_imitation_thread = None
|
||||
self.motion_imitation_running = False
|
||||
|
||||
# Initialize publish thread for arm control
|
||||
# Note: This thread runs alongside locomotion/motion_imitation threads
|
||||
# Note: This thread runs alongside locomotion thread
|
||||
# - Arm thread: controls arms (indices 15-28)
|
||||
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
||||
# Both update different parts of self.msg, both call Write()
|
||||
self.publish_thread = None
|
||||
self.ctrl_lock = threading.Lock()
|
||||
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
|
||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||
self.publish_thread.daemon = True
|
||||
self.publish_thread.start()
|
||||
logger.info("Arm control publish thread started")
|
||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||
self.publish_thread.daemon = True
|
||||
self.publish_thread.start()
|
||||
logger.info("Arm control publish thread started")
|
||||
|
||||
# Load locomotion policy if enabled
|
||||
self.policy = None
|
||||
self.policy_type = None # 'torchscript', 'onnx', or 'motion_imitation'
|
||||
self.motion_loader = None
|
||||
self.policy_type = None # 'torchscript' or 'onnx'
|
||||
|
||||
if config.motion_imitation_control:
|
||||
# Motion imitation mode (dance, etc.)
|
||||
if config.motion_file_path is None:
|
||||
raise ValueError("motion_imitation_control is True but motion_file_path is not set")
|
||||
|
||||
logger.info(f"Loading motion reference from {config.motion_file_path}")
|
||||
|
||||
# Load motion file
|
||||
self.motion_loader = self.MotionLoader(config.motion_file_path, config.motion_fps)
|
||||
|
||||
# Load ONNX policy (optional for now - can run in direct playback mode)
|
||||
if config.motion_policy_path and Path(config.motion_policy_path).exists():
|
||||
logger.info(f"Loading motion imitation policy from {config.motion_policy_path}")
|
||||
self.policy = ort.InferenceSession(config.motion_policy_path)
|
||||
self.policy_type = 'motion_imitation'
|
||||
logger.info("Motion imitation ONNX policy loaded successfully")
|
||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
else:
|
||||
logger.info("Running in DIRECT PLAYBACK mode (no policy - just reference motion)")
|
||||
self.policy = None
|
||||
self.policy_type = 'motion_playback'
|
||||
|
||||
# Initialize motion imitation variables
|
||||
self.motion_counter = 0
|
||||
self.motion_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints from robot
|
||||
self.motion_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.motion_action = np.zeros(29, dtype=np.float32) # 29D action output
|
||||
self.motion_obs = np.zeros(154, dtype=np.float32) # 154D observation
|
||||
self.motion_elapsed_time = 0.0
|
||||
|
||||
# Initialize motion and start
|
||||
self.init_motion_imitation()
|
||||
|
||||
elif config.locomotion_control:
|
||||
if config.locomotion_control:
|
||||
if config.policy_path is None:
|
||||
raise ValueError("locomotion_control is True but policy_path is not set")
|
||||
|
||||
@@ -326,11 +252,33 @@ class UnitreeG1(Robot):
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
elif config.policy_path.endswith('.onnx'):
|
||||
logger.info("Detected ONNX (.onnx) policy")
|
||||
self.policy = ort.InferenceSession(config.policy_path)
|
||||
|
||||
# For GR00T-style policies, load both Balance and Walk policies
|
||||
# Balance policy for standing (low velocity commands)
|
||||
# Walk policy for locomotion (high velocity commands)
|
||||
balance_policy_path = config.policy_path.replace('Walk.onnx', 'Balance.onnx')
|
||||
walk_policy_path = config.policy_path
|
||||
|
||||
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
|
||||
logger.info("Loading dual-policy system (Balance + Walk)")
|
||||
self.policy_balance = ort.InferenceSession(balance_policy_path)
|
||||
self.policy_walk = ort.InferenceSession(walk_policy_path)
|
||||
self.policy = None # Not used when dual policies are loaded
|
||||
logger.info(f"Balance policy loaded from: {balance_policy_path}")
|
||||
logger.info(f"Walk policy loaded from: {walk_policy_path}")
|
||||
logger.info(f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}")
|
||||
else:
|
||||
# Fallback to single policy
|
||||
logger.info("Loading single ONNX policy")
|
||||
self.policy = ort.InferenceSession(config.policy_path)
|
||||
self.policy_balance = None
|
||||
self.policy_walk = None
|
||||
logger.info("ONNX policy loaded successfully")
|
||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
|
||||
self.policy_type = 'onnx'
|
||||
logger.info("ONNX policy loaded successfully")
|
||||
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported policy format: {config.policy_path}. Only .pt (TorchScript) and .onnx (ONNX) are supported.")
|
||||
|
||||
@@ -415,12 +363,7 @@ class UnitreeG1(Robot):
|
||||
return cliped_arm_q_target
|
||||
|
||||
def _ctrl_motor_state(self):
|
||||
"""Arm control thread - publishes commands for arms only.
|
||||
NOTE: This thread is NOT started when motion_imitation_control or locomotion_control is True.
|
||||
Those modes handle their own publishing."""
|
||||
if self.motion_mode:
|
||||
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = 1.0
|
||||
|
||||
"""Arm control thread - publishes commands for arms only."""
|
||||
while True:
|
||||
start_time = time.time()
|
||||
|
||||
@@ -452,10 +395,6 @@ class UnitreeG1(Robot):
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
if self._speed_gradual_max is True:
|
||||
t_elapsed = start_time - self._gradual_start_time
|
||||
self.arm_velocity_limit = 20.0 + (10.0 * min(1.0, t_elapsed / 5.0))
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed))
|
||||
@@ -485,37 +424,6 @@ class UnitreeG1(Robot):
|
||||
"""Return current state dq of the left and right arm motors."""
|
||||
return np.array([self.lowstate_buffer.GetData().motor_state[id].dq for id in G1_29_JointArmIndex])
|
||||
|
||||
def ctrl_dual_arm_go_home(self):
|
||||
"""Move both the left and right arms of the robot to their home position by setting the target joint angles (q) and torques (tau) to zero."""
|
||||
logger.info("[G1_29_ArmController] ctrl_dual_arm_go_home start...")
|
||||
max_attempts = 100
|
||||
current_attempts = 0
|
||||
with self.ctrl_lock:
|
||||
self.q_target = np.zeros(14)
|
||||
#self.q_target[G1_29_JointIndex.kLeftElbow] = 0.5
|
||||
# self.tauff_target = np.zeros(14)
|
||||
tolerance = 0.05 # Tolerance threshold for joint angles to determine "close to zero", can be adjusted based on your motor's precision requirements
|
||||
while current_attempts < max_attempts:
|
||||
current_q = self.get_current_dual_arm_q()
|
||||
if np.all(np.abs(current_q) < tolerance):
|
||||
if self.motion_mode:
|
||||
for weight in np.linspace(1, 0, num=101):
|
||||
self.msg.motor_cmd[G1_29_JointIndex.kNotUsedJoint0].q = weight
|
||||
time.sleep(0.02)
|
||||
logger.info("[G1_29_ArmController] both arms have reached the home position.")
|
||||
break
|
||||
current_attempts += 1
|
||||
time.sleep(0.05)
|
||||
|
||||
def speed_gradual_max(self, t=5.0):
|
||||
"""Parameter t is the total time required for arms velocity to gradually increase to its maximum value, in seconds. The default is 5.0."""
|
||||
self._gradual_start_time = time.time()
|
||||
self._gradual_time = t
|
||||
self._speed_gradual_max = True
|
||||
|
||||
def speed_instant_max(self):
|
||||
"""set arms velocity to the maximum value immediately, instead of gradually increasing."""
|
||||
self.arm_velocity_limit = 30.0
|
||||
|
||||
def _Is_weak_motor(self, motor_index):
|
||||
weak_motors = [
|
||||
@@ -614,159 +522,6 @@ class UnitreeG1(Robot):
|
||||
'motors': motors_data,
|
||||
}
|
||||
|
||||
def audio_control(self, command, volume: int = 80):
|
||||
"""
|
||||
Unified audio/LED control function for the G1 robot.
|
||||
|
||||
Args:
|
||||
command: Can be one of:
|
||||
- str: Text to speak via TTS
|
||||
- tuple[int, int, int]: RGB values (0-255) for LED control
|
||||
- str (path): Path to WAV file to play
|
||||
volume: Volume level 0-100 (default: 80)
|
||||
|
||||
Examples:
|
||||
robot.audio_control("Hello world") # TTS
|
||||
robot.audio_control((255, 0, 0)) # Red LED
|
||||
robot.audio_control("audio.wav") # Play WAV file
|
||||
"""
|
||||
# Set volume
|
||||
self.audio_client.SetVolume(volume)
|
||||
|
||||
# Detect command type and execute
|
||||
if isinstance(command, tuple) and len(command) == 3:
|
||||
# LED control - RGB tuple
|
||||
r, g, b = command
|
||||
logger.info(f"Setting LED to RGB({r}, {g}, {b})")
|
||||
self.audio_client.LedControl(r, g, b)
|
||||
|
||||
elif isinstance(command, str):
|
||||
# Check if it's a file path
|
||||
if Path(command).exists():
|
||||
# Play WAV file
|
||||
logger.info(f"Playing audio file: {command}")
|
||||
self._play_wav_file(command)
|
||||
else:
|
||||
# Text-to-speech
|
||||
logger.info(f"Speaking: {command}")
|
||||
self.audio_client.TtsMaker(command, 0) # 0 for English
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid command type: {type(command)}. "
|
||||
"Expected str (text/path) or tuple[int, int, int] (RGB)"
|
||||
)
|
||||
|
||||
def _read_wav_file(self, filename: str):
|
||||
"""Read WAV file and return PCM data as bytes."""
|
||||
with open(filename, 'rb') as f:
|
||||
def read(fmt):
|
||||
return struct.unpack(fmt, f.read(struct.calcsize(fmt)))
|
||||
|
||||
# Read RIFF header
|
||||
chunk_id, = read('<I')
|
||||
if chunk_id != 0x46464952: # "RIFF"
|
||||
raise ValueError("Not a valid WAV file (invalid RIFF header)")
|
||||
|
||||
_chunk_size, = read('<I')
|
||||
format_tag, = read('<I')
|
||||
if format_tag != 0x45564157: # "WAVE"
|
||||
raise ValueError("Not a valid WAV file (invalid WAVE format)")
|
||||
|
||||
# Read fmt chunk
|
||||
subchunk1_id, = read('<I')
|
||||
subchunk1_size, = read('<I')
|
||||
|
||||
# Skip JUNK chunk if present
|
||||
if subchunk1_id == 0x4B4E554A: # "JUNK"
|
||||
f.seek(subchunk1_size, 1)
|
||||
subchunk1_id, = read('<I')
|
||||
subchunk1_size, = read('<I')
|
||||
|
||||
if subchunk1_id != 0x20746D66: # "fmt "
|
||||
raise ValueError("Invalid fmt chunk")
|
||||
|
||||
if subchunk1_size not in [16, 18]:
|
||||
raise ValueError(f"Unsupported fmt chunk size: {subchunk1_size}")
|
||||
|
||||
audio_format, = read('<H')
|
||||
if audio_format != 1:
|
||||
raise ValueError(f"Only PCM format supported, got format {audio_format}")
|
||||
|
||||
num_channels, = read('<H')
|
||||
sample_rate, = read('<I')
|
||||
_byte_rate, = read('<I')
|
||||
_block_align, = read('<H')
|
||||
bits_per_sample, = read('<H')
|
||||
|
||||
if bits_per_sample != 16:
|
||||
raise ValueError(f"Only 16-bit samples supported, got {bits_per_sample}-bit")
|
||||
|
||||
if sample_rate != 16000:
|
||||
raise ValueError(f"Sample rate must be 16000 Hz, got {sample_rate} Hz")
|
||||
|
||||
if num_channels != 1:
|
||||
raise ValueError(f"Must be mono (1 channel), got {num_channels} channels")
|
||||
|
||||
if subchunk1_size == 18:
|
||||
extra_size, = read('<H')
|
||||
if extra_size != 0:
|
||||
f.seek(extra_size, 1)
|
||||
|
||||
# Find data chunk
|
||||
while True:
|
||||
subchunk2_id, subchunk2_size = read('<II')
|
||||
if subchunk2_id == 0x61746164: # "data"
|
||||
break
|
||||
f.seek(subchunk2_size, 1)
|
||||
|
||||
# Read PCM data
|
||||
raw_pcm = f.read(subchunk2_size)
|
||||
if len(raw_pcm) != subchunk2_size:
|
||||
raise ValueError("Failed to read full PCM data")
|
||||
|
||||
return raw_pcm
|
||||
|
||||
def _play_wav_file(self, filename: str, chunk_size: int = 96000):
|
||||
"""
|
||||
Play a WAV file through the robot's speaker.
|
||||
|
||||
Args:
|
||||
filename: Path to WAV file (must be 16kHz, mono, 16-bit PCM)
|
||||
chunk_size: Bytes per chunk (default: 96000 = ~3 seconds at 16kHz)
|
||||
"""
|
||||
# Read WAV file
|
||||
pcm_data = self._read_wav_file(filename)
|
||||
|
||||
stream_id = str(int(time.time() * 1000))
|
||||
app_name = "lerobot"
|
||||
offset = 0
|
||||
chunk_index = 0
|
||||
total_size = len(pcm_data)
|
||||
|
||||
logger.info(f"Playing audio: {total_size} bytes in {(total_size // chunk_size) + 1} chunks")
|
||||
|
||||
# Send audio in chunks
|
||||
while offset < total_size:
|
||||
remaining = total_size - offset
|
||||
current_chunk_size = min(chunk_size, remaining)
|
||||
chunk = pcm_data[offset:offset + current_chunk_size]
|
||||
|
||||
# Send chunk
|
||||
ret_code, _ = self.audio_client.PlayStream(app_name, stream_id, list(chunk))
|
||||
if ret_code != 0:
|
||||
logger.error(f"Failed to send chunk {chunk_index}, return code: {ret_code}")
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Sent chunk {chunk_index}/{(total_size // chunk_size)}")
|
||||
|
||||
offset += current_chunk_size
|
||||
chunk_index += 1
|
||||
time.sleep(1.0) # Wait between chunks
|
||||
|
||||
# Calculate playback duration
|
||||
duration_seconds = len(pcm_data) / (16000 * 2) # 16kHz, 16-bit (2 bytes)
|
||||
logger.info(f"Audio playback will take ~{duration_seconds:.1f} seconds")
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
obs_array = self.get_current_dual_arm_q()
|
||||
obs_dict = {f"{G1_29_JointArmIndex(motor).name}.pos": val for motor, val in zip(G1_29_JointArmIndex, obs_array, strict=True)}
|
||||
@@ -991,78 +746,6 @@ class UnitreeG1(Robot):
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
class MotionLoader:
|
||||
"""Load and interpolate motion from CSV file for motion imitation."""
|
||||
def __init__(self, motion_file: str, fps: float = 60.0):
|
||||
"""Load motion from CSV file.
|
||||
|
||||
CSV format: [root_pos(3), root_quat_xyzw(4), joint_dof(29)] per row
|
||||
"""
|
||||
self.dt = 1.0 / fps
|
||||
|
||||
# Load CSV
|
||||
data = np.loadtxt(motion_file, delimiter=',')
|
||||
self.num_frames = data.shape[0]
|
||||
self.duration = self.num_frames * self.dt
|
||||
|
||||
# Split data
|
||||
self.root_positions = data[:, 0:3] # (N, 3)
|
||||
self.root_quaternions_xyzw = data[:, 3:7] # (N, 4) [x, y, z, w]
|
||||
self.dof_positions = data[:, 7:] # (N, 29)
|
||||
|
||||
# Compute velocities (finite differences)
|
||||
self.dof_velocities = np.diff(self.dof_positions, axis=0, prepend=self.dof_positions[0:1]) / self.dt
|
||||
|
||||
# Current playback state
|
||||
self.current_time = 0.0
|
||||
self.index_0 = 0
|
||||
self.index_1 = 0
|
||||
self.blend = 0.0
|
||||
|
||||
logger.info(f"MotionLoader: Loaded {self.num_frames} frames, duration={self.duration:.2f}s")
|
||||
|
||||
def update(self, time: float):
|
||||
"""Update motion to specific time (loops at duration)."""
|
||||
self.current_time = time % self.duration # Loop
|
||||
phase = self.current_time / self.duration
|
||||
|
||||
self.index_0 = int(phase * (self.num_frames - 1))
|
||||
self.index_1 = min(self.index_0 + 1, self.num_frames - 1)
|
||||
self.blend = (self.current_time - self.index_0 * self.dt) / self.dt
|
||||
|
||||
def get_joint_pos(self) -> np.ndarray:
|
||||
"""Get interpolated joint positions (29D)."""
|
||||
return self.dof_positions[self.index_0] * (1 - self.blend) + \
|
||||
self.dof_positions[self.index_1] * self.blend
|
||||
|
||||
def get_joint_vel(self) -> np.ndarray:
|
||||
"""Get interpolated joint velocities (29D)."""
|
||||
return self.dof_velocities[self.index_0] * (1 - self.blend) + \
|
||||
self.dof_velocities[self.index_1] * self.blend
|
||||
|
||||
def get_root_quat_wxyz(self) -> np.ndarray:
|
||||
"""Get interpolated root quaternion [w, x, y, z]."""
|
||||
# Spherical linear interpolation (SLERP)
|
||||
q0 = self.root_quaternions_xyzw[self.index_0] # [x, y, z, w]
|
||||
q1 = self.root_quaternions_xyzw[self.index_1]
|
||||
|
||||
# Convert to scipy format [x, y, z, w]
|
||||
r0 = R.from_quat(q0)
|
||||
r1 = R.from_quat(q1)
|
||||
|
||||
# SLERP
|
||||
key_times = [0, 1]
|
||||
key_rots = R.from_quat([q0, q1])
|
||||
slerp = R.from_quat(key_rots.as_quat()) # Simplified - just use linear for now
|
||||
|
||||
# Linear interpolation for simplicity
|
||||
quat_xyzw = q0 * (1 - self.blend) + q1 * self.blend
|
||||
# Normalize
|
||||
quat_xyzw = quat_xyzw / np.linalg.norm(quat_xyzw)
|
||||
|
||||
# Convert to [w, x, y, z]
|
||||
return np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]], dtype=np.float32)
|
||||
|
||||
def locomotion_get_gravity_orientation(self, quaternion):
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
@@ -1287,8 +970,23 @@ class UnitreeG1(Robot):
|
||||
|
||||
# Run policy inference (ONNX) with 516D stacked observation
|
||||
obs_tensor = torch.from_numpy(self.groot_obs_stacked).unsqueeze(0)
|
||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
||||
ort_outs = self.policy.run(None, ort_inputs)
|
||||
|
||||
# Select appropriate policy based on command magnitude (dual-policy system)
|
||||
if self.policy_balance is not None and self.policy_walk is not None:
|
||||
# Dual-policy mode: switch between Balance and Walk
|
||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||
if cmd_magnitude < 0.05:
|
||||
# Use balance/standing policy for small commands
|
||||
selected_policy = self.policy_balance
|
||||
else:
|
||||
# Use walking policy for movement commands
|
||||
selected_policy = self.policy_walk
|
||||
else:
|
||||
# Single policy mode (fallback)
|
||||
selected_policy = self.policy
|
||||
|
||||
ort_inputs = {selected_policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
||||
ort_outs = selected_policy.run(None, ort_inputs)
|
||||
self.groot_action = ort_outs[0].squeeze()
|
||||
|
||||
# Zero out waist actions (yaw=12, roll=13, pitch=14) - only use leg actions (0-11)
|
||||
@@ -1487,10 +1185,7 @@ class UnitreeG1(Robot):
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion test sequence...")
|
||||
|
||||
# 1. Home the arms first
|
||||
logger.info("Homing arms to zero position...")
|
||||
#self.ctrl_dual_arm_go_home()
|
||||
|
||||
|
||||
# 2. Move legs to default position
|
||||
self.locomotion_move_to_default_pos()
|
||||
@@ -1532,354 +1227,6 @@ class UnitreeG1(Robot):
|
||||
logger.info("GR00T locomotion initialization complete! Policy is now running.")
|
||||
logger.info("516D observations (86D × 6 frames), 15D actions (legs + waist)")
|
||||
|
||||
def motion_imitation_run(self):
|
||||
"""Motion imitation policy loop - tracks reference motion (dance_102, etc)."""
|
||||
self.motion_counter += 1
|
||||
self.motion_elapsed_time = self.motion_counter * self.config.motion_control_dt
|
||||
|
||||
# Update motion loader to current time
|
||||
self.motion_loader.update(self.motion_elapsed_time)
|
||||
|
||||
# Get current lowstate
|
||||
lowstate = self.lowstate_buffer.GetData()
|
||||
if lowstate is None:
|
||||
return
|
||||
|
||||
# Get ALL 29 joint positions and velocities from robot
|
||||
# IMPORTANT: Convert from motor order to BFS order to match reference motion
|
||||
# The C++ code does: robot_bfs[i] = motor[joint_ids_map[i]]
|
||||
for i in range(29):
|
||||
motor_idx = self.config.motion_joint_ids_map[i]
|
||||
self.motion_qj_all[i] = lowstate.motor_state[motor_idx].q
|
||||
self.motion_dqj_all[i] = lowstate.motor_state[motor_idx].dq
|
||||
|
||||
# ======== 23 DOF MODE CONFIGURATION ========
|
||||
# For real robot - zeros out joints not present in 23 DOF hardware
|
||||
# Waist: yaw(12), pitch(14) | Wrist: L_pitch/yaw(20,21), R_pitch/yaw(27,28)
|
||||
USE_23DOF = True # Set to True for real robot without these joints
|
||||
JOINTS_TO_ZERO_23DOF = [12,14,20, 21, 27, 28]#12, 14, 20, 21, 27, 28]#
|
||||
|
||||
# Apply 23 DOF zeroing to robot observations if enabled
|
||||
if USE_23DOF:
|
||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
||||
self.motion_qj_all[joint_idx] = 0.0
|
||||
self.motion_dqj_all[joint_idx] = 0.0
|
||||
if self.motion_counter == 1:
|
||||
logger.info("="*60)
|
||||
logger.info("🤖 23 DOF MODE ENABLED")
|
||||
logger.info(f" Zeroing joints: {JOINTS_TO_ZERO_23DOF}")
|
||||
logger.info(" Waist: yaw(12), pitch(14)")
|
||||
logger.info(" Wrist L: pitch(20), yaw(21) | Wrist R: pitch(27), yaw(28)")
|
||||
logger.info(" Applied to: robot obs, reference motion, policy actions")
|
||||
logger.info("="*60)
|
||||
|
||||
# Get IMU data
|
||||
robot_quat = lowstate.imu_state.quaternion # [w, x, y, z]
|
||||
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32) # 3D
|
||||
|
||||
if self.policy is None:
|
||||
# DIRECT PLAYBACK MODE (no policy)
|
||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos()
|
||||
|
||||
# Zero out missing joints for 23 DOF mode
|
||||
if USE_23DOF:
|
||||
# Convert to BFS to zero out, then convert back
|
||||
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
|
||||
for i in range(29):
|
||||
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
||||
motion_joint_pos_bfs_temp[joint_idx] = 0.0
|
||||
# Convert back to DFS for sending
|
||||
for i in range(29):
|
||||
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
|
||||
|
||||
for i in range(29):
|
||||
motor_idx = self.config.motion_joint_ids_map[i]
|
||||
csv_idx = self.config.motion_joint_ids_map[i]
|
||||
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
|
||||
self.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].tau = 0
|
||||
else:
|
||||
# POLICY MODE - Full observation construction and inference
|
||||
|
||||
# ======== DEBUG TEST MODES ========
|
||||
# Mode 1: Direct playback (no policy) - set motion_policy_path = None in config instead
|
||||
# Mode 2: Send default pos (stand still) - TEST_SEND_DEFAULT_POS = True
|
||||
# Mode 3: Policy with zero reference - TEST_WITH_ZEROS = True, TEST_SEND_DEFAULT_POS = False
|
||||
# Mode 4: Policy with real reference - TEST_WITH_ZEROS = False, TEST_SEND_DEFAULT_POS = False
|
||||
TEST_WITH_ZEROS = False # True = use zero reference motion in observation
|
||||
TEST_SEND_DEFAULT_POS = False # True = bypass policy and send default pos (stand still)
|
||||
TEST_DIRECT_PLAYBACK = False # True = bypass policy and send reference motion directly
|
||||
|
||||
if TEST_DIRECT_PLAYBACK:
|
||||
# DEBUG: Play back reference motion without policy
|
||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D in DFS order
|
||||
|
||||
# Zero out missing joints for 23 DOF mode
|
||||
if USE_23DOF:
|
||||
# Convert to BFS to zero out, then convert back
|
||||
motion_joint_pos_bfs_temp = np.zeros(29, dtype=np.float32)
|
||||
for i in range(29):
|
||||
motion_joint_pos_bfs_temp[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
||||
motion_joint_pos_bfs_temp[joint_idx] = 0.0
|
||||
# Convert back to DFS for sending
|
||||
for i in range(29):
|
||||
motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]] = motion_joint_pos_bfs_temp[i]
|
||||
|
||||
# Send directly to motors using joint_ids_map (same as direct playback mode)
|
||||
for i in range(29):
|
||||
motor_idx = self.config.motion_joint_ids_map[i]
|
||||
csv_idx = self.config.motion_joint_ids_map[i]
|
||||
self.msg.motor_cmd[motor_idx].q = motion_joint_pos_dfs[csv_idx]
|
||||
self.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
if self.motion_counter == 1:
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: DIRECT PLAYBACK (reference motion, no policy)")
|
||||
logger.info("="*60)
|
||||
|
||||
target_joint_pos_bfs = None # Not used in this mode
|
||||
|
||||
else:
|
||||
# Run observation construction and policy
|
||||
if TEST_WITH_ZEROS:
|
||||
# Send zeros for reference motion
|
||||
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
|
||||
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
|
||||
if self.motion_counter == 1:
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: Using ZERO reference motion + RUNNING POLICY")
|
||||
logger.info("="*60)
|
||||
else:
|
||||
# Get reference motion (DFS order from CSV)
|
||||
motion_joint_pos_dfs = self.motion_loader.get_joint_pos() # 29D
|
||||
motion_joint_vel_dfs = self.motion_loader.get_joint_vel() # 29D
|
||||
|
||||
# Convert from DFS to BFS order: bfs[i] = dfs[joint_ids_map[i]]
|
||||
motion_joint_pos_bfs = np.zeros(29, dtype=np.float32)
|
||||
motion_joint_vel_bfs = np.zeros(29, dtype=np.float32)
|
||||
for i in range(29):
|
||||
motion_joint_pos_bfs[i] = motion_joint_pos_dfs[self.config.motion_joint_ids_map[i]]
|
||||
motion_joint_vel_bfs[i] = motion_joint_vel_dfs[self.config.motion_joint_ids_map[i]]
|
||||
|
||||
# Zero out missing joints in reference motion for 23 DOF mode
|
||||
if USE_23DOF:
|
||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
||||
motion_joint_pos_bfs[joint_idx] = 0.0
|
||||
motion_joint_vel_bfs[joint_idx] = 0.0
|
||||
|
||||
# Compute motion_anchor_ori_b (6D rotation matrix representation)
|
||||
motion_quat_wxyz = self.motion_loader.get_root_quat_wxyz()
|
||||
robot_rot = R.from_quat([robot_quat[1], robot_quat[2], robot_quat[3], robot_quat[0]]).as_matrix()
|
||||
motion_rot = R.from_quat([motion_quat_wxyz[1], motion_quat_wxyz[2], motion_quat_wxyz[3], motion_quat_wxyz[0]]).as_matrix()
|
||||
relative_rot = robot_rot.T @ motion_rot
|
||||
motion_anchor_ori_b = np.array([relative_rot[0, 0], relative_rot[0, 1],
|
||||
relative_rot[1, 0], relative_rot[1, 1],
|
||||
relative_rot[2, 0], relative_rot[2, 1]], dtype=np.float32)
|
||||
|
||||
# Compute joint positions and velocities relative to default
|
||||
default_joint_pos = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
|
||||
joint_pos_rel = self.motion_qj_all - default_joint_pos
|
||||
joint_vel_rel = self.motion_dqj_all.copy()
|
||||
|
||||
# Build 154D observation:
|
||||
# motion_command (58D) = joint_pos (29D) + joint_vel (29D) from reference
|
||||
# motion_anchor_ori_b (6D)
|
||||
# base_ang_vel (3D)
|
||||
# joint_pos_rel (29D)
|
||||
# joint_vel_rel (29D)
|
||||
# last_action (29D)
|
||||
self.motion_obs[0:29] = motion_joint_pos_bfs
|
||||
self.motion_obs[29:58] = motion_joint_vel_bfs
|
||||
self.motion_obs[58:64] = motion_anchor_ori_b
|
||||
self.motion_obs[64:67] = ang_vel
|
||||
self.motion_obs[67:96] = joint_pos_rel
|
||||
self.motion_obs[96:125] = joint_vel_rel
|
||||
self.motion_obs[125:154] = self.motion_action
|
||||
|
||||
if TEST_SEND_DEFAULT_POS:
|
||||
# DEBUG: Just send default positions (should make robot stand still)
|
||||
target_joint_pos_bfs = default_joint_pos.copy()
|
||||
if self.motion_counter == 1:
|
||||
logger.info("="*60)
|
||||
logger.info("⚠️ DEBUG MODE: Sending DEFAULT positions (NO POLICY)")
|
||||
logger.info("="*60)
|
||||
logger.info(f" Default pos BFS[0:5]: {target_joint_pos_bfs[0:5]}")
|
||||
if self.motion_counter % 50 == 0:
|
||||
logger.info(f" [DEFAULT MODE] Sending: [{target_joint_pos_bfs[0]:.4f}, {target_joint_pos_bfs[6]:.4f}, {target_joint_pos_bfs[12]:.4f}]")
|
||||
logger.info(f" [DEFAULT MODE] Robot at: [{self.motion_qj_all[0]:.4f}, {self.motion_qj_all[6]:.4f}, {self.motion_qj_all[12]:.4f}]")
|
||||
else:
|
||||
# Run ONNX policy inference
|
||||
obs_tensor = torch.from_numpy(self.motion_obs).unsqueeze(0)
|
||||
ort_inputs = {self.policy.get_inputs()[0].name: obs_tensor.cpu().numpy()}
|
||||
ort_outs = self.policy.run(None, ort_inputs)
|
||||
self.motion_action = ort_outs[0].squeeze() # 29D action in BFS order
|
||||
|
||||
# Zero out missing joints in policy actions for 23 DOF mode
|
||||
if USE_23DOF:
|
||||
for joint_idx in JOINTS_TO_ZERO_23DOF:
|
||||
self.motion_action[joint_idx] = 0.0
|
||||
|
||||
# Process actions: scale and add offset
|
||||
action_scale = np.array(self.config.motion_action_scale, dtype=np.float32)
|
||||
target_joint_pos_bfs = default_joint_pos + self.motion_action * action_scale
|
||||
|
||||
# Send commands to motors: motor[joint_ids_map[i]] = action[i]
|
||||
for i in range(29):
|
||||
motor_idx = self.config.motion_joint_ids_map[i]
|
||||
self.msg.motor_cmd[motor_idx].q = target_joint_pos_bfs[i]
|
||||
self.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Debug print (only when running policy, not in TEST_SEND_DEFAULT_POS or TEST_DIRECT_PLAYBACK mode)
|
||||
if self.motion_counter == 1 and self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
||||
logger.info("="*60)
|
||||
logger.info("POLICY MODE OBSERVATION CHECK (First iteration)")
|
||||
logger.info("="*60)
|
||||
logger.info(f"Reference motion (BFS) samples: [{motion_joint_pos_bfs[0]:.3f}, {motion_joint_pos_bfs[6]:.3f}, {motion_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info(f"Robot joints (BFS) samples: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
logger.info(f"Default positions samples: [{default_joint_pos[0]:.3f}, {default_joint_pos[6]:.3f}, {default_joint_pos[12]:.3f}]")
|
||||
logger.info(f"Joint pos rel samples: [{joint_pos_rel[0]:.3f}, {joint_pos_rel[6]:.3f}, {joint_pos_rel[12]:.3f}]")
|
||||
logger.info(f"Joint vel rel samples: [{joint_vel_rel[0]:.3f}, {joint_vel_rel[6]:.3f}, {joint_vel_rel[12]:.3f}]")
|
||||
logger.info(f"Angular velocity: [{ang_vel[0]:.3f}, {ang_vel[1]:.3f}, {ang_vel[2]:.3f}]")
|
||||
logger.info(f"Motion anchor ori: [{motion_anchor_ori_b[0]:.3f}, ..., {motion_anchor_ori_b[5]:.3f}]")
|
||||
logger.info(f"Observation breakdown:")
|
||||
logger.info(f" [0:29] motion_cmd_pos: range [{self.motion_obs[0:29].min():.3f}, {self.motion_obs[0:29].max():.3f}]")
|
||||
logger.info(f" [29:58] motion_cmd_vel: range [{self.motion_obs[29:58].min():.3f}, {self.motion_obs[29:58].max():.3f}]")
|
||||
logger.info(f" [58:64] anchor_ori: range [{self.motion_obs[58:64].min():.3f}, {self.motion_obs[58:64].max():.3f}]")
|
||||
logger.info(f" [64:67] ang_vel: range [{self.motion_obs[64:67].min():.3f}, {self.motion_obs[64:67].max():.3f}]")
|
||||
logger.info(f" [67:96] joint_pos_rel: range [{self.motion_obs[67:96].min():.3f}, {self.motion_obs[67:96].max():.3f}]")
|
||||
logger.info(f" [96:125] joint_vel_rel: range [{self.motion_obs[96:125].min():.3f}, {self.motion_obs[96:125].max():.3f}]")
|
||||
logger.info(f" [125:154] last_action: range [{self.motion_obs[125:154].min():.3f}, {self.motion_obs[125:154].max():.3f}]")
|
||||
logger.info(f"Full obs range: [{self.motion_obs.min():.3f}, {self.motion_obs.max():.3f}]")
|
||||
logger.info(f"Action output (first): [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger.info(f"Action scale samples: [{action_scale[0]:.3f}, {action_scale[6]:.3f}, {action_scale[12]:.3f}]")
|
||||
logger.info(f"Target positions samples: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info("="*60)
|
||||
|
||||
if self.motion_counter % 50 == 0:
|
||||
if self.policy is None:
|
||||
mode = "DIRECT"
|
||||
elif TEST_DIRECT_PLAYBACK:
|
||||
mode = "DIRECT_DEBUG"
|
||||
elif TEST_SEND_DEFAULT_POS:
|
||||
mode = "DEFAULT_POS"
|
||||
elif TEST_WITH_ZEROS:
|
||||
mode = "POLICY_ZEROS"
|
||||
else:
|
||||
mode = "POLICY"
|
||||
logger.info(f"Motion {mode}: t={self.motion_elapsed_time:.2f}s, frame={self.motion_loader.index_0}/{self.motion_loader.num_frames}")
|
||||
if self.policy and not TEST_SEND_DEFAULT_POS and not TEST_DIRECT_PLAYBACK:
|
||||
logger.info(f" Policy action range: [{self.motion_action.min():.3f}, {self.motion_action.max():.3f}]")
|
||||
logger.info(f" Sample actions[0,6,12]: [{self.motion_action[0]:.3f}, {self.motion_action[6]:.3f}, {self.motion_action[12]:.3f}]")
|
||||
logger.info(f" Target pos (after scale)[0,6,12]: [{target_joint_pos_bfs[0]:.3f}, {target_joint_pos_bfs[6]:.3f}, {target_joint_pos_bfs[12]:.3f}]")
|
||||
logger.info(f" Robot pos (BFS)[0,6,12]: [{self.motion_qj_all[0]:.3f}, {self.motion_qj_all[6]:.3f}, {self.motion_qj_all[12]:.3f}]")
|
||||
|
||||
# Send command
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
|
||||
def _motion_imitation_thread_loop(self):
|
||||
"""Background thread that runs the motion imitation policy at specified rate."""
|
||||
logger.info("Motion imitation thread started")
|
||||
while self.motion_imitation_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.motion_imitation_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in motion imitation loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.config.motion_control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Motion imitation thread stopped")
|
||||
|
||||
def start_motion_imitation_thread(self):
|
||||
"""Start the background motion imitation control thread."""
|
||||
if not self.config.motion_imitation_control:
|
||||
logger.warning("motion_imitation_control is False, cannot start thread")
|
||||
return
|
||||
|
||||
if self.motion_imitation_running:
|
||||
logger.warning("Motion imitation thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting motion imitation control thread...")
|
||||
self.motion_imitation_running = True
|
||||
self.motion_imitation_thread = threading.Thread(target=self._motion_imitation_thread_loop, daemon=True)
|
||||
self.motion_imitation_thread.start()
|
||||
logger.info("Motion imitation control thread started!")
|
||||
|
||||
def stop_motion_imitation_thread(self):
|
||||
"""Stop the background motion imitation control thread."""
|
||||
if not self.motion_imitation_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping motion imitation control thread...")
|
||||
self.motion_imitation_running = False
|
||||
if self.motion_imitation_thread:
|
||||
self.motion_imitation_thread.join(timeout=2.0)
|
||||
logger.info("Motion imitation control thread stopped")
|
||||
|
||||
def init_motion_imitation(self):
|
||||
"""Initialize motion imitation - move to default standing pose and start policy."""
|
||||
if not self.config.motion_imitation_control:
|
||||
logger.warning("motion_imitation_control is False, cannot run initialization")
|
||||
return
|
||||
|
||||
logger.info("Starting motion imitation initialization...")
|
||||
|
||||
# Move to default standing position
|
||||
logger.info("Moving to default standing position...")
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / self.config.motion_control_dt)
|
||||
|
||||
# Get current positions (in motor order)
|
||||
current_q_motor = self.get_current_motor_q()
|
||||
|
||||
# target_q is in BFS order from config, need to convert to motor order
|
||||
target_q_bfs = np.array(self.config.motion_default_joint_pos, dtype=np.float32)
|
||||
target_q_motor = np.zeros(29, dtype=np.float32)
|
||||
for i in range(29):
|
||||
motor_idx = self.config.motion_joint_ids_map[i]
|
||||
target_q_motor[motor_idx] = target_q_bfs[i]
|
||||
|
||||
# Interpolate to target (both in motor order now)
|
||||
for i in range(num_steps):
|
||||
alpha = i / num_steps
|
||||
for motor_idx in range(29):
|
||||
self.msg.motor_cmd[motor_idx].q = current_q_motor[motor_idx] * (1 - alpha) + target_q_motor[motor_idx] * alpha
|
||||
self.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.msg.motor_cmd[motor_idx].kp = self.config.motion_stiffness[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].kd = self.config.motion_damping[motor_idx]
|
||||
self.msg.motor_cmd[motor_idx].tau = 0
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
time.sleep(self.config.motion_control_dt)
|
||||
|
||||
logger.info("Reached default position")
|
||||
|
||||
# Wait 2 seconds
|
||||
time.sleep(2.0)
|
||||
|
||||
# Start motion imitation policy thread
|
||||
logger.info("Starting motion imitation policy control...")
|
||||
self.start_motion_imitation_thread()
|
||||
|
||||
logger.info("Motion imitation initialization complete! Policy is now running.")
|
||||
logger.info(f"154D observations, 29D actions. Motion duration: {self.motion_loader.duration:.2f}s")
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
|
||||
Reference in New Issue
Block a user