mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Try adding mobile_base velocity
This commit is contained in:
@@ -25,13 +25,12 @@ from typing import Any
|
|||||||
# from stretch_body.robot_params import RobotParams
|
# from stretch_body.robot_params import RobotParams
|
||||||
|
|
||||||
from lerobot.cameras.utils import make_cameras_from_configs
|
from lerobot.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
|
||||||
|
|
||||||
from ..robot import Robot
|
from ..robot import Robot
|
||||||
from .configuration_reachy2 import Reachy2RobotConfig
|
from .configuration_reachy2 import Reachy2RobotConfig
|
||||||
|
|
||||||
# {lerobot_keys: reachy2_sdk_keys}
|
# {lerobot_keys: reachy2_sdk_keys}
|
||||||
REACHY2_MOTORS = {
|
REACHY2_JOINTS = {
|
||||||
"neck_yaw.pos": "head.neck.yaw",
|
"neck_yaw.pos": "head.neck.yaw",
|
||||||
"neck_pitch.pos": "head.neck.pitch",
|
"neck_pitch.pos": "head.neck.pitch",
|
||||||
"neck_roll.pos": "head.neck.roll",
|
"neck_roll.pos": "head.neck.roll",
|
||||||
@@ -53,9 +52,12 @@ REACHY2_MOTORS = {
|
|||||||
"l_gripper.pos": "l_arm.gripper",
|
"l_gripper.pos": "l_arm.gripper",
|
||||||
"l_antenna.pos": "head.l_antenna",
|
"l_antenna.pos": "head.l_antenna",
|
||||||
"r_antenna.pos": "head.r_antenna",
|
"r_antenna.pos": "head.r_antenna",
|
||||||
# "mobile_base.vx": "mobile_base.vx",
|
}
|
||||||
# "mobile_base.vy": "mobile_base.vy",
|
|
||||||
# "mobile_base.vtheta": "mobile_base.vtheta",
|
REACHY2_VEL = {
|
||||||
|
"mobile_base.vx": "vx",
|
||||||
|
"mobile_base.vy": "vy",
|
||||||
|
"mobile_base.vtheta": "vtheta",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -81,15 +83,6 @@ class Reachy2Robot(Robot):
|
|||||||
@property
|
@property
|
||||||
def observation_features(self) -> dict:
|
def observation_features(self) -> dict:
|
||||||
return {**self.motors_features, **self.camera_features}
|
return {**self.motors_features, **self.camera_features}
|
||||||
# return dict.fromkeys(
|
|
||||||
# REACHY2_MOTORS.keys(),
|
|
||||||
# float,
|
|
||||||
# )
|
|
||||||
# return {
|
|
||||||
# "dtype": "float32",
|
|
||||||
# "shape": len(REACHY2_MOTORS),
|
|
||||||
# "names": {"motors": list(REACHY2_MOTORS)},
|
|
||||||
# }
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict:
|
def action_features(self) -> dict:
|
||||||
@@ -100,21 +93,16 @@ class Reachy2Robot(Robot):
|
|||||||
return {
|
return {
|
||||||
cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras
|
cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras
|
||||||
}
|
}
|
||||||
# cam_ft = {}
|
|
||||||
# for cam_key, cam in self.cameras.items():
|
|
||||||
# cam_ft[cam_key] = {
|
|
||||||
# "shape": (cam.height, cam.width, cam.channels),
|
|
||||||
# "names": ["height", "width", "channels"],
|
|
||||||
# "info": None,
|
|
||||||
# }
|
|
||||||
# return cam_ft
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def motors_features(self) -> dict:
|
def motors_features(self) -> dict:
|
||||||
return dict.fromkeys(
|
return {**dict.fromkeys(
|
||||||
REACHY2_MOTORS.keys(),
|
REACHY2_JOINTS.keys(),
|
||||||
float,
|
float,
|
||||||
)
|
), **dict.fromkeys(
|
||||||
|
REACHY2_VEL.keys(),
|
||||||
|
float,
|
||||||
|
)}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
@@ -129,10 +117,6 @@ class Reachy2Robot(Robot):
|
|||||||
for cam in self.cameras.values():
|
for cam in self.cameras.values():
|
||||||
cam.connect()
|
cam.connect()
|
||||||
|
|
||||||
# if not self.is_connected:
|
|
||||||
# print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
|
||||||
# raise ConnectionError()
|
|
||||||
|
|
||||||
self.configure()
|
self.configure()
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
@@ -147,7 +131,9 @@ class Reachy2Robot(Robot):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_state(self) -> dict:
|
def _get_state(self) -> dict:
|
||||||
return {k: self.reachy.joints[v].present_position for k, v in REACHY2_MOTORS.items()}
|
pos_dict = {k: self.reachy.joints[v].present_position for k, v in REACHY2_JOINTS.items()}
|
||||||
|
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||||
|
return {**pos_dict, **vel_dict}
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, np.ndarray]:
|
def get_observation(self) -> dict[str, np.ndarray]:
|
||||||
obs_dict = {}
|
obs_dict = {}
|
||||||
@@ -157,21 +143,10 @@ class Reachy2Robot(Robot):
|
|||||||
obs_dict = self._get_state()
|
obs_dict = self._get_state()
|
||||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
# state = np.asarray(list(state.values()))
|
# Capture images from cameras
|
||||||
# obs_dict[OBS_STATE] = state
|
|
||||||
|
|
||||||
for cam_key, cam in self.cameras.items():
|
for cam_key, cam in self.cameras.items():
|
||||||
obs_dict[cam_key] = cam.async_read()
|
obs_dict[cam_key] = cam.async_read()
|
||||||
|
|
||||||
# Capture images from cameras
|
|
||||||
# for cam_key, cam in self.cameras.items():
|
|
||||||
# before_camread_t = time.perf_counter()
|
|
||||||
# frame = cam.async_read(timeout_ms=200)
|
|
||||||
# print(f"Async frame shape:", frame.shape)
|
|
||||||
|
|
||||||
# obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read(timeout_ms=200)
|
|
||||||
# self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
|
|
||||||
|
|
||||||
return obs_dict
|
return obs_dict
|
||||||
|
|
||||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -179,13 +154,22 @@ class Reachy2Robot(Robot):
|
|||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
|
|
||||||
before_write_t = time.perf_counter()
|
before_write_t = time.perf_counter()
|
||||||
|
|
||||||
|
vel = {}
|
||||||
for key, val in action.items():
|
for key, val in action.items():
|
||||||
if key not in REACHY2_MOTORS:
|
if key not in REACHY2_JOINTS:
|
||||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
if key not in REACHY2_VEL:
|
||||||
|
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||||
|
else:
|
||||||
|
vel[REACHY2_VEL[key]] = val
|
||||||
else:
|
else:
|
||||||
self.reachy.joints[REACHY2_MOTORS[key]].goal_position = val
|
self.reachy.joints[REACHY2_JOINTS[key]].goal_position = val
|
||||||
|
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||||
|
|
||||||
# # We don't want the teleoperator reachy2_specific to send the goal positions
|
# # We don't want the teleoperator reachy2_specific to send the goal positions
|
||||||
# self.reachy.send_goal_positions()
|
# self.reachy.send_goal_positions()
|
||||||
|
# self.reachy.send_speed_command()
|
||||||
|
|
||||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from .config_reachy2_fake_teleoperator import Reachy2FakeTeleoperatorConfig
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# {lerobot_keys: reachy2_sdk_keys}
|
# {lerobot_keys: reachy2_sdk_keys}
|
||||||
REACHY2_MOTORS = {
|
REACHY2_JOINTS = {
|
||||||
"neck_yaw.pos": "head.neck.yaw",
|
"neck_yaw.pos": "head.neck.yaw",
|
||||||
"neck_pitch.pos": "head.neck.pitch",
|
"neck_pitch.pos": "head.neck.pitch",
|
||||||
"neck_roll.pos": "head.neck.roll",
|
"neck_roll.pos": "head.neck.roll",
|
||||||
@@ -54,9 +54,12 @@ REACHY2_MOTORS = {
|
|||||||
"l_gripper.pos": "l_arm.gripper",
|
"l_gripper.pos": "l_arm.gripper",
|
||||||
"l_antenna.pos": "head.l_antenna",
|
"l_antenna.pos": "head.l_antenna",
|
||||||
"r_antenna.pos": "head.r_antenna",
|
"r_antenna.pos": "head.r_antenna",
|
||||||
# "mobile_base.vx": "mobile_base.vx",
|
}
|
||||||
# "mobile_base.vy": "mobile_base.vy",
|
|
||||||
# "mobile_base.vtheta": "mobile_base.vtheta",
|
REACHY2_VEL = {
|
||||||
|
"mobile_base.vx": "x",
|
||||||
|
"mobile_base.vy": "y",
|
||||||
|
"mobile_base.vtheta": "theta",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -75,10 +78,13 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
return dict.fromkeys(
|
return {**dict.fromkeys(
|
||||||
REACHY2_MOTORS.keys(),
|
REACHY2_JOINTS.keys(),
|
||||||
float,
|
float,
|
||||||
)
|
), **dict.fromkeys(
|
||||||
|
REACHY2_VEL.keys(),
|
||||||
|
float,
|
||||||
|
)}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feedback_features(self) -> dict[str, type]:
|
def feedback_features(self) -> dict[str, type]:
|
||||||
@@ -107,10 +113,11 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
|
|
||||||
def get_action(self) -> dict[str, float]:
|
def get_action(self) -> dict[str, float]:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_MOTORS.items()}
|
joint_action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_JOINTS.items()}
|
||||||
|
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return action
|
return {**joint_action, **vel_action}
|
||||||
|
|
||||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
# TODO(rcadene, aliberts): Implement force feedback
|
# TODO(rcadene, aliberts): Implement force feedback
|
||||||
|
|||||||
Reference in New Issue
Block a user