feat(rl): port haptic follow + torque toggle from #2596 to leader intervention

This commit is contained in:
Khalil Meftah
2026-04-27 17:50:29 +02:00
parent a3cb9f5317
commit 13418dcd7b
5 changed files with 442 additions and 21 deletions
@@ -58,15 +58,38 @@ def _joint_dict(values: list[float]) -> dict[str, float]:
return {f"{name}.pos": v for name, v in zip(MOTOR_NAMES, values, strict=False)}
def _make_step(use_gripper: bool = True) -> LeaderArmInterventionStep:
def _make_step(use_gripper: bool = True, teleop_device: Any = None) -> LeaderArmInterventionStep:
return LeaderArmInterventionStep(
kinematics=_FakeKinematics(), # type: ignore[arg-type]
motor_names=MOTOR_NAMES,
end_effector_step_sizes=STEP_SIZES,
use_gripper=use_gripper,
teleop_device=teleop_device,
)
class _RecordingTeleop:
"""Minimal teleop double that records every send_action call."""
def __init__(self) -> None:
self.calls: list[dict[str, float]] = []
def send_action(self, action: dict[str, float]) -> None:
self.calls.append(dict(action))
class _RaisingTeleop:
"""Teleop double whose send_action raises an unexpected error."""
def __init__(self, exc: Exception) -> None:
self.exc = exc
self.calls = 0
def send_action(self, action: dict[str, float]) -> None:
self.calls += 1
raise self.exc
def _build_transition(
leader_joints: dict[str, float] | None,
follower_joints: dict[str, float] | None,
@@ -178,3 +201,84 @@ def test_reads_follower_from_observation_when_complementary_missing():
teleop_action = out[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
# delta = (20 - 10) * 1e-3 = 0.01, normalised by 0.025 -> 0.4
assert teleop_action["delta_x"] == pytest.approx(0.4)
# --- haptic follow ----------------------------------------------------------
def test_haptic_follow_pushes_follower_joints_to_teleop_device():
"""When teleop_device is set, follower joints should be sent to it every tick."""
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
leader["gripper.pos"] = 50.0
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
follower["gripper.pos"] = 50.0
teleop = _RecordingTeleop()
step = _make_step(teleop_device=teleop)
step(_build_transition(leader, follower))
assert len(teleop.calls) == 1
assert teleop.calls[0] == follower
def test_haptic_follow_uses_observation_when_complementary_missing():
"""Falls back to OBSERVATION dict for haptic follow when complementary is empty."""
leader = _joint_dict([5.0, 0.0, 0.0, 0.0, 0.0])
leader["gripper.pos"] = 50.0
follower = _joint_dict([3.0, 0.0, 0.0, 0.0, 0.0])
follower["gripper.pos"] = 50.0
teleop = _RecordingTeleop()
transition = create_transition(
observation=follower,
complementary_data={"teleop_action": leader},
)
_make_step(teleop_device=teleop)(transition)
assert teleop.calls == [follower]
def test_haptic_follow_skipped_when_no_follower_joints_available():
"""No follower joints -> no haptic write (don't push stale data)."""
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
leader["gripper.pos"] = 50.0
teleop = _RecordingTeleop()
transition = _build_transition(leader, follower_joints=None)
_make_step(teleop_device=teleop)(transition)
assert teleop.calls == []
def test_haptic_follow_swallows_send_action_errors():
"""A failing teleop.send_action must not abort the action pipeline."""
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
leader["gripper.pos"] = 50.0
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
follower["gripper.pos"] = 50.0
teleop = _RaisingTeleop(RuntimeError("bus comms fail"))
step = _make_step(teleop_device=teleop)
out = step(_build_transition(leader, follower))
assert teleop.calls == 1
# The downstream EE-delta payload must still be produced normally.
teleop_action = out[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
assert teleop_action["delta_x"] == pytest.approx(0.4)
def test_haptic_follow_disables_when_send_action_not_implemented():
"""Plain leaders (no haptic follow) opt out via NotImplementedError."""
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
leader["gripper.pos"] = 50.0
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
follower["gripper.pos"] = 50.0
teleop = _RaisingTeleop(NotImplementedError("plain leader, no haptic follow"))
step = _make_step(teleop_device=teleop)
step(_build_transition(leader, follower))
# Tick again and confirm the step gave up rather than spamming the device.
step(_build_transition(leader, follower))
assert teleop.calls == 1
assert step.teleop_device is None