From 594a1bb81d96d8df6ef329dd68edd317a0439c4e Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Wed, 24 Jun 2026 17:47:25 +0200
Subject: [PATCH] refactor(utils): consolidate keyboard listener creation

---
 src/lerobot/rollout/strategies/dagger.py    |  83 ++----------
 src/lerobot/rollout/strategies/highlight.py |  49 +------
 src/lerobot/utils/keyboard_input.py         | 140 +++++++++++++-------
 tests/utils/test_keyboard_input.py          |  27 ++++
 4 files changed, 138 insertions(+), 161 deletions(-)

diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py
index 71f3c163c..21d1e8e98 100644
--- a/src/lerobot/rollout/strategies/dagger.py
+++ b/src/lerobot/rollout/strategies/dagger.py
@@ -47,7 +47,6 @@ from __future__ import annotations
 import contextlib
 import enum
 import logging
-import sys
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
 from threading import Event, Lock
@@ -64,8 +63,7 @@ from lerobot.datasets import VideoEncodingManager
 from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
 from lerobot.utils.constants import ACTION, OBS_STR
 from lerobot.utils.feature_utils import build_dataset_frame
-from lerobot.utils.import_utils import _pynput_available
-from lerobot.utils.keyboard_input import TerminalKeyListener, pynput_can_capture
+from lerobot.utils.keyboard_input import create_key_listener
 from lerobot.utils.pedal import start_pedal_listener
 from lerobot.utils.robot_utils import precise_sleep
 from lerobot.utils.utils import log_say
@@ -74,15 +72,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
 from ..context import RolloutContext
 from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
 
-keyboard = None
-PYNPUT_AVAILABLE = _pynput_available
-if PYNPUT_AVAILABLE:
-    try:
-        from pynput import keyboard
-    except Exception as e:
-        PYNPUT_AVAILABLE = False
-        logging.info("Could not import pynput keyboard backend: %s", e)
-
 logger = logging.getLogger(__name__)
 
 
@@ -177,12 +166,11 @@ class DAggerEvents:
 def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
     """Initialise a keyboard listener for DAgger's 3 controls.
 
-    Uses the pynput global listener on X11 / trusted-macOS / Windows, and falls
-    back to a display-independent terminal reader on Wayland / headless sessions
-    (as long as stdin is an interactive TTY). Returns the listener (exposing
-    ``stop()``) or ``None`` when no keyboard backend is usable.
+    Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
+    Wayland / headless TTY) is delegated to :func:`create_key_listener`. Returns the
+    listener (exposing ``stop()``) or ``None`` when no keyboard backend is usable.
     """
