mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Update send_action test
This commit is contained in:
@@ -142,18 +142,12 @@ class Reachy2FakeTeleoperator(Teleoperator):
|
|||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
if self.reachy and self.is_connected:
|
if self.reachy and self.is_connected:
|
||||||
joint_action = {
|
joint_action = {k: self.reachy.joints_dict[v].goal_position for k, v in self.joints_dict.items()}
|
||||||
k: self.reachy.joints_dict[v].goal_position
|
|
||||||
for k, v in self.joints_dict.items()
|
|
||||||
}
|
|
||||||
if not self.config.with_mobile_base:
|
if not self.config.with_mobile_base:
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return joint_action
|
return joint_action
|
||||||
vel_action = {
|
vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()}
|
||||||
k: self.reachy.mobile_base.last_cmd_vel[v]
|
|
||||||
for k, v in REACHY2_VEL.items()
|
|
||||||
}
|
|
||||||
dt_ms = (time.perf_counter() - start) * 1e3
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
return {**joint_action, **vel_action}
|
return {**joint_action, **vel_action}
|
||||||
|
|||||||
@@ -14,18 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
from lerobot.robots.reachy2 import (
|
from lerobot.robots.reachy2 import (
|
||||||
Reachy2Robot,
|
Reachy2Robot,
|
||||||
Reachy2RobotConfig,
|
Reachy2RobotConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# {lerobot_keys: reachy2_sdk_keys}
|
# {lerobot_keys: reachy2_sdk_keys}
|
||||||
REACHY2_JOINTS = {
|
REACHY2_JOINTS = {
|
||||||
"neck_yaw.pos": "head.neck.yaw",
|
"neck_yaw.pos": "head.neck.yaw",
|
||||||
@@ -68,6 +66,30 @@ PARAMS = [
|
|||||||
|
|
||||||
|
|
||||||
def _make_reachy2_sdk_mock():
|
def _make_reachy2_sdk_mock():
|
||||||
|
class JointSpy:
|
||||||
|
__slots__ = (
|
||||||
|
"present_position",
|
||||||
|
"_goal_position",
|
||||||
|
"set_calls",
|
||||||
|
"set_values",
|
||||||
|
"_on_set",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, present_position=0.0, initial_goal=None, on_set=None):
|
||||||
|
self.present_position = present_position
|
||||||
|
self._goal_position = initial_goal
|
||||||
|
self._on_set = on_set
|
||||||
|
|
||||||
|
@property
|
||||||
|
def goal_position(self):
|
||||||
|
return self._goal_position
|
||||||
|
|
||||||
|
@goal_position.setter
|
||||||
|
def goal_position(self, v):
|
||||||
|
self._goal_position = v
|
||||||
|
if self._on_set:
|
||||||
|
self._on_set()
|
||||||
|
|
||||||
r = MagicMock(name="ReachySDKMock")
|
r = MagicMock(name="ReachySDKMock")
|
||||||
r.is_connected.return_value = True
|
r.is_connected.return_value = True
|
||||||
|
|
||||||
@@ -77,9 +99,19 @@ def _make_reachy2_sdk_mock():
|
|||||||
def _disconnect():
|
def _disconnect():
|
||||||
r.is_connected.return_value = False
|
r.is_connected.return_value = False
|
||||||
|
|
||||||
|
# Global counter of goal_position sets
|
||||||
|
r._goal_position_set_total = 0
|
||||||
|
|
||||||
|
def _on_any_goal_set():
|
||||||
|
r._goal_position_set_total += 1
|
||||||
|
|
||||||
# Mock joints with some dummy positions
|
# Mock joints with some dummy positions
|
||||||
joints = {
|
joints = {
|
||||||
k: MagicMock(present_position=i, goal_position=i + 0.1)
|
k: JointSpy(
|
||||||
|
present_position=float(i),
|
||||||
|
initial_goal=float(i) + 0.1,
|
||||||
|
on_set=_on_any_goal_set,
|
||||||
|
)
|
||||||
for i, k in enumerate(REACHY2_JOINTS.values())
|
for i, k in enumerate(REACHY2_JOINTS.values())
|
||||||
}
|
}
|
||||||
r.joints = joints
|
r.joints = joints
|
||||||
@@ -160,19 +192,15 @@ def test_get_observation(reachy2):
|
|||||||
reachy2.connect()
|
reachy2.connect()
|
||||||
obs = reachy2.get_observation()
|
obs = reachy2.get_observation()
|
||||||
|
|
||||||
expected_keys = {m for m in reachy2.joints_dict.keys()}
|
expected_keys = set(reachy2.joints_dict)
|
||||||
expected_keys.update(
|
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||||
f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base
|
|
||||||
)
|
|
||||||
expected_keys.update(reachy2.cameras.keys())
|
expected_keys.update(reachy2.cameras.keys())
|
||||||
assert set(obs.keys()) == expected_keys
|
assert set(obs.keys()) == expected_keys
|
||||||
|
|
||||||
print(obs)
|
print(obs)
|
||||||
|
|
||||||
for motor in reachy2.joints_dict.keys():
|
for motor in reachy2.joints_dict.keys():
|
||||||
assert (
|
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||||
obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
|
||||||
)
|
|
||||||
if reachy2.config.with_mobile_base:
|
if reachy2.config.with_mobile_base:
|
||||||
for vel in REACHY2_VEL.keys():
|
for vel in REACHY2_VEL.keys():
|
||||||
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||||
@@ -189,6 +217,12 @@ def test_send_action(reachy2):
|
|||||||
|
|
||||||
assert returned == action
|
assert returned == action
|
||||||
|
|
||||||
|
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
|
||||||
|
for motor in reachy2.joints_dict.keys():
|
||||||
|
expected_pos = action[motor]
|
||||||
|
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||||
|
assert real_pos == expected_pos
|
||||||
|
|
||||||
if reachy2.config.with_mobile_base:
|
if reachy2.config.with_mobile_base:
|
||||||
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
|
goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)]
|
||||||
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
|
reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed)
|
||||||
|
|||||||
Reference in New Issue
Block a user