Clean files and add types

This commit is contained in:
glannuzel
2025-08-19 15:26:41 +02:00
parent 6e6031bb37
commit 60c342ad2d
2 changed files with 64 additions and 45 deletions
@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict
from lerobot.cameras import CameraConfig from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode, Cv2Rotation from lerobot.cameras.configs import ColorMode, Cv2Rotation
@@ -51,9 +52,9 @@ class Reachy2RobotConfig(RobotConfig):
with_right_teleop_camera: bool = True with_right_teleop_camera: bool = True
with_torso_camera: bool = False with_torso_camera: bool = False
def __post_init__(self): def __post_init__(self) -> None:
# Add cameras # Add cameras
self.cameras: dict[str, CameraConfig] = {} self.cameras: Dict[str, CameraConfig] = {}
if self.with_left_teleop_camera: if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig( self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop", name="teleop",
+61 -43
View File
@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import time import time
from typing import Any from typing import Any, Dict, Optional, Tuple
import numpy as np import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs from lerobot.cameras.utils import make_cameras_from_configs
@@ -83,24 +83,27 @@ class Reachy2Robot(Robot):
self.reachy: None | ReachySDK = None self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras) self.cameras = make_cameras_from_configs(config.cameras)
self.logs = {} self.logs: Dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict() self.joints_dict: Dict[str, str] = self._generate_joints_dict()
@property @property
def observation_features(self) -> dict: def observation_features(self) -> Dict[str, Any]:
return {**self.motors_features, **self.camera_features} return {**self.motors_features, **self.camera_features}
@property @property
def action_features(self) -> dict: def action_features(self) -> Dict[str, type]:
return self.motors_features return self.motors_features
@property @property
def camera_features(self) -> dict[str, dict]: def camera_features(self) -> Dict[str, Tuple[Optional[int], Optional[int], int]]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras} return {
cam: (self.cameras[cam].height, self.cameras[cam].width, 3)
for cam in self.cameras
}
@property @property
def motors_features(self) -> dict: def motors_features(self) -> Dict[str, type]:
if self.config.with_mobile_base: if self.config.with_mobile_base:
return { return {
**dict.fromkeys( **dict.fromkeys(
@@ -119,7 +122,7 @@ class Reachy2Robot(Robot):
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False return self.reachy.is_connected() if self.reachy is not None else False
def connect(self) -> None: def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address) self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected: if not self.is_connected:
print("Error connecting to Reachy 2.") print("Error connecting to Reachy 2.")
@@ -131,8 +134,9 @@ class Reachy2Robot(Robot):
self.configure() self.configure()
def configure(self) -> None: def configure(self) -> None:
self.reachy.turn_on() if self.reachy is not None:
self.reachy.reset_default_limits() self.reachy.turn_on()
self.reachy.reset_default_limits()
@property @property
def is_calibrated(self) -> bool: def is_calibrated(self) -> bool:
@@ -141,7 +145,7 @@ class Reachy2Robot(Robot):
def calibrate(self) -> None: def calibrate(self) -> None:
pass pass
def _generate_joints_dict(self) -> dict[str, str]: def _generate_joints_dict(self) -> Dict[str, str]:
self.joints = {} self.joints = {}
if self.config.with_neck: if self.config.with_neck:
self.joints.update(REACHY2_NECK_JOINTS) self.joints.update(REACHY2_NECK_JOINTS)
@@ -153,19 +157,27 @@ class Reachy2Robot(Robot):
self.joints.update(REACHY2_ANTENNAS_JOINTS) self.joints.update(REACHY2_ANTENNAS_JOINTS)
return self.joints return self.joints
def _get_state(self) -> dict: def _get_state(self) -> Dict[str, float]:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()} if self.reachy is not None:
if not self.config.with_mobile_base: pos_dict = {
return pos_dict k: self.reachy.joints[v].present_position
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} for k, v in self.joints_dict.items()
return {**pos_dict, **vel_dict} }
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {
k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()
}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]: def get_observation(self) -> Dict[str, np.ndarray]:
obs_dict = {} obs_dict: Dict[str, Any] = {}
# Read Reachy 2 state # Read Reachy 2 state
before_read_t = time.perf_counter() before_read_t = time.perf_counter()
obs_dict = self._get_state() obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras # Capture images from cameras
@@ -174,36 +186,42 @@ class Reachy2Robot(Robot):
return obs_dict return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]: def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]:
if not self.is_connected: if self.reachy is not None:
raise ConnectionError() if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter() before_write_t = time.perf_counter()
vel = {} vel = {}
for key, val in action.items(): for key, val in action.items():
if key not in self.joints_dict: if key not in self.joints_dict:
if key not in REACHY2_VEL: if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.") raise KeyError(
f"Key '{key}' is not a valid motor key in Reachy 2."
)
else:
vel[REACHY2_VEL[key]] = val
else: else:
vel[REACHY2_VEL[key]] = val self.reachy.joints[self.joints_dict[key]].goal_position = val
else:
self.reachy.joints[self.joints_dict[key]].goal_position = val
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base: if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command() self.reachy.mobile_base.set_goal_speed(
vel["vx"], vel["vy"], vel["vtheta"]
)
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t # We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action return action
def disconnect(self) -> None: def disconnect(self) -> None:
self.reachy.turn_off_smoothly() if self.reachy is not None:
self.reachy.disconnect() self.reachy.turn_off_smoothly()
self.reachy.disconnect()
for cam in self.cameras.values(): for cam in self.cameras.values():
cam.disconnect() cam.disconnect()