From b4e454c0ff3637d918e7aab41ee4b4096d79cd90 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 25 Jun 2026 10:58:39 +0200 Subject: [PATCH] feat(utils): display-independent keyboard controls for recording (Wayland / headless / macOS) (#3875) * feat(utils): headless keyboard control * refactor(utils): consolidate keyboard listener creation * fix(rollout): remove import require guard for pynput --------- Co-authored-by: Leo Toff Co-authored-by: Stefano Maestri Co-authored-by: Sahil Chande <85823961+SahilChande@users.noreply.github.com> Co-authored-by: Vinayak Agarwal <63502278+Vinayak-Agarwal-2004@users.noreply.github.com> Co-authored-by: Abdul Rahim Mirani --- docs/source/il_robots.mdx | 16 +- docs/source/lekiwi.mdx | 2 +- examples/lekiwi/evaluate.py | 3 +- examples/lekiwi/record.py | 2 +- examples/phone_to_so100/evaluate.py | 3 +- examples/phone_to_so100/record.py | 2 +- examples/so100_to_so100_EE/evaluate.py | 3 +- examples/so100_to_so100_EE/record.py | 2 +- src/lerobot/common/control_utils.py | 84 ---- src/lerobot/rollout/strategies/dagger.py | 90 +--- src/lerobot/rollout/strategies/episodic.py | 5 +- src/lerobot/rollout/strategies/highlight.py | 58 +-- src/lerobot/scripts/lerobot_record.py | 9 +- .../teleoperators/gamepad/gamepad_utils.py | 10 + .../teleoperators/keyboard/teleop_keyboard.py | 21 +- src/lerobot/utils/keyboard_input.py | 440 ++++++++++++++++++ tests/utils/test_keyboard_input.py | 228 +++++++++ 17 files changed, 758 insertions(+), 220 deletions(-) create mode 100644 src/lerobot/utils/keyboard_input.py create mode 100644 tests/utils/test_keyboard_input.py diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 53ae5af82..6a820e0db 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -390,9 +390,17 @@ Set the flow of data recording using command-line arguments: Control the data recording flow using keyboard shortcuts: -- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to 
the next. -- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it. -- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset. +- Press **Right Arrow (`→`)** or **`n`**: Early stop the current episode or reset time and move to the next. +- Press **Left Arrow (`←`)** or **`r`**: Cancel the current episode and re-record it. +- Press **Escape (`ESC`)** or **`q`**: Immediately stop the session, encode videos, and upload the dataset. + + + +These control-flow shortcuts work on **X11, Wayland, and headless/SSH** sessions. When a global keyboard backend isn't available (Wayland, a headless machine, or macOS without Accessibility permission), `lerobot-record` automatically reads the same keys from the terminal — launch it from an interactive terminal and keep it focused. You can also use the letter equivalents **`n`** (next, same as `→`), **`r`** (re-record, same as `←`) and **`q`** (quit, same as `ESC`). No `$DISPLAY` setup is required. + +This applies to the recording control flow only. Keyboard **teleoperation** (driving the robot with the keyboard) still needs a global key backend, so it works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless sessions. + + #### Tips for gathering data @@ -406,7 +414,7 @@ If you want to dive deeper into this important topic, you can check out the [blo #### Troubleshooting: -- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). +- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as `lerobot-record` runs in an interactive terminal — no `$DISPLAY` setup is needed. 
If the keys have no effect, make sure you are in an interactive (TTY) terminal, not a piped/non-TTY session, and that it is focused; the letter equivalents `n` / `r` / `q` also work. Keyboard _teleoperation_ (as opposed to the recording control flow) still requires a global key backend — an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — and is unavailable on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). ## Visualize a dataset diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 7e7c1a680..739073b65 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -319,7 +319,7 @@ If you want to dive deeper into this important topic, you can check out the [blo #### Troubleshooting: -- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). +- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as you run the recording from an interactive terminal (keep it focused) — no `$DISPLAY` setup is needed; the letter equivalents `n` / `r` / `q` also work. Note that **keyboard teleoperation of the LeKiwi base** is different: it relies on a global key backend and therefore works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). 
## Replay an episode diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 3ddcd1f14..13bb6ac28 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -17,7 +17,7 @@ import logging import time -from lerobot.common.control_utils import init_keyboard_listener, predict_action +from lerobot.common.control_utils import predict_action from lerobot.datasets import LeRobotDataset from lerobot.policies import make_pre_post_processors from lerobot.policies.act import ACTPolicy @@ -26,6 +26,7 @@ from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 2c581f5ff..f62a9eb49 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from lerobot.common.control_utils import init_keyboard_listener from lerobot.datasets import LeRobotDataset from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig @@ -23,6 +22,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import hw_to_dataset_features +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index e859123d0..d1fb4de67 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -18,7 +18,7 @@ import logging import time from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener, predict_action +from lerobot.common.control_utils import predict_action from lerobot.configs import FeatureType, PolicyFeature from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics @@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 87b8e49fd..612e94ab9 100644 --- a/examples/phone_to_so100/record.py 
+++ b/examples/phone_to_so100/record.py @@ -15,7 +15,6 @@ # limitations under the License. from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( @@ -39,6 +38,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.types import RobotAction, RobotObservation from lerobot.utils.feature_utils import combine_feature_dicts +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 63def68d0..2a2022623 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -18,7 +18,7 @@ import logging import time from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener, predict_action +from lerobot.common.control_utils import predict_action from lerobot.configs import FeatureType, PolicyFeature from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics @@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import 
init_rerun, log_rerun_data diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index a0b92da3b..3706ee4f5 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -16,7 +16,6 @@ from lerobot.cameras.opencv import OpenCVCameraConfig -from lerobot.common.control_utils import init_keyboard_listener from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features from lerobot.model.kinematics import RobotKinematics from lerobot.processor import ( @@ -36,6 +35,7 @@ from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.types import RobotAction, RobotObservation from lerobot.utils.feature_utils import combine_feature_dicts +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/src/lerobot/common/control_utils.py b/src/lerobot/common/control_utils.py index ddaf77d26..e3130643d 100644 --- a/src/lerobot/common/control_utils.py +++ b/src/lerobot/common/control_utils.py @@ -17,12 +17,9 @@ from __future__ import annotations ######################################################################################## # Utilities ######################################################################################## -import logging import time -import traceback from contextlib import nullcontext from copy import copy -from functools import cache from typing import TYPE_CHECKING, Any import numpy as np @@ -43,34 +40,6 @@ from lerobot.robots import Robot from lerobot.types import PolicyAction -@cache -def is_headless(): - """ - Detects if the Python script is running in a headless environment (e.g., without a display). - - This function attempts to import `pynput`, a library that requires a graphical environment. 
- If the import fails, it assumes the environment is headless. The result is cached to avoid - re-running the check. - - Returns: - True if the environment is determined to be headless, False otherwise. - """ - try: - import pynput # noqa - - return False - except Exception: - print( - "Error trying to import pynput. Switching to headless mode. " - "As a result, the video stream from the cameras won't be shown, " - "and you won't be able to change the control flow with keyboards. " - "For more info, see traceback below.\n" - ) - traceback.print_exc() - print() - return True - - def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, @@ -122,59 +91,6 @@ def predict_action( return action -def init_keyboard_listener(): - """ - Initializes a non-blocking keyboard listener for real-time user interaction. - - This function sets up a listener for specific keys (right arrow, left arrow, escape) to control - the program flow during execution, such as stopping recording or exiting loops. It gracefully - handles headless environments where keyboard listening is not possible. - - Returns: - A tuple containing: - - The `pynput.keyboard.Listener` instance, or `None` if in a headless environment. - - A dictionary of event flags (e.g., `exit_early`) that are set by key presses. - """ - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. - events = {} - events["exit_early"] = False - events["rerecord_episode"] = False - events["stop_recording"] = False - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." 
- ) - listener = None - return listener, events - - # Only import pynput if not in a headless environment - from pynput import keyboard - - def on_press(key): - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - events["exit_early"] = True - elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("Escape key pressed. Stopping data recording...") - events["stop_recording"] = True - events["exit_early"] = True - except Exception as e: - print(f"Error handling key press: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - return listener, events - - def sanity_check_dataset_name(repo_id, policy_cfg): """ Validates the dataset repository name against the presence of a policy configuration. diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 8791a5502..21d1e8e98 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -47,8 +47,6 @@ from __future__ import annotations import contextlib import enum import logging -import os -import sys import time from concurrent.futures import Future, ThreadPoolExecutor from threading import Event, Lock @@ -58,7 +56,6 @@ import numpy as np from lerobot.common.control_utils import ( follower_smooth_move_to, - is_headless, teleop_smooth_move_to, teleop_supports_feedback, ) @@ -66,7 +63,7 @@ from lerobot.datasets import VideoEncodingManager from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame -from lerobot.utils.import_utils import _pynput_available +from lerobot.utils.keyboard_input import create_key_listener from lerobot.utils.pedal import start_pedal_listener from lerobot.utils.robot_utils 
import precise_sleep from lerobot.utils.utils import log_say @@ -75,19 +72,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon from ..context import RolloutContext from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action -PYNPUT_AVAILABLE = _pynput_available -keyboard = None -if PYNPUT_AVAILABLE: - try: - if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): - logging.info("No DISPLAY set. Skipping pynput import.") - PYNPUT_AVAILABLE = False - else: - from pynput import keyboard - except Exception as e: - PYNPUT_AVAILABLE = False - logging.info(f"Could not import pynput: {e}") - logger = logging.getLogger(__name__) @@ -180,64 +164,36 @@ class DAggerEvents: def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): - """Initialise keyboard listener with DAgger 3-key controls. + """Initialise a keyboard listener for DAgger's 3 controls. - Returns the pynput Listener (or ``None`` in headless mode or when - pynput is unavailable). + Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on + Wayland / headless TTY) is delegated to :func:`create_key_listener`. Returns the + listener (exposing ``stop()``) or ``None`` when no keyboard backend is usable. 
""" - if not PYNPUT_AVAILABLE or is_headless(): - logger.warning("Headless environment or pynput unavailable — keyboard controls disabled") - return None - - # Map config key names to pynput Key objects for special keys - special_keys = { - "space": keyboard.Key.space, - "tab": keyboard.Key.tab, - "enter": keyboard.Key.enter, - } - - def _resolve_key(key) -> str | None: - """Resolve a pynput key event to a config-comparable string.""" - if key == keyboard.Key.esc: - return "esc" - for name, pynput_key in special_keys.items(): - if key == pynput_key: - return name - if hasattr(key, "char") and key.char: - return key.char - return None - - # Build mapping: resolved key string -> DAgger event name + # Map config key names to DAgger event names. key_to_event = { cfg.pause_resume: "pause_resume", cfg.correction: "correction", } - def on_press(key): - try: - resolved = _resolve_key(key) - if resolved is None: - return - if resolved == "esc": - logger.info("Stop recording...") - events.stop_recording.set() - return - if resolved in key_to_event: - events.request_transition(key_to_event[resolved]) - if resolved == cfg.upload: - events.upload_requested.set() - except Exception as e: - logger.debug("Key error: %s", e) + def dispatch(name: str) -> None: + """Apply a resolved key name to the DAgger events.""" + if name == "esc": + logger.info("Stop recording...") + events.stop_recording.set() + return + if name in key_to_event: + events.request_transition(key_to_event[name]) + if name == cfg.upload: + events.upload_requested.set() - listener = keyboard.Listener(on_press=on_press) - listener.start() - logger.info( - "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)", - cfg.pause_resume, - cfg.correction, - cfg.upload, + return create_key_listener( + dispatch, + controls_help=( + f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', " + f"upload='{cfg.upload}', ESC=stop" + ), ) - return listener def 
_init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig): @@ -328,7 +284,7 @@ class DAggerStrategy(RolloutStrategy): logger.info("Stopping DAgger recording") log_say("Stopping DAgger recording", play_sounds) - if self._listener is not None and not is_headless(): + if self._listener is not None: logger.info("Stopping keyboard listener") self._listener.stop() diff --git a/src/lerobot/rollout/strategies/episodic.py b/src/lerobot/rollout/strategies/episodic.py index e925fb2ea..e70e66787 100644 --- a/src/lerobot/rollout/strategies/episodic.py +++ b/src/lerobot/rollout/strategies/episodic.py @@ -35,14 +35,13 @@ import time from lerobot.common.control_utils import ( follower_smooth_move_to, - init_keyboard_listener, - is_headless, teleop_smooth_move_to, teleop_supports_feedback, ) from lerobot.datasets import VideoEncodingManager from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import log_rerun_data @@ -307,7 +306,7 @@ class EpisodicStrategy(RolloutStrategy): log_say("Stop recording", play_sounds, blocking=True) - if not is_headless() and self._listener is not None: + if self._listener is not None: self._listener.stop() if ctx.data.dataset is not None: diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index baff70da7..385a9e2b6 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -18,17 +18,14 @@ from __future__ import annotations import contextlib import logging -import os -import sys import time from concurrent.futures import Future, ThreadPoolExecutor from threading import Event as ThreadingEvent, Lock -from lerobot.common.control_utils import is_headless from lerobot.datasets import 
VideoEncodingManager from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame -from lerobot.utils.import_utils import _pynput_available, require_package +from lerobot.utils.keyboard_input import create_key_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say @@ -37,19 +34,6 @@ from ..context import RolloutContext from ..ring_buffer import RolloutRingBuffer from .core import RolloutStrategy, safe_push_to_hub, send_next_action -PYNPUT_AVAILABLE = _pynput_available -keyboard = None -if PYNPUT_AVAILABLE: - try: - if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): - logging.info("No DISPLAY set. Skipping pynput import.") - PYNPUT_AVAILABLE = False - else: - from pynput import keyboard - except Exception as e: - PYNPUT_AVAILABLE = False - logging.info(f"Could not import pynput: {e}") - logger = logging.getLogger(__name__) @@ -72,7 +56,6 @@ class HighlightStrategy(RolloutStrategy): def __init__(self, config: HighlightStrategyConfig): super().__init__(config) - require_package("pynput", extra="pynput-dep") self._ring: RolloutRingBuffer | None = None self._listener = None self._save_requested = ThreadingEvent() @@ -234,30 +217,27 @@ class HighlightStrategy(RolloutStrategy): logger.info("Highlight strategy teardown complete") def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None: - """Set up keyboard listener for save and push keys.""" - if is_headless(): - logger.warning("Headless environment — highlight keys unavailable") - return + """Set up a keyboard listener for the save and push keys. - try: - save_key = self.config.save_key - push_key = self.config.push_key + Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on + Wayland / headless TTY) is delegated to :func:`create_key_listener`. 
+ """ + save_key = self.config.save_key + push_key = self.config.push_key - def on_press(key): - with contextlib.suppress(Exception): - if hasattr(key, "char") and key.char == save_key: - self._save_requested.set() - elif hasattr(key, "char") and key.char == push_key: - self._push_requested.set() - elif key == keyboard.Key.esc: - self._save_requested.clear() - shutdown_event.set() + def dispatch(name: str) -> None: + """Apply a resolved key name to the highlight events.""" + if name == save_key: + self._save_requested.set() + elif name == push_key: + self._push_requested.set() + elif name == "esc": + self._save_requested.clear() + shutdown_event.set() - self._listener = keyboard.Listener(on_press=on_press) - self._listener.start() - logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key) - except ImportError: - logger.warning("pynput not available — keyboard listener disabled") + self._listener = create_key_listener( + dispatch, controls_help=f"save='{save_key}', push='{push_key}', ESC=stop" + ) def _background_push(self, dataset, cfg) -> None: """Queue a Hub push on the single-worker executor.""" diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 0deb54b90..4d5518c7c 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -96,11 +96,7 @@ from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401 from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 -from lerobot.common.control_utils import ( - init_keyboard_listener, - is_headless, - sanity_check_dataset_robot_compatibility, -) +from lerobot.common.control_utils import sanity_check_dataset_robot_compatibility from lerobot.configs import parser from lerobot.configs.dataset import DatasetRecordConfig from lerobot.datasets import ( 
@@ -155,6 +151,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts from lerobot.utils.import_utils import register_third_party_plugins +from lerobot.utils.keyboard_input import init_keyboard_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import ( init_logging, @@ -508,7 +505,7 @@ def record( if teleop and teleop.is_connected: teleop.disconnect() - if not is_headless() and listener: + if listener is not None: listener.stop() if cfg.dataset.push_to_hub: diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py index c1531ca84..22dbb7cca 100644 --- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py +++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py @@ -18,6 +18,7 @@ import logging from typing import TYPE_CHECKING from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package +from lerobot.utils.keyboard_input import pynput_can_capture from ..utils import TeleopEvents @@ -123,6 +124,15 @@ class KeyboardController(InputController): def start(self): """Start the keyboard listener.""" + if not pynput_can_capture(): + logging.warning( + "Keyboard control is unavailable in this environment. pynput cannot capture keys " + "on Wayland or headless machines, or on macOS without Accessibility / Input " + "Monitoring permission. Keyboard motion will be inactive." + ) + self.running = False + return + from pynput import keyboard def on_press(key): diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 801789bcb..872cc7a26 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -15,8 +15,6 @@ # limitations under the License. 
import logging -import os -import sys import time from queue import Queue from typing import Any @@ -24,6 +22,7 @@ from typing import Any from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.import_utils import _pynput_available, require_package +from lerobot.utils.keyboard_input import pynput_can_capture from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -37,14 +36,10 @@ PYNPUT_AVAILABLE = _pynput_available keyboard = None if PYNPUT_AVAILABLE: try: - if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): - logging.info("No DISPLAY set. Skipping pynput import.") - PYNPUT_AVAILABLE = False - else: - from pynput import keyboard + from pynput import keyboard except Exception as e: PYNPUT_AVAILABLE = False - logging.info(f"Could not import pynput: {e}") + logging.info("Could not import pynput keyboard backend: %s", e) class KeyboardTeleop(Teleoperator): @@ -88,7 +83,7 @@ class KeyboardTeleop(Teleoperator): @check_if_already_connected def connect(self) -> None: - if PYNPUT_AVAILABLE: + if PYNPUT_AVAILABLE and pynput_can_capture(): logging.info("pynput is available - enabling local keyboard listener.") self.listener = keyboard.Listener( on_press=self._on_press, @@ -96,7 +91,13 @@ class KeyboardTeleop(Teleoperator): ) self.listener.start() else: - logging.info("pynput not available - skipping local keyboard listener.") + logging.warning( + "Keyboard teleoperation is unavailable in this environment. pynput can only " + "capture key events on an X11 session (Linux), a Windows desktop, or macOS with " + "Accessibility / Input Monitoring granted - not on Wayland or headless machines. " + "This keyboard teleoperator will produce no actions; use an X11 session, a " + "gamepad, or a leader-arm teleoperator instead." 
+ ) self.listener = None def calibrate(self) -> None: diff --git a/src/lerobot/utils/keyboard_input.py b/src/lerobot/utils/keyboard_input.py new file mode 100644 index 000000000..00c0f53ec --- /dev/null +++ b/src/lerobot/utils/keyboard_input.py @@ -0,0 +1,440 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Display-independent keyboard input for interactive controls. + +This module centralizes everything related to *discrete* keyboard controls +(end-episode-early, re-record, stop, and the rollout strategies' custom keys): + +* environment detection — :func:`is_headless`, :func:`is_wayland`, + :func:`pynput_can_capture` (the single predicate every call-site should use to + decide whether ``pynput`` can actually capture keys here); +* a shared key mapping — :func:`apply_recording_control`; and +* two interchangeable backends behind one ``(listener, events)`` contract: + the ``pynput`` global listener (X11 / trusted-macOS / Windows) and a + standard-library :class:`TerminalKeyListener` that reads the controlling TTY + (Wayland / headless-SSH-with-TTY / macOS without Accessibility permission). + +NOTE: *continuous* key-state teleoperation ("hold a key to keep moving") is +deliberately NOT served here. A terminal in cbreak mode delivers only key-down +bytes — there is no key-release event — so the held-key model cannot be +reproduced. 
@cache
def is_headless() -> bool:
    """Report whether this process has no display server to talk to.

    Only Linux can be reported as headless: it is when neither ``DISPLAY`` (X11)
    nor ``WAYLAND_DISPLAY`` holds a non-empty value. Every other platform is
    treated as having a display. A GUI-less macOS / Windows host is therefore
    misclassified, which is harmless for listener selection because
    ``create_key_listener`` still returns ``None`` when stdin is not a TTY.

    The result is memoised with ``functools.cache``: the environment is sampled
    on the first call only. Call ``is_headless.cache_clear()`` after changing
    the environment (the unit tests do this around every test).
    """
    if platform.system() != "Linux":
        return False
    display_vars = ("DISPLAY", "WAYLAND_DISPLAY")
    return not any(os.environ.get(var) for var in display_vars)
@cache
def pynput_can_capture() -> bool:
    """Tell whether a ``pynput`` global listener is expected to receive key events.

    Decision table:

    * ``pynput`` not installed (``_pynput_available`` is false): ``False``.
    * Linux: ``True`` only when a display is present (``is_headless()`` is false)
      and the session is not Wayland (``is_wayland()`` is false).
    * Any other platform (macOS, Windows, ...): ``True``. On macOS this is
      optimistic; the code that starts the listener re-checks it with
      ``pynput_listener_is_trusted`` once the listener is running.

    The result is memoised with ``functools.cache``; call
    ``pynput_can_capture.cache_clear()`` if the environment changes.
    """
    if not _pynput_available:
        return False
    if platform.system() != "Linux":
        return True
    # Same short-circuit order as "not headless and not wayland".
    return not (is_headless() or is_wayland())
+ """ + if platform.system() != "Darwin": + return True + deadline = time.perf_counter() + timeout_s + while time.perf_counter() < deadline: + if getattr(listener, "IS_TRUSTED", False): + return True + time.sleep(0.02) + return bool(getattr(listener, "IS_TRUSTED", False)) + + +def apply_recording_control(control: str, events: dict) -> None: + """Apply a recording control-flow key press to the shared ``events`` dict. + + Centralizes the mapping so the ``pynput`` and terminal backends behave + identically. ``control`` is one of ``"right"`` (end the loop early), ``"left"`` + (re-record the last episode), or ``"esc"`` (stop recording). + """ + if control == "right": + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif control == "left": + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif control == "esc": + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + + +# Terminal arrow keys arrive as a 3-byte escape sequence whose *final* byte identifies +# the direction. Two encodings exist depending on the terminal's cursor-key mode — CSI +# ("ESC [ X") and SS3 ("ESC O X", common over SSH/tmux) — but both share the same final +# byte, so this single table decodes either. Looked up by TerminalKeyListener._parse; +# an unknown final byte yields None (sequence ignored). +_ARROW_FINAL_BYTES = {"A": "up", "B": "down", "C": "right", "D": "left"} + + +class TerminalKeyListener: + """Display-independent keyboard listener that reads keys from the controlling TTY. + + Used as the Wayland / headless / macOS-untrusted equivalent of the ``pynput`` + listener for *discrete* controls. 
It puts the terminal into cbreak mode with + echo disabled and reads bytes on a daemon thread, decoding them into logical + key names that are passed to ``on_key``: + + * arrow keys (``ESC [ C`` / ``ESC O C`` …) -> ``"right"`` / ``"left"`` / ``"up"`` / ``"down"`` + * a bare ``ESC`` -> ``"esc"`` + * Enter / Tab / Space / Backspace -> ``"enter"`` / ``"tab"`` / ``"space"`` / ``"backspace"`` + * any other printable byte -> that character (e.g. ``"n"``, ``"s"``) + + Only key-down events are produced (terminals have no key-release), so this is + suitable for discrete commands but NOT for continuous "hold-to-move" teleop. + + The terminal is restored on :meth:`stop` and also via an ``atexit`` hook, so a + crash or Ctrl-C never leaves the shell in a no-echo cbreak state. POSIX-only + (``termios`` / ``tty`` / ``select``); those modules are imported lazily so this + file stays importable on Windows (where ``pynput`` is used instead). + """ + + def __init__(self, on_key: Callable[[str], None]): + self._on_key = on_key + self._running = False + self._thread: threading.Thread | None = None + self._fd: int | None = None + self._old_attrs = None + + def _read_char(self, timeout: float) -> str | None: + """Return one character from stdin within ``timeout`` seconds, or ``None``.""" + if self._fd is None: + return None + ready, _, _ = select.select([self._fd], [], [], timeout) + if not ready: + return None + try: + data = os.read(self._fd, 1) + except OSError: + return None + if not data: + return None + return data.decode(errors="ignore") + + def _parse(self, ch: str) -> str | None: + """Decode one (possibly multi-byte) key starting at ``ch`` into a key name.""" + if ch == "\x1b": + # Possible CSI / SS3 escape sequence (arrow keys) or a bare ESC. Use + # short follow-up reads so a lone ESC is not mistaken for a sequence. 
+ ch2 = self._read_char(timeout=0.02) + if ch2 is None: + return "esc" + if ch2 in ("[", "O"): + ch3 = self._read_char(timeout=0.02) + return _ARROW_FINAL_BYTES.get(ch3 or "") + # Some other escape sequence (e.g. Alt+key); ignore it. + return None + if ch in ("\r", "\n"): + return "enter" + if ch == "\t": + return "tab" + if ch == " ": + return "space" + if ch in ("\x7f", "\x08"): + return "backspace" + if ch.isprintable(): + return ch + return None + + def _run(self) -> None: + while self._running: + ch = self._read_char(timeout=0.05) + if ch is None: + continue + name = self._parse(ch) + if name is None: + continue + try: + self._on_key(name) + except Exception as e: # never let a handler error kill the reader thread + logger.debug("Terminal key handler error: %s", e) + + def start(self) -> None: + """Switch the terminal to cbreak mode (echo off) and read keys on a daemon thread. + + No-op when stdin is not a TTY (piped/redirected input) or on platforms + without ``termios`` (e.g. Windows), so non-interactive runs are unaffected. + """ + if not sys.stdin.isatty(): + return + if not _TERMIOS_AVAILABLE: # POSIX-only modules (e.g. unavailable on Windows) + logger.warning("Terminal keyboard input is not supported on this platform.") + return + + self._fd = sys.stdin.fileno() + self._old_attrs = termios.tcgetattr(self._fd) + tty.setcbreak(self._fd) + # Explicitly disable ECHO so arrow-key escape sequences (e.g. ^[[C) are not + # echoed as garbage into the recording terminal. (Independent of the + # version-specific behavior of tty.setcbreak.) + new_attrs = termios.tcgetattr(self._fd) + new_attrs[3] &= ~termios.ECHO # index 3 == lflags + termios.tcsetattr(self._fd, termios.TCSADRAIN, new_attrs) + # Safety net: restore the terminal even if stop() is never reached (crash). 
+ atexit.register(self.stop) + + self._running = True + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + """Stop the reader thread and restore the original terminal attributes. + + Idempotent: safe to call multiple times (e.g. explicitly and via atexit). + """ + self._running = False + thread = self._thread + if thread is not None: + thread.join(timeout=0.5) + self._thread = None + if self._fd is not None and self._old_attrs is not None and _TERMIOS_AVAILABLE: + try: + termios.tcsetattr(self._fd, termios.TCSADRAIN, self._old_attrs) + finally: + self._old_attrs = None + with contextlib.suppress(Exception): + atexit.unregister(self.stop) + + +# Map pynput key objects to the same canonical names TerminalKeyListener emits, so a +# single dispatch works across both backends. Empty when pynput is unavailable. +if keyboard is not None: + _PYNPUT_KEY_NAMES = { + keyboard.Key.right: "right", + keyboard.Key.left: "left", + keyboard.Key.up: "up", + keyboard.Key.down: "down", + keyboard.Key.esc: "esc", + keyboard.Key.enter: "enter", + keyboard.Key.tab: "tab", + keyboard.Key.space: "space", + keyboard.Key.backspace: "backspace", + } +else: + _PYNPUT_KEY_NAMES = {} + + +def _resolve_pynput_key(key) -> str | None: + """Resolve a pynput key event to the canonical name TerminalKeyListener also emits. + + Special keys map through :data:`_PYNPUT_KEY_NAMES`; character keys fall back to their + ``.char`` (e.g. ``"n"``). Returns ``None`` for keys with no mapping and no character. + """ + name = _PYNPUT_KEY_NAMES.get(key) + if name is not None: + return name + # ``or None`` keeps the historical truthy-char semantics: an empty/None char is "no key". + return getattr(key, "char", None) or None + + +def create_key_listener(dispatch: Callable[[str], None], *, controls_help: str = ""): + """Start a keyboard listener that routes resolved key names to ``dispatch``. 
+ + Shared backend selection used by recording and the rollout strategies: + + * the ``pynput`` global listener on X11 / trusted-macOS / Windows (on macOS the + listener's ``IS_TRUSTED`` flag is checked after start, and an untrusted listener is + stopped so the terminal backend is used instead); + * the stdlib :class:`TerminalKeyListener` on Wayland / headless sessions with a TTY; + * ``None`` when no backend is usable (non-interactive / piped runs). + + Both backends pass ``dispatch`` the same canonical key names ("right" / "left" / "up" / + "down" / "esc" / "enter" / "tab" / "space" / "backspace", or a character), so one + ``dispatch`` works regardless of backend. ``controls_help`` is an optional hint + appended to the log messages. + + Returns the listener (exposing ``.stop()``) or ``None``. + """ + suffix = f" ({controls_help})" if controls_help else "" + + if pynput_can_capture() and keyboard is not None: + + def on_press(key): + with contextlib.suppress(Exception): + name = _resolve_pynput_key(key) + if name is not None: + dispatch(name) + + listener = keyboard.Listener(on_press=on_press) + listener.start() + if pynput_listener_is_trusted(listener): + logger.info("Keyboard listener started%s.", suffix) + return listener + # macOS without Accessibility / Input-Monitoring permission: the listener never + # fires. Stop it and fall through to the terminal backend. + logger.warning( + "pynput keyboard listener is not trusted (missing macOS Accessibility / " + "Input Monitoring permission); falling back to terminal keyboard input." 
def init_keyboard_listener():
    """Create the keyboard listener used for interactive recording controls.

    The listener itself comes from :func:`create_key_listener`, which picks the
    backend (``pynput`` global listener, terminal reader, or none when no backend
    is usable). This function only supplies the key handler and the shared flags.

    Recognised keys (matched case-insensitively, so Shift+letter also works):

    * ``right`` or ``n`` -> end the current loop early
    * ``left`` or ``r``  -> end the loop and re-record the episode
    * ``esc`` or ``q``   -> stop recording

    The single-byte letters exist because arrow keys are multi-byte escape
    sequences, which can be split, delayed or intercepted on laggy SSH/VNC links.
    Any other key (including up/down) is ignored.

    Returns:
        A tuple ``(listener, events)``. ``listener`` exposes ``.stop()`` or is
        ``None`` when no backend is usable; ``events`` is the dict of boolean
        flags ``exit_early``, ``rerecord_episode`` and ``stop_recording`` that
        the key presses set.
    """
    events = dict.fromkeys(("exit_early", "rerecord_episode", "stop_recording"), False)

    # Key name (lower-cased) -> control understood by apply_recording_control.
    control_for_key = {
        "right": "right",
        "n": "right",
        "left": "left",
        "r": "left",
        "esc": "esc",
        "q": "esc",
    }

    def on_key(name: str) -> None:
        control = control_for_key.get(name.lower())
        if control is not None:
            apply_recording_control(control, events)

    listener = create_key_listener(on_key, controls_help="Right/Left/Esc, or n=next, r=re-record, q=quit")
    return listener, events
@pytest.mark.parametrize(
    ("system", "env", "expected"),
    [
        ("Linux", {}, True),  # neither display variable is set
        ("Linux", {"DISPLAY": ":0"}, False),  # X11
        ("Linux", {"WAYLAND_DISPLAY": "wayland-0"}, False),  # Wayland
        ("Darwin", {}, False),  # non-Linux is never reported as headless
    ],
)
def test_is_headless(monkeypatch, system, env, expected):
    """``is_headless`` depends only on the platform and the two display variables."""
    _set_platform(monkeypatch, system)
    # Start from a clean slate so the host's own session cannot leak into the case.
    for var in ("DISPLAY", "WAYLAND_DISPLAY"):
        monkeypatch.delenv(var, raising=False)
    for var, value in env.items():
        monkeypatch.setenv(var, value)
    headless = is_headless()
    assert headless is expected
def test_apply_recording_control():
    """Each control sets exactly its own flags; an unknown control changes nothing."""

    def fresh_events():
        return {"exit_early": False, "rerecord_episode": False, "stop_recording": False}

    # "right": end the loop early, nothing else.
    events = fresh_events()
    apply_recording_control("right", events)
    assert events == {"exit_early": True, "rerecord_episode": False, "stop_recording": False}

    # "left": end the loop and request a re-record.
    events = fresh_events()
    apply_recording_control("left", events)
    assert events == {"exit_early": True, "rerecord_episode": True, "stop_recording": False}

    # "esc" on top of the previous state additionally requests a stop.
    apply_recording_control("esc", events)
    assert events["stop_recording"] is True
    assert events["exit_early"] is True

    # Unknown control: no error and no flag touched.
    events = fresh_events()
    apply_recording_control("up", events)
    assert events == fresh_events()
@pytest.mark.parametrize(
    ("byte_seq", "expected"),
    [
        (["\x1b", "[", "C"], ["right"]),  # CSI arrow
        (["\x1b", "O", "D"], ["left"]),  # SS3 arrow (e.g. over SSH/tmux)
        (["\x1b", "O", "A"], ["up"]),  # SS3 shares the final bytes with CSI
        (["\x1b"], ["esc"]),  # bare ESC
        (["\x1b", "[", "A"], ["up"]),  # decoded even though the record handler ignores it
        (["\x1b", "[", "B"], ["down"]),
        (["\x1b", "[", "Z"], []),  # unknown final byte -> sequence ignored
        (["\x1b", "x"], []),  # other escape sequence (e.g. Alt+key) -> ignored
        (["n"], ["n"]),  # letter passthrough
        (["\r"], ["enter"]),
        (["\n"], ["enter"]),
        (["\t"], ["tab"]),
        ([" "], ["space"]),
        (["\x7f"], ["backspace"]),
        (["\x08"], ["backspace"]),
        (["\x01"], []),  # non-printable control byte -> no key
    ],
)
def test_terminal_parsing(byte_seq, expected):
    """Scripted bytes are decoded into the canonical key names, in order."""
    collected = []
    _drive(TerminalKeyListener(collected.append), byte_seq)
    assert collected == expected
def test_resolve_pynput_key_char_fallback():
    """Keys without a special-key mapping resolve to ``.char``, or None when it is empty."""

    def fake_key(char):
        # Minimal stand-in for a pynput key event: an object exposing ``.char``.
        return type("K", (), {"char": char})()

    assert ki._resolve_pynput_key(fake_key("s")) == "s"
    for missing in (None, ""):  # no character / empty character -> "no key"
        assert ki._resolve_pynput_key(fake_key(missing)) is None