mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
Try adding mobile_base velocity
This commit is contained in:
@@ -25,13 +25,12 @@ from typing import Any
|
||||
# from stretch_body.robot_params import RobotParams
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..robot import Robot
|
||||
from .configuration_reachy2 import Reachy2RobotConfig
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_MOTORS = {
|
||||
REACHY2_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
@@ -53,9 +52,12 @@ REACHY2_MOTORS = {
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
# "mobile_base.vx": "mobile_base.vx",
|
||||
# "mobile_base.vy": "mobile_base.vy",
|
||||
# "mobile_base.vtheta": "mobile_base.vtheta",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "vx",
|
||||
"mobile_base.vy": "vy",
|
||||
"mobile_base.vtheta": "vtheta",
|
||||
}
|
||||
|
||||
|
||||
@@ -81,15 +83,6 @@ class Reachy2Robot(Robot):
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return {**self.motors_features, **self.camera_features}
|
||||
# return dict.fromkeys(
|
||||
# REACHY2_MOTORS.keys(),
|
||||
# float,
|
||||
# )
|
||||
# return {
|
||||
# "dtype": "float32",
|
||||
# "shape": len(REACHY2_MOTORS),
|
||||
# "names": {"motors": list(REACHY2_MOTORS)},
|
||||
# }
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
@@ -100,21 +93,16 @@ class Reachy2Robot(Robot):
|
||||
return {
|
||||
cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
# cam_ft = {}
|
||||
# for cam_key, cam in self.cameras.items():
|
||||
# cam_ft[cam_key] = {
|
||||
# "shape": (cam.height, cam.width, cam.channels),
|
||||
# "names": ["height", "width", "channels"],
|
||||
# "info": None,
|
||||
# }
|
||||
# return cam_ft
|
||||
|
||||
@property
|
||||
def motors_features(self) -> dict:
|
||||
return dict.fromkeys(
|
||||
REACHY2_MOTORS.keys(),
|
||||
return {**dict.fromkeys(
|
||||
REACHY2_JOINTS.keys(),
|
||||
float,
|
||||
)
|
||||
), **dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
)}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -129,10 +117,6 @@ class Reachy2Robot(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
# if not self.is_connected:
|
||||
# print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
# raise ConnectionError()
|
||||
|
||||
self.configure()
|
||||
|
||||
def configure(self) -> None:
|
||||
@@ -147,7 +131,9 @@ class Reachy2Robot(Robot):
|
||||
pass
|
||||
|
||||
def _get_state(self) -> dict:
|
||||
return {k: self.reachy.joints[v].present_position for k, v in REACHY2_MOTORS.items()}
|
||||
pos_dict = {k: self.reachy.joints[v].present_position for k, v in REACHY2_JOINTS.items()}
|
||||
vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()}
|
||||
return {**pos_dict, **vel_dict}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict = {}
|
||||
@@ -157,21 +143,10 @@ class Reachy2Robot(Robot):
|
||||
obs_dict = self._get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# state = np.asarray(list(state.values()))
|
||||
# obs_dict[OBS_STATE] = state
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
|
||||
# Capture images from cameras
|
||||
# for cam_key, cam in self.cameras.items():
|
||||
# before_camread_t = time.perf_counter()
|
||||
# frame = cam.async_read(timeout_ms=200)
|
||||
# print(f"Async frame shape:", frame.shape)
|
||||
|
||||
# obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read(timeout_ms=200)
|
||||
# self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -179,13 +154,22 @@ class Reachy2Robot(Robot):
|
||||
raise ConnectionError()
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
|
||||
vel = {}
|
||||
for key, val in action.items():
|
||||
if key not in REACHY2_MOTORS:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
if key not in REACHY2_JOINTS:
|
||||
if key not in REACHY2_VEL:
|
||||
raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.")
|
||||
else:
|
||||
vel[REACHY2_VEL[key]] = val
|
||||
else:
|
||||
self.reachy.joints[REACHY2_MOTORS[key]].goal_position = val
|
||||
self.reachy.joints[REACHY2_JOINTS[key]].goal_position = val
|
||||
self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"])
|
||||
|
||||
# # We don't want the teleoperator reachy2_specific to send the goal positions
|
||||
# self.reachy.send_goal_positions()
|
||||
# self.reachy.send_speed_command()
|
||||
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
return action
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ from .config_reachy2_fake_teleoperator import Reachy2FakeTeleoperatorConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# {lerobot_keys: reachy2_sdk_keys}
|
||||
REACHY2_MOTORS = {
|
||||
REACHY2_JOINTS = {
|
||||
"neck_yaw.pos": "head.neck.yaw",
|
||||
"neck_pitch.pos": "head.neck.pitch",
|
||||
"neck_roll.pos": "head.neck.roll",
|
||||
@@ -54,9 +54,12 @@ REACHY2_MOTORS = {
|
||||
"l_gripper.pos": "l_arm.gripper",
|
||||
"l_antenna.pos": "head.l_antenna",
|
||||
"r_antenna.pos": "head.r_antenna",
|
||||
# "mobile_base.vx": "mobile_base.vx",
|
||||
# "mobile_base.vy": "mobile_base.vy",
|
||||
# "mobile_base.vtheta": "mobile_base.vtheta",
|
||||
}
|
||||
|
||||
REACHY2_VEL = {
|
||||
"mobile_base.vx": "x",
|
||||
"mobile_base.vy": "y",
|
||||
"mobile_base.vtheta": "theta",
|
||||
}
|
||||
|
||||
|
||||
@@ -75,10 +78,13 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return dict.fromkeys(
|
||||
REACHY2_MOTORS.keys(),
|
||||
return {**dict.fromkeys(
|
||||
REACHY2_JOINTS.keys(),
|
||||
float,
|
||||
)
|
||||
), **dict.fromkeys(
|
||||
REACHY2_VEL.keys(),
|
||||
float,
|
||||
)}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
@@ -107,10 +113,11 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_MOTORS.items()}
|
||||
joint_action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_JOINTS.items()}
|
||||
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
return {**joint_action, **vel_action}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
|
||||
Reference in New Issue
Block a user