-    # Map config key names to DAgger event names (shared by both backends).
+    # Map config key names to DAgger event names.
     key_to_event = {
         cfg.pause_resume: "pause_resume",
         cfg.correction: "correction",
@@ -199,62 +187,13 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
         if name == cfg.upload:
             events.upload_requested.set()
 
-    if pynput_can_capture() and keyboard is not None:
-        # Map pynput special keys to the same names the terminal backend emits.
-        special_keys = {
-            "space": keyboard.Key.space,
-            "tab": keyboard.Key.tab,
-            "enter": keyboard.Key.enter,
-        }
-
-        def _resolve_key(key) -> str | None:
-            """Resolve a pynput key event to a config-comparable string."""
-            if key == keyboard.Key.esc:
-                return "esc"
-            for name, pynput_key in special_keys.items():
-                if key == pynput_key:
-                    return name
-            if hasattr(key, "char") and key.char:
-                return key.char
-            return None
-
-        def on_press(key):
-            try:
-                resolved = _resolve_key(key)
-                if resolved is not None:
-                    dispatch(resolved)
-            except Exception as e:
-                logger.debug("Key error: %s", e)
-
-        listener = keyboard.Listener(on_press=on_press)
-        listener.start()
-        logger.info(
-            "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
-            cfg.pause_resume,
-            cfg.correction,
-            cfg.upload,
-        )
-        return listener
-
-    if sys.stdin.isatty():
-        listener = TerminalKeyListener(dispatch)
-        listener.start()
-        logger.info(
-            "DAgger terminal keyboard listener started — no global capture available "
-            "(Wayland/headless); keep this terminal focused "
-            "(pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
-            cfg.pause_resume,
-            cfg.correction,
-            cfg.upload,
-        )
-        return listener
-
-    logger.warning(
-        "DAgger keyboard controls disabled: no usable display (Wayland/headless) and stdin is not "
-        "an interactive terminal. Run from an interactive terminal, or set the DAgger "
-        "input_device to 'pedal'."
+    return create_key_listener(
+        dispatch,
+        controls_help=(
+            f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
+            f"upload='{cfg.upload}', ESC=stop"
+        ),
     )
-    return None
 
 
 def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py
index e258dbd6f..da7ae2fbc 100644
--- a/src/lerobot/rollout/strategies/highlight.py
+++ b/src/lerobot/rollout/strategies/highlight.py
@@ -18,7 +18,6 @@ from __future__ import annotations
 
 import contextlib
 import logging
-import sys
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
 from threading import Event as ThreadingEvent, Lock
@@ -26,8 +25,8 @@ from threading import Event as ThreadingEvent, Lock
 from lerobot.datasets import VideoEncodingManager
 from lerobot.utils.constants import ACTION, OBS_STR
 from lerobot.utils.feature_utils import build_dataset_frame
-from lerobot.utils.import_utils import _pynput_available, require_package
-from lerobot.utils.keyboard_input import TerminalKeyListener, pynput_can_capture
+from lerobot.utils.import_utils import require_package
+from lerobot.utils.keyboard_input import create_key_listener
 from lerobot.utils.robot_utils import precise_sleep
 from lerobot.utils.utils import log_say
 
@@ -36,15 +35,6 @@ from ..context import RolloutContext
 from ..ring_buffer import RolloutRingBuffer
 from .core import RolloutStrategy, safe_push_to_hub, send_next_action
 
-PYNPUT_AVAILABLE = _pynput_available
-keyboard = None
-if PYNPUT_AVAILABLE:
-    try:
-        from pynput import keyboard
-    except Exception as e:
-        PYNPUT_AVAILABLE = False
-        logging.info("Could not import pynput keyboard backend: %s", e)
-
 logger = logging.getLogger(__name__)
 
 
@@ -231,9 +221,8 @@ class HighlightStrategy(RolloutStrategy):
     def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
         """Set up a keyboard listener for the save and push keys.
 
-        Uses the pynput global listener on X11 / trusted-macOS / Windows and falls
-        back to a display-independent terminal reader on Wayland / headless
-        sessions (when stdin is an interactive TTY).
+        Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
+        Wayland / headless TTY) is delegated to :func:`create_key_listener`.
         """
         save_key = self.config.save_key
         push_key = self.config.push_key
@@ -248,34 +237,8 @@ class HighlightStrategy(RolloutStrategy):
                 self._save_requested.clear()
                 shutdown_event.set()
 
-        if pynput_can_capture() and keyboard is not None:
-
-            def on_press(key):
-                with contextlib.suppress(Exception):
-                    if hasattr(key, "char") and key.char:
-                        dispatch(key.char)
-                    elif key == keyboard.Key.esc:
-                        dispatch("esc")
-
-            self._listener = keyboard.Listener(on_press=on_press)
-            self._listener.start()
-            logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
-            return
-
-        if sys.stdin.isatty():
-            self._listener = TerminalKeyListener(dispatch)
-            self._listener.start()
-            logger.info(
-                "Terminal keyboard listener started — no global capture available "
-                "(Wayland/headless); keep this terminal focused (save='%s', push='%s', ESC=stop)",
-                save_key,
-                push_key,
-            )
-            return
-
-        logger.warning(
-            "Highlight keyboard controls disabled: no usable display (Wayland/headless) and stdin "
-            "is not an interactive terminal."
+        self._listener = create_key_listener(
+            dispatch, controls_help=f"save='{save_key}', push='{push_key}', ESC=stop"
         )
 
     def _background_push(self, dataset, cfg) -> None:
diff --git a/src/lerobot/utils/keyboard_input.py b/src/lerobot/utils/keyboard_input.py
index 83e980add..00c0f53ec 100644
--- a/src/lerobot/utils/keyboard_input.py
+++ b/src/lerobot/utils/keyboard_input.py
@@ -307,6 +307,96 @@ class TerminalKeyListener:
         atexit.unregister(self.stop)
 
 
+# Map pynput key objects to the same canonical names TerminalKeyListener emits, so a
+# single dispatch works across both backends. Empty when pynput is unavailable.
+if keyboard is not None:
+    _PYNPUT_KEY_NAMES = {
+        keyboard.Key.right: "right",
+        keyboard.Key.left: "left",
+        keyboard.Key.up: "up",
+        keyboard.Key.down: "down",
+        keyboard.Key.esc: "esc",
+        keyboard.Key.enter: "enter",
+        keyboard.Key.tab: "tab",
+        keyboard.Key.space: "space",
+        keyboard.Key.backspace: "backspace",
+    }
+else:
+    _PYNPUT_KEY_NAMES = {}
+
+
+def _resolve_pynput_key(key) -> str | None:
+    """Resolve a pynput key event to the canonical name TerminalKeyListener also emits.
+
+    Special keys map through :data:`_PYNPUT_KEY_NAMES`; character keys fall back to their
+    ``.char`` (e.g. ``"n"``). Returns ``None`` for keys with no mapping and no character.
+    """
+    name = _PYNPUT_KEY_NAMES.get(key)
+    if name is not None:
+        return name
+    # ``or None`` keeps the historical truthy-char semantics: an empty/None char is "no key".
+    return getattr(key, "char", None) or None
+
+
+def create_key_listener(dispatch: Callable[[str], None], *, controls_help: str = ""):
+    """Start a keyboard listener that routes resolved key names to ``dispatch``.
+
+    Shared backend selection used by recording and the rollout strategies:
+
+    * the ``pynput`` global listener on X11 / trusted-macOS / Windows (on macOS the
+      listener's ``IS_TRUSTED`` flag is checked after start, and an untrusted listener is
+      stopped so the terminal backend is used instead);
+    * the stdlib :class:`TerminalKeyListener` on Wayland / headless sessions with a TTY;
+    * ``None`` when no backend is usable (non-interactive / piped runs).
+
+    Both backends pass ``dispatch`` the same canonical key names ("right" / "left" / "up" /
+    "down" / "esc" / "enter" / "tab" / "space" / "backspace", or a character), so one
+    ``dispatch`` works regardless of backend. ``controls_help`` is an optional hint
+    appended to the log messages.
+
+    Returns the listener (exposing ``.stop()``) or ``None``.
+    """
+    suffix = f" ({controls_help})" if controls_help else ""
+
+    if pynput_can_capture() and keyboard is not None:
+
+        def on_press(key):
+            # An exception escaping the callback would stop the pynput listener, so catch
+            # it here; log it rather than silently discarding a failing ``dispatch``.
+            try:
+                name = _resolve_pynput_key(key)
+                if name is not None:
+                    dispatch(name)
+            except Exception as e:
+                logger.warning("Error handling key press: %s", e)
+
+        listener = keyboard.Listener(on_press=on_press)
+        listener.start()
+        if pynput_listener_is_trusted(listener):
+            logger.info("Keyboard listener started%s.", suffix)
+            return listener
+        # macOS without Accessibility / Input-Monitoring permission: the listener never
+        # fires. Stop it and fall through to the terminal backend.
+        logger.warning(
+            "pynput keyboard listener is not trusted (missing macOS Accessibility / "
+            "Input Monitoring permission); falling back to terminal keyboard input."
+        )
+        listener.stop()
+
+    if sys.stdin.isatty():
+        listener = TerminalKeyListener(dispatch)
+        listener.start()
+        logger.info("Using terminal keyboard input — keep this terminal focused%s.", suffix)
+        return listener
+
+    logger.warning(
+        "Keyboard controls unavailable: no usable display (Wayland/headless) and stdin is "
+        "not an interactive terminal%s.",
+        suffix,
+    )
+    return None
+
+
 def init_keyboard_listener():
     """Initialize a non-blocking keyboard listener for interactive recording controls.
 
@@ -337,6 +427,9 @@ def init_keyboard_listener():
         "stop_recording": False,
     }
 
+    # Accept the single-byte letter equivalents n/r/q alongside the arrow/Esc keys: the
+    # letters are immune to the escape-sequence split/delay/interception that affects arrows
+    # over laggy SSH/VNC links. Case-insensitive so Shift+letter still works.
     def on_key(name: str) -> None:
         key = name.lower()
         if key in ("right", "n"):
@@ -347,50 +440,5 @@ def init_keyboard_listener():
             apply_recording_control("esc", events)
         # other keys (incl. up/down) are intentionally ignored
 
-    if pynput_can_capture() and keyboard is not None:
-
-        def on_press(key):
-            try:
-                if key == keyboard.Key.right:
-                    on_key("right")
-                elif key == keyboard.Key.left:
-                    on_key("left")
-                elif key == keyboard.Key.esc:
-                    on_key("esc")
-                else:
-                    # Character keys (e.g. n/r/q) expose a ``.char``; special keys do not.
-                    char = getattr(key, "char", None)
-                    if char is not None:
-                        on_key(char)
-            except Exception as e:
-                print(f"Error handling key press: {e}")
-
-        listener = keyboard.Listener(on_press=on_press)
-        listener.start()
-        if pynput_listener_is_trusted(listener):
-            return listener, events
-        # macOS without Accessibility / Input-Monitoring permission: the listener
-        # never fires. Stop it and fall through to the terminal backend.
-        logger.warning(
-            "pynput keyboard listener is not trusted (missing macOS Accessibility / "
-            "Input Monitoring permission); falling back to terminal keyboard input. "
-            "Grant permission to your Python binary to restore global hotkeys."
-        )
-        listener.stop()
-
-    # Display-independent fallback (Wayland / headless-with-TTY / macOS-untrusted).
-    if not sys.stdin.isatty():
-        logger.warning(
-            "Keyboard controls unavailable: no usable display (Wayland/headless) and "
-            "stdin is not an interactive terminal. Recording will rely on the "
-            "episode/reset timers (or Ctrl+C)."
-        )
-        return None, events
-
-    listener = TerminalKeyListener(on_key)
-    listener.start()
-    logger.info(
-        "Using terminal keyboard input (no global capture available here). "
-        "Controls: Right/Left/Esc, or n=next, r=re-record, q=quit. Keep this terminal focused."
-    )
+    listener = create_key_listener(on_key, controls_help="Right/Left/Esc, or n=next, r=re-record, q=quit")
     return listener, events
diff --git a/tests/utils/test_keyboard_input.py b/tests/utils/test_keyboard_input.py
index fd644302c..2f0dee889 100644
--- a/tests/utils/test_keyboard_input.py
+++ b/tests/utils/test_keyboard_input.py
@@ -32,6 +32,7 @@ import lerobot.utils.keyboard_input as ki
 from lerobot.utils.keyboard_input import (
     TerminalKeyListener,
     apply_recording_control,
+    create_key_listener,
     init_keyboard_listener,
     is_headless,
     is_wayland,
@@ -199,3 +200,29 @@ def test_init_terminal_key_routing(monkeypatch, key, flag):
     listener, events = init_keyboard_listener()
     listener._on_key(key)
     assert events[flag] is True
+
+
+# --- Shared factory + pynput key resolver -----------------------------------
+def test_resolve_pynput_key_char_fallback():
+    """Unmapped keys fall back to ``.char`` (and yield None when there is none)."""
+    assert ki._resolve_pynput_key(type("K", (), {"char": "s"})()) == "s"
+    assert ki._resolve_pynput_key(type("K", (), {"char": None})()) is None
+    assert ki._resolve_pynput_key(type("K", (), {"char": ""})()) is None  # empty char -> no key
+
+
+def test_create_key_listener_routes_to_dispatch(monkeypatch):
+    """The terminal backend forwards canonical key names straight to ``dispatch``."""
+    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
+    _set_tty(monkeypatch, is_tty=True)
+    monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
+    seen = []
+    listener = create_key_listener(seen.append, controls_help="save='s'")
+    assert isinstance(listener, TerminalKeyListener)
+    listener._on_key("space")
+    assert seen == ["space"]
+
+
+def test_create_key_listener_none_without_tty(monkeypatch):
+    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
+    _set_tty(monkeypatch, is_tty=False)
+    assert create_key_listener(lambda name: None) is None