mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
229 lines
8.2 KiB
Python
229 lines
8.2 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Unit tests for the display-independent keyboard input helpers.
|
|
|
|
These cover the parts most likely to regress: the environment-detection decision
|
|
table (the heart of the Wayland/headless fix), the macOS trust probe, the control
|
|
mapping, the terminal escape-sequence parsing, and backend selection. They require
|
|
neither ``pynput`` nor a real terminal.
|
|
"""
|
|
|
|
import io
|
|
import platform
|
|
import sys
|
|
|
|
import pytest
|
|
|
|
import lerobot.utils.keyboard_input as ki
|
|
from lerobot.utils.keyboard_input import (
|
|
TerminalKeyListener,
|
|
apply_recording_control,
|
|
create_key_listener,
|
|
init_keyboard_listener,
|
|
is_headless,
|
|
is_wayland,
|
|
pynput_can_capture,
|
|
pynput_listener_is_trusted,
|
|
)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def _clear_detection_caches():
|
|
"""The detection helpers are ``@cache``-decorated; clear around each test."""
|
|
for fn in (is_headless, is_wayland, pynput_can_capture):
|
|
fn.cache_clear()
|
|
yield
|
|
for fn in (is_headless, is_wayland, pynput_can_capture):
|
|
fn.cache_clear()
|
|
|
|
|
|
def _set_platform(monkeypatch, name):
|
|
monkeypatch.setattr(platform, "system", lambda: name)
|
|
|
|
|
|
def _set_tty(monkeypatch, is_tty):
|
|
stdin = io.StringIO("")
|
|
stdin.isatty = lambda: is_tty
|
|
monkeypatch.setattr(sys, "stdin", stdin)
|
|
|
|
|
|
# --- Environment detection (the core of the fix) ---------------------------
|
|
@pytest.mark.parametrize(
|
|
("system", "env", "expected"),
|
|
[
|
|
("Linux", {}, True), # no display server
|
|
("Linux", {"DISPLAY": ":0"}, False), # X11
|
|
("Linux", {"WAYLAND_DISPLAY": "wayland-0"}, False), # Wayland
|
|
("Darwin", {}, False), # display always assumed present
|
|
],
|
|
)
|
|
def test_is_headless(monkeypatch, system, env, expected):
|
|
_set_platform(monkeypatch, system)
|
|
monkeypatch.delenv("DISPLAY", raising=False)
|
|
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
|
for key, value in env.items():
|
|
monkeypatch.setenv(key, value)
|
|
assert is_headless() is expected
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("env", "expected"),
|
|
[
|
|
({"XDG_SESSION_TYPE": "wayland"}, True),
|
|
({"WAYLAND_DISPLAY": "wayland-0"}, True),
|
|
({"XDG_SESSION_TYPE": "x11"}, False),
|
|
({}, False),
|
|
],
|
|
)
|
|
def test_is_wayland(monkeypatch, env, expected):
|
|
monkeypatch.delenv("XDG_SESSION_TYPE", raising=False)
|
|
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
|
for key, value in env.items():
|
|
monkeypatch.setenv(key, value)
|
|
assert is_wayland() is expected
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("system", "env", "pynput_available", "expected"),
|
|
[
|
|
("Linux", {"DISPLAY": ":0"}, True, True), # X11
|
|
("Linux", {"DISPLAY": ":0", "WAYLAND_DISPLAY": "wayland-0"}, True, False), # Wayland
|
|
("Linux", {}, True, False), # headless
|
|
("Darwin", {}, True, True),
|
|
("Linux", {"DISPLAY": ":0"}, False, False), # pynput not installed
|
|
],
|
|
)
|
|
def test_pynput_can_capture(monkeypatch, system, env, pynput_available, expected):
|
|
_set_platform(monkeypatch, system)
|
|
monkeypatch.setattr(ki, "_pynput_available", pynput_available)
|
|
for var in ("DISPLAY", "WAYLAND_DISPLAY", "XDG_SESSION_TYPE"):
|
|
monkeypatch.delenv(var, raising=False)
|
|
for key, value in env.items():
|
|
monkeypatch.setenv(key, value)
|
|
assert pynput_can_capture() is expected
|
|
|
|
|
|
# --- macOS trust probe ------------------------------------------------------
|
|
class _FakeListener:
|
|
def __init__(self, is_trusted):
|
|
self.IS_TRUSTED = is_trusted
|
|
|
|
|
|
def test_pynput_listener_is_trusted(monkeypatch):
|
|
_set_platform(monkeypatch, "Linux")
|
|
assert pynput_listener_is_trusted(_FakeListener(False)) is True # non-macOS: always assumed ok
|
|
_set_platform(monkeypatch, "Darwin")
|
|
assert pynput_listener_is_trusted(_FakeListener(False), timeout_s=0.05) is False
|
|
|
|
|
|
# --- Control mapping --------------------------------------------------------
|
|
def test_apply_recording_control():
|
|
events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}
|
|
apply_recording_control("left", events)
|
|
assert events == {"exit_early": True, "rerecord_episode": True, "stop_recording": False}
|
|
apply_recording_control("esc", events)
|
|
assert events["stop_recording"] is True
|
|
apply_recording_control("up", events) # unknown control -> no-op (no error)
|
|
|
|
|
|
# --- Terminal escape-sequence parsing (the tricky bit) ----------------------
|
|
def _drive(listener, byte_seq):
|
|
"""Run the listener's read loop over a scripted list of bytes (no real terminal)."""
|
|
script = list(byte_seq)
|
|
|
|
def fake_read(timeout):
|
|
if script:
|
|
return script.pop(0)
|
|
listener._running = False
|
|
return None
|
|
|
|
listener._read_char = fake_read
|
|
listener._running = True
|
|
listener._run()
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("byte_seq", "expected"),
|
|
[
|
|
(["\x1b", "[", "C"], ["right"]), # CSI arrow
|
|
(["\x1b", "O", "D"], ["left"]), # SS3 arrow (e.g. over SSH/tmux)
|
|
(["\x1b"], ["esc"]), # bare ESC
|
|
(["\x1b", "[", "A"], ["up"]), # decoded even though the record handler ignores it
|
|
(["n"], ["n"]), # letter passthrough
|
|
],
|
|
)
|
|
def test_terminal_parsing(byte_seq, expected):
|
|
collected = []
|
|
_drive(TerminalKeyListener(collected.append), byte_seq)
|
|
assert collected == expected
|
|
|
|
|
|
# --- Backend selection ------------------------------------------------------
|
|
def test_init_selects_terminal_when_pynput_cannot_capture(monkeypatch):
|
|
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
|
_set_tty(monkeypatch, is_tty=True)
|
|
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None) # avoid touching termios
|
|
listener, _ = init_keyboard_listener()
|
|
assert isinstance(listener, TerminalKeyListener)
|
|
|
|
|
|
def test_init_returns_none_without_tty(monkeypatch):
|
|
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
|
_set_tty(monkeypatch, is_tty=False)
|
|
listener, _ = init_keyboard_listener()
|
|
assert listener is None
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
("key", "flag"),
|
|
[("right", "exit_early"), ("r", "rerecord_episode"), ("q", "stop_recording")],
|
|
)
|
|
def test_init_terminal_key_routing(monkeypatch, key, flag):
|
|
"""Arrows and their letter equivalents drive the same events (terminal backend)."""
|
|
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
|
_set_tty(monkeypatch, is_tty=True)
|
|
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
|
|
listener, events = init_keyboard_listener()
|
|
listener._on_key(key)
|
|
assert events[flag] is True
|
|
|
|
|
|
# --- Shared factory + pynput key resolver -----------------------------------
|
|
def test_resolve_pynput_key_char_fallback():
|
|
"""Unmapped keys fall back to ``.char`` (and yield None when there is none)."""
|
|
assert ki._resolve_pynput_key(type("K", (), {"char": "s"})()) == "s"
|
|
assert ki._resolve_pynput_key(type("K", (), {"char": None})()) is None
|
|
assert ki._resolve_pynput_key(type("K", (), {"char": ""})()) is None # empty char -> no key
|
|
|
|
|
|
def test_create_key_listener_routes_to_dispatch(monkeypatch):
|
|
"""The terminal backend forwards canonical key names straight to ``dispatch``."""
|
|
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
|
_set_tty(monkeypatch, is_tty=True)
|
|
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
|
|
seen = []
|
|
listener = create_key_listener(seen.append, controls_help="save='s'")
|
|
assert isinstance(listener, TerminalKeyListener)
|
|
listener._on_key("space")
|
|
assert seen == ["space"]
|
|
|
|
|
|
def test_create_key_listener_none_without_tty(monkeypatch):
|
|
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
|
_set_tty(monkeypatch, is_tty=False)
|
|
assert create_key_listener(lambda name: None) is None
|