mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
Clean files and add types
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from lerobot.cameras import CameraConfig
|
from lerobot.cameras import CameraConfig
|
||||||
from lerobot.cameras.configs import ColorMode, Cv2Rotation
|
from lerobot.cameras.configs import ColorMode, Cv2Rotation
|
||||||
@@ -51,9 +52,9 @@ class Reachy2RobotConfig(RobotConfig):
|
|||||||
with_right_teleop_camera: bool = True
|
with_right_teleop_camera: bool = True
|
||||||
with_torso_camera: bool = False
|
with_torso_camera: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
# Add cameras
|
# Add cameras
|
||||||
self.cameras: dict[str, CameraConfig] = {}
|
self.cameras: Dict[str, CameraConfig] = {}
|
||||||
if self.with_left_teleop_camera:
|
if self.with_left_teleop_camera:
|
||||||
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
self.cameras["teleop_left"] = Reachy2CameraConfig(
|
||||||
name="teleop",
|
name="teleop",
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
@@ -83,24 +83,27 @@ class Reachy2Robot(Robot):
|
|||||||
self.reachy: None | ReachySDK = None
|
self.reachy: None | ReachySDK = None
|
||||||
self.cameras = make_cameras_from_configs(config.cameras)
|
self.cameras = make_cameras_from_configs(config.cameras)
|
||||||
|
|
||||||
self.logs = {}
|
self.logs: Dict[str, float] = {}
|
||||||
|
|
||||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
self.joints_dict: Dict[str, str] = self._generate_joints_dict()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_features(self) -> dict:
|
def observation_features(self) -> Dict[str, Any]:
|
||||||
return {**self.motors_features, **self.camera_features}
|
return {**self.motors_features, **self.camera_features}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict:
|
def action_features(self) -> Dict[str, type]:
|
||||||
return self.motors_features
|
return self.motors_features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_features(self) -> dict[str, dict]:
|
def camera_features(self) -> Dict[str, Tuple[Optional[int], Optional[int], int]]:
|
||||||
return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras}
|
return {
|
||||||
|
cam: (self.cameras[cam].height, self.cameras[cam].width, 3)
|
||||||
|
for cam in self.cameras
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def motors_features(self) -> dict:
|
def motors_features(self) -> Dict[str, type]:
|
||||||
if self.config.with_mobile_base:
|
if self.config.with_mobile_base:
|
||||||
return {
|
return {
|
||||||
**dict.fromkeys(
|
**dict.fromkeys(
|
||||||
@@ -119,7 +122,7 @@ class Reachy2Robot(Robot):
|
|||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self.reachy.is_connected() if self.reachy is not None else False
|
return self.reachy.is_connected() if self.reachy is not None else False
|
||||||
|
|
||||||
def connect(self) -> None:
|
def connect(self, calibrate: bool = False) -> None:
|
||||||
self.reachy = ReachySDK(self.config.ip_address)
|
self.reachy = ReachySDK(self.config.ip_address)
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print("Error connecting to Reachy 2.")
|
print("Error connecting to Reachy 2.")
|
||||||
@@ -131,8 +134,9 @@ class Reachy2Robot(Robot):
|
|||||||
self.configure()
|
self.configure()
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.reachy.turn_on()
|
if self.reachy is not None:
|
||||||
self.reachy.reset_default_limits()
|
self.reachy.turn_on()
|
||||||
|
self.reachy.reset_default_limits()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_calibrated(self) -> bool:
|
def is_calibrated(self) -> bool:
|
||||||
@@ -141,7 +145,7 @@ class Reachy2Robot(Robot):
|
|||||||
def calibrate(self) -> None:
|
def calibrate(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _generate_joints_dict(self) -> dict[str, str]:
|
def _generate_joints_dict(self) -> Dict[str, str]:
|
||||||
self.joints = {}
|
self.joints = {}
|
||||||
if self.config.with_neck:
|
if self.config.with_neck:
|
||||||
self.joints.update(REACHY2_NECK_JOINTS)
|
self.joints.update(REACHY2_NECK_JOINTS)
|
||||||
@@ -153,19 +157,27 @@ class Reachy2Robot(Robot):
|
|||||||
self.joints.update(REACHY2_ANTENNAS_JOINTS)
|
self.joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||||
return self.joints
|
return self.joints
|
||||||
|
|
||||||
def _get_state(self) -> dict:
|
def _get_state(self) -> Dict[str, float]:
|
||||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
|
if self.reachy is not None:
|
||||||
if not self.config.with_mobile_base:
|
pos_dict = {
|
||||||
return pos_dict
|
k: self.reachy.joints[v].present_position
|
||||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
for k, v in self.joints_dict.items()
|
||||||
return {**pos_dict, **vel_dict}
|
}
|
||||||
|
if not self.config.with_mobile_base:
|
||||||
|
return pos_dict
|
||||||
|
vel_dict = {
|
||||||
|
k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()
|
||||||
|
}
|
||||||
|
return {**pos_dict, **vel_dict}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, np.ndarray]:
|
def get_observation(self) -> Dict[str, np.ndarray]:
|
||||||
obs_dict = {}
|
obs_dict: Dict[str, Any] = {}
|
||||||
|
|
||||||
# Read Reachy 2 state
|
# Read Reachy 2 state
|
||||||
before_read_t = time.perf_counter()
|
before_read_t = time.perf_counter()
|
||||||
obs_dict = self._get_state()
|
obs_dict.update(self._get_state())
|
||||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
# Capture images from cameras
|
# Capture images from cameras
|
||||||
@@ -174,36 +186,42 @@ class Reachy2Robot(Robot):
|
|||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
def send_action(self, action: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
if not self.is_connected:
|
if self.reachy is not None:
|
||||||
raise ConnectionError()
|
if not self.is_connected:
|
||||||
|
raise ConnectionError()
|
||||||
|
|
||||||
before_write_t = time.perf_counter()
|
before_write_t = time.perf_counter()
|
||||||
|
|
||||||
vel = {}
|
vel = {}
|
||||||
for key, val in action.items():
|
for key, val in action.items():
|
||||||
if key not in self.joints_dict:
|
if key not in self.joints_dict:
|
||||||
if key not in REACHY2_VEL:
|
if key not in REACHY2_VEL:
|
||||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
raise KeyError(
|
||||||
|
f"Key '{key}' is not a valid motor key in Reachy 2."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vel[REACHY2_VEL[key]] = val
|
||||||
else:
|
else:
|
||||||
vel[REACHY2_VEL[key]] = val
|
self.reachy.joints[self.joints_dict[key]].goal_position = val
|
||||||
else:
|
|
||||||
self.reachy.joints[self.joints_dict[key]].goal_position = val
|
|
||||||
|
|
||||||
if self.config.with_mobile_base:
|
|
||||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
|
||||||
|
|
||||||
# We don't send the goal positions if we control Reachy 2 externally
|
|
||||||
if not self.use_external_commands:
|
|
||||||
self.reachy.send_goal_positions()
|
|
||||||
if self.config.with_mobile_base:
|
if self.config.with_mobile_base:
|
||||||
self.reachy.mobile_base.send_speed_command()
|
self.reachy.mobile_base.set_goal_speed(
|
||||||
|
vel["vx"], vel["vy"], vel["vtheta"]
|
||||||
|
)
|
||||||
|
|
||||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
# We don't send the goal positions if we control Reachy 2 externally
|
||||||
|
if not self.use_external_commands:
|
||||||
|
self.reachy.send_goal_positions()
|
||||||
|
if self.config.with_mobile_base:
|
||||||
|
self.reachy.mobile_base.send_speed_command()
|
||||||
|
|
||||||
|
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.reachy.turn_off_smoothly()
|
if self.reachy is not None:
|
||||||
self.reachy.disconnect()
|
self.reachy.turn_off_smoothly()
|
||||||
|
self.reachy.disconnect()
|
||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.disconnect()
|
cam.disconnect()
|
||||||
|
|||||||
Reference in New Issue
Block a user