Clean files and add types

This commit is contained in:
glannuzel
2025-08-19 15:26:41 +02:00
parent 6e6031bb37
commit 60c342ad2d
2 changed files with 64 additions and 45 deletions
@@ -13,6 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Dict
from lerobot.cameras import CameraConfig
from lerobot.cameras.configs import ColorMode, Cv2Rotation
@@ -51,9 +52,9 @@ class Reachy2RobotConfig(RobotConfig):
with_right_teleop_camera: bool = True
with_torso_camera: bool = False
def __post_init__(self):
def __post_init__(self) -> None:
# Add cameras
self.cameras: dict[str, CameraConfig] = {}
self.cameras: Dict[str, CameraConfig] = {}
if self.with_left_teleop_camera:
self.cameras["teleop_left"] = Reachy2CameraConfig(
name="teleop",
+61 -43
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import time
from typing import Any
from typing import Any, Dict, Optional, Tuple
import numpy as np
from lerobot.cameras.utils import make_cameras_from_configs
@@ -83,24 +83,27 @@ class Reachy2Robot(Robot):
self.reachy: None | ReachySDK = None
self.cameras = make_cameras_from_configs(config.cameras)
self.logs = {}
self.logs: Dict[str, float] = {}
self.joints_dict: dict[str, str] = self._generate_joints_dict()
self.joints_dict: Dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict:
def observation_features(self) -> Dict[str, Any]:
return {**self.motors_features, **self.camera_features}
@property
def action_features(self) -> dict:
def action_features(self) -> Dict[str, type]:
return self.motors_features
@property
def camera_features(self) -> dict[str, dict]:
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
def camera_features(self) -> Dict[str, Tuple[Optional[int], Optional[int], int]]:
return {
cam: (self.cameras[cam].height, self.cameras[cam].width, 3)
for cam in self.cameras
}
@property
def motors_features(self) -> dict:
def motors_features(self) -> Dict[str, type]:
if self.config.with_mobile_base:
return {
**dict.fromkeys(
@@ -119,7 +122,7 @@ class Reachy2Robot(Robot):
def is_connected(self) -> bool:
return self.reachy.is_connected() if self.reachy is not None else False
def connect(self) -> None:
def connect(self, calibrate: bool = False) -> None:
self.reachy = ReachySDK(self.config.ip_address)
if not self.is_connected:
print("Error connecting to Reachy 2.")
@@ -131,8 +134,9 @@ class Reachy2Robot(Robot):
self.configure()
def configure(self) -> None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
if self.reachy is not None:
self.reachy.turn_on()
self.reachy.reset_default_limits()
@property
def is_calibrated(self) -> bool:
@@ -141,7 +145,7 @@ class Reachy2Robot(Robot):
def calibrate(self) -> None:
pass
def _generate_joints_dict(self) -> dict[str, str]:
def _generate_joints_dict(self) -> Dict[str, str]:
self.joints = {}
if self.config.with_neck:
self.joints.update(REACHY2_NECK_JOINTS)
@@ -153,19 +157,27 @@ class Reachy2Robot(Robot):
self.joints.update(REACHY2_ANTENNAS_JOINTS)
return self.joints
def _get_state(self) -> dict:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
return {**pos_dict, **vel_dict}
def _get_state(self) -> Dict[str, float]:
if self.reachy is not None:
pos_dict = {
k: self.reachy.joints[v].present_position
for k, v in self.joints_dict.items()
}
if not self.config.with_mobile_base:
return pos_dict
vel_dict = {
k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()
}
return {**pos_dict, **vel_dict}
else:
return {}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict = {}
def get_observation(self) -> Dict[str, np.ndarray]:
obs_dict: Dict[str, Any] = {}
# Read Reachy 2 state
before_read_t = time.perf_counter()
obs_dict = self._get_state()
obs_dict.update(self._get_state())
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
@@ -174,36 +186,42 @@ class Reachy2Robot(Robot):
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
if not self.is_connected:
raise ConnectionError()
def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]:
if self.reachy is not None:
if not self.is_connected:
raise ConnectionError()
before_write_t = time.perf_counter()
before_write_t = time.perf_counter()
vel = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
vel = {}
for key, val in action.items():
if key not in self.joints_dict:
if key not in REACHY2_VEL:
raise KeyError(
f"Key '{key}' is not a valid motor key in Reachy 2."
)
else:
vel[REACHY2_VEL[key]] = val
else:
vel[REACHY2_VEL[key]] = val
else:
self.reachy.joints[self.joints_dict[key]].goal_position = val
self.reachy.joints[self.joints_dict[key]].goal_position = val
if self.config.with_mobile_base:
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.reachy.mobile_base.set_goal_speed(
vel["vx"], vel["vy"], vel["vtheta"]
)
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# We don't send the goal positions if we control Reachy 2 externally
if not self.use_external_commands:
self.reachy.send_goal_positions()
if self.config.with_mobile_base:
self.reachy.mobile_base.send_speed_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
return action
def disconnect(self) -> None:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()
if self.reachy is not None:
self.reachy.turn_off_smoothly()
self.reachy.disconnect()
for cam in self.cameras.values():
cam.disconnect()