mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Clean files and add types
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.cameras.configs import ColorMode, Cv2Rotation
|
||||
@@ -51,9 +52,9 @@ class Reachy2RobotConfig(RobotConfig):
|
||||
with_right_teleop_camera: bool = True
|
||||
with_torso_camera: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# Add cameras
|
||||
self.cameras: dict[str, CameraConfig] = {}
|
||||
self.cameras: Dict[str, CameraConfig] = {}
|
||||
if self.with_left_teleop_camera:
|
||||
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
||||
name="teleop",
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
@@ -83,24 +83,27 @@ class Reachy2Robot(Robot):
|
||||
self.reachy: None | ReachySDK = None
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.logs = {}
|
||||
self.logs: Dict[str, float] = {}
|
||||
|
||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||
self.joints_dict: Dict[str, str] = self._generate_joints_dict()
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
def observation_features(self) -> Dict[str, Any]:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
def action_features(self) -> Dict[str, type]:
|
||||
return self.motors_features
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
|
||||
def camera_features(self) -> Dict[str, Tuple[Optional[int], Optional[int], int]]:
|
||||
return {
|
||||
cam: (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||
for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
|
||||
def motors_features(self) -> dict:
|
||||
def motors_features(self) -> Dict[str, type]:
|
||||
if self.config.with_mobile_base:
|
||||
return {
|
||||
**dict.fromkeys(
|
||||
@@ -119,7 +122,7 @@ class Reachy2Robot(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
def connect(self) -> None:
|
||||
def connect(self, calibrate: bool = False) -> None:
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
if not self.is_connected:
|
||||
print("Error connecting to Reachy 2.")
|
||||
@@ -131,8 +134,9 @@ class Reachy2Robot(Robot):
|
||||
self.configure()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.reachy.turn_on()
|
||||
self.reachy.reset_default_limits()
|
||||
if self.reachy is not None:
|
||||
self.reachy.turn_on()
|
||||
self.reachy.reset_default_limits()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
@@ -141,7 +145,7 @@ class Reachy2Robot(Robot):
|
||||
def calibrate(self) -> None:
|
||||
pass
|
||||
|
||||
def _generate_joints_dict(self) -> dict[str, str]:
|
||||
def _generate_joints_dict(self) -> Dict[str, str]:
|
||||
self.joints = {}
|
||||
if self.config.with_neck:
|
||||
self.joints.update(REACHY2_NECK_JOINTS)
|
||||
@@ -153,19 +157,27 @@ class Reachy2Robot(Robot):
|
||||
self.joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||
return self.joints
|
||||
|
||||
def _get_state(self) -> dict:
|
||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
|
||||
if not self.config.with_mobile_base:
|
||||
return pos_dict
|
||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
return {**pos_dict, **vel_dict}
|
||||
def _get_state(self) -> Dict[str, float]:
|
||||
if self.reachy is not None:
|
||||
pos_dict = {
|
||||
k: self.reachy.joints[v].present_position
|
||||
for k, v in self.joints_dict.items()
|
||||
}
|
||||
if not self.config.with_mobile_base:
|
||||
return pos_dict
|
||||
vel_dict = {
|
||||
k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()
|
||||
}
|
||||
return {**pos_dict, **vel_dict}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict = {}
|
||||
def get_observation(self) -> Dict[str, np.ndarray]:
|
||||
obs_dict: Dict[str, Any] = {}
|
||||
|
||||
# Read Reachy 2 state
|
||||
before_read_t = time.perf_counter()
|
||||
obs_dict = self._get_state()
|
||||
obs_dict.update(self._get_state())
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# Capture images from cameras
|
||||
@@ -174,36 +186,42 @@ class Reachy2Robot(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if self.reachy is not None:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
before_write_t = time.perf_counter()
|
||||
|
||||
vel = {}
|
||||
for key, val in action.items():
|
||||
if key not in self.joints_dict:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
vel = {}
|
||||
for key, val in action.items():
|
||||
if key not in self.joints_dict:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(
|
||||
f"Key '{key}' is not a valid motor key in Reachy 2."
|
||||
)
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = val
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = val
|
||||
else:
|
||||
self.reachy.joints[self.joints_dict[key]].goal_position = val
|
||||
self.reachy.joints[self.joints_dict[key]].goal_position = val
|
||||
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||
|
||||
# We don't send the goal positions if we control Reachy 2 externally
|
||||
if not self.use_external_commands:
|
||||
self.reachy.send_goal_positions()
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
self.reachy.mobile_base.set_goal_speed(
|
||||
vel["vx"], vel["vy"], vel["vtheta"]
|
||||
)
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
# We don't send the goal positions if we control Reachy 2 externally
|
||||
if not self.use_external_commands:
|
||||
self.reachy.send_goal_positions()
|
||||
if self.config.with_mobile_base:
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
return action
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.reachy.turn_off_smoothly()
|
||||
self.reachy.disconnect()
|
||||
if self.reachy is not None:
|
||||
self.reachy.turn_off_smoothly()
|
||||
self.reachy.disconnect()
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
Reference in New Issue
Block a user