From 44cdd5ac90575ee8c3dd13ceeb190efa4bba0931 Mon Sep 17 00:00:00 2001 From: glannuzel Date: Mon, 11 Aug 2025 17:38:25 +0200 Subject: [PATCH] Call generate_joints_dict on _init_ --- src/lerobot/robots/reachy2/robot_reachy2.py | 15 ++++++++------- .../reachy2_fake_teleoperator.py | 17 +++++++++-------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 7e9575e44..481a66ca4 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -86,8 +86,7 @@ class Reachy2Robot(Robot): self.logs = {} - self.joints_dict: dict[str, str] = {} - self.generate_joints_dict() + self.joints_dict: dict[str, str] = self._generate_joints_dict() @property def observation_features(self) -> dict: @@ -142,15 +141,17 @@ class Reachy2Robot(Robot): def calibrate(self) -> None: pass - def generate_joints_dict(self) -> dict[str, str]: + def _generate_joints_dict(self) -> dict[str, str]: + self.joints = {} if self.config.with_neck: - self.joints_dict.update(REACHY2_NECK_JOINTS) + self.joints.update(REACHY2_NECK_JOINTS) if self.config.with_l_arm: - self.joints_dict.update(REACHY2_L_ARM_JOINTS) + self.joints.update(REACHY2_L_ARM_JOINTS) if self.config.with_r_arm: - self.joints_dict.update(REACHY2_R_ARM_JOINTS) + self.joints.update(REACHY2_R_ARM_JOINTS) if self.config.with_antennas: - self.joints_dict.update(REACHY2_ANTENNAS_JOINTS) + self.joints.update(REACHY2_ANTENNAS_JOINTS) + return self.joints def _get_state(self) -> dict: pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()} diff --git a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py index 6fdf796b0..c77bc0ed4 100644 --- a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py @@ -85,18 +85,19 @@ class Reachy2FakeTeleoperator(Teleoperator): self.config = config self.reachy: None | ReachySDK = None - self.joints_dict: dict[str, str] = {} - self.generate_joints_dict() + self.joints_dict: dict[str, str] = self._generate_joints_dict() - def generate_joints_dict(self) -> dict[str, str]: + def _generate_joints_dict(self) -> dict[str, str]: + self.joints = {} if self.config.with_neck: - self.joints_dict.update(REACHY2_NECK_JOINTS) + self.joints.update(REACHY2_NECK_JOINTS) if self.config.with_l_arm: - self.joints_dict.update(REACHY2_L_ARM_JOINTS) + self.joints.update(REACHY2_L_ARM_JOINTS) if self.config.with_r_arm: - self.joints_dict.update(REACHY2_R_ARM_JOINTS) + self.joints.update(REACHY2_R_ARM_JOINTS) if self.config.with_antennas: - self.joints_dict.update(REACHY2_ANTENNAS_JOINTS) + self.joints.update(REACHY2_ANTENNAS_JOINTS) + return self.joints @property def action_features(self) -> dict[str, type]: @@ -138,7 +139,7 @@ class Reachy2FakeTeleoperator(Teleoperator): def get_action(self) -> dict[str, float]: start = time.perf_counter() - joint_action = {k: self.reachy.joints[v].goal_position for k, v in REACHY2_JOINTS.items()} + joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} if not self.config.with_mobile_base: dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read action: {dt_ms:.1f}ms")