mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
refactor(utils): consolidate keyboard listener creation
This commit is contained in:
@@ -47,7 +47,6 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event, Lock
|
||||
@@ -64,8 +63,7 @@ from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
from lerobot.utils.keyboard_input import TerminalKeyListener, pynput_can_capture
|
||||
from lerobot.utils.keyboard_input import create_key_listener
|
||||
from lerobot.utils.pedal import start_pedal_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
@@ -74,15 +72,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info("Could not import pynput keyboard backend: %s", e)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -177,12 +166,11 @@ class DAggerEvents:
|
||||
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
"""Initialise a keyboard listener for DAgger's 3 controls.
|
||||
|
||||
Uses the pynput global listener on X11 / trusted-macOS / Windows, and falls
|
||||
back to a display-independent terminal reader on Wayland / headless sessions
|
||||
(as long as stdin is an interactive TTY). Returns the listener (exposing
|
||||
``stop()``) or ``None`` when no keyboard backend is usable.
|
||||
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
|
||||
Wayland / headless TTY) is delegated to :func:`create_key_listener`. Returns the
|
||||
listener (exposing ``stop()``) or ``None`` when no keyboard backend is usable.
|
||||
"""
|
||||
# Map config key names to DAgger event names (shared by both backends).
|
||||
# Map config key names to DAgger event names.
|
||||
key_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
@@ -199,62 +187,13 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
if name == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
|
||||
if pynput_can_capture() and keyboard is not None:
|
||||
# Map pynput special keys to the same names the terminal backend emits.
|
||||
special_keys = {
|
||||
"space": keyboard.Key.space,
|
||||
"tab": keyboard.Key.tab,
|
||||
"enter": keyboard.Key.enter,
|
||||
}
|
||||
|
||||
def _resolve_key(key) -> str | None:
|
||||
"""Resolve a pynput key event to a config-comparable string."""
|
||||
if key == keyboard.Key.esc:
|
||||
return "esc"
|
||||
for name, pynput_key in special_keys.items():
|
||||
if key == pynput_key:
|
||||
return name
|
||||
if hasattr(key, "char") and key.char:
|
||||
return key.char
|
||||
return None
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
resolved = _resolve_key(key)
|
||||
if resolved is not None:
|
||||
dispatch(resolved)
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
)
|
||||
return listener
|
||||
|
||||
if sys.stdin.isatty():
|
||||
listener = TerminalKeyListener(dispatch)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger terminal keyboard listener started — no global capture available "
|
||||
"(Wayland/headless); keep this terminal focused "
|
||||
"(pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
)
|
||||
return listener
|
||||
|
||||
logger.warning(
|
||||
"DAgger keyboard controls disabled: no usable display (Wayland/headless) and stdin is not "
|
||||
"an interactive terminal. Run from an interactive terminal, or set the DAgger "
|
||||
"input_device to 'pedal'."
|
||||
return create_key_listener(
|
||||
dispatch,
|
||||
controls_help=(
|
||||
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
|
||||
f"upload='{cfg.upload}', ESC=stop"
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
|
||||
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event as ThreadingEvent, Lock
|
||||
@@ -26,8 +25,8 @@ from threading import Event as ThreadingEvent, Lock
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
from lerobot.utils.keyboard_input import TerminalKeyListener, pynput_can_capture
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from lerobot.utils.keyboard_input import create_key_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -36,15 +35,6 @@ from ..context import RolloutContext
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info("Could not import pynput keyboard backend: %s", e)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -231,9 +221,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
|
||||
"""Set up a keyboard listener for the save and push keys.
|
||||
|
||||
Uses the pynput global listener on X11 / trusted-macOS / Windows and falls
|
||||
back to a display-independent terminal reader on Wayland / headless
|
||||
sessions (when stdin is an interactive TTY).
|
||||
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
|
||||
Wayland / headless TTY) is delegated to :func:`create_key_listener`.
|
||||
"""
|
||||
save_key = self.config.save_key
|
||||
push_key = self.config.push_key
|
||||
@@ -248,34 +237,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self._save_requested.clear()
|
||||
shutdown_event.set()
|
||||
|
||||
if pynput_can_capture() and keyboard is not None:
|
||||
|
||||
def on_press(key):
|
||||
with contextlib.suppress(Exception):
|
||||
if hasattr(key, "char") and key.char:
|
||||
dispatch(key.char)
|
||||
elif key == keyboard.Key.esc:
|
||||
dispatch("esc")
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||
return
|
||||
|
||||
if sys.stdin.isatty():
|
||||
self._listener = TerminalKeyListener(dispatch)
|
||||
self._listener.start()
|
||||
logger.info(
|
||||
"Terminal keyboard listener started — no global capture available "
|
||||
"(Wayland/headless); keep this terminal focused (save='%s', push='%s', ESC=stop)",
|
||||
save_key,
|
||||
push_key,
|
||||
)
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Highlight keyboard controls disabled: no usable display (Wayland/headless) and stdin "
|
||||
"is not an interactive terminal."
|
||||
self._listener = create_key_listener(
|
||||
dispatch, controls_help=f"save='{save_key}', push='{push_key}', ESC=stop"
|
||||
)
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
|
||||
@@ -307,6 +307,92 @@ class TerminalKeyListener:
|
||||
atexit.unregister(self.stop)
|
||||
|
||||
|
||||
# Lookup table from pynput special-key objects to the canonical key names that
# TerminalKeyListener emits, so one dispatch callback serves both backends.
# Stays empty when the pynput ``keyboard`` module could not be loaded.
_PYNPUT_KEY_NAMES = {}
if keyboard is not None:
    _PYNPUT_KEY_NAMES = {
        getattr(keyboard.Key, canonical): canonical
        for canonical in (
            "right",
            "left",
            "up",
            "down",
            "esc",
            "enter",
            "tab",
            "space",
            "backspace",
        )
    }
|
||||
|
||||
|
||||
def _resolve_pynput_key(key) -> str | None:
    """Translate a pynput key event into the canonical key name shared with TerminalKeyListener.

    The special-key table :data:`_PYNPUT_KEY_NAMES` is consulted first; keys absent from it
    resolve to their ``.char`` attribute (e.g. ``"n"``). The result is ``None`` when the key
    has no table entry and its ``char`` is missing, ``None`` or empty.
    """
    mapped = _PYNPUT_KEY_NAMES.get(key)
    if mapped is None:
        # Truthiness test on purpose: an empty-string char counts as "no key", same as None.
        char = getattr(key, "char", None)
        return char if char else None
    return mapped
|
||||
|
||||
|
||||
def create_key_listener(dispatch: Callable[[str], None], *, controls_help: str = ""):
    """Start a keyboard listener that routes resolved key names to ``dispatch``.

    Shared backend selection used by recording and the rollout strategies:

    * the ``pynput`` global listener on X11 / trusted-macOS / Windows (on macOS the
      listener's ``IS_TRUSTED`` flag is checked after start, and an untrusted listener is
      stopped so the terminal backend is used instead);
    * the stdlib :class:`TerminalKeyListener` on Wayland / headless sessions with a TTY;
    * ``None`` when no backend is usable (non-interactive / piped runs).

    Both backends pass ``dispatch`` the same canonical key names ("right" / "left" / "up" /
    "down" / "esc" / "enter" / "tab" / "space" / "backspace", or a character), so one
    ``dispatch`` works regardless of backend.

    Args:
        dispatch: Callback invoked with the canonical name of each pressed key. On the
            pynput backend, keys that resolve to no name are dropped, and an exception
            raised while resolving or dispatching is logged as a warning instead of
            propagating into the listener.
        controls_help: Optional hint appended in parentheses to the log messages.

    Returns:
        The started listener (exposing ``.stop()``), or ``None`` when no backend is usable.
    """
    suffix = f" ({controls_help})" if controls_help else ""

    if pynput_can_capture() and keyboard is not None:

        def on_press(key):
            # Catch-all at the callback boundary so a faulty handler cannot take the
            # listener down, but report the failure: silently swallowing it makes the
            # keyboard controls look dead with no hint as to why.
            try:
                name = _resolve_pynput_key(key)
                if name is not None:
                    dispatch(name)
            except Exception as exc:
                logger.warning("Error handling key press %r: %s", key, exc)

        listener = keyboard.Listener(on_press=on_press)
        listener.start()
        if pynput_listener_is_trusted(listener):
            logger.info("Keyboard listener started%s.", suffix)
            return listener
        # macOS without Accessibility / Input-Monitoring permission: the listener never
        # fires. Stop it and fall through to the terminal backend.
        logger.warning(
            "pynput keyboard listener is not trusted (missing macOS Accessibility / "
            "Input Monitoring permission); falling back to terminal keyboard input."
        )
        listener.stop()

    if sys.stdin.isatty():
        listener = TerminalKeyListener(dispatch)
        listener.start()
        logger.info("Using terminal keyboard input — keep this terminal focused%s.", suffix)
        return listener

    logger.warning(
        "Keyboard controls unavailable: no usable display (Wayland/headless) and stdin is "
        "not an interactive terminal%s.",
        suffix,
    )
    return None
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize a non-blocking keyboard listener for interactive recording controls.
|
||||
|
||||
@@ -337,6 +423,9 @@ def init_keyboard_listener():
|
||||
"stop_recording": False,
|
||||
}
|
||||
|
||||
# Accept the single-byte letter equivalents n/r/q alongside the arrow/Esc keys: the
|
||||
# letters are immune to the escape-sequence split/delay/interception that affects arrows
|
||||
# over laggy SSH/VNC links. Case-insensitive so Shift+letter still works.
|
||||
def on_key(name: str) -> None:
|
||||
key = name.lower()
|
||||
if key in ("right", "n"):
|
||||
@@ -347,50 +436,5 @@ def init_keyboard_listener():
|
||||
apply_recording_control("esc", events)
|
||||
# other keys (incl. up/down) are intentionally ignored
|
||||
|
||||
if pynput_can_capture() and keyboard is not None:
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
on_key("right")
|
||||
elif key == keyboard.Key.left:
|
||||
on_key("left")
|
||||
elif key == keyboard.Key.esc:
|
||||
on_key("esc")
|
||||
else:
|
||||
# Character keys (e.g. n/r/q) expose a ``.char``; special keys do not.
|
||||
char = getattr(key, "char", None)
|
||||
if char is not None:
|
||||
on_key(char)
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
if pynput_listener_is_trusted(listener):
|
||||
return listener, events
|
||||
# macOS without Accessibility / Input-Monitoring permission: the listener
|
||||
# never fires. Stop it and fall through to the terminal backend.
|
||||
logger.warning(
|
||||
"pynput keyboard listener is not trusted (missing macOS Accessibility / "
|
||||
"Input Monitoring permission); falling back to terminal keyboard input. "
|
||||
"Grant permission to your Python binary to restore global hotkeys."
|
||||
)
|
||||
listener.stop()
|
||||
|
||||
# Display-independent fallback (Wayland / headless-with-TTY / macOS-untrusted).
|
||||
if not sys.stdin.isatty():
|
||||
logger.warning(
|
||||
"Keyboard controls unavailable: no usable display (Wayland/headless) and "
|
||||
"stdin is not an interactive terminal. Recording will rely on the "
|
||||
"episode/reset timers (or Ctrl+C)."
|
||||
)
|
||||
return None, events
|
||||
|
||||
listener = TerminalKeyListener(on_key)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"Using terminal keyboard input (no global capture available here). "
|
||||
"Controls: Right/Left/Esc, or n=next, r=re-record, q=quit. Keep this terminal focused."
|
||||
)
|
||||
listener = create_key_listener(on_key, controls_help="Right/Left/Esc, or n=next, r=re-record, q=quit")
|
||||
return listener, events
|
||||
|
||||
@@ -32,6 +32,7 @@ import lerobot.utils.keyboard_input as ki
|
||||
from lerobot.utils.keyboard_input import (
|
||||
TerminalKeyListener,
|
||||
apply_recording_control,
|
||||
create_key_listener,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
is_wayland,
|
||||
@@ -199,3 +200,29 @@ def test_init_terminal_key_routing(monkeypatch, key, flag):
|
||||
listener, events = init_keyboard_listener()
|
||||
listener._on_key(key)
|
||||
assert events[flag] is True
|
||||
|
||||
|
||||
# --- Shared factory + pynput key resolver -----------------------------------
|
||||
def test_resolve_pynput_key_char_fallback():
    """A key outside the special-key table resolves to its ``.char``; a None or empty char gives None."""

    def fake_key(char):
        # Minimal stand-in for a key event: only the ``char`` attribute is populated.
        return type("K", (), {"char": char})()

    assert ki._resolve_pynput_key(fake_key("s")) == "s"
    for blank in (None, ""):  # a missing and an empty char both mean "no key"
        assert ki._resolve_pynput_key(fake_key(blank)) is None
|
||||
|
||||
|
||||
def test_create_key_listener_routes_to_dispatch(monkeypatch):
    """With pynput capture off and a TTY stdin, the factory returns a terminal listener wired to ``dispatch``."""
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=True)
    # start() is replaced with a no-op; only the routing through _on_key is exercised here.
    monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)

    received = []
    created = create_key_listener(received.append, controls_help="save='s'")

    assert isinstance(created, TerminalKeyListener)
    created._on_key("space")
    assert received == ["space"]
|
||||
|
||||
|
||||
def test_create_key_listener_none_without_tty(monkeypatch):
    """The factory yields ``None`` when pynput capture is off and stdin is not a TTY."""
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=False)

    result = create_key_listener(lambda name: None)

    assert result is None
|
||||
|
||||
Reference in New Issue
Block a user