mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
Feat/g1 improvements record sim (#2765)
This PR extends the integration of Unitree g1 with the LeRobot codebase. By converting robot state to a flat dict we can now record and replay episodes (example groot/holosoma scripts need to be adjusted as well). We also improve the simulation integration by calling .step @ _subscribe_motor_state instead of it running in a separate thread. We also add ZMQ camera to lerobot, streaming base64 images over json
This commit is contained in:
@@ -111,34 +111,29 @@ class GrootLocomotionController:
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
|
||||
self.cmd[0] = self.robot.remote_controller.ly # Forward/backward
|
||||
self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right
|
||||
self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate
|
||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
||||
|
||||
# Get joint positions and velocities
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
# Get joint positions and velocities from flat dict
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
@@ -150,8 +145,8 @@ class GrootLocomotionController:
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
@@ -219,6 +214,8 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
# Initialize gr00T locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
@@ -234,7 +231,7 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while True:
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
groot_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@@ -126,24 +126,23 @@ class HolosomaLocomotionController:
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
||||
self.cmd[:] = [ly, -lx, -rx]
|
||||
|
||||
# Get joint positions and velocities
|
||||
for i in range(29):
|
||||
self.qj[i] = robot_state.motor_state[i].q
|
||||
self.dqj[i] = robot_state.motor_state[i].dq
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.qj[idx] = obs[f"{name}.q"]
|
||||
self.dqj[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
@@ -151,8 +150,8 @@ class HolosomaLocomotionController:
|
||||
self.dqj[idx] = 0.0
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
@@ -220,6 +219,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
robot.connect()
|
||||
|
||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
||||
|
||||
@@ -230,7 +230,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while True:
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
holosoma_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
Reference in New Issue
Block a user