From ecca5180720eb70fcec33796a041f259e9fc559a Mon Sep 17 00:00:00 2001 From: glannuzel Date: Tue, 26 Aug 2025 16:31:52 +0200 Subject: [PATCH] Update send_action test --- .../reachy2_fake_teleoperator.py | 10 +--- tests/robots/test_reachy2.py | 56 +++++++++++++++---- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py index f9418dc58..9de428697 100644 --- a/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_fake_teleoperator/reachy2_fake_teleoperator.py @@ -142,18 +142,12 @@ class Reachy2FakeTeleoperator(Teleoperator): start = time.perf_counter() if self.reachy and self.is_connected: - joint_action = { - k: self.reachy.joints_dict[v].goal_position - for k, v in self.joints_dict.items() - } + joint_action = {k: self.reachy.joints_dict[v].goal_position for k, v in self.joints_dict.items()} if not self.config.with_mobile_base: dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read action: {dt_ms:.1f}ms") return joint_action - vel_action = { - k: self.reachy.mobile_base.last_cmd_vel[v] - for k, v in REACHY2_VEL.items() - } + vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()} dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read action: {dt_ms:.1f}ms") return {**joint_action, **vel_action} diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index a2c40c5bf..669376ca3 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -14,18 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager from unittest.mock import MagicMock, patch -import pytest import numpy as np +import pytest from lerobot.robots.reachy2 import ( Reachy2Robot, Reachy2RobotConfig, ) - # {lerobot_keys: reachy2_sdk_keys} REACHY2_JOINTS = { "neck_yaw.pos": "head.neck.yaw", @@ -68,6 +66,30 @@ PARAMS = [ def _make_reachy2_sdk_mock(): + class JointSpy: + __slots__ = ( + "present_position", + "_goal_position", + "set_calls", + "set_values", + "_on_set", + ) + + def __init__(self, present_position=0.0, initial_goal=None, on_set=None): + self.present_position = present_position + self._goal_position = initial_goal + self._on_set = on_set + + @property + def goal_position(self): + return self._goal_position + + @goal_position.setter + def goal_position(self, v): + self._goal_position = v + if self._on_set: + self._on_set() + r = MagicMock(name="ReachySDKMock") r.is_connected.return_value = True @@ -77,9 +99,19 @@ def _make_reachy2_sdk_mock(): def _disconnect(): r.is_connected.return_value = False + # Global counter of goal_position sets + r._goal_position_set_total = 0 + + def _on_any_goal_set(): + r._goal_position_set_total += 1 + # Mock joints with some dummy positions joints = { - k: MagicMock(present_position=i, goal_position=i + 0.1) + k: JointSpy( + present_position=float(i), + initial_goal=float(i) + 0.1, + on_set=_on_any_goal_set, + ) for i, k in enumerate(REACHY2_JOINTS.values()) } r.joints = joints @@ -160,19 +192,15 @@ def test_get_observation(reachy2): reachy2.connect() obs = reachy2.get_observation() - expected_keys = {m for m in reachy2.joints_dict.keys()} - expected_keys.update( - f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base - ) + expected_keys = set(reachy2.joints_dict) + expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base) expected_keys.update(reachy2.cameras.keys()) assert set(obs.keys()) == expected_keys print(obs) for motor in reachy2.joints_dict.keys(): - assert ( - obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position - ) + assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position if reachy2.config.with_mobile_base: for vel in REACHY2_VEL.keys(): assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] @@ -189,6 +217,12 @@ def test_send_action(reachy2): assert returned == action + assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict) + for motor in reachy2.joints_dict.keys(): + expected_pos = action[motor] + real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position + assert real_pos == expected_pos + if reachy2.config.with_mobile_base: goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)] reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)