Merge branch 'main' into feature/add-multitask-dit

This commit is contained in:
Bryson Jones
2026-01-12 09:13:46 -08:00
committed by GitHub
25 changed files with 844 additions and 395 deletions
+34 -19
View File
@@ -38,6 +38,7 @@ docker run --rm -it \
start_rviz:=true start_sdk_server:=true mujoco:=true start_rviz:=true start_sdk_server:=true mujoco:=true
``` ```
> [!NOTE]
> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance: > If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance:
> >
> ``` > ```
@@ -141,7 +142,7 @@ If you choose this option but still want to use the VR teleoperation application
First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command: First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command:
```bash ```bash
python -m lerobot.record \ lerobot-record \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.id=r2-0000 \ --robot.id=r2-0000 \
@@ -150,6 +151,7 @@ python -m lerobot.record \
--teleop.type=reachy2_teleoperator \ --teleop.type=reachy2_teleoperator \
--teleop.ip_address=192.168.0.200 \ --teleop.ip_address=192.168.0.200 \
--teleop.with_mobile_base=false \ --teleop.with_mobile_base=false \
--robot.with_torso_camera=true \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--dataset.single_task="Reachy 2 recording test" \ --dataset.single_task="Reachy 2 recording test" \
--dataset.num_episodes=1 \ --dataset.num_episodes=1 \
@@ -165,7 +167,7 @@ python -m lerobot.record \
**Extended setup overview (all options included):** **Extended setup overview (all options included):**
```bash ```bash
python -m lerobot.record \ lerobot-record \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.use_external_commands=true \ --robot.use_external_commands=true \
@@ -177,6 +179,8 @@ python -m lerobot.record \
--robot.with_left_teleop_camera=true \ --robot.with_left_teleop_camera=true \
--robot.with_right_teleop_camera=true \ --robot.with_right_teleop_camera=true \
--robot.with_torso_camera=false \ --robot.with_torso_camera=false \
--robot.camera_width=640 \
--robot.camera_height=480 \
--robot.disable_torque_on_disconnect=false \ --robot.disable_torque_on_disconnect=false \
--robot.max_relative_target=5.0 \ --robot.max_relative_target=5.0 \
--teleop.type=reachy2_teleoperator \ --teleop.type=reachy2_teleoperator \
@@ -212,9 +216,10 @@ Must be set to true if a compliant Reachy 2 is used to control another one.
From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies. From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies.
To avoid this, you can exclude specific parts from recording and replay using: To avoid this, you can exclude specific parts from recording and replay using:
```` ```bash
--robot.with_<part>=false --robot.with_<part>=false
```, ```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
It determine whether the corresponding part is recorded in the observations. True if not set. It determine whether the corresponding part is recorded in the observations. True if not set.
@@ -222,49 +227,60 @@ By default, **all parts are recorded**.
The same per-part mechanism is available in `reachy2_teleoperator` as well. The same per-part mechanism is available in `reachy2_teleoperator` as well.
```` ```bash
--teleop.with\_<part> --teleop.with\_<part>
``` ```
with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. with `<part>` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`.
Determine whether the corresponding part is recorded in the actions. True if not set. Determine whether the corresponding part is recorded in the actions. True if not set.
> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator. > **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator.
For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`. > For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`.
##### Use the relevant cameras ##### Use the relevant cameras
You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with: You can do the same for **cameras**. Enable or disable each camera with default parameters using:
```bash
--robot.with_left_teleop_camera=<true|false> \
--robot.with_right_teleop_camera=<true|false> \
--robot.with_torso_camera=<true|false>
``` ```
--robot.with_left_teleop_camera=<true|false> By default, no camera is recorded, all camera arguments are set to `false`.
--robot.with_right_teleop_camera=<true|false> If you want to, you can use custom `width` and `height` parameters for Reachy 2's cameras using the `--robot.camera_width` & `--robot.camera_height` argument:
--robot.with_torso_camera=<true|false>
```` ```bash
--robot.camera_width=1920 \
--robot.camera_height=1080
```
This will change the resolution of all 3 default robot cameras (enabled by the above bool arguments).
If you want, you can add additional cameras other than the ones in the robot as usual with:
```bash
--robot.cameras="{ extra: {type: opencv, index_or_path: 42, width: 640, height: 480, fps: 30}}" \
```
## Step 2: Replay ## Step 2: Replay
Make sure the robot is configured with the same parts as the dataset: Make sure the robot is configured with the same parts as the dataset:
```bash ```bash
python -m lerobot.replay \ lerobot-replay \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--robot.use_external_commands=false \ --robot.use_external_commands=false \
--robot.with_mobile_base=false \ --robot.with_mobile_base=false \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--dataset.episode=0 --dataset.episode=0
--display_data=true ```
````
## Step 3: Train ## Step 3: Train
```bash ```bash
python -m lerobot.scripts.train \ lerobot-train \
--dataset.repo_id=pollen_robotics/record_test \ --dataset.repo_id=pollen_robotics/record_test \
--policy.type=act \ --policy.type=act \
--output_dir=outputs/train/reachy2_test \ --output_dir=outputs/train/reachy2_test \
@@ -277,10 +293,9 @@ python -m lerobot.scripts.train \
## Step 4: Evaluate ## Step 4: Evaluate
```bash ```bash
python -m lerobot.record \ lerobot-eval \
--robot.type=reachy2 \ --robot.type=reachy2 \
--robot.ip_address=192.168.0.200 \ --robot.ip_address=192.168.0.200 \
--display_data=false \
--dataset.repo_id=pollen_robotics/eval_record_test \ --dataset.repo_id=pollen_robotics/eval_record_test \
--dataset.single_task="Evaluate reachy2 policy" \ --dataset.single_task="Evaluate reachy2 policy" \
--dataset.num_episodes=10 \ --dataset.num_episodes=10 \
+3 -3
View File
@@ -7,7 +7,7 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i
We support both 29 and 23 DOF G1 EDU version. We introduce: We support both 29 and 23 DOF G1 EDU version. We introduce:
- **`unitree g1` robot class, handling low level read/write from/to the humanoid** - **`unitree g1` robot class, handling low level read/write from/to the humanoid**
- **ZMQ socket bridge** for remote communication over wlan, allowing for remote policy deployment as well as over eth or directly on the Orin - **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot
- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma - **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma
- **Simulation mode** for testing policies without the physical robot in mujoco - **Simulation mode** for testing policies without the physical robot in mujoco
@@ -110,7 +110,7 @@ ssh unitree@<YOUR_ROBOT_IP>
# Password: 123 # Password: 123
``` ```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`). Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address.
--- ---
@@ -188,7 +188,7 @@ Press `Ctrl+C` to stop the policy.
## Running in Simulation Mode (MuJoCo) ## Running in Simulation Mode (MuJoCo)
You can now test and develop policies without a physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config.
## Additional Resources ## Additional Resources
+18 -21
View File
@@ -111,34 +111,29 @@ class GrootLocomotionController:
def run_step(self): def run_step(self):
# Get current observation # Get current observation
robot_state = self.robot.get_observation() obs = self.robot.get_observation()
if robot_state is None: if not obs:
return return
# Get command from remote controller # Get command from remote controller
if robot_state.wireless_remote is not None: if obs["remote.buttons"][0]: # R1 - raise waist
self.robot.remote_controller.set(robot_state.wireless_remote)
if self.robot.remote_controller.button[0]: # R1 - raise waist
self.groot_height_cmd += 0.001 self.groot_height_cmd += 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if self.robot.remote_controller.button[4]: # R2 - lower waist if obs["remote.buttons"][4]: # R2 - lower waist
self.groot_height_cmd -= 0.001 self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.cmd[0] = self.robot.remote_controller.ly # Forward/backward self.cmd[0] = obs["remote.ly"] # Forward/backward
self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right self.cmd[1] = obs["remote.lx"] * -1 # Left/right
self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate
# Get joint positions and velocities # Get joint positions and velocities from flat dict
for i in range(29): for motor in G1_29_JointIndex:
self.groot_qj_all[i] = robot_state.motor_state[i].q name = motor.name
self.groot_dqj_all[i] = robot_state.motor_state[i].dq idx = motor.value
self.groot_qj_all[idx] = obs[f"{name}.q"]
self.groot_dqj_all[idx] = obs[f"{name}.dq"]
# Adapt observation for g1_23dof # Adapt observation for g1_23dof
for idx in MISSING_JOINTS: for idx in MISSING_JOINTS:
@@ -150,8 +145,8 @@ class GrootLocomotionController:
dqj_obs = self.groot_dqj_all.copy() dqj_obs = self.groot_dqj_all.copy()
# Express IMU data in gravity frame of reference # Express IMU data in gravity frame of reference
quat = robot_state.imu_state.quaternion quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat) gravity_orientation = self.robot.get_gravity_orientation(quat)
# Scale joint positions and velocities before policy inference # Scale joint positions and velocities before policy inference
@@ -219,6 +214,8 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
config = UnitreeG1Config() config = UnitreeG1Config()
robot = UnitreeG1(config) robot = UnitreeG1(config)
robot.connect()
# Initialize gr00T locomotion controller # Initialize gr00T locomotion controller
groot_controller = GrootLocomotionController( groot_controller = GrootLocomotionController(
policy_balance=policy_balance, policy_balance=policy_balance,
@@ -234,7 +231,7 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None:
logger.info("Press Ctrl+C to stop") logger.info("Press Ctrl+C to stop")
# Run step # Run step
while True: while not robot._shutdown_event.is_set():
start_time = time.time() start_time = time.time()
groot_controller.run_step() groot_controller.run_step()
elapsed = time.time() - start_time elapsed = time.time() - start_time
+14 -14
View File
@@ -126,24 +126,23 @@ class HolosomaLocomotionController:
def run_step(self): def run_step(self):
# Get current observation # Get current observation
robot_state = self.robot.get_observation() obs = self.robot.get_observation()
if robot_state is None: if not obs:
return return
# Get command from remote controller # Get command from remote controller
if robot_state.wireless_remote is not None: ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0
self.robot.remote_controller.set(robot_state.wireless_remote) lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0
rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0
ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0
self.cmd[:] = [ly, -lx, -rx] self.cmd[:] = [ly, -lx, -rx]
# Get joint positions and velocities # Get joint positions and velocities
for i in range(29): for motor in G1_29_JointIndex:
self.qj[i] = robot_state.motor_state[i].q name = motor.name
self.dqj[i] = robot_state.motor_state[i].dq idx = motor.value
self.qj[idx] = obs[f"{name}.q"]
self.dqj[idx] = obs[f"{name}.dq"]
# Adapt observation for g1_23dof # Adapt observation for g1_23dof
for idx in MISSING_JOINTS: for idx in MISSING_JOINTS:
@@ -151,8 +150,8 @@ class HolosomaLocomotionController:
self.dqj[idx] = 0.0 self.dqj[idx] = 0.0
# Express IMU data in gravity frame of reference # Express IMU data in gravity frame of reference
quat = robot_state.imu_state.quaternion quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]]
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32)
gravity = self.robot.get_gravity_orientation(quat) gravity = self.robot.get_gravity_orientation(quat)
# Scale joint positions and velocities before policy inference # Scale joint positions and velocities before policy inference
@@ -220,6 +219,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
# Initialize robot # Initialize robot
config = UnitreeG1Config() config = UnitreeG1Config()
robot = UnitreeG1(config) robot = UnitreeG1(config)
robot.connect()
holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd) holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd)
@@ -230,7 +230,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -
logger.info("Press Ctrl+C to stop") logger.info("Press Ctrl+C to stop")
# Run step # Run step
while True: while not robot._shutdown_event.is_set():
start_time = time.time() start_time = time.time()
holosoma_controller.run_step() holosoma_controller.run_step()
elapsed = time.time() - start_time elapsed = time.time() - start_time
+1 -1
View File
@@ -111,7 +111,7 @@ unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0", "pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0,<2.0.0" "onnxruntime>=1.16.0,<2.0.0"
] ]
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"]
kinematics = ["lerobot[placo-dep]"] kinematics = ["lerobot[placo-dep]"]
intelrealsense = [ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
@@ -35,18 +35,19 @@ class Reachy2CameraConfig(CameraConfig):
name="teleop", name="teleop",
image_type="left", image_type="left",
ip_address="192.168.0.200", # IP address of the robot ip_address="192.168.0.200", # IP address of the robot
fps=15, port=50065, # Port of the camera server
width=640, width=640,
height=480, height=480,
fps=30, # Not configurable for Reachy 2 cameras
color_mode=ColorMode.RGB, color_mode=ColorMode.RGB,
) # Left teleop camera, 640x480 @ 15FPS ) # Left teleop camera, 640x480 @ 30FPS
``` ```
Attributes: Attributes:
name: Name of the camera device. Can be "teleop" or "depth". name: Name of the camera device. Can be "teleop" or "depth".
image_type: Type of image stream. For "teleop" camera, can be "left" or "right". image_type: Type of image stream. For "teleop" camera, can be "left" or "right".
For "depth" camera, can be "rgb" or "depth". (depth is not supported yet) For "depth" camera, can be "rgb" or "depth". (depth is not supported yet)
fps: Requested frames per second for the color stream. fps: Requested frames per second for the color stream. Not configurable for Reachy 2 cameras.
width: Requested frame width in pixels for the color stream. width: Requested frame width in pixels for the color stream.
height: Requested frame height in pixels for the color stream. height: Requested frame height in pixels for the color stream.
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
@@ -62,7 +63,6 @@ class Reachy2CameraConfig(CameraConfig):
color_mode: ColorMode = ColorMode.RGB color_mode: ColorMode = ColorMode.RGB
ip_address: str | None = "localhost" ip_address: str | None = "localhost"
port: int = 50065 port: int = 50065
# use_depth: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.name not in ["teleop", "depth"]: if self.name not in ["teleop", "depth"]:
@@ -16,12 +16,13 @@
Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager. Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager.
""" """
from __future__ import annotations
import logging import logging
import os import os
import platform import platform
import time import time
from threading import Event, Lock, Thread from typing import TYPE_CHECKING, Any
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
@@ -30,10 +31,19 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy import numpy as np # type: ignore # TODO: add type stubs for numpy
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk from lerobot.utils.import_utils import _reachy2_sdk_available
CameraManager,
) if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
else:
CameraManager = None
class CameraView:
LEFT = 0
RIGHT = 1
from lerobot.utils.errors import DeviceNotConnectedError from lerobot.utils.errors import DeviceNotConnectedError
@@ -69,17 +79,10 @@ class Reachy2Camera(Camera):
self.config = config self.config = config
self.fps = config.fps
self.color_mode = config.color_mode self.color_mode = config.color_mode
self.cam_manager: CameraManager | None = None self.cam_manager: CameraManager | None = None
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})" return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})"
@@ -100,44 +103,23 @@ class Reachy2Camera(Camera):
def connect(self, warmup: bool = True) -> None: def connect(self, warmup: bool = True) -> None:
""" """
Connects to the Reachy2 CameraManager as specified in the configuration. Connects to the Reachy2 CameraManager as specified in the configuration.
Raises:
DeviceNotConnectedError: If the camera is not connected.
""" """
self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port) self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"Could not connect to {self}.")
self.cam_manager.initialize_cameras() self.cam_manager.initialize_cameras()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@staticmethod @staticmethod
def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]: def find_cameras() -> list[dict[str, Any]]:
""" """
Detects available Reachy 2 cameras. Detection not implemented for Reachy2 cameras.
Returns:
List[Dict[str, Any]]: A list of dictionaries,
where each dictionary contains 'name', 'stereo',
and the default profile properties (width, height, fps).
""" """
initialized_cameras = [] raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.")
camera_manager = CameraManager(host=ip_address, port=port)
for camera in [camera_manager.teleop, camera_manager.depth]:
if camera is None:
continue
height, width, _, _, _, _, _ = camera.get_parameters()
camera_info = {
"name": camera._cam_info.name,
"stereo": camera._cam_info.stereo,
"default_profile": {
"width": width,
"height": height,
"fps": 30,
},
}
initialized_cameras.append(camera_info)
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
""" """
@@ -155,26 +137,32 @@ class Reachy2Camera(Camera):
(height, width, channels), using the specified or default (height, width, channels), using the specified or default
color mode and applying any configured rotation. color mode and applying any configured rotation.
""" """
start_time = time.perf_counter()
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.") raise DeviceNotConnectedError(f"{self} is not connected.")
start_time = time.perf_counter() if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
else:
if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"):
if self.config.image_type == "left": if self.config.image_type == "left":
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0] frame = self.cam_manager.teleop.get_frame(
CameraView.LEFT, size=(self.config.width, self.config.height)
)[0]
elif self.config.image_type == "right": elif self.config.image_type == "right":
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0] frame = self.cam_manager.teleop.get_frame(
CameraView.RIGHT, size=(self.config.width, self.config.height)
)[0]
elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"): elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"):
if self.config.image_type == "depth": if self.config.image_type == "depth":
frame = self.cam_manager.depth.get_depth_frame()[0] frame = self.cam_manager.depth.get_depth_frame()[0]
elif self.config.image_type == "rgb": elif self.config.image_type == "rgb":
frame = self.cam_manager.depth.get_frame(size=(640, 480))[0] frame = self.cam_manager.depth.get_frame(size=(self.config.width, self.config.height))[0]
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
if frame is None: if frame is None:
return np.empty((0, 0, 3), dtype=np.uint8) return np.empty((0, 0, 3), dtype=np.uint8)
@@ -187,63 +175,11 @@ class Reachy2Camera(Camera):
return frame return frame
def _read_loop(self) -> None:
"""
Internal loop run by the background thread for asynchronous reading.
On each iteration:
1. Reads a color frame
2. Stores result in latest_frame (thread-safe)
3. Sets new_frame_event to notify listeners
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
with self.frame_lock:
self.latest_frame = color_image
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except Exception as e:
logger.warning(f"Error reading frame in background thread for {self}: {e}")
def _start_read_thread(self) -> None:
"""Starts or restarts the background read thread if it's not running."""
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=0.1)
if self.stop_event is not None:
self.stop_event.set()
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
""" """
Reads the latest available frame asynchronously. Reads the latest available frame.
This method retrieves the most recent frame captured by the background This method retrieves the most recent frame available in Reachy 2's low-level software.
read thread. It does not block waiting for the camera hardware directly,
but may wait up to timeout_ms for the background thread to provide a frame.
Args: Args:
timeout_ms (float): Maximum time in milliseconds to wait for a frame timeout_ms (float): Maximum time in milliseconds to wait for a frame
@@ -261,22 +197,10 @@ class Reachy2Camera(Camera):
if not self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.") raise DeviceNotConnectedError(f"{self} is not connected.")
if self.thread is None or not self.thread.is_alive(): frame = self.read()
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
thread_alive = self.thread is not None and self.thread.is_alive()
raise TimeoutError(
f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
f"Read thread alive: {thread_alive}."
)
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None: if frame is None:
raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") raise RuntimeError(f"Internal error: No frame available for {self}.")
return frame return frame
@@ -287,12 +211,9 @@ class Reachy2Camera(Camera):
Raises: Raises:
DeviceNotConnectedError: If the camera is already disconnected. DeviceNotConnectedError: If the camera is already disconnected.
""" """
if not self.is_connected and self.thread is None: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} not connected.") raise DeviceNotConnectedError(f"{self} not connected.")
if self.thread is not None:
self._stop_read_thread()
if self.cam_manager is not None: if self.cam_manager is not None:
self.cam_manager.disconnect() self.cam_manager.disconnect()
+5
View File
@@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
cameras[key] = Reachy2Camera(cfg) cameras[key] = Reachy2Camera(cfg)
elif cfg.type == "zmq":
from .zmq.camera_zmq import ZMQCamera
cameras[key] = ZMQCamera(cfg)
else: else:
try: try:
cameras[key] = cast(Camera, make_device_from_device_class(cfg)) cameras[key] = cast(Camera, make_device_from_device_class(cfg))
+20
View File
@@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .camera_zmq import ZMQCamera
from .configuration_zmq import ZMQCameraConfig
__all__ = ["ZMQCamera", "ZMQCameraConfig"]
+235
View File
@@ -0,0 +1,235 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ZMQCamera - Captures frames from remote cameras via ZeroMQ using JSON protocol in the
following format:
{
"timestamps": {"camera_name": float},
"images": {"camera_name": "<base64-jpeg>"}
}
"""
import base64
import json
import logging
import time
from threading import Event, Lock, Thread
from typing import Any
import cv2
import numpy as np
from numpy.typing import NDArray
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
from .configuration_zmq import ZMQCameraConfig
logger = logging.getLogger(__name__)
class ZMQCamera(Camera):
"""
Example usage:
```python
from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig
config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera")
camera = ZMQCamera(config)
camera.connect()
frame = camera.read()
camera.disconnect()
```
"""
def __init__(self, config: ZMQCameraConfig):
super().__init__(config)
import zmq
self.config = config
self.server_address = config.server_address
self.port = config.port
self.camera_name = config.camera_name
self.color_mode = config.color_mode
self.timeout_ms = config.timeout_ms
self.context: zmq.Context | None = None
self.socket: zmq.Socket | None = None
self._connected = False
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})"
@property
def is_connected(self) -> bool:
return self._connected and self.context is not None and self.socket is not None
def connect(self, warmup: bool = True) -> None:
"""Connect to ZMQ camera server."""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} is already connected.")
logger.info(f"Connecting to {self}...")
try:
import zmq
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)
self.socket.setsockopt(zmq.CONFLATE, True)
self.socket.connect(f"tcp://{self.server_address}:{self.port}")
self._connected = True
# Auto-detect resolution
if self.width is None or self.height is None:
h, w = self.read().shape[:2]
self.height = h
self.width = w
logger.info(f"{self} resolution: {w}x{h}")
logger.info(f"{self} connected.")
if warmup:
time.sleep(0.1)
except Exception as e:
self._cleanup()
raise RuntimeError(f"Failed to connect to {self}: {e}") from e
def _cleanup(self):
"""Clean up ZMQ resources."""
self._connected = False
if self.socket:
self.socket.close()
self.socket = None
if self.context:
self.context.term()
self.context = None
@staticmethod
def find_cameras() -> list[dict[str, Any]]:
"""ZMQ cameras require manual configuration (server address/port)."""
return []
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
"""
Read a single frame from the ZMQ camera.
Returns:
np.ndarray: Decoded frame (height, width, 3)
"""
if not self.is_connected or self.socket is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
try:
message = self.socket.recv_string()
except Exception as e:
if type(e).__name__ == "Again":
raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e
raise
# Decode JSON message
data = json.loads(message)
if "images" not in data:
raise RuntimeError(f"{self} invalid message: missing 'images' key")
images = data["images"]
# Get image by camera name or first available
if self.camera_name in images:
img_b64 = images[self.camera_name]
elif images:
img_b64 = next(iter(images.values()))
else:
raise RuntimeError(f"{self} no images in message")
# Decode base64 JPEG
img_bytes = base64.b64decode(img_b64)
frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
if frame is None:
raise RuntimeError(f"{self} failed to decode image")
return frame
def _read_loop(self) -> None:
while self.stop_event and not self.stop_event.is_set():
try:
frame = self.read()
with self.frame_lock:
self.latest_frame = frame
self.new_frame_event.set()
except DeviceNotConnectedError:
break
except TimeoutError:
pass
except Exception as e:
logger.warning(f"Read error: {e}")
def _start_read_thread(self) -> None:
if self.thread and self.thread.is_alive():
return
self.stop_event = Event()
self.thread = Thread(target=self._read_loop, daemon=True)
self.thread.start()
def _stop_read_thread(self) -> None:
if self.stop_event:
self.stop_event.set()
if self.thread and self.thread.is_alive():
self.thread.join(timeout=2.0)
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
"""Read latest frame asynchronously (non-blocking)."""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
if not self.thread or not self.thread.is_alive():
self._start_read_thread()
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms")
with self.frame_lock:
frame = self.latest_frame
self.new_frame_event.clear()
if frame is None:
raise RuntimeError(f"{self} no frame available")
return frame
def disconnect(self) -> None:
"""Disconnect from ZMQ camera."""
if not self.is_connected and not self.thread:
raise DeviceNotConnectedError(f"{self} not connected.")
self._stop_read_thread()
self._cleanup()
logger.info(f"{self} disconnected.")
@@ -0,0 +1,46 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
__all__ = ["ZMQCameraConfig", "ColorMode"]
@CameraConfig.register_subclass("zmq")
@dataclass
class ZMQCameraConfig(CameraConfig):
server_address: str
port: int = 5555
camera_name: str = "zmq_camera"
color_mode: ColorMode = ColorMode.RGB
timeout_ms: int = 5000
def __post_init__(self) -> None:
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
)
if self.timeout_ms <= 0:
raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")
if not self.server_address:
raise ValueError("`server_address` cannot be empty.")
if self.port <= 0 or self.port > 65535:
raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")
+114
View File
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Streams camera images over ZMQ.
Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ.
"""
import base64
import contextlib
import json
import logging
import time
from collections import deque
import cv2
import numpy as np
import zmq
from lerobot.cameras.configs import ColorMode
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
logger = logging.getLogger(__name__)
def encode_image(image: np.ndarray, quality: int = 80) -> str:
"""Encode RGB image to base64 JPEG string."""
_, buffer = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
return base64.b64encode(buffer).decode("utf-8")
class ImageServer:
def __init__(self, config: dict, port: int = 5555):
self.fps = config.get("fps", 30)
self.cameras: dict[str, OpenCVCamera] = {}
for name, cfg in config.get("cameras", {}).items():
shape = cfg.get("shape", [480, 640])
cam_config = OpenCVCameraConfig(
index_or_path=cfg.get("device_id", 0),
fps=self.fps,
width=shape[1],
height=shape[0],
color_mode=ColorMode.RGB,
)
camera = OpenCVCamera(cam_config)
camera.connect()
self.cameras[name] = camera
logger.info(f"Camera {name}: {shape[1]}x{shape[0]}")
# ZMQ PUB socket
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.setsockopt(zmq.SNDHWM, 20)
self.socket.setsockopt(zmq.LINGER, 0)
self.socket.bind(f"tcp://*:{port}")
logger.info(f"ImageServer running on port {port}")
def run(self):
frame_count = 0
frame_times = deque(maxlen=60)
try:
while True:
t0 = time.time()
# Build message
message = {"timestamps": {}, "images": {}}
for name, cam in self.cameras.items():
frame = cam.read() # Returns RGB
message["timestamps"][name] = time.time()
message["images"][name] = encode_image(frame)
# Send as JSON string (suppress if buffer full)
with contextlib.suppress(zmq.Again):
self.socket.send_string(json.dumps(message), zmq.NOBLOCK)
frame_count += 1
frame_times.append(time.time() - t0)
if frame_count % 60 == 0:
logger.debug(f"FPS: {len(frame_times) / sum(frame_times):.1f}")
sleep = (1.0 / self.fps) - (time.time() - t0)
if sleep > 0:
time.sleep(sleep)
except KeyboardInterrupt:
pass
finally:
for cam in self.cameras.values():
cam.disconnect()
self.socket.close()
self.context.term()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}}
ImageServer(config, port=5555).run()
@@ -30,6 +30,8 @@ class Reachy2RobotConfig(RobotConfig):
# IP address of the Reachy 2 robot # IP address of the Reachy 2 robot
ip_address: str | None = "localhost" ip_address: str | None = "localhost"
# Port of the Reachy 2 robot
port: int = 50065
# If True, turn_off_smoothly() will be sent to the robot before disconnecting. # If True, turn_off_smoothly() will be sent to the robot before disconnecting.
disable_torque_on_disconnect: bool = False disable_torque_on_disconnect: bool = False
@@ -51,11 +53,16 @@ class Reachy2RobotConfig(RobotConfig):
# Robot cameras # Robot cameras
# Set to True if you want to use the corresponding cameras in the observations. # Set to True if you want to use the corresponding cameras in the observations.
# By default, only the teleop cameras are used. # By default, no camera is used.
with_left_teleop_camera: bool = True with_left_teleop_camera: bool = False
with_right_teleop_camera: bool = True with_right_teleop_camera: bool = False
with_torso_camera: bool = False with_torso_camera: bool = False
# Camera parameters
camera_width: int = 640
camera_height: int = 480
# For cameras other than the 3 default Reachy 2 cameras.
cameras: dict[str, CameraConfig] = field(default_factory=dict) cameras: dict[str, CameraConfig] = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
@@ -65,9 +72,10 @@ class Reachy2RobotConfig(RobotConfig):
name="teleop", name="teleop",
image_type="left", image_type="left",
ip_address=self.ip_address, ip_address=self.ip_address,
fps=15, port=self.port,
width=640, width=self.camera_width,
height=480, height=self.camera_height,
fps=30, # Not configurable for Reachy 2 cameras
color_mode=ColorMode.RGB, color_mode=ColorMode.RGB,
) )
if self.with_right_teleop_camera: if self.with_right_teleop_camera:
@@ -75,9 +83,10 @@ class Reachy2RobotConfig(RobotConfig):
name="teleop", name="teleop",
image_type="right", image_type="right",
ip_address=self.ip_address, ip_address=self.ip_address,
fps=15, port=self.port,
width=640, width=self.camera_width,
height=480, height=self.camera_height,
fps=30, # Not configurable for Reachy 2 cameras
color_mode=ColorMode.RGB, color_mode=ColorMode.RGB,
) )
if self.with_torso_camera: if self.with_torso_camera:
@@ -85,9 +94,10 @@ class Reachy2RobotConfig(RobotConfig):
name="depth", name="depth",
image_type="rgb", image_type="rgb",
ip_address=self.ip_address, ip_address=self.ip_address,
fps=15, port=self.port,
width=640, width=self.camera_width,
height=480, height=self.camera_height,
fps=30, # Not configurable for Reachy 2 cameras
color_mode=ColorMode.RGB, color_mode=ColorMode.RGB,
) )
+8 -2
View File
@@ -13,19 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import time import time
from typing import Any from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
from reachy2_sdk import ReachySDK
from lerobot.cameras.utils import make_cameras_from_configs from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.utils.import_utils import _reachy2_sdk_available
from ..robot import Robot from ..robot import Robot
from ..utils import ensure_safe_goal_position from ..utils import ensure_safe_goal_position
from .configuration_reachy2 import Reachy2RobotConfig from .configuration_reachy2 import Reachy2RobotConfig
if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk import ReachySDK
else:
ReachySDK = None
# {lerobot_keys: reachy2_sdk_keys} # {lerobot_keys: reachy2_sdk_keys}
REACHY2_NECK_JOINTS = { REACHY2_NECK_JOINTS = {
"neck_yaw.pos": "head.neck.yaw", "neck_yaw.pos": "head.neck.yaw",
@@ -16,6 +16,8 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
from ..config import RobotConfig from ..config import RobotConfig
_GAINS: dict[str, dict[str, list[float]]] = { _GAINS: dict[str, dict[str, list[float]]] = {
@@ -60,3 +62,6 @@ class UnitreeG1Config(RobotConfig):
# Socket config for ZMQ bridge # Socket config for ZMQ bridge
robot_ip: str = "192.168.123.164" # default G1 IP robot_ip: str = "192.168.123.164" # default G1 IP
# Cameras (ZMQ-based remote cameras)
cameras: dict[str, CameraConfig] = field(default_factory=dict)
+185 -85
View File
@@ -23,13 +23,8 @@ from functools import cached_property
from typing import Any from typing import Any
import numpy as np import numpy as np
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
from lerobot.cameras.utils import make_cameras_from_configs
from lerobot.envs.factory import make_env from lerobot.envs.factory import make_env
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
@@ -43,8 +38,6 @@ logger = logging.getLogger(__name__)
kTopicLowCommand_Debug = "rt/lowcmd" kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowState = "rt/lowstate" kTopicLowState = "rt/lowstate"
G1_29_Num_Motors = 29
@dataclass @dataclass
class MotorState: class MotorState:
@@ -66,28 +59,12 @@ class IMUState:
# g1 observation class # g1 observation class
@dataclass @dataclass
class G1_29_LowState: # noqa: N801 class G1_29_LowState: # noqa: N801
motor_state: list[MotorState] = field( motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex])
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
)
imu_state: IMUState = field(default_factory=IMUState) imu_state: IMUState = field(default_factory=IMUState)
wireless_remote: Any = None # Raw wireless remote data wireless_remote: Any = None # Raw wireless remote data
mode_machine: int = 0 # Robot mode mode_machine: int = 0 # Robot mode
class DataBuffer:
def __init__(self):
self.data = None
self.lock = threading.Lock()
def get_data(self):
with self.lock:
return self.data
def set_data(self, data):
with self.lock:
self.data = data
class UnitreeG1(Robot): class UnitreeG1(Robot):
config_class = UnitreeG1Config config_class = UnitreeG1Config
name = "unitree_g1" name = "unitree_g1"
@@ -117,9 +94,12 @@ class UnitreeG1(Robot):
logger.info("Initialize UnitreeG1...") logger.info("Initialize UnitreeG1...")
self.config = config self.config = config
self.control_dt = config.control_dt self.control_dt = config.control_dt
# Initialize cameras config (ZMQ-based) - actual connection in connect()
self._cameras = make_cameras_from_configs(config.cameras)
# Import channel classes based on mode
if config.is_simulation: if config.is_simulation:
from unitree_sdk2py.core.channel import ( from unitree_sdk2py.core.channel import (
ChannelFactoryInitialize, ChannelFactoryInitialize,
@@ -133,62 +113,33 @@ class UnitreeG1(Robot):
ChannelSubscriber, ChannelSubscriber,
) )
# connect robot # Store for use in connect()
self.ChannelFactoryInitialize = ChannelFactoryInitialize self._ChannelFactoryInitialize = ChannelFactoryInitialize
self.connect() self._ChannelPublisher = ChannelPublisher
self._ChannelSubscriber = ChannelSubscriber
# initialize direct motor control interface # Initialize state variables
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) self.sim_env = None
self.lowcmd_publisher.Init() self._env_wrapper = None
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState) self._lowstate = None
self.lowstate_subscriber.Init()
self.lowstate_buffer = DataBuffer()
# initialize subscribe thread to read robot state
self._shutdown_event = threading.Event() self._shutdown_event = threading.Event()
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) self.subscribe_thread = None
self.subscribe_thread.start()
while not self.is_connected:
time.sleep(0.1)
# initialize hg's lowcmd msg
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
self.msg.mode_pr = 0
# Wait for first state message to arrive
lowstate = None
while lowstate is None:
lowstate = self.lowstate_buffer.get_data()
if lowstate is None:
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# initialize all motors with unified kp/kd from config
self.kp = np.array(config.kp, dtype=np.float32)
self.kd = np.array(config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
# Initialize remote controller
self.remote_controller = self.RemoteController() self.remote_controller = self.RemoteController()
def _subscribe_motor_state(self): # polls robot state @ 250Hz def _subscribe_motor_state(self): # polls robot state @ 250Hz
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
start_time = time.time() start_time = time.time()
# Step simulation if in simulation mode
if self.config.is_simulation and self.sim_env is not None:
self.sim_env.step()
msg = self.lowstate_subscriber.Read() msg = self.lowstate_subscriber.Read()
if msg is not None: if msg is not None:
lowstate = G1_29_LowState() lowstate = G1_29_LowState()
# Capture motor states # Capture motor states using jointindex
for id in range(G1_29_Num_Motors): for id in G1_29_JointIndex:
lowstate.motor_state[id].q = msg.motor_state[id].q lowstate.motor_state[id].q = msg.motor_state[id].q
lowstate.motor_state[id].dq = msg.motor_state[id].dq lowstate.motor_state[id].dq = msg.motor_state[id].dq
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
@@ -207,7 +158,7 @@ class UnitreeG1(Robot):
# Capture mode_machine # Capture mode_machine
lowstate.mode_machine = msg.mode_machine lowstate.mode_machine = msg.mode_machine
self.lowstate_buffer.set_data(lowstate) self._lowstate = lowstate
current_time = time.time() current_time = time.time()
all_t_elapsed = current_time - start_time all_t_elapsed = current_time - start_time
@@ -216,7 +167,7 @@ class UnitreeG1(Robot):
@cached_property @cached_property
def action_features(self) -> dict[str, type]: def action_features(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
def calibrate(self) -> None: # robot is already calibrated def calibrate(self) -> None: # robot is already calibrated
pass pass
@@ -225,20 +176,153 @@ class UnitreeG1(Robot):
pass pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS def connect(self, calibrate: bool = True) -> None: # connect to DDS
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
# Initialize DDS channel and simulation environment
if self.config.is_simulation: if self.config.is_simulation:
self.ChannelFactoryInitialize(0, "lo") self._ChannelFactoryInitialize(0, "lo")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
# Extract the actual gym env from the dict structure
self.sim_env = self._env_wrapper["hub_env"][0].envs[0]
else: else:
self.ChannelFactoryInitialize(0) self._ChannelFactoryInitialize(0)
# Initialize direct motor control interface
self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
self.lowcmd_publisher.Init()
self.lowstate_subscriber = self._ChannelSubscriber(kTopicLowState, hg_LowState)
self.lowstate_subscriber.Init()
# Start subscribe thread to read robot state
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread.start()
# Connect cameras
for cam in self._cameras.values():
if not cam.is_connected:
cam.connect()
logger.info(f"Connected {len(self._cameras)} camera(s).")
# Initialize lowcmd message
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
self.msg.mode_pr = 0
# Wait for first state message to arrive
lowstate = None
while lowstate is None:
lowstate = self._lowstate
if lowstate is None:
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# Initialize all motors with unified kp/kd from config
self.kp = np.array(self.config.kp, dtype=np.float32)
self.kd = np.array(self.config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
def disconnect(self): def disconnect(self):
# Signal thread to stop and unblock any waits
self._shutdown_event.set() self._shutdown_event.set()
# Wait for subscribe thread to finish
if self.subscribe_thread is not None:
self.subscribe_thread.join(timeout=2.0) self.subscribe_thread.join(timeout=2.0)
if self.config.is_simulation: if self.subscribe_thread.is_alive():
self.mujoco_env["hub_env"][0].envs[0].kill_sim() logger.warning("Subscribe thread did not stop cleanly")
# Close simulation environment
if self.config.is_simulation and self.sim_env is not None:
try:
# Force-kill the image publish subprocess first to avoid long waits
if hasattr(self.sim_env, "simulator") and hasattr(self.sim_env.simulator, "sim_env"):
sim_env_inner = self.sim_env.simulator.sim_env
if hasattr(sim_env_inner, "image_publish_process"):
proc = sim_env_inner.image_publish_process
if proc.process and proc.process.is_alive():
logger.info("Force-terminating image publish subprocess...")
proc.stop_event.set()
proc.process.terminate()
proc.process.join(timeout=1)
if proc.process.is_alive():
proc.process.kill()
self.sim_env.close()
except Exception as e:
logger.warning(f"Error closing sim_env: {e}")
self.sim_env = None
self._env_wrapper = None
# Disconnect cameras
for cam in self._cameras.values():
cam.disconnect()
def get_observation(self) -> dict[str, Any]: def get_observation(self) -> dict[str, Any]:
return self.lowstate_buffer.get_data() lowstate = self._lowstate
if lowstate is None:
return {}
obs = {}
# Motors - q, dq, tau for all joints
for motor in G1_29_JointIndex:
name = motor.name
idx = motor.value
obs[f"{name}.q"] = lowstate.motor_state[idx].q
obs[f"{name}.dq"] = lowstate.motor_state[idx].dq
obs[f"{name}.tau"] = lowstate.motor_state[idx].tau_est
# IMU - gyroscope
if lowstate.imu_state.gyroscope:
obs["imu.gyro.x"] = lowstate.imu_state.gyroscope[0]
obs["imu.gyro.y"] = lowstate.imu_state.gyroscope[1]
obs["imu.gyro.z"] = lowstate.imu_state.gyroscope[2]
# IMU - accelerometer
if lowstate.imu_state.accelerometer:
obs["imu.accel.x"] = lowstate.imu_state.accelerometer[0]
obs["imu.accel.y"] = lowstate.imu_state.accelerometer[1]
obs["imu.accel.z"] = lowstate.imu_state.accelerometer[2]
# IMU - quaternion
if lowstate.imu_state.quaternion:
obs["imu.quat.w"] = lowstate.imu_state.quaternion[0]
obs["imu.quat.x"] = lowstate.imu_state.quaternion[1]
obs["imu.quat.y"] = lowstate.imu_state.quaternion[2]
obs["imu.quat.z"] = lowstate.imu_state.quaternion[3]
# IMU - rpy
if lowstate.imu_state.rpy:
obs["imu.rpy.roll"] = lowstate.imu_state.rpy[0]
obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1]
obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2]
# Controller - parse wireless_remote and add to obs
if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24:
self.remote_controller.set(lowstate.wireless_remote)
obs["remote.buttons"] = self.remote_controller.button.copy()
obs["remote.lx"] = self.remote_controller.lx
obs["remote.ly"] = self.remote_controller.ly
obs["remote.rx"] = self.remote_controller.rx
obs["remote.ry"] = self.remote_controller.ry
# Cameras - read images from ZMQ cameras
for cam_name, cam in self._cameras.items():
obs[cam_name] = cam.async_read()
return obs
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
@@ -246,11 +330,15 @@ class UnitreeG1(Robot):
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.lowstate_buffer.get_data() is not None return self._lowstate is not None
@property @property
def _motors_ft(self) -> dict[str, type]: def _motors_ft(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex}
@property
def cameras(self) -> dict:
return self._cameras
@property @property
def _cameras_ft(self) -> dict[str, tuple]: def _cameras_ft(self) -> dict[str, tuple]:
@@ -293,22 +381,34 @@ class UnitreeG1(Robot):
self, self,
control_dt: float | None = None, control_dt: float | None = None,
default_positions: list[float] | None = None, default_positions: list[float] | None = None,
) -> None: # interpolate to default position ) -> None: # move robot to default position
if control_dt is None: if control_dt is None:
control_dt = self.config.control_dt control_dt = self.config.control_dt
if default_positions is None: if default_positions is None:
default_positions = np.array(self.config.default_positions, dtype=np.float32) default_positions = np.array(self.config.default_positions, dtype=np.float32)
if self.config.is_simulation and self.sim_env is not None:
self.sim_env.reset()
for motor in G1_29_JointIndex:
self.msg.motor_cmd[motor.value].q = default_positions[motor.value]
self.msg.motor_cmd[motor.value].qd = 0
self.msg.motor_cmd[motor.value].kp = self.kp[motor.value]
self.msg.motor_cmd[motor.value].kd = self.kd[motor.value]
self.msg.motor_cmd[motor.value].tau = 0
self.msg.crc = self.crc.Crc(self.msg)
self.lowcmd_publisher.Write(self.msg)
else:
total_time = 3.0 total_time = 3.0
num_steps = int(total_time / control_dt) num_steps = int(total_time / control_dt)
# get current state # get current state
robot_state = self.get_observation() obs = self.get_observation()
# record current positions # record current positions
init_dof_pos = np.zeros(29, dtype=np.float32) init_dof_pos = np.zeros(29, dtype=np.float32)
for i in range(29): for motor in G1_29_JointIndex:
init_dof_pos[i] = robot_state.motor_state[i].q init_dof_pos[motor.value] = obs[f"{motor.name}.q"]
# Interpolate to default position # Interpolate to default position
for step in range(num_steps): for step in range(num_steps):
+10
View File
@@ -73,7 +73,9 @@ from lerobot.cameras import ( # noqa: F401
CameraConfig, # noqa: F401 CameraConfig, # noqa: F401
) )
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.reachy2_camera.configuration_reachy2_camera import Reachy2CameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.image_writer import safe_stop_image_writer
@@ -102,7 +104,9 @@ from lerobot.robots import ( # noqa: F401
koch_follower, koch_follower,
make_robot_from_config, make_robot_from_config,
omx_follower, omx_follower,
reachy2,
so_follower, so_follower,
unitree_g1,
) )
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
Teleoperator, Teleoperator,
@@ -112,6 +116,7 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader, koch_leader,
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
reachy2_teleoperator,
so_leader, so_leader,
) )
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
@@ -508,6 +513,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
): ):
log_say("Reset the environment", cfg.play_sounds) log_say("Reset the environment", cfg.play_sounds)
# reset g1 robot
if robot.name == "unitree_g1":
robot.reset()
record_loop( record_loop(
robot=robot, robot=robot,
events=events, events=events,
+2
View File
@@ -59,7 +59,9 @@ from lerobot.robots import ( # noqa: F401
koch_follower, koch_follower,
make_robot_from_config, make_robot_from_config,
omx_follower, omx_follower,
reachy2,
so_follower, so_follower,
unitree_g1,
) )
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.import_utils import register_third_party_plugins
@@ -76,6 +76,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower, koch_follower,
make_robot_from_config, make_robot_from_config,
omx_follower, omx_follower,
reachy2,
so_follower, so_follower,
) )
from lerobot.teleoperators import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401
@@ -88,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader, koch_leader,
make_teleoperator_from_config, make_teleoperator_from_config,
omx_leader, omx_leader,
reachy2_teleoperator,
so_leader, so_leader,
) )
from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.import_utils import register_third_party_plugins
@@ -13,11 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import logging import logging
import time import time
from typing import TYPE_CHECKING
from reachy2_sdk import ReachySDK from lerobot.utils.import_utils import _reachy2_sdk_available
if TYPE_CHECKING or _reachy2_sdk_available:
from reachy2_sdk import ReachySDK
else:
ReachySDK = None
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..teleoperator import Teleoperator from ..teleoperator import Teleoperator
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
@@ -75,6 +84,7 @@ class Reachy2Teleoperator(Teleoperator):
def __init__(self, config: Reachy2TeleoperatorConfig): def __init__(self, config: Reachy2TeleoperatorConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.reachy: None | ReachySDK = None self.reachy: None | ReachySDK = None
@@ -117,9 +127,13 @@ class Reachy2Teleoperator(Teleoperator):
return self.reachy.is_connected() if self.reachy is not None else False return self.reachy.is_connected() if self.reachy is not None else False
def connect(self, calibrate: bool = True) -> None: def connect(self, calibrate: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.reachy = ReachySDK(self.config.ip_address) self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected: if not self.is_connected:
raise ConnectionError() raise DeviceNotConnectedError()
logger.info(f"{self} connected.") logger.info(f"{self} connected.")
@property @property
@@ -135,19 +149,20 @@ class Reachy2Teleoperator(Teleoperator):
def get_action(self) -> dict[str, float]: def get_action(self) -> dict[str, float]:
start = time.perf_counter() start = time.perf_counter()
if self.reachy and self.is_connected: if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
joint_action: dict[str, float] = {}
vel_action: dict[str, float] = {}
if self.config.use_present_position: if self.config.use_present_position:
joint_action = { joint_action = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()
}
else: else:
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base: if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return joint_action return joint_action
if self.config.use_present_position: if self.config.use_present_position:
vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
else: else:
@@ -160,5 +175,5 @@ class Reachy2Teleoperator(Teleoperator):
raise NotImplementedError raise NotImplementedError
def disconnect(self) -> None: def disconnect(self) -> None:
if self.reachy and self.is_connected: if self.is_connected:
self.reachy.disconnect() self.reachy.disconnect()
+1
View File
@@ -64,6 +64,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
_transformers_available = is_package_available("transformers") _transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft") _peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy") _scipy_available = is_package_available("scipy")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
def make_device_from_device_class(config: ChoiceRegistry) -> Any: def make_device_from_device_class(config: ChoiceRegistry) -> Any:
+2 -12
View File
@@ -20,6 +20,8 @@ from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
pytest.importorskip("reachy2_sdk")
from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig
from lerobot.utils.errors import DeviceNotConnectedError from lerobot.utils.errors import DeviceNotConnectedError
@@ -127,24 +129,12 @@ def test_async_read(camera):
try: try:
img = camera.async_read() img = camera.async_read()
assert camera.thread is not None
assert camera.thread.is_alive()
assert isinstance(img, np.ndarray) assert isinstance(img, np.ndarray)
finally: finally:
if camera.is_connected: if camera.is_connected:
camera.disconnect() camera.disconnect()
def test_async_read_timeout(camera):
camera.connect()
try:
with pytest.raises(TimeoutError):
camera.async_read(timeout_ms=0)
finally:
if camera.is_connected:
camera.disconnect()
def test_read_before_connect(camera): def test_read_before_connect(camera):
with pytest.raises(DeviceNotConnectedError): with pytest.raises(DeviceNotConnectedError):
_ = camera.read() _ = camera.read()
-1
View File
@@ -28,7 +28,6 @@ pytest_plugins = [
"tests.fixtures.files", "tests.fixtures.files",
"tests.fixtures.hub", "tests.fixtures.hub",
"tests.fixtures.optimizers", "tests.fixtures.optimizers",
"tests.plugins.reachy2_sdk",
] ]
-46
View File
@@ -1,46 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import types
from unittest.mock import MagicMock
def _install_reachy2_sdk_stub():
sdk = types.ModuleType("reachy2_sdk")
sdk.__path__ = []
sdk.ReachySDK = MagicMock(name="ReachySDK")
media = types.ModuleType("reachy2_sdk.media")
media.__path__ = []
camera = types.ModuleType("reachy2_sdk.media.camera")
camera.CameraView = MagicMock(name="CameraView")
camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager")
camera_manager.CameraManager = MagicMock(name="CameraManager")
sdk.media = media
media.camera = camera
media.camera_manager = camera_manager
# Register in sys.modules
sys.modules.setdefault("reachy2_sdk", sdk)
sys.modules.setdefault("reachy2_sdk.media", media)
sys.modules.setdefault("reachy2_sdk.media.camera", camera)
sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager)
def pytest_sessionstart(session):
_install_reachy2_sdk_stub()
+2
View File
@@ -19,6 +19,8 @@ from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
pytest.importorskip("reachy2_sdk")
from lerobot.robots.reachy2 import ( from lerobot.robots.reachy2 import (
REACHY2_ANTENNAS_JOINTS, REACHY2_ANTENNAS_JOINTS,
REACHY2_L_ARM_JOINTS, REACHY2_L_ARM_JOINTS,