mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
feat(rl): port haptic follow + torque toggle from #2596 to leader intervention
This commit is contained in:
@@ -58,15 +58,38 @@ def _joint_dict(values: list[float]) -> dict[str, float]:
|
||||
return {f"{name}.pos": v for name, v in zip(MOTOR_NAMES, values, strict=False)}
|
||||
|
||||
|
||||
def _make_step(use_gripper: bool = True) -> LeaderArmInterventionStep:
|
||||
def _make_step(use_gripper: bool = True, teleop_device: Any = None) -> LeaderArmInterventionStep:
|
||||
return LeaderArmInterventionStep(
|
||||
kinematics=_FakeKinematics(), # type: ignore[arg-type]
|
||||
motor_names=MOTOR_NAMES,
|
||||
end_effector_step_sizes=STEP_SIZES,
|
||||
use_gripper=use_gripper,
|
||||
teleop_device=teleop_device,
|
||||
)
|
||||
|
||||
|
||||
class _RecordingTeleop:
|
||||
"""Minimal teleop double that records every send_action call."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, float]] = []
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> None:
|
||||
self.calls.append(dict(action))
|
||||
|
||||
|
||||
class _RaisingTeleop:
|
||||
"""Teleop double whose send_action raises an unexpected error."""
|
||||
|
||||
def __init__(self, exc: Exception) -> None:
|
||||
self.exc = exc
|
||||
self.calls = 0
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> None:
|
||||
self.calls += 1
|
||||
raise self.exc
|
||||
|
||||
|
||||
def _build_transition(
|
||||
leader_joints: dict[str, float] | None,
|
||||
follower_joints: dict[str, float] | None,
|
||||
@@ -178,3 +201,84 @@ def test_reads_follower_from_observation_when_complementary_missing():
|
||||
teleop_action = out[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
# delta = (20 - 10) * 1e-3 = 0.01, normalised by 0.025 -> 0.4
|
||||
assert teleop_action["delta_x"] == pytest.approx(0.4)
|
||||
|
||||
|
||||
# --- haptic follow ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_haptic_follow_pushes_follower_joints_to_teleop_device():
|
||||
"""When teleop_device is set, follower joints should be sent to it every tick."""
|
||||
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
|
||||
leader["gripper.pos"] = 50.0
|
||||
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
|
||||
follower["gripper.pos"] = 50.0
|
||||
teleop = _RecordingTeleop()
|
||||
|
||||
step = _make_step(teleop_device=teleop)
|
||||
step(_build_transition(leader, follower))
|
||||
|
||||
assert len(teleop.calls) == 1
|
||||
assert teleop.calls[0] == follower
|
||||
|
||||
|
||||
def test_haptic_follow_uses_observation_when_complementary_missing():
|
||||
"""Falls back to OBSERVATION dict for haptic follow when complementary is empty."""
|
||||
leader = _joint_dict([5.0, 0.0, 0.0, 0.0, 0.0])
|
||||
leader["gripper.pos"] = 50.0
|
||||
follower = _joint_dict([3.0, 0.0, 0.0, 0.0, 0.0])
|
||||
follower["gripper.pos"] = 50.0
|
||||
teleop = _RecordingTeleop()
|
||||
|
||||
transition = create_transition(
|
||||
observation=follower,
|
||||
complementary_data={"teleop_action": leader},
|
||||
)
|
||||
_make_step(teleop_device=teleop)(transition)
|
||||
|
||||
assert teleop.calls == [follower]
|
||||
|
||||
|
||||
def test_haptic_follow_skipped_when_no_follower_joints_available():
|
||||
"""No follower joints -> no haptic write (don't push stale data)."""
|
||||
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
|
||||
leader["gripper.pos"] = 50.0
|
||||
teleop = _RecordingTeleop()
|
||||
|
||||
transition = _build_transition(leader, follower_joints=None)
|
||||
_make_step(teleop_device=teleop)(transition)
|
||||
|
||||
assert teleop.calls == []
|
||||
|
||||
|
||||
def test_haptic_follow_swallows_send_action_errors():
|
||||
"""A failing teleop.send_action must not abort the action pipeline."""
|
||||
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
|
||||
leader["gripper.pos"] = 50.0
|
||||
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
|
||||
follower["gripper.pos"] = 50.0
|
||||
teleop = _RaisingTeleop(RuntimeError("bus comms fail"))
|
||||
|
||||
step = _make_step(teleop_device=teleop)
|
||||
out = step(_build_transition(leader, follower))
|
||||
|
||||
assert teleop.calls == 1
|
||||
# The downstream EE-delta payload must still be produced normally.
|
||||
teleop_action = out[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
assert teleop_action["delta_x"] == pytest.approx(0.4)
|
||||
|
||||
|
||||
def test_haptic_follow_disables_when_send_action_not_implemented():
|
||||
"""Plain leaders (no haptic follow) opt out via NotImplementedError."""
|
||||
leader = _joint_dict([20.0, 0.0, 0.0, 0.0, 0.0])
|
||||
leader["gripper.pos"] = 50.0
|
||||
follower = _joint_dict([10.0, 0.0, 0.0, 0.0, 0.0])
|
||||
follower["gripper.pos"] = 50.0
|
||||
teleop = _RaisingTeleop(NotImplementedError("plain leader, no haptic follow"))
|
||||
|
||||
step = _make_step(teleop_device=teleop)
|
||||
step(_build_transition(leader, follower))
|
||||
# Tick again and confirm the step gave up rather than spamming the device.
|
||||
step(_build_transition(leader, follower))
|
||||
|
||||
assert teleop.calls == 1
|
||||
assert step.teleop_device is None
|
||||
|
||||
Reference in New Issue
Block a user