Update send_action test

This commit is contained in:
glannuzel
2025-08-26 16:31:52 +02:00
parent 31117fcf94
commit ecca518072
2 changed files with 47 additions and 19 deletions
@@ -142,18 +142,12 @@ class Reachy2FakeTeleoperator(Teleoperator):
start = time.perf_counter() start = time.perf_counter()
if self.reachy and self.is_connected: if self.reachy and self.is_connected:
joint_action = { joint_action = {k: self.reachy.joints_dict[v].goal_position for k, v in self.joints_dict.items()}
k: self.reachy.joints_dict[v].goal_position
for k, v in self.joints_dict.items()
}
if not self.config.with_mobile_base: if not self.config.with_mobile_base:
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return joint_action return joint_action
vel_action = { vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
k: self.reachy.mobile_base.last_cmd_vel[v]
for k, v in REACHY2_VEL.items()
}
dt_ms = (time.perf_counter() - start) * 1e3 dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms") logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return {**joint_action, **vel_action} return {**joint_action, **vel_action}
+45 -11
View File
@@ -14,18 +14,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from contextlib import contextmanager
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
import numpy as np import numpy as np
import pytest
from lerobot.robots.reachy2 import ( from lerobot.robots.reachy2 import (
Reachy2Robot, Reachy2Robot,
Reachy2RobotConfig, Reachy2RobotConfig,
) )
# {lerobot_keys: reachy2_sdk_keys} # {lerobot_keys: reachy2_sdk_keys}
REACHY2_JOINTS = { REACHY2_JOINTS = {
"neck_yaw.pos": "head.neck.yaw", "neck_yaw.pos": "head.neck.yaw",
@@ -68,6 +66,30 @@ PARAMS = [
def _make_reachy2_sdk_mock(): def _make_reachy2_sdk_mock():
class JointSpy:
__slots__ = (
"present_position",
"_goal_position",
"set_calls",
"set_values",
"_on_set",
)
def __init__(self, present_position=0.0, initial_goal=None, on_set=None):
self.present_position = present_position
self._goal_position = initial_goal
self._on_set = on_set
@property
def goal_position(self):
return self._goal_position
@goal_position.setter
def goal_position(self, v):
self._goal_position = v
if self._on_set:
self._on_set()
r = MagicMock(name="ReachySDKMock") r = MagicMock(name="ReachySDKMock")
r.is_connected.return_value = True r.is_connected.return_value = True
@@ -77,9 +99,19 @@ def _make_reachy2_sdk_mock():
def _disconnect(): def _disconnect():
r.is_connected.return_value = False r.is_connected.return_value = False
# Global counter of goal_position sets
r._goal_position_set_total = 0
def _on_any_goal_set():
r._goal_position_set_total += 1
# Mock joints with some dummy positions # Mock joints with some dummy positions
joints = { joints = {
k: MagicMock(present_position=i, goal_position=i + 0.1) k: JointSpy(
present_position=float(i),
initial_goal=float(i) + 0.1,
on_set=_on_any_goal_set,
)
for i, k in enumerate(REACHY2_JOINTS.values()) for i, k in enumerate(REACHY2_JOINTS.values())
} }
r.joints = joints r.joints = joints
@@ -160,19 +192,15 @@ def test_get_observation(reachy2):
reachy2.connect() reachy2.connect()
obs = reachy2.get_observation() obs = reachy2.get_observation()
expected_keys = {m for m in reachy2.joints_dict.keys()} expected_keys = set(reachy2.joints_dict)
expected_keys.update( expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base
)
expected_keys.update(reachy2.cameras.keys()) expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys assert set(obs.keys()) == expected_keys
print(obs) print(obs)
for motor in reachy2.joints_dict.keys(): for motor in reachy2.joints_dict.keys():
assert ( assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
)
if reachy2.config.with_mobile_base: if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys(): for vel in REACHY2_VEL.keys():
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
@@ -189,6 +217,12 @@ def test_send_action(reachy2):
assert returned == action assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
assert real_pos == expected_pos
if reachy2.config.with_mobile_base: if reachy2.config.with_mobile_base:
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)] goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed) reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)