Call generate_joints_dict on _init_

This commit is contained in:
glannuzel
2025-08-11 17:38:25 +02:00
parent ab6bbd68a7
commit 44cdd5ac90
2 changed files with 17 additions and 15 deletions
+8 -7
View File
@@ -86,8 +86,7 @@ class Reachy2Robot(Robot):
self.logs = {}
self.joints_dict: dict[str, str] = {}
self.generate_joints_dict()
self.joints_dict: dict[str, str] = self._generate_joints_dict()
@property
def observation_features(self) -> dict:
@@ -142,15 +141,17 @@ class Reachy2Robot(Robot):
def calibrate(self) -> None:
pass
def generate_joints_dict(self) -> dict[str, str]:
def _generate_joints_dict(self) -> dict[str, str]:
self.joints = {}
if self.config.with_neck:
self.joints_dict.update(REACHY2_NECK_JOINTS)
self.joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
self.joints_dict.update(REACHY2_L_ARM_JOINTS)
self.joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
self.joints_dict.update(REACHY2_R_ARM_JOINTS)
self.joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
self.joints_dict.update(REACHY2_ANTENNAS_JOINTS)
self.joints.update(REACHY2_ANTENNAS_JOINTS)
return self.joints
def _get_state(self) -> dict:
pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()}
@@ -85,18 +85,19 @@ class Reachy2FakeTeleoperator(Teleoperator):
self.config = config
self.reachy: None | ReachySDK = None
self.joints_dict: dict[str, str] = {}
self.generate_joints_dict()
self.joints_dict: dict[str, str] = self._generate_joints_dict()
def generate_joints_dict(self) -> dict[str, str]:
def _generate_joints_dict(self) -> dict[str, str]:
self.joints = {}
if self.config.with_neck:
self.joints_dict.update(REACHY2_NECK_JOINTS)
self.joints.update(REACHY2_NECK_JOINTS)
if self.config.with_l_arm:
self.joints_dict.update(REACHY2_L_ARM_JOINTS)
self.joints.update(REACHY2_L_ARM_JOINTS)
if self.config.with_r_arm:
self.joints_dict.update(REACHY2_R_ARM_JOINTS)
self.joints.update(REACHY2_R_ARM_JOINTS)
if self.config.with_antennas:
self.joints_dict.update(REACHY2_ANTENNAS_JOINTS)
self.joints.update(REACHY2_ANTENNAS_JOINTS)
return self.joints
@property
def action_features(self) -> dict[str, type]:
@@ -138,7 +139,7 @@ class Reachy2FakeTeleoperator(Teleoperator):
def get_action(self) -> dict[str, float]:
start = time.perf_counter()
joint_action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_JOINTS.items()}
joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()}
if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")