mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
add amazon policies
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -26,7 +26,7 @@ from ..config import RobotConfig
|
|||||||
class UnitreeG1Config(RobotConfig):
|
class UnitreeG1Config(RobotConfig):
|
||||||
# id: str = "unitree_g1"
|
# id: str = "unitree_g1"
|
||||||
motion_mode: bool = False
|
motion_mode: bool = False
|
||||||
simulation_mode: bool = True
|
simulation_mode: bool = False
|
||||||
kp_high = 40.0
|
kp_high = 40.0
|
||||||
kd_high = 3.0
|
kd_high = 3.0
|
||||||
kp_low = 80.0
|
kp_low = 80.0
|
||||||
@@ -56,13 +56,15 @@ class UnitreeG1Config(RobotConfig):
|
|||||||
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
|
# This robot class ONLY uses sockets to communicate with a bridge on the Orin
|
||||||
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
|
# Run 'python dds_to_socket.py' on the Orin first, then set this to the Orin's IP
|
||||||
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
|
# Example: socket_host="192.168.123.164" (Orin's wlan0 IP)
|
||||||
socket_host: str | None = None # = "172.18.129.215"#
|
socket_host = "172.18.129.215"#
|
||||||
socket_port: int | None = None
|
socket_port: int | None = None
|
||||||
|
|
||||||
# Locomotion control
|
# Locomotion control
|
||||||
locomotion_control: bool = False
|
locomotion_control: bool = True
|
||||||
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
|
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
|
||||||
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"
|
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"
|
||||||
|
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/amazon_fastsac_g1_29dof.onnx"
|
||||||
|
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/amazon_ppo_g1_29dof.onnx"
|
||||||
|
|
||||||
# Motion imitation (dance_102, gangnam_style, etc)
|
# Motion imitation (dance_102, gangnam_style, etc)
|
||||||
motion_imitation_control: bool = False
|
motion_imitation_control: bool = False
|
||||||
@@ -100,7 +102,81 @@ class UnitreeG1Config(RobotConfig):
|
|||||||
# GR00T-specific scaling (different from regular locomotion!)
|
# GR00T-specific scaling (different from regular locomotion!)
|
||||||
groot_ang_vel_scale: float = 0.25 # GR00T uses 0.5, not 0.25
|
groot_ang_vel_scale: float = 0.25 # GR00T uses 0.5, not 0.25
|
||||||
groot_cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) # yaw is 0.5 for GR00T
|
groot_cmd_scale: list = field(default_factory=lambda: [2.0, 2.0, 0.25]) # yaw is 0.5 for GR00T
|
||||||
num_locomotion_actions: int = 12
|
|
||||||
num_locomotion_obs: int = 47
|
# Locomotion dimensions (12-DOF legs-only vs 29-DOF whole-body)
|
||||||
|
num_locomotion_actions: int = 29 # 12 for legs-only, 29 for whole-body
|
||||||
|
num_locomotion_obs: int = 100 # 47 for legs-only (12-DOF), 100 for whole-body (29-DOF)
|
||||||
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
|
max_cmd: list = field(default_factory=lambda: [0.8, 0.5, 1.57])
|
||||||
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
locomotion_imu_type: str = "pelvis" # "torso" or "pelvis"
|
||||||
|
|
||||||
|
# 29-DOF whole-body locomotion parameters
|
||||||
|
default_all_joint_angles: list = field(default_factory=lambda: [
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # left leg (from holosoma)
|
||||||
|
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0, # right leg (from holosoma)
|
||||||
|
0.0, 0.0, 0.0, # waist (yaw, roll, pitch)
|
||||||
|
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # left arm (from holosoma)
|
||||||
|
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0, # right arm (from holosoma)
|
||||||
|
])
|
||||||
|
# KP/KD values from holosoma (tuned for G1 hardware)
|
||||||
|
all_joint_kps: list = field(default_factory=lambda: [
|
||||||
|
40.179238471, # left_hip_pitch
|
||||||
|
99.098427777, # left_hip_roll
|
||||||
|
40.179238471, # left_hip_yaw
|
||||||
|
99.098427777, # left_knee
|
||||||
|
28.501246196, # left_ankle_pitch
|
||||||
|
28.501246196, # left_ankle_roll
|
||||||
|
40.179238471, # right_hip_pitch
|
||||||
|
99.098427777, # right_hip_roll
|
||||||
|
40.179238471, # right_hip_yaw
|
||||||
|
99.098427777, # right_knee
|
||||||
|
28.501246196, # right_ankle_pitch
|
||||||
|
28.501246196, # right_ankle_roll
|
||||||
|
40.179238471, # waist_yaw
|
||||||
|
28.501246196, # waist_roll
|
||||||
|
28.501246196, # waist_pitch
|
||||||
|
14.250623098, # left_shoulder_pitch
|
||||||
|
14.250623098, # left_shoulder_roll
|
||||||
|
14.250623098, # left_shoulder_yaw
|
||||||
|
14.250623098, # left_elbow
|
||||||
|
14.250623098, # left_wrist_roll
|
||||||
|
16.778327481, # left_wrist_pitch
|
||||||
|
16.778327481, # left_wrist_yaw
|
||||||
|
14.250623098, # right_shoulder_pitch
|
||||||
|
14.250623098, # right_shoulder_roll
|
||||||
|
14.250623098, # right_shoulder_yaw
|
||||||
|
14.250623098, # right_elbow
|
||||||
|
14.250623098, # right_wrist_roll
|
||||||
|
16.778327481, # right_wrist_pitch
|
||||||
|
16.778327481, # right_wrist_yaw
|
||||||
|
])
|
||||||
|
all_joint_kds: list = field(default_factory=lambda: [
|
||||||
|
2.557889765, # left_hip_pitch
|
||||||
|
6.308801854, # left_hip_roll
|
||||||
|
2.557889765, # left_hip_yaw
|
||||||
|
6.308801854, # left_knee
|
||||||
|
1.814445687, # left_ankle_pitch
|
||||||
|
1.814445687, # left_ankle_roll
|
||||||
|
2.557889765, # right_hip_pitch
|
||||||
|
6.308801854, # right_hip_roll
|
||||||
|
2.557889765, # right_hip_yaw
|
||||||
|
6.308801854, # right_knee
|
||||||
|
1.814445687, # right_ankle_pitch
|
||||||
|
1.814445687, # right_ankle_roll
|
||||||
|
2.557889765, # waist_yaw
|
||||||
|
1.814445687, # waist_roll
|
||||||
|
1.814445687, # waist_pitch
|
||||||
|
0.907222843, # left_shoulder_pitch
|
||||||
|
0.907222843, # left_shoulder_roll
|
||||||
|
0.907222843, # left_shoulder_yaw
|
||||||
|
0.907222843, # left_elbow
|
||||||
|
0.907222843, # left_wrist_roll
|
||||||
|
1.068141502, # left_wrist_pitch
|
||||||
|
1.068141502, # left_wrist_yaw
|
||||||
|
0.907222843, # right_shoulder_pitch
|
||||||
|
0.907222843, # right_shoulder_roll
|
||||||
|
0.907222843, # right_shoulder_yaw
|
||||||
|
0.907222843, # right_elbow
|
||||||
|
0.907222843, # right_wrist_roll
|
||||||
|
1.068141502, # right_wrist_pitch
|
||||||
|
1.068141502, # right_wrist_yaw
|
||||||
|
])
|
||||||
@@ -0,0 +1,347 @@
|
|||||||
|
import cv2
|
||||||
|
import zmq
|
||||||
|
import time
|
||||||
|
import struct
|
||||||
|
from collections import deque
|
||||||
|
import numpy as np
|
||||||
|
import pyrealsense2 as rs
|
||||||
|
import logging_mp
|
||||||
|
|
||||||
|
logger_mp = logging_mp.get_logger(__name__, level=logging_mp.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
class RealSenseCamera(object):
|
||||||
|
def __init__(self, img_shape, fps, serial_number=None, enable_depth=False) -> None:
|
||||||
|
"""
|
||||||
|
img_shape: [height, width]
|
||||||
|
serial_number: serial number
|
||||||
|
"""
|
||||||
|
self.img_shape = img_shape
|
||||||
|
self.fps = fps
|
||||||
|
self.serial_number = serial_number
|
||||||
|
self.enable_depth = enable_depth
|
||||||
|
|
||||||
|
align_to = rs.stream.color
|
||||||
|
self.align = rs.align(align_to)
|
||||||
|
self.init_realsense()
|
||||||
|
|
||||||
|
def init_realsense(self):
|
||||||
|
self.pipeline = rs.pipeline()
|
||||||
|
config = rs.config()
|
||||||
|
if self.serial_number is not None:
|
||||||
|
config.enable_device(self.serial_number)
|
||||||
|
|
||||||
|
config.enable_stream(rs.stream.color, self.img_shape[1], self.img_shape[0], rs.format.bgr8, self.fps)
|
||||||
|
|
||||||
|
if self.enable_depth:
|
||||||
|
config.enable_stream(rs.stream.depth, self.img_shape[1], self.img_shape[0], rs.format.z16, self.fps)
|
||||||
|
|
||||||
|
profile = self.pipeline.start(config)
|
||||||
|
self._device = profile.get_device()
|
||||||
|
if self._device is None:
|
||||||
|
logger_mp.error("[Image Server] pipe_profile.get_device() is None .")
|
||||||
|
if self.enable_depth:
|
||||||
|
assert self._device is not None
|
||||||
|
depth_sensor = self._device.first_depth_sensor()
|
||||||
|
self.g_depth_scale = depth_sensor.get_depth_scale()
|
||||||
|
|
||||||
|
self.intrinsics = profile.get_stream(rs.stream.color).as_video_stream_profile().get_intrinsics()
|
||||||
|
|
||||||
|
def get_frame(self):
|
||||||
|
frames = self.pipeline.wait_for_frames()
|
||||||
|
aligned_frames = self.align.process(frames)
|
||||||
|
color_frame = aligned_frames.get_color_frame()
|
||||||
|
|
||||||
|
if self.enable_depth:
|
||||||
|
depth_frame = aligned_frames.get_depth_frame()
|
||||||
|
|
||||||
|
if not color_frame:
|
||||||
|
return None
|
||||||
|
|
||||||
|
color_image = np.asanyarray(color_frame.get_data())
|
||||||
|
# color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||||
|
depth_image = np.asanyarray(depth_frame.get_data()) if self.enable_depth else None
|
||||||
|
return color_image, depth_image
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.pipeline.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenCVCamera:
|
||||||
|
def __init__(self, device_id, img_shape, fps):
|
||||||
|
"""
|
||||||
|
decive_id: /dev/video* or *
|
||||||
|
img_shape: [height, width]
|
||||||
|
"""
|
||||||
|
self.id = device_id
|
||||||
|
self.fps = fps
|
||||||
|
self.img_shape = img_shape
|
||||||
|
self.cap = cv2.VideoCapture(self.id, cv2.CAP_V4L2)
|
||||||
|
self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter.fourcc("M", "J", "P", "G"))
|
||||||
|
self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.img_shape[0])
|
||||||
|
self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.img_shape[1])
|
||||||
|
self.cap.set(cv2.CAP_PROP_FPS, self.fps)
|
||||||
|
|
||||||
|
# Test if the camera can read frames
|
||||||
|
if not self._can_read_frame():
|
||||||
|
logger_mp.error(
|
||||||
|
f"[Image Server] Camera {self.id} Error: Failed to initialize the camera or read frames. Exiting..."
|
||||||
|
)
|
||||||
|
self.release()
|
||||||
|
|
||||||
|
def _can_read_frame(self):
|
||||||
|
success, _ = self.cap.read()
|
||||||
|
return success
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.cap.release()
|
||||||
|
|
||||||
|
def get_frame(self):
|
||||||
|
ret, color_image = self.cap.read()
|
||||||
|
if not ret:
|
||||||
|
return None
|
||||||
|
return color_image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServer:
|
||||||
|
def __init__(self, config, port=5554, Unit_Test=False):
|
||||||
|
"""
|
||||||
|
config example1:
|
||||||
|
{
|
||||||
|
'fps':30 # frame per second
|
||||||
|
'head_camera_type': 'opencv', # opencv or realsense
|
||||||
|
'head_camera_image_shape': [480, 1280], # Head camera resolution [height, width]
|
||||||
|
'head_camera_id_numbers': [0], # '/dev/video0' (opencv)
|
||||||
|
'wrist_camera_type': 'realsense',
|
||||||
|
'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||||
|
'wrist_camera_id_numbers': ["218622271789", "241222076627"], # realsense camera's serial number
|
||||||
|
}
|
||||||
|
|
||||||
|
config example2:
|
||||||
|
{
|
||||||
|
'fps':30 # frame per second
|
||||||
|
'head_camera_type': 'realsense', # opencv or realsense
|
||||||
|
'head_camera_image_shape': [480, 640], # Head camera resolution [height, width]
|
||||||
|
'head_camera_id_numbers': ["218622271739"], # realsense camera's serial number
|
||||||
|
'wrist_camera_type': 'opencv',
|
||||||
|
'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||||
|
'wrist_camera_id_numbers': [0,1], # '/dev/video0' and '/dev/video1' (opencv)
|
||||||
|
}
|
||||||
|
|
||||||
|
If you are not using the wrist camera, you can comment out its configuration, like this below:
|
||||||
|
config:
|
||||||
|
{
|
||||||
|
'fps':30 # frame per second
|
||||||
|
'head_camera_type': 'opencv', # opencv or realsense
|
||||||
|
'head_camera_image_shape': [480, 1280], # Head camera resolution [height, width]
|
||||||
|
'head_camera_id_numbers': [0], # '/dev/video0' (opencv)
|
||||||
|
#'wrist_camera_type': 'realsense',
|
||||||
|
#'wrist_camera_image_shape': [480, 640], # Wrist camera resolution [height, width]
|
||||||
|
#'wrist_camera_id_numbers': ["218622271789", "241222076627"], # serial number (realsense)
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
logger_mp.info(config)
|
||||||
|
self.fps = config.get("fps", 30)
|
||||||
|
self.head_camera_type = config.get("head_camera_type", "opencv")
|
||||||
|
self.head_image_shape = config.get("head_camera_image_shape", [480, 640]) # (height, width)
|
||||||
|
self.head_camera_id_numbers = config.get("head_camera_id_numbers", [0])
|
||||||
|
|
||||||
|
self.wrist_camera_type = config.get("wrist_camera_type", None)
|
||||||
|
self.wrist_image_shape = config.get("wrist_camera_image_shape", [480, 640]) # (height, width)
|
||||||
|
self.wrist_camera_id_numbers = config.get("wrist_camera_id_numbers", None)
|
||||||
|
|
||||||
|
self.port = port
|
||||||
|
self.Unit_Test = Unit_Test
|
||||||
|
|
||||||
|
# Initialize head cameras
|
||||||
|
self.head_cameras = []
|
||||||
|
if self.head_camera_type == "opencv":
|
||||||
|
for device_id in self.head_camera_id_numbers:
|
||||||
|
camera = OpenCVCamera(device_id=device_id, img_shape=self.head_image_shape, fps=self.fps)
|
||||||
|
self.head_cameras.append(camera)
|
||||||
|
elif self.head_camera_type == "realsense":
|
||||||
|
for serial_number in self.head_camera_id_numbers:
|
||||||
|
camera = RealSenseCamera(img_shape=self.head_image_shape, fps=self.fps, serial_number=serial_number)
|
||||||
|
self.head_cameras.append(camera)
|
||||||
|
else:
|
||||||
|
logger_mp.warning(f"[Image Server] Unsupported head_camera_type: {self.head_camera_type}")
|
||||||
|
|
||||||
|
# Initialize wrist cameras if provided
|
||||||
|
self.wrist_cameras = []
|
||||||
|
if self.wrist_camera_type and self.wrist_camera_id_numbers:
|
||||||
|
if self.wrist_camera_type == "opencv":
|
||||||
|
for device_id in self.wrist_camera_id_numbers:
|
||||||
|
camera = OpenCVCamera(device_id=device_id, img_shape=self.wrist_image_shape, fps=self.fps)
|
||||||
|
self.wrist_cameras.append(camera)
|
||||||
|
elif self.wrist_camera_type == "realsense":
|
||||||
|
for serial_number in self.wrist_camera_id_numbers:
|
||||||
|
camera = RealSenseCamera(
|
||||||
|
img_shape=self.wrist_image_shape, fps=self.fps, serial_number=serial_number
|
||||||
|
)
|
||||||
|
self.wrist_cameras.append(camera)
|
||||||
|
else:
|
||||||
|
logger_mp.warning(f"[Image Server] Unsupported wrist_camera_type: {self.wrist_camera_type}")
|
||||||
|
|
||||||
|
# Set ZeroMQ context and socket
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.PUB)
|
||||||
|
self.socket.bind(f"tcp://*:{self.port}")
|
||||||
|
|
||||||
|
if self.Unit_Test:
|
||||||
|
self._init_performance_metrics()
|
||||||
|
|
||||||
|
for cam in self.head_cameras:
|
||||||
|
if isinstance(cam, OpenCVCamera):
|
||||||
|
logger_mp.info(
|
||||||
|
f"[Image Server] Head camera {cam.id} resolution: {cam.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)} x {cam.cap.get(cv2.CAP_PROP_FRAME_WIDTH)}"
|
||||||
|
)
|
||||||
|
elif isinstance(cam, RealSenseCamera):
|
||||||
|
logger_mp.info(
|
||||||
|
f"[Image Server] Head camera {cam.serial_number} resolution: {cam.img_shape[0]} x {cam.img_shape[1]}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger_mp.warning("[Image Server] Unknown camera type in head_cameras.")
|
||||||
|
|
||||||
|
for cam in self.wrist_cameras:
|
||||||
|
if isinstance(cam, OpenCVCamera):
|
||||||
|
logger_mp.info(
|
||||||
|
f"[Image Server] Wrist camera {cam.id} resolution: {cam.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)} x {cam.cap.get(cv2.CAP_PROP_FRAME_WIDTH)}"
|
||||||
|
)
|
||||||
|
elif isinstance(cam, RealSenseCamera):
|
||||||
|
logger_mp.info(
|
||||||
|
f"[Image Server] Wrist camera {cam.serial_number} resolution: {cam.img_shape[0]} x {cam.img_shape[1]}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger_mp.warning("[Image Server] Unknown camera type in wrist_cameras.")
|
||||||
|
|
||||||
|
logger_mp.info("[Image Server] Image server has started, waiting for client connections...")
|
||||||
|
|
||||||
|
def _init_performance_metrics(self):
|
||||||
|
self.frame_count = 0 # Total frames sent
|
||||||
|
self.time_window = 1.0 # Time window for FPS calculation (in seconds)
|
||||||
|
self.frame_times = deque() # Timestamps of frames sent within the time window
|
||||||
|
self.start_time = time.time() # Start time of the streaming
|
||||||
|
|
||||||
|
def _update_performance_metrics(self, current_time):
|
||||||
|
# Add current time to frame times deque
|
||||||
|
self.frame_times.append(current_time)
|
||||||
|
# Remove timestamps outside the time window
|
||||||
|
while self.frame_times and self.frame_times[0] < current_time - self.time_window:
|
||||||
|
self.frame_times.popleft()
|
||||||
|
# Increment frame count
|
||||||
|
self.frame_count += 1
|
||||||
|
|
||||||
|
def _print_performance_metrics(self, current_time):
|
||||||
|
if self.frame_count % 30 == 0:
|
||||||
|
elapsed_time = current_time - self.start_time
|
||||||
|
real_time_fps = len(self.frame_times) / self.time_window
|
||||||
|
logger_mp.info(
|
||||||
|
f"[Image Server] Real-time FPS: {real_time_fps:.2f}, Total frames sent: {self.frame_count}, Elapsed time: {elapsed_time:.2f} sec"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
for cam in self.head_cameras:
|
||||||
|
cam.release()
|
||||||
|
for cam in self.wrist_cameras:
|
||||||
|
cam.release()
|
||||||
|
self.socket.close()
|
||||||
|
self.context.term()
|
||||||
|
logger_mp.info("[Image Server] The server has been closed.")
|
||||||
|
|
||||||
|
def send_process(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
head_frames = []
|
||||||
|
for cam in self.head_cameras:
|
||||||
|
if self.head_camera_type == "opencv":
|
||||||
|
color_image = cam.get_frame()
|
||||||
|
if color_image is None:
|
||||||
|
logger_mp.error("[Image Server] Head camera frame read is error.")
|
||||||
|
break
|
||||||
|
elif self.head_camera_type == "realsense":
|
||||||
|
color_image, depth_iamge = cam.get_frame()
|
||||||
|
if color_image is None:
|
||||||
|
logger_mp.error("[Image Server] Head camera frame read is error.")
|
||||||
|
break
|
||||||
|
head_frames.append(color_image)
|
||||||
|
if len(head_frames) != len(self.head_cameras):
|
||||||
|
break
|
||||||
|
head_color = cv2.hconcat(head_frames)
|
||||||
|
|
||||||
|
if self.wrist_cameras:
|
||||||
|
wrist_frames = []
|
||||||
|
for cam in self.wrist_cameras:
|
||||||
|
if self.wrist_camera_type == "opencv":
|
||||||
|
color_image = cam.get_frame()
|
||||||
|
if color_image is None:
|
||||||
|
logger_mp.error("[Image Server] Wrist camera frame read is error.")
|
||||||
|
break
|
||||||
|
elif self.wrist_camera_type == "realsense":
|
||||||
|
color_image, depth_iamge = cam.get_frame()
|
||||||
|
if color_image is None:
|
||||||
|
logger_mp.error("[Image Server] Wrist camera frame read is error.")
|
||||||
|
break
|
||||||
|
wrist_frames.append(color_image)
|
||||||
|
wrist_color = cv2.hconcat(wrist_frames)
|
||||||
|
|
||||||
|
# Concatenate head and wrist frames
|
||||||
|
full_color = cv2.hconcat([head_color, wrist_color])
|
||||||
|
else:
|
||||||
|
full_color = head_color
|
||||||
|
|
||||||
|
ret, buffer = cv2.imencode(".jpg", full_color)
|
||||||
|
if not ret:
|
||||||
|
logger_mp.error("[Image Server] Frame imencode is failed.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
jpg_bytes = buffer.tobytes()
|
||||||
|
|
||||||
|
if self.Unit_Test:
|
||||||
|
timestamp = time.time()
|
||||||
|
frame_id = self.frame_count
|
||||||
|
header = struct.pack("dI", timestamp, frame_id) # 8-byte double, 4-byte unsigned int
|
||||||
|
message = header + jpg_bytes
|
||||||
|
else:
|
||||||
|
message = jpg_bytes
|
||||||
|
|
||||||
|
self.socket.send(message)
|
||||||
|
|
||||||
|
if self.Unit_Test:
|
||||||
|
current_time = time.time()
|
||||||
|
self._update_performance_metrics(current_time)
|
||||||
|
self._print_performance_metrics(current_time)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger_mp.warning("[Image Server] Interrupted by user.")
|
||||||
|
finally:
|
||||||
|
self._close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# config = {
|
||||||
|
# "fps": 30,
|
||||||
|
# "head_camera_type": "opencv",
|
||||||
|
# "head_camera_image_shape": [480, 1280], # Head camera resolution
|
||||||
|
# "head_camera_id_numbers": [0],
|
||||||
|
# "wrist_camera_type": "opencv",
|
||||||
|
# "wrist_camera_image_shape": [480, 640], # Wrist camera resolution
|
||||||
|
# "wrist_camera_id_numbers": [2, 4],
|
||||||
|
#
|
||||||
|
#infrared
|
||||||
|
# config = {
|
||||||
|
# "fps": 30,
|
||||||
|
# "head_camera_type": "opencv",
|
||||||
|
# "head_camera_image_shape": [480, 640],
|
||||||
|
# "head_camera_id_numbers": [2], # <-- wrist cam that reported 480x640
|
||||||
|
# # no wrist_* keys
|
||||||
|
# }
|
||||||
|
#rgb
|
||||||
|
config = {
|
||||||
|
"fps": 30,
|
||||||
|
"head_camera_type": "opencv",
|
||||||
|
"head_camera_image_shape": [480,640], # match the device
|
||||||
|
"head_camera_id_numbers": [4], # /dev/video4 is RGB
|
||||||
|
}
|
||||||
|
|
||||||
|
server = ImageServer(config, Unit_Test=False)
|
||||||
|
server.send_process()
|
||||||
@@ -243,13 +243,17 @@ class UnitreeG1(Robot):
|
|||||||
# - Arm thread: controls arms (indices 15-28)
|
# - Arm thread: controls arms (indices 15-28)
|
||||||
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
# - Locomotion thread: controls legs (0-11), waist (12-14)
|
||||||
# Both update different parts of self.msg, both call Write()
|
# Both update different parts of self.msg, both call Write()
|
||||||
|
# DISABLE for 29-DOF policies since they control ALL joints including arms
|
||||||
self.publish_thread = None
|
self.publish_thread = None
|
||||||
self.ctrl_lock = threading.Lock()
|
self.ctrl_lock = threading.Lock()
|
||||||
if not config.motion_imitation_control: # Allow with locomotion, disable only for motion imitation
|
is_29dof = config.policy_path and '29dof' in config.policy_path.lower()
|
||||||
|
if not config.motion_imitation_control and not is_29dof:
|
||||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||||
self.publish_thread.daemon = True
|
self.publish_thread.daemon = True
|
||||||
self.publish_thread.start()
|
self.publish_thread.start()
|
||||||
logger.info("Arm control publish thread started")
|
logger.info("Arm control publish thread started")
|
||||||
|
elif is_29dof:
|
||||||
|
logger.info("Arm control thread DISABLED (29-DOF policy controls all joints)")
|
||||||
|
|
||||||
# Load locomotion policy if enabled
|
# Load locomotion policy if enabled
|
||||||
self.policy = None
|
self.policy = None
|
||||||
@@ -305,23 +309,32 @@ class UnitreeG1(Robot):
|
|||||||
elif config.policy_path.endswith('.onnx'):
|
elif config.policy_path.endswith('.onnx'):
|
||||||
logger.info("Detected ONNX (.onnx) policy")
|
logger.info("Detected ONNX (.onnx) policy")
|
||||||
|
|
||||||
# For GR00T-style policies, load both Balance and Walk policies
|
# Check if this is a GR00T dual-policy system (Walk.onnx)
|
||||||
# Balance policy for standing (low velocity commands)
|
# Only try loading dual policies if the filename explicitly contains "Walk"
|
||||||
# Walk policy for locomotion (high velocity commands)
|
if 'Walk.onnx' in config.policy_path:
|
||||||
balance_policy_path = config.policy_path.replace('Walk.onnx', 'Balance.onnx')
|
balance_policy_path = config.policy_path.replace('Walk.onnx', 'Balance.onnx')
|
||||||
walk_policy_path = config.policy_path
|
walk_policy_path = config.policy_path
|
||||||
|
|
||||||
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
|
if Path(balance_policy_path).exists() and Path(walk_policy_path).exists():
|
||||||
logger.info("Loading dual-policy system (Balance + Walk)")
|
logger.info("Loading GR00T dual-policy system (Balance + Walk)")
|
||||||
self.policy_balance = ort.InferenceSession(balance_policy_path)
|
self.policy_balance = ort.InferenceSession(balance_policy_path)
|
||||||
self.policy_walk = ort.InferenceSession(walk_policy_path)
|
self.policy_walk = ort.InferenceSession(walk_policy_path)
|
||||||
self.policy = None # Not used when dual policies are loaded
|
self.policy = None # Not used when dual policies are loaded
|
||||||
logger.info(f"Balance policy loaded from: {balance_policy_path}")
|
logger.info(f"Balance policy loaded from: {balance_policy_path}")
|
||||||
logger.info(f"Walk policy loaded from: {walk_policy_path}")
|
logger.info(f"Walk policy loaded from: {walk_policy_path}")
|
||||||
logger.info(f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}")
|
logger.info(f"ONNX input: {self.policy_balance.get_inputs()[0].name}, shape: {self.policy_balance.get_inputs()[0].shape}")
|
||||||
logger.info(f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}")
|
logger.info(f"ONNX output: {self.policy_balance.get_outputs()[0].name}, shape: {self.policy_balance.get_outputs()[0].shape}")
|
||||||
|
else:
|
||||||
|
# Single policy
|
||||||
|
logger.info("Loading single ONNX policy")
|
||||||
|
self.policy = ort.InferenceSession(config.policy_path)
|
||||||
|
self.policy_balance = None
|
||||||
|
self.policy_walk = None
|
||||||
|
logger.info("ONNX policy loaded successfully")
|
||||||
|
logger.info(f"ONNX input: {self.policy.get_inputs()[0].name}, shape: {self.policy.get_inputs()[0].shape}")
|
||||||
|
logger.info(f"ONNX output: {self.policy.get_outputs()[0].name}, shape: {self.policy.get_outputs()[0].shape}")
|
||||||
else:
|
else:
|
||||||
# Fallback to single policy
|
# Single ONNX policy (not GR00T)
|
||||||
logger.info("Loading single ONNX policy")
|
logger.info("Loading single ONNX policy")
|
||||||
self.policy = ort.InferenceSession(config.policy_path)
|
self.policy = ort.InferenceSession(config.policy_path)
|
||||||
self.policy_balance = None
|
self.policy_balance = None
|
||||||
@@ -343,8 +356,27 @@ class UnitreeG1(Robot):
|
|||||||
self.locomotion_obs = np.zeros(config.num_locomotion_obs, dtype=np.float32)
|
self.locomotion_obs = np.zeros(config.num_locomotion_obs, dtype=np.float32)
|
||||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
# GR00T-specific variables (for ONNX policies with 29 joints)
|
# Detect 29-DOF policy from filename
|
||||||
if self.policy_type == 'onnx':
|
self.is_29dof_policy = '29dof' in config.policy_path.lower()
|
||||||
|
|
||||||
|
# Joints that G1 23-DOF doesn't have (freeze these)
|
||||||
|
# 12: waist_yaw, 14: waist_pitch
|
||||||
|
# 20: left_wrist_pitch, 21: left_wrist_yaw
|
||||||
|
# 27: right_wrist_pitch, 28: right_wrist_yaw
|
||||||
|
self.joints_to_freeze_23dof = [12, 14, 20, 21, 27, 28]
|
||||||
|
|
||||||
|
# Phase state for 29-DOF locomotion (2D: left foot, right foot)
|
||||||
|
if self.is_29dof_policy:
|
||||||
|
self.phase_29dof = np.zeros((1, 2), dtype=np.float32)
|
||||||
|
self.phase_29dof[0, 0] = 0.0 # left foot starts at 0
|
||||||
|
self.phase_29dof[0, 1] = np.pi # right foot starts at π
|
||||||
|
gait_period = 1.0 # seconds
|
||||||
|
self.phase_dt_29dof = 2 * np.pi / (50.0 * gait_period) # 50Hz control rate
|
||||||
|
self.last_unscaled_action = np.zeros(29, dtype=np.float32)
|
||||||
|
self.is_standing_29dof = False # Track standing state for phase reset
|
||||||
|
|
||||||
|
# GR00T-specific variables (ONLY for GR00T dual-policy system)
|
||||||
|
if hasattr(self, 'policy_balance') and self.policy_balance is not None:
|
||||||
self.groot_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints
|
self.groot_qj_all = np.zeros(29, dtype=np.float32) # All 29 joints
|
||||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||||
self.groot_action = np.zeros(15, dtype=np.float32) # 15D action (legs + waist)
|
self.groot_action = np.zeros(15, dtype=np.float32) # 15D action (legs + waist)
|
||||||
@@ -364,9 +396,14 @@ class UnitreeG1(Robot):
|
|||||||
self.start_keyboard_controls()
|
self.start_keyboard_controls()
|
||||||
|
|
||||||
# Use different init based on policy type
|
# Use different init based on policy type
|
||||||
if self.policy_type == 'onnx':
|
if hasattr(self, 'is_29dof_policy') and self.is_29dof_policy:
|
||||||
|
# 29-DOF whole-body ONNX policy
|
||||||
|
self.init_29dof_locomotion()
|
||||||
|
elif hasattr(self, 'policy_balance') and self.policy_balance is not None:
|
||||||
|
# GR00T dual-policy system
|
||||||
self.init_groot_locomotion()
|
self.init_groot_locomotion()
|
||||||
else:
|
else:
|
||||||
|
# Regular 12-DOF policy
|
||||||
self.init_locomotion()
|
self.init_locomotion()
|
||||||
elif self.simulation_mode:
|
elif self.simulation_mode:
|
||||||
# Even without locomotion, provide keyboard feedback in sim
|
# Even without locomotion, provide keyboard feedback in sim
|
||||||
@@ -1195,6 +1232,150 @@ class UnitreeG1(Robot):
|
|||||||
self.msg.crc = self.crc.Crc(self.msg)
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
self.lowcmd_publisher.Write(self.msg)
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
|
||||||
|
def locomotion_29dof_run(self):
|
||||||
|
"""29-DOF whole-body locomotion policy loop - controls ALL 29 joints."""
|
||||||
|
self.locomotion_counter += 1
|
||||||
|
|
||||||
|
if self.locomotion_counter == 1:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🚀 RUNNING 29-DOF LOCOMOTION POLICY (all joints active)")
|
||||||
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
# Get current lowstate
|
||||||
|
lowstate = self.lowstate_buffer.GetData()
|
||||||
|
if lowstate is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update remote controller from lowstate
|
||||||
|
if lowstate.wireless_remote is not None:
|
||||||
|
self.remote_controller.set(lowstate.wireless_remote)
|
||||||
|
else:
|
||||||
|
self.remote_controller.lx = 0.0
|
||||||
|
self.remote_controller.ly = 0.0
|
||||||
|
self.remote_controller.rx = 0.0
|
||||||
|
self.remote_controller.ry = 0.0
|
||||||
|
|
||||||
|
# Get ALL 29 joint positions and velocities
|
||||||
|
for i in range(29):
|
||||||
|
self.qj[i] = lowstate.motor_state[i].q
|
||||||
|
self.dqj[i] = lowstate.motor_state[i].dq
|
||||||
|
|
||||||
|
# Get IMU data
|
||||||
|
quat = lowstate.imu_state.quaternion
|
||||||
|
ang_vel = np.array(lowstate.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
|
||||||
|
if self.config.locomotion_imu_type == "torso":
|
||||||
|
waist_yaw = lowstate.motor_state[12].q
|
||||||
|
waist_yaw_omega = lowstate.motor_state[12].dq
|
||||||
|
quat, ang_vel_3d = self.locomotion_transform_imu_data(waist_yaw, waist_yaw_omega, quat, np.array([ang_vel]))
|
||||||
|
ang_vel = ang_vel_3d.flatten()
|
||||||
|
|
||||||
|
# Get velocity commands from remote controller FIRST (before phase calculation!)
|
||||||
|
if not self.simulation_mode:
|
||||||
|
# Apply deadzone (0.1) like holosoma does
|
||||||
|
ly = self.remote_controller.ly if abs(self.remote_controller.ly) > 0.1 else 0.0
|
||||||
|
lx = self.remote_controller.lx if abs(self.remote_controller.lx) > 0.1 else 0.0
|
||||||
|
rx = self.remote_controller.rx if abs(self.remote_controller.rx) > 0.1 else 0.0
|
||||||
|
|
||||||
|
self.locomotion_cmd[0] = ly # forward/backward
|
||||||
|
self.locomotion_cmd[1] = -lx # left/right (inverted)
|
||||||
|
self.locomotion_cmd[2] = -rx # yaw (inverted)
|
||||||
|
|
||||||
|
if self.locomotion_counter % 50 == 0:
|
||||||
|
logger.debug(f"29-DOF Remote - ly:{ly:.2f}, lx:{lx:.2f}, rx:{rx:.2f}")
|
||||||
|
|
||||||
|
# Create observation with correct scaling factors
|
||||||
|
gravity_orientation = self.locomotion_get_gravity_orientation(quat)
|
||||||
|
qj_obs = (self.qj - np.array(self.config.default_all_joint_angles)) * 1.0 # dof_pos: ×1.0
|
||||||
|
dqj_obs = self.dqj * 0.05 # dof_vel: ×0.05
|
||||||
|
ang_vel_scaled = ang_vel * 0.25 # base_ang_vel: ×0.25
|
||||||
|
|
||||||
|
# Zero out observations for joints missing in G1 23-DOF
|
||||||
|
# [12: waist_yaw, 14: waist_pitch, 20: left_wrist_pitch, 21: left_wrist_yaw, 27: right_wrist_pitch, 28: right_wrist_yaw]
|
||||||
|
for joint_idx in self.joints_to_freeze_23dof:
|
||||||
|
qj_obs[joint_idx] = 0.0
|
||||||
|
dqj_obs[joint_idx] = 0.0
|
||||||
|
|
||||||
|
# Update phase using holosoma's method
|
||||||
|
# Check if standing (low velocity commands)
|
||||||
|
cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
|
||||||
|
ang_cmd_norm = np.abs(self.locomotion_cmd[2])
|
||||||
|
|
||||||
|
if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
|
||||||
|
# Standing still - both feet at π
|
||||||
|
self.phase_29dof[0, :] = np.pi * np.ones(2)
|
||||||
|
self.is_standing_29dof = True
|
||||||
|
elif self.is_standing_29dof:
|
||||||
|
# Resuming walking from standing - reset phase to initial state
|
||||||
|
self.phase_29dof = np.array([[0.0, np.pi]], dtype=np.float32)
|
||||||
|
self.is_standing_29dof = False
|
||||||
|
else:
|
||||||
|
# Walking - update phase
|
||||||
|
phase_tp1 = self.phase_29dof + self.phase_dt_29dof
|
||||||
|
self.phase_29dof = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi
|
||||||
|
|
||||||
|
# Compute sin/cos phase for both feet
|
||||||
|
sin_phase = np.sin(self.phase_29dof[0, :]) # shape (2,)
|
||||||
|
cos_phase = np.cos(self.phase_29dof[0, :]) # shape (2,)
|
||||||
|
|
||||||
|
# Build 100D observation vector (components in ALPHABETICAL order!)
|
||||||
|
# Joints within each 29D component stay in motor index order (0-28)
|
||||||
|
self.locomotion_obs[0:29] = self.last_unscaled_action # 1. actions (previous UNSCALED, ×1.0)
|
||||||
|
self.locomotion_obs[29:32] = ang_vel_scaled # 2. base_ang_vel (×0.25)
|
||||||
|
self.locomotion_obs[32] = self.locomotion_cmd[2] # 3. command_ang_vel (yaw, ×1.0)
|
||||||
|
self.locomotion_obs[33:35] = self.locomotion_cmd[:2] # 4. command_lin_vel (vx, vy, ×1.0)
|
||||||
|
self.locomotion_obs[35:37] = cos_phase # 5. cos_phase (2D: left, right)
|
||||||
|
self.locomotion_obs[37:66] = qj_obs # 6. dof_pos (relative, ×1.0)
|
||||||
|
self.locomotion_obs[66:95] = dqj_obs # 7. dof_vel (×0.05)
|
||||||
|
self.locomotion_obs[95:98] = gravity_orientation # 8. projected_gravity (×1.0)
|
||||||
|
self.locomotion_obs[98:100] = sin_phase # 9. sin_phase (2D: left, right)
|
||||||
|
|
||||||
|
# Get action from policy network (ONNX)
|
||||||
|
obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
|
||||||
|
ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
|
||||||
|
ort_outs = self.policy.run(None, ort_inputs)
|
||||||
|
|
||||||
|
# Post-process ONNX output: clip to ±100, then scale by 0.25
|
||||||
|
raw_action = ort_outs[0].squeeze()
|
||||||
|
clipped_action = np.clip(raw_action, -100.0, 100.0)
|
||||||
|
|
||||||
|
# Zero out actions for joints missing in G1 23-DOF
|
||||||
|
for joint_idx in self.joints_to_freeze_23dof:
|
||||||
|
clipped_action[joint_idx] = 0.0
|
||||||
|
|
||||||
|
self.last_unscaled_action = clipped_action.copy() # Store UNSCALED for next obs
|
||||||
|
self.locomotion_action = clipped_action * 0.25 # Scale by policy_action_scale for motors
|
||||||
|
|
||||||
|
# Debug logging (first 5 iterations)
|
||||||
|
if self.locomotion_counter <= 5:
|
||||||
|
print(f"\n[29DOF Debug #{self.locomotion_counter}]")
|
||||||
|
print(f" Phase (left, right): ({self.phase_29dof[0,0]:.3f}, {self.phase_29dof[0,1]:.3f})")
|
||||||
|
print(f" Sin phase: {sin_phase}, Cos phase: {cos_phase}")
|
||||||
|
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||||
|
print(f" Obs[0:5] (last unscaled actions): {self.locomotion_obs[0:5]}")
|
||||||
|
print(f" Obs[37:42] (dof_pos): {self.locomotion_obs[37:42]}")
|
||||||
|
print(f" Raw action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")
|
||||||
|
print(f" Scaled action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
|
||||||
|
|
||||||
|
# Transform action to target joint positions (ALL 29 joints)
|
||||||
|
target_dof_pos = np.array(self.config.default_all_joint_angles) + self.locomotion_action
|
||||||
|
|
||||||
|
if self.locomotion_counter <= 5:
|
||||||
|
print(f" Default[0:6]: {self.config.default_all_joint_angles[0:6]}")
|
||||||
|
print(f" Target pos[0:6]: {target_dof_pos[0:6]}\n")
|
||||||
|
|
||||||
|
# Send commands to ALL 29 motors
|
||||||
|
for i in range(29):
|
||||||
|
self.msg.motor_cmd[i].q = target_dof_pos[i]
|
||||||
|
self.msg.motor_cmd[i].qd = 0
|
||||||
|
self.msg.motor_cmd[i].kp = self.config.all_joint_kps[i]
|
||||||
|
self.msg.motor_cmd[i].kd = self.config.all_joint_kds[i]
|
||||||
|
self.msg.motor_cmd[i].tau = 0
|
||||||
|
|
||||||
|
# Send command
|
||||||
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
|
||||||
def groot_locomotion_run(self):
|
def groot_locomotion_run(self):
|
||||||
"""GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action."""
|
"""GR00T-style locomotion policy loop for ONNX policies - reads all 29 joints, outputs 15D action."""
|
||||||
self.locomotion_counter += 1
|
self.locomotion_counter += 1
|
||||||
@@ -1359,10 +1540,15 @@ class UnitreeG1(Robot):
|
|||||||
while self.locomotion_running:
|
while self.locomotion_running:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
# Use different run function based on policy type
|
# Route to appropriate locomotion method
|
||||||
if self.policy_type == 'onnx':
|
if hasattr(self, 'is_29dof_policy') and self.is_29dof_policy:
|
||||||
|
# 29-DOF whole-body ONNX policy: 100D → 29D
|
||||||
|
self.locomotion_29dof_run()
|
||||||
|
elif hasattr(self, 'policy_balance') and self.policy_balance is not None:
|
||||||
|
# GR00T dual-policy system: 516D → 15D
|
||||||
self.groot_locomotion_run()
|
self.groot_locomotion_run()
|
||||||
else:
|
else:
|
||||||
|
# Regular 12-DOF TorchScript or ONNX: 47D → 12D
|
||||||
self.locomotion_run()
|
self.locomotion_run()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in locomotion loop: {e}")
|
logger.error(f"Error in locomotion loop: {e}")
|
||||||
@@ -1523,6 +1709,72 @@ class UnitreeG1(Robot):
|
|||||||
logger.info("Locomotion test sequence complete! Policy is now running in background.")
|
logger.info("Locomotion test sequence complete! Policy is now running in background.")
|
||||||
logger.info("Use robot.stop_locomotion_thread() to stop the policy.")
|
logger.info("Use robot.stop_locomotion_thread() to stop the policy.")
|
||||||
|
|
||||||
|
def init_29dof_locomotion(self):
|
||||||
|
"""Initialize 29-DOF whole-body locomotion - moves all 29 joints to default pose."""
|
||||||
|
if not self.config.locomotion_control:
|
||||||
|
logger.warning("locomotion_control is False, cannot run 29-DOF init")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting 29-DOF whole-body locomotion initialization...")
|
||||||
|
|
||||||
|
# Move all joints to default position
|
||||||
|
logger.info("Moving all 29 joints to default position...")
|
||||||
|
total_time = 3.0
|
||||||
|
num_step = int(total_time / self.config.locomotion_control_dt)
|
||||||
|
|
||||||
|
default_pos = np.array(self.config.default_all_joint_angles, dtype=np.float32)
|
||||||
|
|
||||||
|
# Get current lowstate
|
||||||
|
lowstate = self.lowstate_buffer.GetData()
|
||||||
|
if lowstate is None:
|
||||||
|
logger.error("Cannot get lowstate for locomotion")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Record the current positions of all 29 joints
|
||||||
|
init_dof_pos = np.zeros(29, dtype=np.float32)
|
||||||
|
for i in range(29):
|
||||||
|
init_dof_pos[i] = lowstate.motor_state[i].q
|
||||||
|
|
||||||
|
# Move all joints to default pos
|
||||||
|
for i in range(num_step):
|
||||||
|
alpha = i / num_step
|
||||||
|
for motor_idx in range(29):
|
||||||
|
target_pos = default_pos[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].q = init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||||
|
self.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.msg.motor_cmd[motor_idx].kp = self.config.all_joint_kps[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].kd = self.config.all_joint_kds[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
time.sleep(self.config.locomotion_control_dt)
|
||||||
|
logger.info("Reached default position (all 29 joints)")
|
||||||
|
|
||||||
|
# Wait 3 seconds
|
||||||
|
time.sleep(3.0)
|
||||||
|
|
||||||
|
# Hold position for 2 seconds
|
||||||
|
logger.info("Holding default position...")
|
||||||
|
hold_time = 2.0
|
||||||
|
num_steps = int(hold_time / self.config.locomotion_control_dt)
|
||||||
|
for _ in range(num_steps):
|
||||||
|
for motor_idx in range(29):
|
||||||
|
self.msg.motor_cmd[motor_idx].q = default_pos[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].qd = 0
|
||||||
|
self.msg.motor_cmd[motor_idx].kp = self.config.all_joint_kps[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].kd = self.config.all_joint_kds[motor_idx]
|
||||||
|
self.msg.motor_cmd[motor_idx].tau = 0
|
||||||
|
self.msg.crc = self.crc.Crc(self.msg)
|
||||||
|
self.lowcmd_publisher.Write(self.msg)
|
||||||
|
time.sleep(self.config.locomotion_control_dt)
|
||||||
|
|
||||||
|
# Start locomotion policy thread
|
||||||
|
logger.info("Starting 29-DOF locomotion policy control...")
|
||||||
|
self.start_locomotion_thread()
|
||||||
|
|
||||||
|
logger.info("29-DOF locomotion initialization complete! Policy is now running.")
|
||||||
|
logger.info("100D observations → 29D actions (ALL joints: legs + waist + arms)")
|
||||||
|
|
||||||
def init_groot_locomotion(self):
|
def init_groot_locomotion(self):
|
||||||
"""Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
|
"""Initialize GR00T-style locomotion for ONNX policies (29 DOF, 15D actions)."""
|
||||||
if not self.config.locomotion_control:
|
if not self.config.locomotion_control:
|
||||||
|
|||||||
Reference in New Issue
Block a user