mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
Feat/g1 improvements record sim (#2765)
This PR extends the integration of the Unitree G1 with the LeRobot codebase. By converting the robot state to a flat dict we can now record and replay episodes (the example gr00t/holosoma scripts are adjusted accordingly). We also improve the simulation integration by calling `.step()` inside `_subscribe_motor_state` instead of running the simulation in a separate thread, and we add a ZMQ camera to LeRobot that streams base64-encoded images over JSON.
This commit is contained in:
@@ -7,7 +7,7 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
|
||||
We support both the 29 and 23 DOF G1 EDU versions. We introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low level read/write from/to the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over wlan, allowing for remote policy deployment as well as over eth or directly on the Orin
|
||||
- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
|
||||
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
|
||||
- **Simulation mode** for testing policies without the physical robot in mujoco
|
||||
|
||||
@@ -110,7 +110,7 @@ ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
|
||||
|
||||
---
|
||||
|
||||
@@ -188,7 +188,7 @@ Press `Ctrl+C` to stop the policy.
|
||||
|
||||
## Running in Simulation Mode (MuJoCo)
|
||||
|
||||
You can now test and develop policies without a physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
|
||||
@@ -111,34 +111,29 @@ class GrootLocomotionController:
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
if obs["remote.buttons"][0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if obs["remote.buttons"][4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
|
||||
self.cmd[0] = self.robot.remote_controller.ly # Forward/backward
|
||||
self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right
|
||||
self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate
|
||||
self.cmd[0] = obs["remote.ly"] # Forward/backward
|
||||
self.cmd[1] = obs["remote.lx"] * -1 # Left/right
|
||||
self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
|
||||
|
||||
# Get joint positions and velocities
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
# Get joint positions and velocities from flat dict
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.groot_qj_all[idx] = obs[f"{name}.q"]
|
||||
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
@@ -150,8 +145,8 @@ class GrootLocomotionController:
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
@@ -219,6 +214,8 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
robot.connect()
|
||||
|
||||
# Initialize gr00T locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
@@ -234,7 +231,7 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while True:
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
groot_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@@ -126,24 +126,23 @@ class HolosomaLocomotionController:
|
||||
|
||||
def run_step(self):
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
obs = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
if not obs:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
|
||||
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
|
||||
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
|
||||
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
|
||||
ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
|
||||
lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
|
||||
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
|
||||
self.cmd[:] = [ly, -lx, -rx]
|
||||
|
||||
# Get joint positions and velocities
|
||||
for i in range(29):
|
||||
self.qj[i] = robot_state.motor_state[i].q
|
||||
self.dqj[i] = robot_state.motor_state[i].dq
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
self.qj[idx] = obs[f"{name}.q"]
|
||||
self.dqj[idx] = obs[f"{name}.dq"]
|
||||
|
||||
# Adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
@@ -151,8 +150,8 @@ class HolosomaLocomotionController:
|
||||
self.dqj[idx] = 0.0
|
||||
|
||||
# Express IMU data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
|
||||
ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
|
||||
gravity = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# Scale joint positions and velocities before policy inference
|
||||
@@ -220,6 +219,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
|
||||
# Initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
robot.connect()
|
||||
|
||||
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
|
||||
|
||||
@@ -230,7 +230,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# Run step
|
||||
while True:
|
||||
while not robot._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
holosoma_controller.run_step()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
@@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
elif cfg.type == "zmq":
|
||||
from .zmq.camera_zmq import ZMQCamera
|
||||
|
||||
cameras[key] = ZMQCamera(cfg)
|
||||
|
||||
else:
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .camera_zmq import ZMQCamera
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
|
||||
__all__ = ["ZMQCamera", "ZMQCameraConfig"]
|
||||
@@ -0,0 +1,235 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
ZMQCamera - Captures frames from remote cameras via ZeroMQ using JSON protocol in the
|
||||
following format:
|
||||
{
|
||||
"timestamps": {"camera_name": float},
|
||||
"images": {"camera_name": "<base64-jpeg>"}
|
||||
}
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZMQCamera(Camera):
    """
    Captures frames published by a remote ZMQ server. The server sends JSON
    messages holding base64-encoded JPEGs keyed by camera name (see the module
    docstring for the exact message format).

    Example usage:
    ```python
    from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig

    config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
    camera = ZMQCamera(config)
    camera.connect()
    frame = camera.read()
    camera.disconnect()
    ```
    """

    def __init__(self, config: ZMQCameraConfig):
        super().__init__(config)
        # zmq is imported lazily so this module can be imported without pyzmq
        # installed. Keep a module reference so other methods (connect, read)
        # can use zmq names without re-importing or name hacks.
        import zmq

        self._zmq = zmq

        self.config = config
        self.server_address = config.server_address
        self.port = config.port
        self.camera_name = config.camera_name
        self.color_mode = config.color_mode
        self.timeout_ms = config.timeout_ms

        self.context: zmq.Context | None = None
        self.socket: zmq.Socket | None = None
        self._connected = False

        # Background-read state used by async_read().
        self.thread: Thread | None = None
        self.stop_event: Event | None = None
        self.frame_lock: Lock = Lock()
        self.latest_frame: NDArray[Any] | None = None
        self.new_frame_event: Event = Event()

    def __str__(self) -> str:
        return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})"

    @property
    def is_connected(self) -> bool:
        """True once connect() succeeded and the socket/context still exist."""
        return self._connected and self.context is not None and self.socket is not None

    def connect(self, warmup: bool = True) -> None:
        """Connect to ZMQ camera server.

        Args:
            warmup: if True, sleep briefly after connecting so the SUB socket
                has time to start receiving frames.

        Raises:
            DeviceAlreadyConnectedError: if already connected.
            RuntimeError: if the socket setup or first read fails.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} is already connected.")

        logger.info(f"Connecting to {self}...")

        try:
            zmq = self._zmq

            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.SUB)
            self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
            self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
            # CONFLATE keeps only the most recent message; the option value is
            # an int (was previously passed as a bool).
            self.socket.setsockopt(zmq.CONFLATE, 1)
            self.socket.connect(f"tcp://{self.server_address}:{self.port}")
            self._connected = True

            # Auto-detect resolution from the first received frame.
            if self.width is None or self.height is None:
                h, w = self.read().shape[:2]
                self.height = h
                self.width = w
                logger.info(f"{self} resolution: {w}x{h}")

            logger.info(f"{self} connected.")

            if warmup:
                time.sleep(0.1)

        except Exception as e:
            self._cleanup()
            raise RuntimeError(f"Failed to connect to {self}: {e}") from e

    def _cleanup(self):
        """Clean up ZMQ resources."""
        self._connected = False
        if self.socket:
            self.socket.close()
            self.socket = None
        if self.context:
            self.context.term()
            self.context = None

    @staticmethod
    def find_cameras() -> list[dict[str, Any]]:
        """ZMQ cameras require manual configuration (server address/port)."""
        return []

    def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
        """
        Read a single frame from the ZMQ camera.

        Args:
            color_mode: optional override of the configured color mode for
                this read only.

        Returns:
            np.ndarray: Decoded frame (height, width, 3) in the requested
            color mode.

        Raises:
            DeviceNotConnectedError: if called before connect().
            TimeoutError: if no message arrives within `timeout_ms`.
            RuntimeError: if the message is malformed or the image cannot be decoded.
        """
        if not self.is_connected or self.socket is None:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        try:
            message = self.socket.recv_string()
        except self._zmq.Again as e:
            # RCVTIMEO expired without a message.
            raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e

        # Decode JSON message
        data = json.loads(message)

        if "images" not in data:
            raise RuntimeError(f"{self} invalid message: missing 'images' key")

        images = data["images"]

        # Get image by camera name or first available
        if self.camera_name in images:
            img_b64 = images[self.camera_name]
        elif images:
            img_b64 = next(iter(images.values()))
        else:
            raise RuntimeError(f"{self} no images in message")

        # Decode base64 JPEG. cv2.imdecode returns BGR channel order for a
        # standard JPEG payload (assumes the publisher emits standard
        # BGR-encoded JPEGs — confirm against the image server).
        img_bytes = base64.b64decode(img_b64)
        frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)

        if frame is None:
            raise RuntimeError(f"{self} failed to decode image")

        # Honor the requested color mode. Previously both the `color_mode`
        # argument and `self.color_mode` were ignored, so BGR was always
        # returned even though the config default is RGB.
        requested = color_mode if color_mode is not None else self.color_mode
        if requested == ColorMode.RGB:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        return frame

    def _read_loop(self) -> None:
        """Background thread body: keep the latest frame fresh for async_read()."""
        while self.stop_event and not self.stop_event.is_set():
            try:
                frame = self.read()
                with self.frame_lock:
                    self.latest_frame = frame
                self.new_frame_event.set()
            except DeviceNotConnectedError:
                break
            except TimeoutError:
                # Transient: keep polling until stop_event is set.
                pass
            except Exception as e:
                logger.warning(f"Read error: {e}")

    def _start_read_thread(self) -> None:
        """Start the background read thread if it is not already running."""
        if self.thread and self.thread.is_alive():
            return
        self.stop_event = Event()
        self.thread = Thread(target=self._read_loop, daemon=True)
        self.thread.start()

    def _stop_read_thread(self) -> None:
        """Signal the background read thread to stop and join it."""
        if self.stop_event:
            self.stop_event.set()
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=2.0)
        self.thread = None
        self.stop_event = None

    def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
        """Read latest frame asynchronously (non-blocking).

        Lazily starts the background read thread on first call.

        Raises:
            DeviceNotConnectedError: if called before connect().
            TimeoutError: if no new frame arrives within `timeout_ms`.
            RuntimeError: if no frame is available after the event fired.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        if not self.thread or not self.thread.is_alive():
            self._start_read_thread()

        if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
            raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")

        with self.frame_lock:
            frame = self.latest_frame
            self.new_frame_event.clear()

        if frame is None:
            raise RuntimeError(f"{self} no frame available")

        return frame

    def disconnect(self) -> None:
        """Disconnect from ZMQ camera."""
        if not self.is_connected and not self.thread:
            raise DeviceNotConnectedError(f"{self} not connected.")

        self._stop_read_thread()
        self._cleanup()
        logger.info(f"{self} disconnected.")
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["ZMQCameraConfig", "ColorMode"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("zmq")
@dataclass
class ZMQCameraConfig(CameraConfig):
    """Configuration for a ZMQ-subscriber camera.

    Validated eagerly in `__post_init__`; the first invalid field raises
    ValueError.
    """

    server_address: str  # host/IP of the publishing server (required)
    port: int = 5555  # TCP port the publisher binds
    camera_name: str = "zmq_camera"  # key selecting a stream in the "images" payload
    color_mode: ColorMode = ColorMode.RGB  # RGB or BGR only
    timeout_ms: int = 5000  # receive timeout in milliseconds

    def __post_init__(self) -> None:
        """Check each field, raising ValueError on the first invalid one."""
        valid_modes = {ColorMode.RGB, ColorMode.BGR}
        if self.color_mode not in valid_modes:
            raise ValueError(
                f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
            )

        if self.timeout_ms <= 0:
            raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

        if not self.server_address:
            raise ValueError("`server_address` cannot be empty.")

        if not (0 < self.port <= 65535):
            raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")
|
||||
@@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Streams camera images over ZMQ.
|
||||
Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
from lerobot.cameras.configs import ColorMode
|
||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def encode_image(image: np.ndarray, quality: int = 80) -> str:
    """Encode an RGB image to a base64 JPEG string.

    Args:
        image: RGB image array of shape (H, W, 3) (the cameras in this module
            are configured with ColorMode.RGB).
        quality: JPEG quality in [0, 100]; higher means larger/better.

    Returns:
        Base64-encoded JPEG as a UTF-8 string.

    Raises:
        RuntimeError: if JPEG encoding fails.
    """
    # cv2.imencode interprets the array as BGR, so convert the RGB input first;
    # otherwise the emitted JPEG would have red/blue channels swapped.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    if not ok:
        # Previously the success flag was discarded, silently sending garbage.
        raise RuntimeError("JPEG encoding failed")
    return base64.b64encode(buffer).decode("utf-8")
|
||||
|
||||
|
||||
class ImageServer:
    """Publishes JPEG-compressed camera frames over a ZMQ PUB socket as JSON."""

    def __init__(self, config: dict, port: int = 5555):
        """Open every configured camera and bind the PUB socket.

        Args:
            config: dict of the form
                {"fps": int, "cameras": {name: {"device_id": int, "shape": [H, W]}}}
            port: TCP port to bind the publisher on.
        """
        self.fps = config.get("fps", 30)
        self.cameras: dict[str, OpenCVCamera] = {}

        for name, cfg in config.get("cameras", {}).items():
            shape = cfg.get("shape", [480, 640])  # [height, width]
            camera = OpenCVCamera(
                OpenCVCameraConfig(
                    index_or_path=cfg.get("device_id", 0),
                    fps=self.fps,
                    width=shape[1],
                    height=shape[0],
                    color_mode=ColorMode.RGB,
                )
            )
            camera.connect()
            self.cameras[name] = camera
            logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")

        # PUB socket with a bounded send queue; LINGER=0 drops pending
        # messages immediately on close.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.setsockopt(zmq.SNDHWM, 20)
        self.socket.setsockopt(zmq.LINGER, 0)
        self.socket.bind(f"tcp://*:{port}")

        logger.info(f"ImageServer running on port {port}")

    def run(self):
        """Capture-and-publish loop at `self.fps`; Ctrl+C stops it cleanly."""
        frame_count = 0
        frame_times = deque(maxlen=60)

        try:
            while True:
                loop_start = time.time()

                # One JSON message carries every camera's frame and timestamp.
                message = {"timestamps": {}, "images": {}}
                for name, cam in self.cameras.items():
                    rgb_frame = cam.read()  # RGB, per the camera config above
                    message["timestamps"][name] = time.time()
                    message["images"][name] = encode_image(rgb_frame)

                # Non-blocking send; drop the frame if the send queue is full.
                with contextlib.suppress(zmq.Again):
                    self.socket.send_string(json.dumps(message), zmq.NOBLOCK)

                frame_count += 1
                frame_times.append(time.time() - loop_start)

                if frame_count % 60 == 0:
                    logger.debug(f"FPS: {len(frame_times) / sum(frame_times):.1f}")

                # Sleep off the remainder of the frame budget.
                remaining = (1.0 / self.fps) - (time.time() - loop_start)
                if remaining > 0:
                    time.sleep(remaining)

        except KeyboardInterrupt:
            pass
        finally:
            # Release cameras and ZMQ resources regardless of how we exited.
            for cam in self.cameras.values():
                cam.disconnect()
            self.socket.close()
            self.context.term()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Single head camera at 640x480 / 30 fps on /dev/video4.
    server_config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}}
    server = ImageServer(server_config, port=5555)
    server.run()
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||
@@ -60,3 +62,6 @@ class UnitreeG1Config(RobotConfig):
|
||||
|
||||
# Socket config for ZMQ bridge
|
||||
robot_ip: str = "192.168.123.164" # default G1 IP
|
||||
|
||||
# Cameras (ZMQ-based remote cameras)
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -23,13 +23,8 @@ from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
@@ -43,8 +38,6 @@ logger = logging.getLogger(__name__)
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
G1_29_Num_Motors = 29
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorState:
|
||||
@@ -66,28 +59,12 @@ class IMUState:
|
||||
# g1 observation class
|
||||
@dataclass
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(
|
||||
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||
)
|
||||
motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
class DataBuffer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_data(self):
|
||||
with self.lock:
|
||||
return self.data
|
||||
|
||||
def set_data(self, data):
|
||||
with self.lock:
|
||||
self.data = data
|
||||
|
||||
|
||||
class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
@@ -117,9 +94,12 @@ class UnitreeG1(Robot):
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
self.config = config
|
||||
|
||||
self.control_dt = config.control_dt
|
||||
|
||||
# Initialize cameras config (ZMQ-based) - actual connection in connect()
|
||||
self._cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
# Import channel classes based on mode
|
||||
if config.is_simulation:
|
||||
from unitree_sdk2py.core.channel import (
|
||||
ChannelFactoryInitialize,
|
||||
@@ -133,62 +113,33 @@ class UnitreeG1(Robot):
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
# connect robot
|
||||
self.ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self.connect()
|
||||
# Store for use in connect()
|
||||
self._ChannelFactoryInitialize = ChannelFactoryInitialize
|
||||
self._ChannelPublisher = ChannelPublisher
|
||||
self._ChannelSubscriber = ChannelSubscriber
|
||||
|
||||
# initialize direct motor control interface
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
self.lowstate_buffer = DataBuffer()
|
||||
|
||||
# initialize subscribe thread to read robot state
|
||||
# Initialize state variables
|
||||
self.sim_env = None
|
||||
self._env_wrapper = None
|
||||
self._lowstate = None
|
||||
self._shutdown_event = threading.Event()
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
while not self.is_connected:
|
||||
time.sleep(0.1)
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
self.msg.mode_pr = 0
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
while lowstate is None:
|
||||
lowstate = self.lowstate_buffer.get_data()
|
||||
if lowstate is None:
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(config.kp, dtype=np.float32)
|
||||
self.kd = np.array(config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
# Initialize remote controller
|
||||
self.subscribe_thread = None
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
|
||||
# Step simulation if in simulation mode
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
self.sim_env.step()
|
||||
|
||||
msg = self.lowstate_subscriber.Read()
|
||||
if msg is not None:
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states
|
||||
for id in range(G1_29_Num_Motors):
|
||||
# Capture motor states using jointindex
|
||||
for id in G1_29_JointIndex:
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
@@ -207,7 +158,7 @@ class UnitreeG1(Robot):
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
self.lowstate_buffer.set_data(lowstate)
|
||||
self._lowstate = lowstate
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
@@ -216,7 +167,7 @@ class UnitreeG1(Robot):
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
pass
|
||||
@@ -225,20 +176,153 @@ class UnitreeG1(Robot):
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
# Initialize DDS channel and simulation environment
|
||||
if self.config.is_simulation:
|
||||
self.ChannelFactoryInitialize(0, "lo")
|
||||
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||
self._ChannelFactoryInitialize(0, "lo")
|
||||
self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||
# Extract the actual gym env from the dict structure
|
||||
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
|
||||
else:
|
||||
self.ChannelFactoryInitialize(0)
|
||||
self._ChannelFactoryInitialize(0)
|
||||
|
||||
# Initialize direct motor control interface
|
||||
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = self._ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
|
||||
# Start subscribe thread to read robot state
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
# Connect cameras
|
||||
for cam in self._cameras.values():
|
||||
if not cam.is_connected:
|
||||
cam.connect()
|
||||
|
||||
logger.info(f"Connected {len(self._cameras)} camera(s).")
|
||||
|
||||
# Initialize lowcmd message
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
self.msg.mode_pr = 0
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
while lowstate is None:
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# Initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(self.config.kp, dtype=np.float32)
|
||||
self.kd = np.array(self.config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
def disconnect(self):
|
||||
# Signal thread to stop and unblock any waits
|
||||
self._shutdown_event.set()
|
||||
self.subscribe_thread.join(timeout=2.0)
|
||||
if self.config.is_simulation:
|
||||
self.mujoco_env["hub_env"][0].envs[0].kill_sim()
|
||||
|
||||
# Wait for subscribe thread to finish
|
||||
if self.subscribe_thread is not None:
|
||||
self.subscribe_thread.join(timeout=2.0)
|
||||
if self.subscribe_thread.is_alive():
|
||||
logger.warning("Subscribe thread did not stop cleanly")
|
||||
|
||||
# Close simulation environment
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
try:
|
||||
# Force-kill the image publish subprocess first to avoid long waits
|
||||
if hasattr(self.sim_env, "simulator") and hasattr(self.sim_env.simulator, "sim_env"):
|
||||
sim_env_inner = self.sim_env.simulator.sim_env
|
||||
if hasattr(sim_env_inner, "image_publish_process"):
|
||||
proc = sim_env_inner.image_publish_process
|
||||
if proc.process and proc.process.is_alive():
|
||||
logger.info("Force-terminating image publish subprocess...")
|
||||
proc.stop_event.set()
|
||||
proc.process.terminate()
|
||||
proc.process.join(timeout=1)
|
||||
if proc.process.is_alive():
|
||||
proc.process.kill()
|
||||
self.sim_env.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing sim_env: {e}")
|
||||
self.sim_env = None
|
||||
self._env_wrapper = None
|
||||
|
||||
# Disconnect cameras
|
||||
for cam in self._cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
return self.lowstate_buffer.get_data()
|
||||
lowstate = self._lowstate
|
||||
if lowstate is None:
|
||||
return {}
|
||||
|
||||
obs = {}
|
||||
|
||||
# Motors - q, dq, tau for all joints
|
||||
for motor in G1_29_JointIndex:
|
||||
name = motor.name
|
||||
idx = motor.value
|
||||
obs[f"{name}.q"] = lowstate.motor_state[idx].q
|
||||
obs[f"{name}.dq"] = lowstate.motor_state[idx].dq
|
||||
obs[f"{name}.tau"] = lowstate.motor_state[idx].tau_est
|
||||
|
||||
# IMU - gyroscope
|
||||
if lowstate.imu_state.gyroscope:
|
||||
obs["imu.gyro.x"] = lowstate.imu_state.gyroscope[0]
|
||||
obs["imu.gyro.y"] = lowstate.imu_state.gyroscope[1]
|
||||
obs["imu.gyro.z"] = lowstate.imu_state.gyroscope[2]
|
||||
|
||||
# IMU - accelerometer
|
||||
if lowstate.imu_state.accelerometer:
|
||||
obs["imu.accel.x"] = lowstate.imu_state.accelerometer[0]
|
||||
obs["imu.accel.y"] = lowstate.imu_state.accelerometer[1]
|
||||
obs["imu.accel.z"] = lowstate.imu_state.accelerometer[2]
|
||||
|
||||
# IMU - quaternion
|
||||
if lowstate.imu_state.quaternion:
|
||||
obs["imu.quat.w"] = lowstate.imu_state.quaternion[0]
|
||||
obs["imu.quat.x"] = lowstate.imu_state.quaternion[1]
|
||||
obs["imu.quat.y"] = lowstate.imu_state.quaternion[2]
|
||||
obs["imu.quat.z"] = lowstate.imu_state.quaternion[3]
|
||||
|
||||
# IMU - rpy
|
||||
if lowstate.imu_state.rpy:
|
||||
obs["imu.rpy.roll"] = lowstate.imu_state.rpy[0]
|
||||
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
|
||||
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
|
||||
|
||||
# Controller - parse wireless_remote and add to obs
|
||||
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
|
||||
self.remote_controller.set(lowstate.wireless_remote)
|
||||
obs["remote.buttons"] = self.remote_controller.button.copy()
|
||||
obs["remote.lx"] = self.remote_controller.lx
|
||||
obs["remote.ly"] = self.remote_controller.ly
|
||||
obs["remote.rx"] = self.remote_controller.rx
|
||||
obs["remote.ry"] = self.remote_controller.ry
|
||||
|
||||
# Cameras - read images from ZMQ cameras
|
||||
for cam_name, cam in self._cameras.items():
|
||||
obs[cam_name] = cam.async_read()
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
@@ -246,11 +330,15 @@ class UnitreeG1(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.lowstate_buffer.get_data() is not None
|
||||
return self._lowstate is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def cameras(self) -> dict:
|
||||
return self._cameras
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
@@ -293,39 +381,51 @@ class UnitreeG1(Robot):
|
||||
self,
|
||||
control_dt: float | None = None,
|
||||
default_positions: list[float] | None = None,
|
||||
) -> None: # interpolate to default position
|
||||
) -> None: # move robot to default position
|
||||
if control_dt is None:
|
||||
control_dt = self.config.control_dt
|
||||
if default_positions is None:
|
||||
default_positions = np.array(self.config.default_positions, dtype=np.float32)
|
||||
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / control_dt)
|
||||
if self.config.is_simulation and self.sim_env is not None:
|
||||
self.sim_env.reset()
|
||||
|
||||
# get current state
|
||||
robot_state = self.get_observation()
|
||||
|
||||
# record current positions
|
||||
init_dof_pos = np.zeros(29, dtype=np.float32)
|
||||
for i in range(29):
|
||||
init_dof_pos[i] = robot_state.motor_state[i].q
|
||||
|
||||
# Interpolate to default position
|
||||
for step in range(num_steps):
|
||||
start_time = time.time()
|
||||
|
||||
alpha = step / num_steps
|
||||
action_dict = {}
|
||||
for motor in G1_29_JointIndex:
|
||||
target_pos = default_positions[motor.value]
|
||||
interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha
|
||||
action_dict[f"{motor.name}.q"] = float(interp_pos)
|
||||
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
|
||||
self.msg.motor_cmd[motor.value].qd = 0
|
||||
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
|
||||
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
|
||||
self.msg.motor_cmd[motor.value].tau = 0
|
||||
self.msg.crc = self.crc.Crc(self.msg)
|
||||
self.lowcmd_publisher.Write(self.msg)
|
||||
else:
|
||||
total_time = 3.0
|
||||
num_steps = int(total_time / control_dt)
|
||||
|
||||
self.send_action(action_dict)
|
||||
# get current state
|
||||
obs = self.get_observation()
|
||||
|
||||
# Maintain constant control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
# record current positions
|
||||
init_dof_pos = np.zeros(29, dtype=np.float32)
|
||||
for motor in G1_29_JointIndex:
|
||||
init_dof_pos[motor.value] = obs[f"{motor.name}.q"]
|
||||
|
||||
# Interpolate to default position
|
||||
for step in range(num_steps):
|
||||
start_time = time.time()
|
||||
|
||||
alpha = step / num_steps
|
||||
action_dict = {}
|
||||
for motor in G1_29_JointIndex:
|
||||
target_pos = default_positions[motor.value]
|
||||
interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha
|
||||
action_dict[f"{motor.name}.q"] = float(interp_pos)
|
||||
|
||||
self.send_action(action_dict)
|
||||
|
||||
# Maintain constant control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, control_dt - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info("Reached default position")
|
||||
|
||||
@@ -74,6 +74,7 @@ from lerobot.cameras import ( # noqa: F401
|
||||
)
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
@@ -103,6 +104,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
@@ -508,6 +510,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
|
||||
# reset g1 robot
|
||||
if robot.name == "unitree_g1":
|
||||
robot.reset()
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
|
||||
@@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
Reference in New Issue
Block a user