mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d8491ba179 | |||
| 594a1bb81d | |||
| 264a7490ac | |||
| 6f0ba4be38 |
@@ -390,9 +390,17 @@ Set the flow of data recording using command-line arguments:
|
||||
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
|
||||
- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it.
|
||||
- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset.
|
||||
- Press **Right Arrow (`→`)** or **`n`**: Early stop the current episode or reset time and move to the next.
|
||||
- Press **Left Arrow (`←`)** or **`r`**: Cancel the current episode and re-record it.
|
||||
- Press **Escape (`ESC`)** or **`q`**: Immediately stop the session, encode videos, and upload the dataset.
|
||||
|
||||
<Tip>
|
||||
|
||||
These control-flow shortcuts work on **X11, Wayland, and headless/SSH** sessions. When a global keyboard backend isn't available (Wayland, a headless machine, or macOS without Accessibility permission), `lerobot-record` automatically reads the same keys from the terminal — launch it from an interactive terminal and keep it focused. You can also use the letter equivalents **`n`** (next, same as `→`), **`r`** (re-record, same as `←`) and **`q`** (quit, same as `ESC`). No `$DISPLAY` setup is required.
|
||||
|
||||
This applies to the recording control flow only. Keyboard **teleoperation** (driving the robot with the keyboard) still needs a global key backend, so it works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless sessions.
|
||||
|
||||
</Tip>
|
||||
|
||||
#### Tips for gathering data
|
||||
|
||||
@@ -406,7 +414,7 @@ If you want to dive deeper into this important topic, you can check out the [blo
|
||||
|
||||
#### Troubleshooting:
|
||||
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as `lerobot-record` runs in an interactive terminal — no `$DISPLAY` setup is needed. If the keys have no effect, make sure you are in an interactive (TTY) terminal, not a piped/non-TTY session, and that it is focused; the letter equivalents `n` / `r` / `q` also work. Keyboard _teleoperation_ (as opposed to the recording control flow) still requires a global key backend — an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — and is unavailable on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ If you want to dive deeper into this important topic, you can check out the [blo
|
||||
|
||||
#### Troubleshooting:
|
||||
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as you run the recording from an interactive terminal (keep it focused) — no `$DISPLAY` setup is needed; the letter equivalents `n` / `r` / `q` also work. Note that **keyboard teleoperation of the LeKiwi base** is different: it relies on a global key backend and therefore works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
## Replay an episode
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import predict_action
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
@@ -26,6 +26,7 @@ from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
@@ -23,6 +22,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
@@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
@@ -39,6 +38,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
@@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
@@ -36,6 +35,7 @@ from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
@@ -17,12 +17,9 @@ from __future__ import annotations
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
@@ -43,34 +40,6 @@ from lerobot.robots import Robot
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""
|
||||
Detects if the Python script is running in a headless environment (e.g., without a display).
|
||||
|
||||
This function attempts to import `pynput`, a library that requires a graphical environment.
|
||||
If the import fails, it assumes the environment is headless. The result is cached to avoid
|
||||
re-running the check.
|
||||
|
||||
Returns:
|
||||
True if the environment is determined to be headless, False otherwise.
|
||||
"""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
print(
|
||||
"Error trying to import pynput. Switching to headless mode. "
|
||||
"As a result, the video stream from the cameras won't be shown, "
|
||||
"and you won't be able to change the control flow with keyboards. "
|
||||
"For more info, see traceback below.\n"
|
||||
)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
return True
|
||||
|
||||
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -122,59 +91,6 @@ def predict_action(
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""
|
||||
Initializes a non-blocking keyboard listener for real-time user interaction.
|
||||
|
||||
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
|
||||
the program flow during execution, such as stopping recording or exiting loops. It gracefully
|
||||
handles headless environments where keyboard listening is not possible.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
|
||||
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
|
||||
"""
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
"""
|
||||
Validates the dataset repository name against the presence of a policy configuration.
|
||||
|
||||
@@ -73,8 +73,17 @@ class EvalConfig:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||
use_async_envs: bool = True
|
||||
# Whether to record eval rollouts as a LeRobot dataset on disk.
|
||||
recording: bool = False
|
||||
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
|
||||
# suffixed by task and env index). Requires recording=true.
|
||||
recording_repo_id: str | None = None
|
||||
# Whether the pushed recording repositories should be private.
|
||||
recording_private: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.recording_repo_id is not None and not self.recording:
|
||||
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
|
||||
if self.batch_size == 0:
|
||||
self.batch_size = self._auto_batch_size()
|
||||
if self.batch_size > self.n_episodes:
|
||||
|
||||
@@ -47,8 +47,6 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event, Lock
|
||||
@@ -58,7 +56,6 @@ import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
@@ -66,7 +63,7 @@ from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
from lerobot.utils.keyboard_input import create_key_listener
|
||||
from lerobot.utils.pedal import start_pedal_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
@@ -75,19 +72,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -180,64 +164,36 @@ class DAggerEvents:
|
||||
|
||||
|
||||
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
"""Initialise keyboard listener with DAgger 3-key controls.
|
||||
"""Initialise a keyboard listener for DAgger's 3 controls.
|
||||
|
||||
Returns the pynput Listener (or ``None`` in headless mode or when
|
||||
pynput is unavailable).
|
||||
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
|
||||
Wayland / headless TTY) is delegated to :func:`create_key_listener`. Returns the
|
||||
listener (exposing ``stop()``) or ``None`` when no keyboard backend is usable.
|
||||
"""
|
||||
if not PYNPUT_AVAILABLE or is_headless():
|
||||
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
|
||||
return None
|
||||
|
||||
# Map config key names to pynput Key objects for special keys
|
||||
special_keys = {
|
||||
"space": keyboard.Key.space,
|
||||
"tab": keyboard.Key.tab,
|
||||
"enter": keyboard.Key.enter,
|
||||
}
|
||||
|
||||
def _resolve_key(key) -> str | None:
|
||||
"""Resolve a pynput key event to a config-comparable string."""
|
||||
if key == keyboard.Key.esc:
|
||||
return "esc"
|
||||
for name, pynput_key in special_keys.items():
|
||||
if key == pynput_key:
|
||||
return name
|
||||
if hasattr(key, "char") and key.char:
|
||||
return key.char
|
||||
return None
|
||||
|
||||
# Build mapping: resolved key string -> DAgger event name
|
||||
# Map config key names to DAgger event names.
|
||||
key_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
}
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
resolved = _resolve_key(key)
|
||||
if resolved is None:
|
||||
return
|
||||
if resolved == "esc":
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording.set()
|
||||
return
|
||||
if resolved in key_to_event:
|
||||
events.request_transition(key_to_event[resolved])
|
||||
if resolved == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
def dispatch(name: str) -> None:
|
||||
"""Apply a resolved key name to the DAgger events."""
|
||||
if name == "esc":
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording.set()
|
||||
return
|
||||
if name in key_to_event:
|
||||
events.request_transition(key_to_event[name])
|
||||
if name == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
return create_key_listener(
|
||||
dispatch,
|
||||
controls_help=(
|
||||
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
|
||||
f"upload='{cfg.upload}', ESC=stop"
|
||||
),
|
||||
)
|
||||
return listener
|
||||
|
||||
|
||||
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
@@ -328,7 +284,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Stopping DAgger recording")
|
||||
log_say("Stopping DAgger recording", play_sounds)
|
||||
|
||||
if self._listener is not None and not is_headless():
|
||||
if self._listener is not None:
|
||||
logger.info("Stopping keyboard listener")
|
||||
self._listener.stop()
|
||||
|
||||
|
||||
@@ -35,14 +35,13 @@ import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
@@ -307,7 +306,7 @@ class EpisodicStrategy(RolloutStrategy):
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
if self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
|
||||
@@ -18,17 +18,14 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event as ThreadingEvent, Lock
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
from lerobot.utils.keyboard_input import create_key_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
@@ -37,19 +34,6 @@ from ..context import RolloutContext
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +56,6 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
def __init__(self, config: HighlightStrategyConfig):
|
||||
super().__init__(config)
|
||||
require_package("pynput", extra="pynput-dep")
|
||||
self._ring: RolloutRingBuffer | None = None
|
||||
self._listener = None
|
||||
self._save_requested = ThreadingEvent()
|
||||
@@ -234,30 +217,27 @@ class HighlightStrategy(RolloutStrategy):
|
||||
logger.info("Highlight strategy teardown complete")
|
||||
|
||||
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
|
||||
"""Set up keyboard listener for save and push keys."""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment — highlight keys unavailable")
|
||||
return
|
||||
"""Set up a keyboard listener for the save and push keys.
|
||||
|
||||
try:
|
||||
save_key = self.config.save_key
|
||||
push_key = self.config.push_key
|
||||
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
|
||||
Wayland / headless TTY) is delegated to :func:`create_key_listener`.
|
||||
"""
|
||||
save_key = self.config.save_key
|
||||
push_key = self.config.push_key
|
||||
|
||||
def on_press(key):
|
||||
with contextlib.suppress(Exception):
|
||||
if hasattr(key, "char") and key.char == save_key:
|
||||
self._save_requested.set()
|
||||
elif hasattr(key, "char") and key.char == push_key:
|
||||
self._push_requested.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
self._save_requested.clear()
|
||||
shutdown_event.set()
|
||||
def dispatch(name: str) -> None:
|
||||
"""Apply a resolved key name to the highlight events."""
|
||||
if name == save_key:
|
||||
self._save_requested.set()
|
||||
elif name == push_key:
|
||||
self._push_requested.set()
|
||||
elif name == "esc":
|
||||
self._save_requested.clear()
|
||||
shutdown_event.set()
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||
except ImportError:
|
||||
logger.warning("pynput not available — keyboard listener disabled")
|
||||
self._listener = create_key_listener(
|
||||
dispatch, controls_help=f"save='{save_key}', push='{push_key}', ESC=stop"
|
||||
)
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Queue a Hub push on the single-worker executor."""
|
||||
|
||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
check_env_attributes_and_types,
|
||||
close_envs,
|
||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
shape = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
return features
|
||||
|
||||
|
||||
def _build_raw_frame(
|
||||
raw_obs: dict,
|
||||
env_idx: int,
|
||||
action: np.ndarray,
|
||||
reward: float,
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if key.startswith("next."):
|
||||
continue
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
val = raw_obs[key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -105,6 +165,10 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -145,6 +209,33 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
recording_datasets: list[LeRobotDataset] | None = None
|
||||
raw_observation = None
|
||||
task_desc = ""
|
||||
if recording_dir is not None and env_features is not None:
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
recording_datasets = []
|
||||
multi_env = env.num_envs > 1
|
||||
base_repo_id = recording_repo_id or "eval_recording"
|
||||
for i in range(env.num_envs):
|
||||
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
|
||||
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
|
||||
recording_datasets.append(
|
||||
LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=True,
|
||||
)
|
||||
)
|
||||
raw_observation = deepcopy(observation)
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
task_desc = ""
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -162,80 +253,112 @@ def rollout(
|
||||
leave=False,
|
||||
)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
try:
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# Infer "task" from sub-environments (prefer natural language description).
|
||||
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
|
||||
try:
|
||||
observation["task"] = list(env.call("task_description"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
# Infer "task" from sub-environments (prefer natural language description).
|
||||
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
|
||||
try:
|
||||
observation["task"] = list(env.call("task"))
|
||||
observation["task"] = list(env.call("task_description"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
observation["task"] = [""] * env.num_envs
|
||||
try:
|
||||
observation["task"] = list(env.call("task"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
observation["task"] = [""] * env.num_envs
|
||||
|
||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||
observation = env_preprocessor(observation)
|
||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||
observation = env_preprocessor(observation)
|
||||
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {ACTION: action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition[ACTION]
|
||||
action_transition = {ACTION: action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition[ACTION]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
|
||||
# Apply the next action.
|
||||
observation, reward, terminated, truncated, info = env.step(action_numpy)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
# Apply the next action.
|
||||
observation, reward, terminated, truncated, info = env.step(action_numpy)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available if none of the envs finished.
|
||||
if "final_info" in info:
|
||||
final_info = info["final_info"]
|
||||
if not isinstance(final_info, dict):
|
||||
raise RuntimeError(
|
||||
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
|
||||
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available if none of the envs finished.
|
||||
if "final_info" in info:
|
||||
final_info = info["final_info"]
|
||||
if not isinstance(final_info, dict):
|
||||
raise RuntimeError(
|
||||
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
|
||||
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
|
||||
)
|
||||
successes = final_info["is_success"].tolist()
|
||||
elif "is_success" in info:
|
||||
is_success = info["is_success"]
|
||||
successes = (
|
||||
is_success.tolist()
|
||||
if hasattr(is_success, "tolist")
|
||||
else [bool(is_success)] * env.num_envs
|
||||
)
|
||||
successes = final_info["is_success"].tolist()
|
||||
elif "is_success" in info:
|
||||
is_success = info["is_success"]
|
||||
successes = (
|
||||
is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_datasets is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
continue
|
||||
frame = _build_raw_frame(
|
||||
raw_observation,
|
||||
env_idx,
|
||||
action_numpy[env_idx],
|
||||
reward[env_idx],
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_datasets[env_idx].features,
|
||||
)
|
||||
recording_datasets[env_idx].add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_datasets[env_idx].save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
# and allows logging/saving (e.g., videos) to be triggered consistently.
|
||||
done = terminated | truncated | done
|
||||
if step + 1 == max_steps:
|
||||
done = np.ones_like(done, dtype=bool)
|
||||
|
||||
all_actions.append(torch.from_numpy(action_numpy))
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_dones.append(torch.from_numpy(done))
|
||||
all_successes.append(torch.tensor(successes))
|
||||
|
||||
step += 1
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
)
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||
# and allows logging/saving (e.g., videos) to be triggered consistently.
|
||||
done = terminated | truncated | done
|
||||
if step + 1 == max_steps:
|
||||
done = np.ones_like(done, dtype=bool)
|
||||
|
||||
all_actions.append(torch.from_numpy(action_numpy))
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_dones.append(torch.from_numpy(done))
|
||||
all_successes.append(torch.tensor(successes))
|
||||
|
||||
step += 1
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
)
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
finally:
|
||||
if recording_datasets is not None:
|
||||
for ds in recording_datasets:
|
||||
ds.finalize()
|
||||
if recording_repo_id is not None:
|
||||
if ds.num_episodes > 0:
|
||||
ds.push_to_hub(private=recording_private)
|
||||
else:
|
||||
logging.warning("No episodes recorded for %s — skipping push to hub.", ds.repo_id)
|
||||
|
||||
# Track the final observation.
|
||||
if return_observations:
|
||||
@@ -273,6 +396,10 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -361,6 +488,10 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -563,6 +694,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||
|
||||
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
@@ -572,10 +707,15 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(cfg.output_dir) / "videos",
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=False,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
recording_dir=recording_dir,
|
||||
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||
recording_repo_id=cfg.eval.recording_repo_id,
|
||||
recording_private=cfg.eval.recording_private,
|
||||
)
|
||||
print("Overall Aggregated Metrics:")
|
||||
print(info["overall"])
|
||||
@@ -618,6 +758,10 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -635,6 +779,10 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -661,6 +809,10 @@ def run_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
):
|
||||
"""
|
||||
Run eval_one for a single (task_group, task_id, env).
|
||||
@@ -672,7 +824,13 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
||||
task_recording_dir = None
|
||||
task_repo_id = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
if recording_repo_id is not None:
|
||||
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
|
||||
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
@@ -685,8 +843,12 @@ def run_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=task_recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=task_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
# ensure we always provide video_paths key to simplify accumulation
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
return task_group, task_id, metrics
|
||||
@@ -702,6 +864,10 @@ def eval_policy_all(
|
||||
n_episodes: int,
|
||||
*,
|
||||
max_episodes_rendered: int = 0,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
recording_repo_id: str | None = None,
|
||||
recording_private: bool = False,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
@@ -761,6 +927,10 @@ def eval_policy_all(
|
||||
videos_dir=videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
recording_repo_id=recording_repo_id,
|
||||
recording_private=recording_private,
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
|
||||
@@ -96,11 +96,7 @@ from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.common.control_utils import (
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.common.control_utils import sanity_check_dataset_robot_compatibility
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
@@ -155,6 +151,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.keyboard_input import init_keyboard_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -508,7 +505,7 @@ def record(
|
||||
if teleop and teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package
|
||||
from lerobot.utils.keyboard_input import pynput_can_capture
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
@@ -123,6 +124,15 @@ class KeyboardController(InputController):
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
if not pynput_can_capture():
|
||||
logging.warning(
|
||||
"Keyboard control is unavailable in this environment. pynput cannot capture keys "
|
||||
"on Wayland or headless machines, or on macOS without Accessibility / Input "
|
||||
"Monitoring permission. Keyboard motion will be inactive."
|
||||
)
|
||||
self.running = False
|
||||
return
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
@@ -24,6 +22,7 @@ from typing import Any
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
from lerobot.utils.keyboard_input import pynput_can_capture
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -37,14 +36,10 @@ PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
logging.info("Could not import pynput keyboard backend: %s", e)
|
||||
|
||||
|
||||
class KeyboardTeleop(Teleoperator):
|
||||
@@ -88,7 +83,7 @@ class KeyboardTeleop(Teleoperator):
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if PYNPUT_AVAILABLE:
|
||||
if PYNPUT_AVAILABLE and pynput_can_capture():
|
||||
logging.info("pynput is available - enabling local keyboard listener.")
|
||||
self.listener = keyboard.Listener(
|
||||
on_press=self._on_press,
|
||||
@@ -96,7 +91,13 @@ class KeyboardTeleop(Teleoperator):
|
||||
)
|
||||
self.listener.start()
|
||||
else:
|
||||
logging.info("pynput not available - skipping local keyboard listener.")
|
||||
logging.warning(
|
||||
"Keyboard teleoperation is unavailable in this environment. pynput can only "
|
||||
"capture key events on an X11 session (Linux), a Windows desktop, or macOS with "
|
||||
"Accessibility / Input Monitoring granted - not on Wayland or headless machines. "
|
||||
"This keyboard teleoperator will produce no actions; use an X11 session, a "
|
||||
"gamepad, or a leader-arm teleoperator instead."
|
||||
)
|
||||
self.listener = None
|
||||
|
||||
def calibrate(self) -> None:
|
||||
|
||||
@@ -0,0 +1,440 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Display-independent keyboard input for interactive controls.
|
||||
|
||||
This module centralizes everything related to *discrete* keyboard controls
|
||||
(end-episode-early, re-record, stop, and the rollout strategies' custom keys):
|
||||
|
||||
* environment detection — :func:`is_headless`, :func:`is_wayland`,
|
||||
:func:`pynput_can_capture` (the single predicate every call-site should use to
|
||||
decide whether ``pynput`` can actually capture keys here);
|
||||
* a shared key mapping — :func:`apply_recording_control`; and
|
||||
* two interchangeable backends behind one ``(listener, events)`` contract:
|
||||
the ``pynput`` global listener (X11 / trusted-macOS / Windows) and a
|
||||
standard-library :class:`TerminalKeyListener` that reads the controlling TTY
|
||||
(Wayland / headless-SSH-with-TTY / macOS without Accessibility permission).
|
||||
|
||||
NOTE: *continuous* key-state teleoperation ("hold a key to keep moving") is
|
||||
deliberately NOT served here. A terminal in cbreak mode delivers only key-down
|
||||
bytes — there is no key-release event — so the held-key model cannot be
|
||||
reproduced. Those teleoperators stay on ``pynput`` and use
|
||||
:func:`pynput_can_capture` to warn instead of silently doing nothing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import select
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .import_utils import _pynput_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# POSIX-only terminal modules (absent on Windows, where the pynput backend is used).
|
||||
if TYPE_CHECKING:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
_TERMIOS_AVAILABLE = True
|
||||
else:
|
||||
try:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
_TERMIOS_AVAILABLE = True
|
||||
except ImportError: # POSIX-only modules; unavailable on Windows
|
||||
termios = tty = None
|
||||
_TERMIOS_AVAILABLE = False
|
||||
|
||||
keyboard = None
|
||||
if _pynput_available:
|
||||
try:
|
||||
from pynput import keyboard
|
||||
except Exception as e: # e.g. no reachable X display on a headless Linux box
|
||||
logger.info("Could not import pynput keyboard backend: %s", e)
|
||||
|
||||
|
||||
@cache
|
||||
def is_headless() -> bool:
|
||||
"""Return ``True`` when no display server is available.
|
||||
|
||||
* Linux: headless when neither ``DISPLAY`` (X11) nor ``WAYLAND_DISPLAY`` is set.
|
||||
* macOS / Windows: a display is always assumed to be present. A genuinely GUI-less
|
||||
Mac/Windows CI host would be misclassified but it doesn't matter, because the
|
||||
sys.stdin.isatty() gate returns None there regardless.
|
||||
"""
|
||||
if platform.system() == "Linux":
|
||||
return not (os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
|
||||
return False
|
||||
|
||||
|
||||
@cache
|
||||
def is_wayland() -> bool:
|
||||
"""Return ``True`` when running under a Wayland session.
|
||||
|
||||
``pynput`` relies on an X11 backend. Under Wayland it still imports (XWayland
|
||||
is usually present and ``$DISPLAY`` is set) but cannot capture *global*
|
||||
hotkeys, so the documented arrow/Esc shortcuts silently do nothing. This case
|
||||
is invisible to :func:`is_headless`, hence the dedicated check.
|
||||
"""
|
||||
return os.environ.get("XDG_SESSION_TYPE", "").lower() == "wayland" or bool(
|
||||
os.environ.get("WAYLAND_DISPLAY")
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def pynput_can_capture() -> bool:
|
||||
"""Return ``True`` when a ``pynput`` global listener can actually capture keys.
|
||||
|
||||
This is the single predicate every keyboard call-site should use to choose
|
||||
between the ``pynput`` backend and a fallback. It is intentionally
|
||||
conservative:
|
||||
|
||||
* Linux: only a real X11 session (a display is present *and* it is not Wayland).
|
||||
* macOS: ``True`` here — Accessibility / Input-Monitoring permission
|
||||
(``IS_TRUSTED``) can only be confirmed at runtime *after* starting a
|
||||
listener, so :func:`init_keyboard_listener` refines this with
|
||||
:func:`pynput_listener_is_trusted`.
|
||||
* Windows: ``True`` (the low-level global hook needs no special permission).
|
||||
|
||||
Always ``False`` when ``pynput`` is not installed.
|
||||
"""
|
||||
if not _pynput_available:
|
||||
return False
|
||||
if platform.system() == "Linux":
|
||||
return not is_headless() and not is_wayland()
|
||||
return True
|
||||
|
||||
|
||||
def pynput_listener_is_trusted(listener, timeout_s: float = 1.0) -> bool:
|
||||
"""Best-effort check that a freshly started ``pynput`` listener can capture.
|
||||
|
||||
On macOS, ``pynput`` sets ``listener.IS_TRUSTED`` on its *listener thread*
|
||||
once the Quartz event tap is created; the class default is ``False``. We
|
||||
therefore wait for the thread to either flip it ``True`` (trusted) or for a
|
||||
short timeout to elapse (untrusted — it stays ``False`` forever). This biases
|
||||
toward the common trusted case (returns as soon as the flag flips) and only
|
||||
pays the full ``timeout_s`` on an already-broken untrusted machine.
|
||||
|
||||
On non-macOS backends the attribute is absent and capture is assumed to work.
|
||||
"""
|
||||
if platform.system() != "Darwin":
|
||||
return True
|
||||
deadline = time.perf_counter() + timeout_s
|
||||
while time.perf_counter() < deadline:
|
||||
if getattr(listener, "IS_TRUSTED", False):
|
||||
return True
|
||||
time.sleep(0.02)
|
||||
return bool(getattr(listener, "IS_TRUSTED", False))
|
||||
|
||||
|
||||
def apply_recording_control(control: str, events: dict) -> None:
|
||||
"""Apply a recording control-flow key press to the shared ``events`` dict.
|
||||
|
||||
Centralizes the mapping so the ``pynput`` and terminal backends behave
|
||||
identically. ``control`` is one of ``"right"`` (end the loop early), ``"left"``
|
||||
(re-record the last episode), or ``"esc"`` (stop recording).
|
||||
"""
|
||||
if control == "right":
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif control == "left":
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif control == "esc":
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
|
||||
|
||||
# Terminal arrow keys arrive as a 3-byte escape sequence whose *final* byte identifies
|
||||
# the direction. Two encodings exist depending on the terminal's cursor-key mode — CSI
|
||||
# ("ESC [ X") and SS3 ("ESC O X", common over SSH/tmux) — but both share the same final
|
||||
# byte, so this single table decodes either. Looked up by TerminalKeyListener._parse;
|
||||
# an unknown final byte yields None (sequence ignored).
|
||||
_ARROW_FINAL_BYTES = {"A": "up", "B": "down", "C": "right", "D": "left"}
|
||||
|
||||
|
||||
class TerminalKeyListener:
|
||||
"""Display-independent keyboard listener that reads keys from the controlling TTY.
|
||||
|
||||
Used as the Wayland / headless / macOS-untrusted equivalent of the ``pynput``
|
||||
listener for *discrete* controls. It puts the terminal into cbreak mode with
|
||||
echo disabled and reads bytes on a daemon thread, decoding them into logical
|
||||
key names that are passed to ``on_key``:
|
||||
|
||||
* arrow keys (``ESC [ C`` / ``ESC O C`` …) -> ``"right"`` / ``"left"`` / ``"up"`` / ``"down"``
|
||||
* a bare ``ESC`` -> ``"esc"``
|
||||
* Enter / Tab / Space / Backspace -> ``"enter"`` / ``"tab"`` / ``"space"`` / ``"backspace"``
|
||||
* any other printable byte -> that character (e.g. ``"n"``, ``"s"``)
|
||||
|
||||
Only key-down events are produced (terminals have no key-release), so this is
|
||||
suitable for discrete commands but NOT for continuous "hold-to-move" teleop.
|
||||
|
||||
The terminal is restored on :meth:`stop` and also via an ``atexit`` hook, so a
|
||||
crash or Ctrl-C never leaves the shell in a no-echo cbreak state. POSIX-only
|
||||
(``termios`` / ``tty`` / ``select``); those modules are imported lazily so this
|
||||
file stays importable on Windows (where ``pynput`` is used instead).
|
||||
"""
|
||||
|
||||
def __init__(self, on_key: Callable[[str], None]):
|
||||
self._on_key = on_key
|
||||
self._running = False
|
||||
self._thread: threading.Thread | None = None
|
||||
self._fd: int | None = None
|
||||
self._old_attrs = None
|
||||
|
||||
def _read_char(self, timeout: float) -> str | None:
|
||||
"""Return one character from stdin within ``timeout`` seconds, or ``None``."""
|
||||
if self._fd is None:
|
||||
return None
|
||||
ready, _, _ = select.select([self._fd], [], [], timeout)
|
||||
if not ready:
|
||||
return None
|
||||
try:
|
||||
data = os.read(self._fd, 1)
|
||||
except OSError:
|
||||
return None
|
||||
if not data:
|
||||
return None
|
||||
return data.decode(errors="ignore")
|
||||
|
||||
def _parse(self, ch: str) -> str | None:
|
||||
"""Decode one (possibly multi-byte) key starting at ``ch`` into a key name."""
|
||||
if ch == "\x1b":
|
||||
# Possible CSI / SS3 escape sequence (arrow keys) or a bare ESC. Use
|
||||
# short follow-up reads so a lone ESC is not mistaken for a sequence.
|
||||
ch2 = self._read_char(timeout=0.02)
|
||||
if ch2 is None:
|
||||
return "esc"
|
||||
if ch2 in ("[", "O"):
|
||||
ch3 = self._read_char(timeout=0.02)
|
||||
return _ARROW_FINAL_BYTES.get(ch3 or "")
|
||||
# Some other escape sequence (e.g. Alt+key); ignore it.
|
||||
return None
|
||||
if ch in ("\r", "\n"):
|
||||
return "enter"
|
||||
if ch == "\t":
|
||||
return "tab"
|
||||
if ch == " ":
|
||||
return "space"
|
||||
if ch in ("\x7f", "\x08"):
|
||||
return "backspace"
|
||||
if ch.isprintable():
|
||||
return ch
|
||||
return None
|
||||
|
||||
def _run(self) -> None:
|
||||
while self._running:
|
||||
ch = self._read_char(timeout=0.05)
|
||||
if ch is None:
|
||||
continue
|
||||
name = self._parse(ch)
|
||||
if name is None:
|
||||
continue
|
||||
try:
|
||||
self._on_key(name)
|
||||
except Exception as e: # never let a handler error kill the reader thread
|
||||
logger.debug("Terminal key handler error: %s", e)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Switch the terminal to cbreak mode (echo off) and read keys on a daemon thread.
|
||||
|
||||
No-op when stdin is not a TTY (piped/redirected input) or on platforms
|
||||
without ``termios`` (e.g. Windows), so non-interactive runs are unaffected.
|
||||
"""
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
if not _TERMIOS_AVAILABLE: # POSIX-only modules (e.g. unavailable on Windows)
|
||||
logger.warning("Terminal keyboard input is not supported on this platform.")
|
||||
return
|
||||
|
||||
self._fd = sys.stdin.fileno()
|
||||
self._old_attrs = termios.tcgetattr(self._fd)
|
||||
tty.setcbreak(self._fd)
|
||||
# Explicitly disable ECHO so arrow-key escape sequences (e.g. ^[[C) are not
|
||||
# echoed as garbage into the recording terminal. (Independent of the
|
||||
# version-specific behavior of tty.setcbreak.)
|
||||
new_attrs = termios.tcgetattr(self._fd)
|
||||
new_attrs[3] &= ~termios.ECHO # index 3 == lflags
|
||||
termios.tcsetattr(self._fd, termios.TCSADRAIN, new_attrs)
|
||||
# Safety net: restore the terminal even if stop() is never reached (crash).
|
||||
atexit.register(self.stop)
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the reader thread and restore the original terminal attributes.
|
||||
|
||||
Idempotent: safe to call multiple times (e.g. explicitly and via atexit).
|
||||
"""
|
||||
self._running = False
|
||||
thread = self._thread
|
||||
if thread is not None:
|
||||
thread.join(timeout=0.5)
|
||||
self._thread = None
|
||||
if self._fd is not None and self._old_attrs is not None and _TERMIOS_AVAILABLE:
|
||||
try:
|
||||
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._old_attrs)
|
||||
finally:
|
||||
self._old_attrs = None
|
||||
with contextlib.suppress(Exception):
|
||||
atexit.unregister(self.stop)
|
||||
|
||||
|
||||
# Map pynput key objects to the same canonical names TerminalKeyListener emits, so a
|
||||
# single dispatch works across both backends. Empty when pynput is unavailable.
|
||||
if keyboard is not None:
|
||||
_PYNPUT_KEY_NAMES = {
|
||||
keyboard.Key.right: "right",
|
||||
keyboard.Key.left: "left",
|
||||
keyboard.Key.up: "up",
|
||||
keyboard.Key.down: "down",
|
||||
keyboard.Key.esc: "esc",
|
||||
keyboard.Key.enter: "enter",
|
||||
keyboard.Key.tab: "tab",
|
||||
keyboard.Key.space: "space",
|
||||
keyboard.Key.backspace: "backspace",
|
||||
}
|
||||
else:
|
||||
_PYNPUT_KEY_NAMES = {}
|
||||
|
||||
|
||||
def _resolve_pynput_key(key) -> str | None:
|
||||
"""Resolve a pynput key event to the canonical name TerminalKeyListener also emits.
|
||||
|
||||
Special keys map through :data:`_PYNPUT_KEY_NAMES`; character keys fall back to their
|
||||
``.char`` (e.g. ``"n"``). Returns ``None`` for keys with no mapping and no character.
|
||||
"""
|
||||
name = _PYNPUT_KEY_NAMES.get(key)
|
||||
if name is not None:
|
||||
return name
|
||||
# ``or None`` keeps the historical truthy-char semantics: an empty/None char is "no key".
|
||||
return getattr(key, "char", None) or None
|
||||
|
||||
|
||||
def create_key_listener(dispatch: Callable[[str], None], *, controls_help: str = ""):
|
||||
"""Start a keyboard listener that routes resolved key names to ``dispatch``.
|
||||
|
||||
Shared backend selection used by recording and the rollout strategies:
|
||||
|
||||
* the ``pynput`` global listener on X11 / trusted-macOS / Windows (on macOS the
|
||||
listener's ``IS_TRUSTED`` flag is checked after start, and an untrusted listener is
|
||||
stopped so the terminal backend is used instead);
|
||||
* the stdlib :class:`TerminalKeyListener` on Wayland / headless sessions with a TTY;
|
||||
* ``None`` when no backend is usable (non-interactive / piped runs).
|
||||
|
||||
Both backends pass ``dispatch`` the same canonical key names ("right" / "left" / "up" /
|
||||
"down" / "esc" / "enter" / "tab" / "space" / "backspace", or a character), so one
|
||||
``dispatch`` works regardless of backend. ``controls_help`` is an optional hint
|
||||
appended to the log messages.
|
||||
|
||||
Returns the listener (exposing ``.stop()``) or ``None``.
|
||||
"""
|
||||
suffix = f" ({controls_help})" if controls_help else ""
|
||||
|
||||
if pynput_can_capture() and keyboard is not None:
|
||||
|
||||
def on_press(key):
|
||||
with contextlib.suppress(Exception):
|
||||
name = _resolve_pynput_key(key)
|
||||
if name is not None:
|
||||
dispatch(name)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
if pynput_listener_is_trusted(listener):
|
||||
logger.info("Keyboard listener started%s.", suffix)
|
||||
return listener
|
||||
# macOS without Accessibility / Input-Monitoring permission: the listener never
|
||||
# fires. Stop it and fall through to the terminal backend.
|
||||
logger.warning(
|
||||
"pynput keyboard listener is not trusted (missing macOS Accessibility / "
|
||||
"Input Monitoring permission); falling back to terminal keyboard input."
|
||||
)
|
||||
listener.stop()
|
||||
|
||||
if sys.stdin.isatty():
|
||||
listener = TerminalKeyListener(dispatch)
|
||||
listener.start()
|
||||
logger.info("Using terminal keyboard input — keep this terminal focused%s.", suffix)
|
||||
return listener
|
||||
|
||||
logger.warning(
|
||||
"Keyboard controls unavailable: no usable display (Wayland/headless) and stdin is "
|
||||
"not an interactive terminal%s.",
|
||||
suffix,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize a non-blocking keyboard listener for interactive recording controls.
|
||||
|
||||
Backend selection:
|
||||
|
||||
* ``pynput`` global listener when :func:`pynput_can_capture` is true (real
|
||||
X11, macOS, Windows). On macOS the listener's ``IS_TRUSTED`` flag is checked
|
||||
after start; if the process lacks Accessibility / Input-Monitoring
|
||||
permission, the listener is stopped and the terminal backend is used.
|
||||
* a :class:`TerminalKeyListener` reading the controlling TTY when ``pynput``
|
||||
cannot capture (Wayland / headless-SSH / macOS-untrusted) *and* stdin is a TTY.
|
||||
* otherwise no listener (non-interactive / piped runs) — recording relies on
|
||||
the episode/reset timers (or Ctrl+C).
|
||||
|
||||
Both backends accept the same controls: Right/Left/Esc, plus the single-byte letter
|
||||
equivalents ``n`` (next), ``r`` (re-record) and ``q`` (quit). The letters are the most
|
||||
reliable choice over high-latency SSH/VNC links, where arrow-key escape sequences can
|
||||
be split, delayed, or intercepted by the terminal.
|
||||
|
||||
Returns:
|
||||
A tuple ``(listener, events)`` where ``listener`` exposes ``.stop()`` or is
|
||||
``None``, and ``events`` is the dict of flags (``exit_early``,
|
||||
``rerecord_episode``, ``stop_recording``) set by key presses.
|
||||
"""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
}
|
||||
|
||||
# Accept the single-byte letter equivalents n/r/q alongside the arrow/Esc keys: the
|
||||
# letters are immune to the escape-sequence split/delay/interception that affects arrows
|
||||
# over laggy SSH/VNC links. Case-insensitive so Shift+letter still works.
|
||||
def on_key(name: str) -> None:
|
||||
key = name.lower()
|
||||
if key in ("right", "n"):
|
||||
apply_recording_control("right", events)
|
||||
elif key in ("left", "r"):
|
||||
apply_recording_control("left", events)
|
||||
elif key in ("esc", "q"):
|
||||
apply_recording_control("esc", events)
|
||||
# other keys (incl. up/down) are intentionally ignored
|
||||
|
||||
listener = create_key_listener(on_key, controls_help="Right/Left/Esc, or n=next, r=re-record, q=quit")
|
||||
return listener, events
|
||||
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for the display-independent keyboard input helpers.
|
||||
|
||||
These cover the parts most likely to regress: the environment-detection decision
|
||||
table (the heart of the Wayland/headless fix), the macOS trust probe, the control
|
||||
mapping, the terminal escape-sequence parsing, and backend selection. They require
|
||||
neither ``pynput`` nor a real terminal.
|
||||
"""
|
||||
|
||||
import io
|
||||
import platform
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import lerobot.utils.keyboard_input as ki
|
||||
from lerobot.utils.keyboard_input import (
|
||||
TerminalKeyListener,
|
||||
apply_recording_control,
|
||||
create_key_listener,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
is_wayland,
|
||||
pynput_can_capture,
|
||||
pynput_listener_is_trusted,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_detection_caches():
|
||||
"""The detection helpers are ``@cache``-decorated; clear around each test."""
|
||||
for fn in (is_headless, is_wayland, pynput_can_capture):
|
||||
fn.cache_clear()
|
||||
yield
|
||||
for fn in (is_headless, is_wayland, pynput_can_capture):
|
||||
fn.cache_clear()
|
||||
|
||||
|
||||
def _set_platform(monkeypatch, name):
|
||||
monkeypatch.setattr(platform, "system", lambda: name)
|
||||
|
||||
|
||||
def _set_tty(monkeypatch, is_tty):
|
||||
stdin = io.StringIO("")
|
||||
stdin.isatty = lambda: is_tty
|
||||
monkeypatch.setattr(sys, "stdin", stdin)
|
||||
|
||||
|
||||
# --- Environment detection (the core of the fix) ---------------------------
|
||||
@pytest.mark.parametrize(
|
||||
("system", "env", "expected"),
|
||||
[
|
||||
("Linux", {}, True), # no display server
|
||||
("Linux", {"DISPLAY": ":0"}, False), # X11
|
||||
("Linux", {"WAYLAND_DISPLAY": "wayland-0"}, False), # Wayland
|
||||
("Darwin", {}, False), # display always assumed present
|
||||
],
|
||||
)
|
||||
def test_is_headless(monkeypatch, system, env, expected):
|
||||
_set_platform(monkeypatch, system)
|
||||
monkeypatch.delenv("DISPLAY", raising=False)
|
||||
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
||||
for key, value in env.items():
|
||||
monkeypatch.setenv(key, value)
|
||||
assert is_headless() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "expected"),
|
||||
[
|
||||
({"XDG_SESSION_TYPE": "wayland"}, True),
|
||||
({"WAYLAND_DISPLAY": "wayland-0"}, True),
|
||||
({"XDG_SESSION_TYPE": "x11"}, False),
|
||||
({}, False),
|
||||
],
|
||||
)
|
||||
def test_is_wayland(monkeypatch, env, expected):
|
||||
monkeypatch.delenv("XDG_SESSION_TYPE", raising=False)
|
||||
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
||||
for key, value in env.items():
|
||||
monkeypatch.setenv(key, value)
|
||||
assert is_wayland() is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("system", "env", "pynput_available", "expected"),
|
||||
[
|
||||
("Linux", {"DISPLAY": ":0"}, True, True), # X11
|
||||
("Linux", {"DISPLAY": ":0", "WAYLAND_DISPLAY": "wayland-0"}, True, False), # Wayland
|
||||
("Linux", {}, True, False), # headless
|
||||
("Darwin", {}, True, True),
|
||||
("Linux", {"DISPLAY": ":0"}, False, False), # pynput not installed
|
||||
],
|
||||
)
|
||||
def test_pynput_can_capture(monkeypatch, system, env, pynput_available, expected):
|
||||
_set_platform(monkeypatch, system)
|
||||
monkeypatch.setattr(ki, "_pynput_available", pynput_available)
|
||||
for var in ("DISPLAY", "WAYLAND_DISPLAY", "XDG_SESSION_TYPE"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
for key, value in env.items():
|
||||
monkeypatch.setenv(key, value)
|
||||
assert pynput_can_capture() is expected
|
||||
|
||||
|
||||
# --- macOS trust probe ------------------------------------------------------
|
||||
class _FakeListener:
|
||||
def __init__(self, is_trusted):
|
||||
self.IS_TRUSTED = is_trusted
|
||||
|
||||
|
||||
def test_pynput_listener_is_trusted(monkeypatch):
|
||||
_set_platform(monkeypatch, "Linux")
|
||||
assert pynput_listener_is_trusted(_FakeListener(False)) is True # non-macOS: always assumed ok
|
||||
_set_platform(monkeypatch, "Darwin")
|
||||
assert pynput_listener_is_trusted(_FakeListener(False), timeout_s=0.05) is False
|
||||
|
||||
|
||||
# --- Control mapping --------------------------------------------------------
|
||||
def test_apply_recording_control():
|
||||
events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}
|
||||
apply_recording_control("left", events)
|
||||
assert events == {"exit_early": True, "rerecord_episode": True, "stop_recording": False}
|
||||
apply_recording_control("esc", events)
|
||||
assert events["stop_recording"] is True
|
||||
apply_recording_control("up", events) # unknown control -> no-op (no error)
|
||||
|
||||
|
||||
# --- Terminal escape-sequence parsing (the tricky bit) ----------------------
|
||||
def _drive(listener, byte_seq):
|
||||
"""Run the listener's read loop over a scripted list of bytes (no real terminal)."""
|
||||
script = list(byte_seq)
|
||||
|
||||
def fake_read(timeout):
|
||||
if script:
|
||||
return script.pop(0)
|
||||
listener._running = False
|
||||
return None
|
||||
|
||||
listener._read_char = fake_read
|
||||
listener._running = True
|
||||
listener._run()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("byte_seq", "expected"),
|
||||
[
|
||||
(["\x1b", "[", "C"], ["right"]), # CSI arrow
|
||||
(["\x1b", "O", "D"], ["left"]), # SS3 arrow (e.g. over SSH/tmux)
|
||||
(["\x1b"], ["esc"]), # bare ESC
|
||||
(["\x1b", "[", "A"], ["up"]), # decoded even though the record handler ignores it
|
||||
(["n"], ["n"]), # letter passthrough
|
||||
],
|
||||
)
|
||||
def test_terminal_parsing(byte_seq, expected):
|
||||
collected = []
|
||||
_drive(TerminalKeyListener(collected.append), byte_seq)
|
||||
assert collected == expected
|
||||
|
||||
|
||||
# --- Backend selection ------------------------------------------------------
|
||||
def test_init_selects_terminal_when_pynput_cannot_capture(monkeypatch):
|
||||
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
||||
_set_tty(monkeypatch, is_tty=True)
|
||||
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None) # avoid touching termios
|
||||
listener, _ = init_keyboard_listener()
|
||||
assert isinstance(listener, TerminalKeyListener)
|
||||
|
||||
|
||||
def test_init_returns_none_without_tty(monkeypatch):
|
||||
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
||||
_set_tty(monkeypatch, is_tty=False)
|
||||
listener, _ = init_keyboard_listener()
|
||||
assert listener is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("key", "flag"),
|
||||
[("right", "exit_early"), ("r", "rerecord_episode"), ("q", "stop_recording")],
|
||||
)
|
||||
def test_init_terminal_key_routing(monkeypatch, key, flag):
|
||||
"""Arrows and their letter equivalents drive the same events (terminal backend)."""
|
||||
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
||||
_set_tty(monkeypatch, is_tty=True)
|
||||
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
|
||||
listener, events = init_keyboard_listener()
|
||||
listener._on_key(key)
|
||||
assert events[flag] is True
|
||||
|
||||
|
||||
# --- Shared factory + pynput key resolver -----------------------------------
|
||||
def test_resolve_pynput_key_char_fallback():
|
||||
"""Unmapped keys fall back to ``.char`` (and yield None when there is none)."""
|
||||
assert ki._resolve_pynput_key(type("K", (), {"char": "s"})()) == "s"
|
||||
assert ki._resolve_pynput_key(type("K", (), {"char": None})()) is None
|
||||
assert ki._resolve_pynput_key(type("K", (), {"char": ""})()) is None # empty char -> no key
|
||||
|
||||
|
||||
def test_create_key_listener_routes_to_dispatch(monkeypatch):
|
||||
"""The terminal backend forwards canonical key names straight to ``dispatch``."""
|
||||
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
||||
_set_tty(monkeypatch, is_tty=True)
|
||||
monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
|
||||
seen = []
|
||||
listener = create_key_listener(seen.append, controls_help="save='s'")
|
||||
assert isinstance(listener, TerminalKeyListener)
|
||||
listener._on_key("space")
|
||||
assert seen == ["space"]
|
||||
|
||||
|
||||
def test_create_key_listener_none_without_tty(monkeypatch):
|
||||
monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
|
||||
_set_tty(monkeypatch, is_tty=False)
|
||||
assert create_key_listener(lambda name: None) is None
|
||||
Reference in New Issue
Block a user