mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
Fix generate_joints
This commit is contained in:
@@ -143,16 +143,16 @@ class Reachy2Robot(Robot):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _generate_joints_dict(self) -> dict[str, str]:
|
def _generate_joints_dict(self) -> dict[str, str]:
|
||||||
self.joints = {}
|
joints = {}
|
||||||
if self.config.with_neck:
|
if self.config.with_neck:
|
||||||
self.joints.update(REACHY2_NECK_JOINTS)
|
joints.update(REACHY2_NECK_JOINTS)
|
||||||
if self.config.with_l_arm:
|
if self.config.with_l_arm:
|
||||||
self.joints.update(REACHY2_L_ARM_JOINTS)
|
joints.update(REACHY2_L_ARM_JOINTS)
|
||||||
if self.config.with_r_arm:
|
if self.config.with_r_arm:
|
||||||
self.joints.update(REACHY2_R_ARM_JOINTS)
|
joints.update(REACHY2_R_ARM_JOINTS)
|
||||||
if self.config.with_antennas:
|
if self.config.with_antennas:
|
||||||
self.joints.update(REACHY2_ANTENNAS_JOINTS)
|
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||||
return self.joints
|
return joints
|
||||||
|
|
||||||
def _get_state(self) -> dict[str, float]:
|
def _get_state(self) -> dict[str, float]:
|
||||||
if self.reachy is not None:
|
if self.reachy is not None:
|
||||||
|
|||||||
@@ -87,16 +87,16 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
self.joints_dict: dict[str, str] = self._generate_joints_dict()
|
||||||
|
|
||||||
def _generate_joints_dict(self) -> dict[str, str]:
|
def _generate_joints_dict(self) -> dict[str, str]:
|
||||||
self.joints = {}
|
joints = {}
|
||||||
if self.config.with_neck:
|
if self.config.with_neck:
|
||||||
self.joints.update(REACHY2_NECK_JOINTS)
|
joints.update(REACHY2_NECK_JOINTS)
|
||||||
if self.config.with_l_arm:
|
if self.config.with_l_arm:
|
||||||
self.joints.update(REACHY2_L_ARM_JOINTS)
|
joints.update(REACHY2_L_ARM_JOINTS)
|
||||||
if self.config.with_r_arm:
|
if self.config.with_r_arm:
|
||||||
self.joints.update(REACHY2_R_ARM_JOINTS)
|
joints.update(REACHY2_R_ARM_JOINTS)
|
||||||
if self.config.with_antennas:
|
if self.config.with_antennas:
|
||||||
self.joints.update(REACHY2_ANTENNAS_JOINTS)
|
joints.update(REACHY2_ANTENNAS_JOINTS)
|
||||||
return self.joints
|
return joints
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_features(self) -> dict[str, type]:
|
def action_features(self) -> dict[str, type]:
|
||||||
@@ -125,7 +125,6 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
def connect(self, calibrate: bool = True) -> None:
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
self.reachy = ReachySDK(self.config.ip_address)
|
self.reachy = ReachySDK(self.config.ip_address)
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
print("Error connecting to Reachy 2.")
|
|
||||||
raise ConnectionError()
|
raise ConnectionError()
|
||||||
logger.info(f"{self} connected.")
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
@@ -141,12 +140,20 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
|
|
||||||
def get_action(self) -> dict[str, float]:
|
def get_action(self) -> dict[str, float]:
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
|
|
||||||
if not self.config.with_mobile_base:
|
if self.reachy and self.is_connected:
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
joint_action = {
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
k: self.reachy.joints_dict[v].goal_position
|
||||||
return joint_action
|
for k, v in self.joints_dict.items()
|
||||||
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
}
|
||||||
|
if not self.config.with_mobile_base:
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
|
return joint_action
|
||||||
|
vel_action = {
|
||||||
|
k: self.reachy.mobile_base.last_cmd_vel[v]
|
||||||
|
for k, v in REACHY2_VEL.items()
|
||||||
|
}
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return {**joint_action, **vel_action}
|
return {**joint_action, **vel_action}
|
||||||
@@ -156,4 +163,5 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
def disconnect(self) -> None:
|
||||||
self.reachy.disconnect()
|
if self.reachy and self.is_connected:
|
||||||
|
self.reachy.disconnect()
|
||||||
|
|||||||
Reference in New Issue
Block a user