From 60c342ad2d6d6f31ca63bc0b62792b83d70aab0c Mon Sep 17 00:00:00 2001 From: glannuzel Date: Tue, 19 Aug 2025 15:26:41 +0200 Subject: [PATCH] Clean files and add types --- .../robots/reachy2/configuration_reachy2.py | 5 +- src/lerobot/robots/reachy2/robot_reachy2.py | 104 ++++++++++-------- 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py index 6fd76fde2..300e05071 100644 --- a/src/lerobot/robots/reachy2/configuration_reachy2.py +++ b/src/lerobot/robots/reachy2/configuration_reachy2.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass +from typing import Dict from lerobot.cameras import CameraConfig from lerobot.cameras.configs import ColorMode, Cv2Rotation @@ -51,9 +52,9 @@ class Reachy2RobotConfig(RobotConfig): with_right_teleop_camera: bool = True with_torso_camera: bool = False - def __post_init__(self): + def __post_init__(self) -> None: # Add cameras - self.cameras: dict[str, CameraConfig] = {} + self.cameras: Dict[str, CameraConfig] = {} if self.with_left_teleop_camera: self.cameras["teleop_left"] = Reachy2CameraConfig( name="teleop", diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 0d4e96bc5..7f6984c21 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -15,7 +15,7 @@ # limitations under the License. import time -from typing import Any +from typing import Any, Dict, Optional, Tuple import numpy as np from lerobot.cameras.utils import make_cameras_from_configs @@ -83,24 +83,27 @@ class Reachy2Robot(Robot): self.reachy: None | ReachySDK = None self.cameras = make_cameras_from_configs(config.cameras) - self.logs = {} + self.logs: Dict[str, float] = {} - self.joints_dict: dict[str, str] = self._generate_joints_dict() + self.joints_dict: Dict[str, str] = self._generate_joints_dict() @property - def observation_features(self) -> dict: + def observation_features(self) -> Dict[str, Any]: return {**self.motors_features, **self.camera_features} @property - def action_features(self) -> dict: + def action_features(self) -> Dict[str, type]: return self.motors_features @property - def camera_features(self) -> dict[str, dict]: - return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras} + def camera_features(self) -> Dict[str, Tuple[Optional[int], Optional[int], int]]: + return { + cam: (self.cameras[cam].height, self.cameras[cam].width, 3) + for cam in self.cameras + } @property - def motors_features(self) -> dict: + def motors_features(self) -> Dict[str, type]: if self.config.with_mobile_base: return { **dict.fromkeys( @@ -119,7 +122,7 @@ class Reachy2Robot(Robot): def is_connected(self) -> bool: return self.reachy.is_connected() if self.reachy is not None else False - def connect(self) -> None: + def connect(self, calibrate: bool = False) -> None: self.reachy = ReachySDK(self.config.ip_address) if not self.is_connected: print("Error connecting to Reachy 2.") @@ -131,8 +134,9 @@ class Reachy2Robot(Robot): self.configure() def configure(self) -> None: - self.reachy.turn_on() - self.reachy.reset_default_limits() + if self.reachy is not None: + self.reachy.turn_on() + self.reachy.reset_default_limits() @property def is_calibrated(self) -> bool: @@ -141,7 +145,7 @@ class Reachy2Robot(Robot): def calibrate(self) -> None: pass - def _generate_joints_dict(self) -> dict[str, str]: + def _generate_joints_dict(self) -> Dict[str, str]: self.joints = {} if self.config.with_neck: self.joints.update(REACHY2_NECK_JOINTS) @@ -153,19 +157,27 @@ class Reachy2Robot(Robot): self.joints.update(REACHY2_ANTENNAS_JOINTS) return self.joints - def _get_state(self) -> dict: - pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()} - if not self.config.with_mobile_base: - return pos_dict - vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} - return {**pos_dict, **vel_dict} + def _get_state(self) -> Dict[str, float]: + if self.reachy is not None: + pos_dict = { + k: self.reachy.joints[v].present_position + for k, v in self.joints_dict.items() + } + if not self.config.with_mobile_base: + return pos_dict + vel_dict = { + k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items() + } + return {**pos_dict, **vel_dict} + else: + return {} - def get_observation(self) -> dict[str, np.ndarray]: - obs_dict = {} + def get_observation(self) -> Dict[str, np.ndarray]: + obs_dict: Dict[str, Any] = {} # Read Reachy 2 state before_read_t = time.perf_counter() - obs_dict = self._get_state() + obs_dict.update(self._get_state()) self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t # Capture images from cameras @@ -174,36 +186,42 @@ class Reachy2Robot(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - if not self.is_connected: - raise ConnectionError() + def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]: + if self.reachy is not None: + if not self.is_connected: + raise ConnectionError() - before_write_t = time.perf_counter() + before_write_t = time.perf_counter() - vel = {} - for key, val in action.items(): - if key not in self.joints_dict: - if key not in REACHY2_VEL: - raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.") + vel = {} + for key, val in action.items(): + if key not in self.joints_dict: + if key not in REACHY2_VEL: + raise KeyError( + f"Key '{key}' is not a valid motor key in Reachy 2." + ) + else: + vel[REACHY2_VEL[key]] = val else: - vel[REACHY2_VEL[key]] = val - else: - self.reachy.joints[self.joints_dict[key]].goal_position = val + self.reachy.joints[self.joints_dict[key]].goal_position = val - if self.config.with_mobile_base: - self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"]) - - # We don't send the goal positions if we control Reachy 2 externally - if not self.use_external_commands: - self.reachy.send_goal_positions() if self.config.with_mobile_base: - self.reachy.mobile_base.send_speed_command() + self.reachy.mobile_base.set_goal_speed( + vel["vx"], vel["vy"], vel["vtheta"] + ) - self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + # We don't send the goal positions if we control Reachy 2 externally + if not self.use_external_commands: + self.reachy.send_goal_positions() + if self.config.with_mobile_base: + self.reachy.mobile_base.send_speed_command() + + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t return action def disconnect(self) -> None: - self.reachy.turn_off_smoothly() - self.reachy.disconnect() + if self.reachy is not None: + self.reachy.turn_off_smoothly() + self.reachy.disconnect() for cam in self.cameras.values(): cam.disconnect()