diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index bf60f3014..7d250692b 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -143,16 +143,16 @@ class Reachy2Robot(Robot): pass def _generate_joints_dict(self) -> dict[str, str]: - self.joints = {} + joints = {} if self.config.with_neck: - self.joints.update(REACHY2_NECK_JOINTS) + joints.update(REACHY2_NECK_JOINTS) if self.config.with_l_arm: - self.joints.update(REACHY2_L_ARM_JOINTS) + joints.update(REACHY2_L_ARM_JOINTS) if self.config.with_r_arm: - self.joints.update(REACHY2_R_ARM_JOINTS) + joints.update(REACHY2_R_ARM_JOINTS) if self.config.with_antennas: - self.joints.update(REACHY2_ANTENNAS_JOINTS) - return self.joints + joints.update(REACHY2_ANTENNAS_JOINTS) + return joints def _get_state(self) -> dict[str, float]: if self.reachy is not None: diff --git a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py index b790e9f4d..f9418dc58 100644 --- a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py @@ -87,16 +87,16 @@ class Reachy2FakeTeleoperator(Teleoperator): self.joints_dict: dict[str, str] = self._generate_joints_dict() def _generate_joints_dict(self) -> dict[str, str]: - self.joints = {} + joints = {} if self.config.with_neck: - self.joints.update(REACHY2_NECK_JOINTS) + joints.update(REACHY2_NECK_JOINTS) if self.config.with_l_arm: - self.joints.update(REACHY2_L_ARM_JOINTS) + joints.update(REACHY2_L_ARM_JOINTS) if self.config.with_r_arm: - self.joints.update(REACHY2_R_ARM_JOINTS) + joints.update(REACHY2_R_ARM_JOINTS) if self.config.with_antennas: - self.joints.update(REACHY2_ANTENNAS_JOINTS) - return self.joints + joints.update(REACHY2_ANTENNAS_JOINTS) + return joints @property def action_features(self) -> dict[str, type]: @@ -125,7 +125,6 @@ class Reachy2FakeTeleoperator(Teleoperator): def connect(self, calibrate: bool = True) -> None: self.reachy = ReachySDK(self.config.ip_address) if not self.is_connected: - print("Error connecting to Reachy 2.") raise ConnectionError() logger.info(f"{self} connected.") @@ -141,12 +140,20 @@ class Reachy2FakeTeleoperator(Teleoperator): def get_action(self) -> dict[str, float]: start = time.perf_counter() - joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} - if not self.config.with_mobile_base: - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read action: {dt_ms:.1f}ms") - return joint_action - vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()} + + if self.reachy and self.is_connected: + joint_action = { + k: self.reachy.joints_dict[v].goal_position + for k, v in self.joints_dict.items() + } + if not self.config.with_mobile_base: + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return joint_action + vel_action = { + k: self.reachy.mobile_base.last_cmd_vel[v] + for k, v in REACHY2_VEL.items() + } dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read action: {dt_ms:.1f}ms") return {**joint_action, **vel_action} @@ -156,4 +163,5 @@ class Reachy2FakeTeleoperator(Teleoperator): raise NotImplementedError def disconnect(self) -> None: - self.reachy.disconnect() + if self.reachy and self.is_connected: + self.reachy.disconnect()