Compare commits

4 Commits

Author SHA1 Message Date
Steven Palma d8491ba179 fix(rollout): remove import require guard for pynput 2026-06-24 18:40:35 +02:00
Steven Palma 594a1bb81d refactor(utils): consolidate keyboard listener creation 2026-06-24 17:47:25 +02:00
Steven Palma 264a7490ac feat(utils): headless keyboard control 2026-06-24 17:23:51 +02:00
Khalil Meftah 6f0ba4be38 Record eval rollouts as LeRobot datasets (#3825)
* feat(eval): record eval rollouts as raw LeRobot datasets

- Record raw env observations inline during rollout(), before
preprocess_observation() transforms them. Uses LeRobotDataset.create()
with add_frame()/save_episode().

- Supports vectorized envs: each env in the batch records independently,
with save_episode() called per env on termination. Each task gets its
own dataset under output_dir/recordings/{task_group}_{task_id}/.

Enabled via --eval.recording=true; disabled by default.

* fix(eval): use FeatureType enum comparison instead of string value

* refactor(eval): per-env datasets recording, no double reset

- Extract _infer_shape_from_obs() to reduce nesting in feature conversion
- Move dataset creation into rollout() using its own env.reset() observation,
  eliminating the extra reset in run_one()
- Replace deepcopy with _shallow_copy_obs() for raw observation stashing
- Support batch_size > 1: each parallel env records to its own dataset
  (single env skips the env_0/ nesting for simplicity)
- One-time warning for env_features keys missing from observations
- Pass recording_dir + env_features through the call chain instead of
  a pre-built recording_dataset object

* refactor(eval): remove shape inference and shallow copy helpers

* feat(eval): optionally push recorded eval datasets to the Hub

* fix(eval): address review comments

- Wrap rollout loop in try/finally so finalize() runs on crash/interrupt
- Guard push_to_hub with num_episodes > 0 to avoid pushing empty datasets
- Hoist loop-invariant multi_env and base_repo_id out of creation loop
2026-06-23 14:03:57 +02:00
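The review fixes listed at the end of the commit message above (finalize on crash or interrupt, skip pushing empty datasets) can be pictured with a small stand-in recorder. This is a minimal sketch, not the code from the PR: `FakeRecorder` and `run_rollout` are hypothetical names, and placing the push inside `finally` is an assumption made only to keep the example short. Only the names `finalize()`, `push_to_hub`, `save_episode()` and `num_episodes` come from the commit message.

```python
class FakeRecorder:
    """Hypothetical stand-in for a recording dataset (not the LeRobot class)."""

    def __init__(self):
        self.num_episodes = 0
        self.finalized = False
        self.pushed = False

    def save_episode(self):
        self.num_episodes += 1

    def finalize(self):
        self.finalized = True

    def push_to_hub(self):
        self.pushed = True


def run_rollout(recorder, steps, fail_at=None, push=True):
    """Toy rollout loop: finalize always runs, push only if episodes exist."""
    try:
        for step in range(steps):
            if fail_at is not None and step == fail_at:
                raise KeyboardInterrupt  # simulate Ctrl+C mid-rollout
            recorder.save_episode()
    finally:
        # Runs on normal exit, on exceptions, and on interrupts.
        recorder.finalize()
        # Guard: never push a dataset that contains no episodes.
        if push and recorder.num_episodes > 0:
            recorder.push_to_hub()
```

With this shape, an interrupted run still leaves a finalized dataset on disk, and a run that crashed before saving any episode is not pushed.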
19 changed files with 1006 additions and 289 deletions
+12 -4
@@ -390,9 +390,17 @@ Set the flow of data recording using command-line arguments:
Control the data recording flow using keyboard shortcuts:
- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next.
- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it.
- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset.
- Press **Right Arrow (`→`)** or **`n`**: Early stop the current episode or reset time and move to the next.
- Press **Left Arrow (`←`)** or **`r`**: Cancel the current episode and re-record it.
- Press **Escape (`ESC`)** or **`q`**: Immediately stop the session, encode videos, and upload the dataset.
<Tip>
These control-flow shortcuts work on **X11, Wayland, and headless/SSH** sessions. When a global keyboard backend isn't available (Wayland, a headless machine, or macOS without Accessibility permission), `lerobot-record` automatically reads the same keys from the terminal — launch it from an interactive terminal and keep it focused. You can also use the letter equivalents **`n`** (next, same as `→`), **`r`** (re-record, same as `←`) and **`q`** (quit, same as `ESC`). No `$DISPLAY` setup is required.
This applies to the recording control flow only. Keyboard **teleoperation** (driving the robot with the keyboard) still needs a global key backend, so it works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless sessions.
</Tip>
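As a rough illustration of the shortcut table described in the tip, the letter keys can be treated as aliases that resolve to the same three event flags the arrow and Escape keys set. This is a sketch under assumptions: the key labels `"right"`, `"left"` and `"esc"` and the helper names are hypothetical, while the flag names (`exit_early`, `rerecord_episode`, `stop_recording`) are taken from the `init_keyboard_listener` code shown later in this diff.

```python
# Hypothetical key labels; the real backend may name keys differently.
KEY_ALIASES = {"n": "right", "r": "left", "q": "esc"}


def new_events():
    """Fresh set of control-flow flags, all cleared."""
    return {"exit_early": False, "rerecord_episode": False, "stop_recording": False}


def apply_key(name, events):
    """Update the flags for one key press; unknown keys are ignored."""
    name = KEY_ALIASES.get(name, name)
    if name == "right":
        events["exit_early"] = True
    elif name == "left":
        events["rerecord_episode"] = True
        events["exit_early"] = True
    elif name == "esc":
        events["stop_recording"] = True
        events["exit_early"] = True
    return events
```

Routing both key families through one table keeps the global-listener path and the terminal-reader path behaviorally identical.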
#### Tips for gathering data
@@ -406,7 +414,7 @@ If you want to dive deeper into this important topic, you can check out the [blo
#### Troubleshooting:
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as `lerobot-record` runs in an interactive terminal — no `$DISPLAY` setup is needed. If the keys have no effect, make sure you are in an interactive (TTY) terminal, not a piped/non-TTY session, and that it is focused; the letter equivalents `n` / `r` / `q` also work. Keyboard _teleoperation_ (as opposed to the recording control flow) still requires a global key backend — an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — and is unavailable on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
## Visualize a dataset
+1 -1
@@ -319,7 +319,7 @@ If you want to dive deeper into this important topic, you can check out the [blo
#### Troubleshooting:
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
- On Linux, the recording control-flow keys (arrow keys, Escape) work on X11, Wayland, and headless/SSH sessions as long as you run the recording from an interactive terminal (keep it focused) — no `$DISPLAY` setup is needed; the letter equivalents `n` / `r` / `q` also work. Note that **keyboard teleoperation of the LeKiwi base** is different: it relies on a global key backend and therefore works only on an X11 session, a Windows desktop, or macOS with Accessibility/Input Monitoring granted — not on Wayland or headless machines. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
## Replay an episode
+2 -1
@@ -17,7 +17,7 @@
import logging
import time
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import predict_action
from lerobot.datasets import LeRobotDataset
from lerobot.policies import make_pre_post_processors
from lerobot.policies.act import ACTPolicy
@@ -26,6 +26,7 @@ from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+1 -1
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets import LeRobotDataset
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
@@ -23,6 +22,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -1
@@ -18,7 +18,7 @@ import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+1 -1
@@ -15,7 +15,6 @@
# limitations under the License.
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
@@ -39,6 +38,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -1
@@ -18,7 +18,7 @@ import logging
import time
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener, predict_action
from lerobot.common.control_utils import predict_action
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -41,6 +41,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+1 -1
@@ -16,7 +16,6 @@
from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
@@ -36,6 +35,7 @@ from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
-84
@@ -17,12 +17,9 @@ from __future__ import annotations
########################################################################################
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -43,34 +40,6 @@ from lerobot.robots import Robot
from lerobot.types import PolicyAction
@cache
def is_headless():
"""
Detects if the Python script is running in a headless environment (e.g., without a display).
This function attempts to import `pynput`, a library that requires a graphical environment.
If the import fails, it assumes the environment is headless. The result is cached to avoid
re-running the check.
Returns:
True if the environment is determined to be headless, False otherwise.
"""
try:
import pynput # noqa
return False
except Exception:
print(
"Error trying to import pynput. Switching to headless mode. "
"As a result, the video stream from the cameras won't be shown, "
"and you won't be able to change the control flow with keyboards. "
"For more info, see traceback below.\n"
)
traceback.print_exc()
print()
return True
def predict_action(
observation: dict[str, np.ndarray],
policy: PreTrainedPolicy,
@@ -122,59 +91,6 @@ def predict_action(
return action
def init_keyboard_listener():
"""
Initializes a non-blocking keyboard listener for real-time user interaction.
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
the program flow during execution, such as stopping recording or exiting loops. It gracefully
handles headless environments where keyboard listening is not possible.
Returns:
A tuple containing:
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
"""
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def sanity_check_dataset_name(repo_id, policy_cfg):
"""
Validates the dataset repository name against the presence of a policy configuration.
+9
@@ -73,8 +73,17 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
# Whether to record eval rollouts as a LeRobot dataset on disk.
recording: bool = False
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
# suffixed by task and env index). Requires recording=true.
recording_repo_id: str | None = None
# Whether the pushed recording repositories should be private.
recording_private: bool = False
def __post_init__(self) -> None:
if self.recording_repo_id is not None and not self.recording:
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
if self.batch_size == 0:
self.batch_size = self._auto_batch_size()
if self.batch_size > self.n_episodes:
+23 -67
@@ -47,8 +47,6 @@ from __future__ import annotations
import contextlib
import enum
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
@@ -58,7 +56,6 @@ import numpy as np
from lerobot.common.control_utils import (
follower_smooth_move_to,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
@@ -66,7 +63,7 @@ from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.keyboard_input import create_key_listener
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -75,19 +72,6 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
from ..context import RolloutContext
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
@@ -180,64 +164,36 @@ class DAggerEvents:
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
"""Initialise keyboard listener with DAgger 3-key controls.
"""Initialise a keyboard listener for DAgger's 3 controls.
Returns the pynput Listener (or ``None`` in headless mode or when
pynput is unavailable).
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
Wayland / headless TTY) is delegated to :func:`create_key_listener`. Returns the
listener (exposing ``stop()``) or ``None`` when no keyboard backend is usable.
"""
if not PYNPUT_AVAILABLE or is_headless():
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
return None
# Map config key names to pynput Key objects for special keys
special_keys = {
"space": keyboard.Key.space,
"tab": keyboard.Key.tab,
"enter": keyboard.Key.enter,
}
def _resolve_key(key) -> str | None:
"""Resolve a pynput key event to a config-comparable string."""
if key == keyboard.Key.esc:
return "esc"
for name, pynput_key in special_keys.items():
if key == pynput_key:
return name
if hasattr(key, "char") and key.char:
return key.char
return None
# Build mapping: resolved key string -> DAgger event name
# Map config key names to DAgger event names.
key_to_event = {
cfg.pause_resume: "pause_resume",
cfg.correction: "correction",
}
def on_press(key):
try:
resolved = _resolve_key(key)
if resolved is None:
return
if resolved == "esc":
logger.info("Stop recording...")
events.stop_recording.set()
return
if resolved in key_to_event:
events.request_transition(key_to_event[resolved])
if resolved == cfg.upload:
events.upload_requested.set()
except Exception as e:
logger.debug("Key error: %s", e)
def dispatch(name: str) -> None:
"""Apply a resolved key name to the DAgger events."""
if name == "esc":
logger.info("Stop recording...")
events.stop_recording.set()
return
if name in key_to_event:
events.request_transition(key_to_event[name])
if name == cfg.upload:
events.upload_requested.set()
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
return create_key_listener(
dispatch,
controls_help=(
f"pause_resume='{cfg.pause_resume}', correction='{cfg.correction}', "
f"upload='{cfg.upload}', ESC=stop"
),
)
return listener
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
@@ -328,7 +284,7 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Stopping DAgger recording")
log_say("Stopping DAgger recording", play_sounds)
if self._listener is not None and not is_headless():
if self._listener is not None:
logger.info("Stopping keyboard listener")
self._listener.stop()
+2 -3
@@ -35,14 +35,13 @@ import time
from lerobot.common.control_utils import (
follower_smooth_move_to,
init_keyboard_listener,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_rerun_data
@@ -307,7 +306,7 @@ class EpisodicStrategy(RolloutStrategy):
log_say("Stop recording", play_sounds, blocking=True)
if not is_headless() and self._listener is not None:
if self._listener is not None:
self._listener.stop()
if ctx.data.dataset is not None:
+19 -39
@@ -18,17 +18,14 @@ from __future__ import annotations
import contextlib
import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent, Lock
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.keyboard_input import create_key_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -37,19 +34,6 @@ from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logger = logging.getLogger(__name__)
@@ -72,7 +56,6 @@ class HighlightStrategy(RolloutStrategy):
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
require_package("pynput", extra="pynput-dep")
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = ThreadingEvent()
@@ -234,30 +217,27 @@ class HighlightStrategy(RolloutStrategy):
logger.info("Highlight strategy teardown complete")
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
"""Set up keyboard listener for save and push keys."""
if is_headless():
logger.warning("Headless environment — highlight keys unavailable")
return
"""Set up a keyboard listener for the save and push keys.
try:
save_key = self.config.save_key
push_key = self.config.push_key
Backend selection (pynput on X11 / trusted-macOS / Windows, a terminal reader on
Wayland / headless TTY) is delegated to :func:`create_key_listener`.
"""
save_key = self.config.save_key
push_key = self.config.push_key
def on_press(key):
with contextlib.suppress(Exception):
if hasattr(key, "char") and key.char == save_key:
self._save_requested.set()
elif hasattr(key, "char") and key.char == push_key:
self._push_requested.set()
elif key == keyboard.Key.esc:
self._save_requested.clear()
shutdown_event.set()
def dispatch(name: str) -> None:
"""Apply a resolved key name to the highlight events."""
if name == save_key:
self._save_requested.set()
elif name == push_key:
self._push_requested.set()
elif name == "esc":
self._save_requested.clear()
shutdown_event.set()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
self._listener = create_key_listener(
dispatch, controls_help=f"save='{save_key}', push='{push_key}', ESC=stop"
)
def _background_push(self, dataset, cfg) -> None:
"""Queue a Hub push on the single-worker executor."""
+239 -69
@@ -72,8 +72,9 @@ from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.configs import parser
from lerobot.configs import FeatureType, parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs import (
check_env_attributes_and_types,
close_envs,
@@ -84,7 +85,7 @@ from lerobot.envs import (
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
)
def _env_features_to_dataset_features(env_features: dict) -> dict:
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
features = {}
for key, ft in env_features.items():
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
features[key] = {"dtype": "float32", "shape": shape, "names": None}
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
return features
def _build_raw_frame(
raw_obs: dict,
env_idx: int,
action: np.ndarray,
reward: float,
success: bool,
done: bool,
task: str,
env_features: dict,
) -> dict:
"""Build a dataset frame from raw env observations for one env index.
Keys in the frame match the keys in env_features so they align with the
dataset schema created by _env_features_to_dataset_features().
"""
frame: dict[str, Any] = {}
for key in env_features:
if key == ACTION:
continue
if key.startswith("next."):
continue
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
for cam_name, img in raw_obs["pixels"].items():
candidate = f"{OBS_IMAGES}.{cam_name}"
if candidate == key:
frame[key] = img[env_idx]
if key in frame:
continue
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
frame[key] = raw_obs["pixels"][env_idx]
continue
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
val = raw_obs[key][env_idx]
if val.dtype == np.float64:
val = val.astype(np.float32)
frame[key] = val
frame[ACTION] = action
frame["next.reward"] = np.atleast_1d(np.float32(reward))
frame["next.success"] = np.atleast_1d(np.bool_(success))
frame["next.done"] = np.atleast_1d(np.bool_(done))
frame["task"] = task
return frame
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
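To make the schema produced by `_env_features_to_dataset_features()` concrete, here is a standalone re-creation of the same mapping using stand-in types. `ToyFeatureType`, `ToyFeature` and the example feature keys are hypothetical; the dtype strings, the `names` values and the three `next.*` entries follow the hunk above.

```python
from dataclasses import dataclass
from enum import Enum


class ToyFeatureType(Enum):
    """Hypothetical stand-in for lerobot.configs.FeatureType."""

    VISUAL = "VISUAL"
    STATE = "STATE"


@dataclass
class ToyFeature:
    type: ToyFeatureType
    shape: tuple


def to_dataset_features(env_features):
    """Map env features to the dict layout used for dataset creation."""
    features = {}
    for key, ft in env_features.items():
        shape = tuple(ft.shape)
        if ft.type is ToyFeatureType.VISUAL:
            # Camera streams are stored as video with HWC axis names.
            features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
        else:
            features[key] = {"dtype": "float32", "shape": shape, "names": None}
    # Per-step outcome columns appended to every recording.
    for name, dtype in (("next.reward", "float32"), ("next.success", "bool"), ("next.done", "bool")):
        features[name] = {"dtype": dtype, "shape": (1,), "names": None}
    return features
```

For example, one camera plus one state vector yields five columns: the two inputs and the three `next.*` outcome fields.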
@@ -105,6 +165,10 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -145,6 +209,33 @@ def rollout(
if render_callback is not None:
render_callback(env)
recording_datasets: list[LeRobotDataset] | None = None
raw_observation = None
task_desc = ""
if recording_dir is not None and env_features is not None:
features = _env_features_to_dataset_features(env_features)
fps = env.unwrapped.metadata.get("render_fps", 30)
recording_datasets = []
multi_env = env.num_envs > 1
base_repo_id = recording_repo_id or "eval_recording"
for i in range(env.num_envs):
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
recording_datasets.append(
LeRobotDataset.create(
repo_id=repo_id,
fps=fps,
features=features,
root=root,
use_videos=True,
)
)
raw_observation = deepcopy(observation)
try:
task_desc = list(env.call("task_description"))[0]
except (AttributeError, NotImplementedError):
task_desc = ""
all_observations = []
all_actions = []
all_rewards = []
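The per-env naming in the hunk above (no `env_0/` nesting for a single env, one sub-directory and repo suffix per env otherwise) can be isolated into a small helper to see the resulting layout. `recording_targets` is a hypothetical name for this sketch; the branching and the `"eval_recording"` fallback follow the diff.

```python
from pathlib import Path


def recording_targets(recording_dir, num_envs, recording_repo_id=None):
    """Return one (root, repo_id) pair per env, mirroring the creation loop."""
    multi_env = num_envs > 1
    base_repo_id = recording_repo_id or "eval_recording"
    targets = []
    for i in range(num_envs):
        # A single env writes straight into recording_dir.
        root = recording_dir / f"env_{i}" if multi_env else recording_dir
        repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
        targets.append((root, repo_id))
    return targets
```

With `num_envs=2` and a repo id of `user/eval`, this gives `env_0` and `env_1` sub-directories paired with `user/eval_env_0` and `user/eval_env_1`.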
@@ -162,80 +253,112 @@ def rollout(
leave=False,
)
check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
try:
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
# Infer "task" from sub-environments (prefer natural language description).
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
try:
observation["task"] = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
# Infer "task" from sub-environments (prefer natural language description).
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
try:
observation["task"] = list(env.call("task"))
observation["task"] = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
observation["task"] = [""] * env.num_envs
try:
observation["task"] = list(env.call("task"))
except (AttributeError, NotImplementedError):
observation["task"] = [""] * env.num_envs
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {ACTION: action}
action_transition = env_postprocessor(action_transition)
action = action_transition[ACTION]
action_transition = {ACTION: action}
action_transition = env_postprocessor(action_transition)
action = action_transition[ACTION]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
)
successes = final_info["is_success"].tolist()
elif "is_success" in info:
is_success = info["is_success"]
successes = (
is_success.tolist()
if hasattr(is_success, "tolist")
else [bool(is_success)] * env.num_envs
)
successes = final_info["is_success"].tolist()
elif "is_success" in info:
is_success = info["is_success"]
successes = (
is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
else:
successes = [False] * env.num_envs
if recording_datasets is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_datasets[env_idx].features,
)
recording_datasets[env_idx].add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_datasets[env_idx].save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`,
# and allows logging/saving (e.g., videos) to be triggered consistently.
done = terminated | truncated | done
if step + 1 == max_steps:
done = np.ones_like(done, dtype=bool)
all_actions.append(torch.from_numpy(action_numpy))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
finally:
if recording_datasets is not None:
for ds in recording_datasets:
ds.finalize()
if recording_repo_id is not None:
if ds.num_episodes > 0:
ds.push_to_hub(private=recording_private)
else:
logging.warning("No episodes recorded for %s — skipping push to hub.", ds.repo_id)
# Track the final observation.
if return_observations:
@@ -273,6 +396,10 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> dict:
"""
Args:
@@ -361,6 +488,10 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -563,6 +694,10 @@ def eval_main(cfg: EvalPipelineConfig):
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
@@ -572,10 +707,15 @@ def eval_main(cfg: EvalPipelineConfig):
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=False,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
recording_dir=recording_dir,
env_features=cfg.env.features if cfg.eval.recording else None,
recording_repo_id=cfg.eval.recording_repo_id,
recording_private=cfg.eval.recording_private,
)
print("Overall Aggregated Metrics:")
print(info["overall"])
@@ -618,6 +758,10 @@ def eval_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> TaskMetrics:
"""Evaluates one task_id of one suite using the provided vec env."""
@@ -635,6 +779,10 @@ def eval_one(
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
)
per_episode = task_result["per_episode"]
@@ -661,6 +809,10 @@ def run_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
):
"""
Run eval_one for a single (task_group, task_id, env).
@@ -672,7 +824,13 @@ def run_one(
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
task_videos_dir.mkdir(parents=True, exist_ok=True)
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
task_recording_dir = None
task_repo_id = None
if recording_dir is not None and env_features is not None:
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
if recording_repo_id is not None:
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
metrics = eval_one(
env,
policy=policy,
@@ -685,8 +843,12 @@ def run_one(
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dir=task_recording_dir,
env_features=env_features,
recording_repo_id=task_repo_id,
recording_private=recording_private,
)
# ensure we always provide video_paths key to simplify accumulation
if max_episodes_rendered > 0:
metrics.setdefault("video_paths", [])
return task_group, task_id, metrics
@@ -702,6 +864,10 @@ def eval_policy_all(
n_episodes: int,
*,
max_episodes_rendered: int = 0,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
@@ -761,6 +927,10 @@ def eval_policy_all(
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
)
if max_parallel_tasks <= 1:
+3 -6
@@ -96,11 +96,7 @@ from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.common.control_utils import (
init_keyboard_listener,
is_headless,
sanity_check_dataset_robot_compatibility,
)
from lerobot.common.control_utils import sanity_check_dataset_robot_compatibility
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
@@ -155,6 +151,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.keyboard_input import init_keyboard_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
init_logging,
@@ -508,7 +505,7 @@ def record(
if teleop and teleop.is_connected:
teleop.disconnect()
if not is_headless() and listener:
if listener is not None:
listener.stop()
if cfg.dataset.push_to_hub:
@@ -18,6 +18,7 @@ import logging
from typing import TYPE_CHECKING
from lerobot.utils.import_utils import _hidapi_available, _pygame_available, require_package
from lerobot.utils.keyboard_input import pynput_can_capture
from ..utils import TeleopEvents
@@ -123,6 +124,15 @@ class KeyboardController(InputController):
def start(self):
"""Start the keyboard listener."""
if not pynput_can_capture():
logging.warning(
"Keyboard control is unavailable in this environment. pynput cannot capture keys "
"on Wayland or headless machines, or on macOS without Accessibility / Input "
"Monitoring permission. Keyboard motion will be inactive."
)
self.running = False
return
from pynput import keyboard
def on_press(key):
@@ -15,8 +15,6 @@
# limitations under the License.
import logging
import os
import sys
import time
from queue import Queue
from typing import Any
@@ -24,6 +22,7 @@ from typing import Any
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.keyboard_input import pynput_can_capture
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -37,14 +36,10 @@ PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
PYNPUT_AVAILABLE = False
else:
from pynput import keyboard
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
logging.info("Could not import pynput keyboard backend: %s", e)
class KeyboardTeleop(Teleoperator):
@@ -88,7 +83,7 @@ class KeyboardTeleop(Teleoperator):
@check_if_already_connected
def connect(self) -> None:
if PYNPUT_AVAILABLE:
if PYNPUT_AVAILABLE and pynput_can_capture():
logging.info("pynput is available - enabling local keyboard listener.")
self.listener = keyboard.Listener(
on_press=self._on_press,
@@ -96,7 +91,13 @@ class KeyboardTeleop(Teleoperator):
)
self.listener.start()
else:
logging.info("pynput not available - skipping local keyboard listener.")
logging.warning(
"Keyboard teleoperation is unavailable in this environment. pynput can only "
"capture key events on an X11 session (Linux), a Windows desktop, or macOS with "
"Accessibility / Input Monitoring granted - not on Wayland or headless machines. "
"This keyboard teleoperator will produce no actions; use an X11 session, a "
"gamepad, or a leader-arm teleoperator instead."
)
self.listener = None
def calibrate(self) -> None:
+440
@@ -0,0 +1,440 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Display-independent keyboard input for interactive controls.
This module centralizes everything related to *discrete* keyboard controls
(end-episode-early, re-record, stop, and the rollout strategies' custom keys):
* environment detection: :func:`is_headless`, :func:`is_wayland`,
:func:`pynput_can_capture` (the single predicate every call-site should use to
decide whether ``pynput`` can actually capture keys here);
* a shared key mapping: :func:`apply_recording_control`; and
* two interchangeable backends behind one ``(listener, events)`` contract:
the ``pynput`` global listener (X11 / trusted-macOS / Windows) and a
standard-library :class:`TerminalKeyListener` that reads the controlling TTY
(Wayland / headless-SSH-with-TTY / macOS without Accessibility permission).
NOTE: *continuous* key-state teleoperation ("hold a key to keep moving") is
deliberately NOT served here. A terminal in cbreak mode delivers only key-down
bytes (there is no key-release event), so the held-key model cannot be
reproduced. Those teleoperators stay on ``pynput`` and use
:func:`pynput_can_capture` to warn instead of silently doing nothing.
"""
from __future__ import annotations
import atexit
import contextlib
import logging
import os
import platform
import select
import sys
import threading
import time
from collections.abc import Callable
from functools import cache
from typing import TYPE_CHECKING
from .import_utils import _pynput_available
logger = logging.getLogger(__name__)
# POSIX-only terminal modules (absent on Windows, where the pynput backend is used).
if TYPE_CHECKING:
import termios
import tty
_TERMIOS_AVAILABLE = True
else:
try:
import termios
import tty
_TERMIOS_AVAILABLE = True
except ImportError: # POSIX-only modules; unavailable on Windows
termios = tty = None
_TERMIOS_AVAILABLE = False
keyboard = None
if _pynput_available:
try:
from pynput import keyboard
except Exception as e: # e.g. no reachable X display on a headless Linux box
logger.info("Could not import pynput keyboard backend: %s", e)
@cache
def is_headless() -> bool:
"""Return ``True`` when no display server is available.
* Linux: headless when neither ``DISPLAY`` (X11) nor ``WAYLAND_DISPLAY`` is set.
* macOS / Windows: a display is always assumed to be present. A genuinely GUI-less
Mac/Windows CI host would be misclassified, but that is harmless, because the
``sys.stdin.isatty()`` gate yields ``None`` (no listener) there regardless.
"""
if platform.system() == "Linux":
return not (os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
return False
@cache
def is_wayland() -> bool:
"""Return ``True`` when running under a Wayland session.
``pynput`` relies on an X11 backend. Under Wayland it still imports (XWayland
is usually present and ``$DISPLAY`` is set) but cannot capture *global*
hotkeys, so the documented arrow/Esc shortcuts silently do nothing. This case
is invisible to :func:`is_headless`, hence the dedicated check.
"""
return os.environ.get("XDG_SESSION_TYPE", "").lower() == "wayland" or bool(
os.environ.get("WAYLAND_DISPLAY")
)
@cache
def pynput_can_capture() -> bool:
"""Return ``True`` when a ``pynput`` global listener can actually capture keys.
This is the single predicate every keyboard call-site should use to choose
between the ``pynput`` backend and a fallback. It is intentionally
conservative:
* Linux: only a real X11 session (a display is present *and* it is not Wayland).
* macOS: ``True`` here. Accessibility / Input-Monitoring permission
(``IS_TRUSTED``) can only be confirmed at runtime *after* starting a
listener, so :func:`create_key_listener` (and through it
:func:`init_keyboard_listener`) refines this with
:func:`pynput_listener_is_trusted`.
* Windows: ``True`` (the low-level global hook needs no special permission).
Always ``False`` when ``pynput`` is not installed.
"""
if not _pynput_available:
return False
if platform.system() == "Linux":
return not is_headless() and not is_wayland()
return True
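The decision table in the docstring above can be restated as a pure function, which makes it easy to check each row without touching the real environment. This is only a sketch: `can_use_global_listener` and its parameters are hypothetical names introduced for illustration, and the environment is passed in explicitly instead of being read from `os.environ`.

```python
def can_use_global_listener(system: str, env: dict, pynput_installed: bool = True) -> bool:
    """Hypothetical restatement of the pynput_can_capture decision table."""
    if not pynput_installed:
        return False
    if system != "Linux":
        # macOS / Windows: assumed capable here; macOS trust is checked at runtime.
        return True
    has_display = bool(env.get("DISPLAY") or env.get("WAYLAND_DISPLAY"))
    on_wayland = env.get("XDG_SESSION_TYPE", "").lower() == "wayland" or bool(
        env.get("WAYLAND_DISPLAY")
    )
    # Linux: only a real X11 session (display present and not Wayland).
    return has_display and not on_wayland


x11_ok = can_use_global_listener("Linux", {"DISPLAY": ":0"})
xwayland_ok = can_use_global_listener("Linux", {"DISPLAY": ":0", "WAYLAND_DISPLAY": "wayland-0"})
headless_ok = can_use_global_listener("Linux", {})
```

Note that an XWayland session sets `DISPLAY` as well, which is why the Wayland check has to override the presence of a display.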
def pynput_listener_is_trusted(listener, timeout_s: float = 1.0) -> bool:
"""Best-effort check that a freshly started ``pynput`` listener can capture.
On macOS, ``pynput`` sets ``listener.IS_TRUSTED`` on its *listener thread*
once the Quartz event tap is created; the class default is ``False``. We
therefore wait for the thread to either flip it ``True`` (trusted) or for a
short timeout to elapse (untrusted it stays ``False`` forever). This biases
toward the common trusted case (returns as soon as the flag flips) and only
pays the full ``timeout_s`` on an already-broken untrusted machine.
On non-macOS backends the attribute is absent and capture is assumed to work.
"""
if platform.system() != "Darwin":
return True
deadline = time.perf_counter() + timeout_s
while time.perf_counter() < deadline:
if getattr(listener, "IS_TRUSTED", False):
return True
time.sleep(0.02)
return bool(getattr(listener, "IS_TRUSTED", False))
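The poll-until-deadline pattern used above can be exercised without pynput or macOS. The sketch below is an assumption-laden illustration: `wait_for_flag` and `FakeListener` are hypothetical names, and a `threading.Timer` stands in for the listener thread that would flip the flag.

```python
import threading
import time


def wait_for_flag(obj, attr: str, timeout_s: float = 1.0, poll_s: float = 0.02) -> bool:
    """Return True as soon as ``obj.attr`` becomes truthy, else False after the timeout."""
    deadline = time.perf_counter() + timeout_s
    while time.perf_counter() < deadline:
        if getattr(obj, attr, False):
            return True
        time.sleep(poll_s)
    # One last read, in case the flag flipped during the final sleep.
    return bool(getattr(obj, attr, False))


class FakeListener:
    """Stand-in for a listener whose class default is 'not trusted'."""

    IS_TRUSTED = False


listener = FakeListener()
# Simulate a background thread that marks the listener trusted shortly after start.
threading.Timer(0.05, lambda: setattr(listener, "IS_TRUSTED", True)).start()
became_trusted = wait_for_flag(listener, "IS_TRUSTED", timeout_s=1.0)
```

The trusted case returns early, so the full timeout is only paid when the flag never flips.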
def apply_recording_control(control: str, events: dict) -> None:
"""Apply a recording control-flow key press to the shared ``events`` dict.
Centralizes the mapping so the ``pynput`` and terminal backends behave
identically. ``control`` is one of ``"right"`` (end the loop early), ``"left"``
(re-record the last episode), or ``"esc"`` (stop recording).
"""
if control == "right":
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif control == "left":
print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif control == "esc":
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
# Terminal arrow keys arrive as a 3-byte escape sequence whose *final* byte identifies
# the direction. Two encodings exist depending on the terminal's cursor-key mode — CSI
# ("ESC [ X") and SS3 ("ESC O X", common over SSH/tmux) — but both share the same final
# byte, so this single table decodes either. Looked up by TerminalKeyListener._parse;
# an unknown final byte yields None (sequence ignored).
_ARROW_FINAL_BYTES = {"A": "up", "B": "down", "C": "right", "D": "left"}
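As a small illustration of the comment above, the sketch below decodes a complete 3-byte sequence in one call. It is a simplified stand-in, not the listener's actual parser (which reads byte by byte with timeouts); `decode_arrow` and `ARROW_FINAL_BYTES` are hypothetical names for this example.

```python
from typing import Optional

ARROW_FINAL_BYTES = {"A": "up", "B": "down", "C": "right", "D": "left"}


def decode_arrow(sequence: str) -> Optional[str]:
    """Decode a full 'ESC [ X' (CSI) or 'ESC O X' (SS3) arrow sequence."""
    if len(sequence) != 3 or sequence[0] != "\x1b" or sequence[1] not in ("[", "O"):
        return None
    # Both encodings share the same final byte, so one table covers either.
    return ARROW_FINAL_BYTES.get(sequence[2])


csi_right = decode_arrow("\x1b[C")
ss3_left = decode_arrow("\x1bOD")
unknown = decode_arrow("\x1b[Z")
```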
class TerminalKeyListener:
"""Display-independent keyboard listener that reads keys from the controlling TTY.
Used as the Wayland / headless / macOS-untrusted equivalent of the ``pynput``
listener for *discrete* controls. It puts the terminal into cbreak mode with
echo disabled and reads bytes on a daemon thread, decoding them into logical
key names that are passed to ``on_key``:
* arrow keys (``ESC [ C`` / ``ESC O C``) -> ``"right"`` / ``"left"`` / ``"up"`` / ``"down"``
* a bare ``ESC`` -> ``"esc"``
* Enter / Tab / Space / Backspace -> ``"enter"`` / ``"tab"`` / ``"space"`` / ``"backspace"``
* any other printable byte -> that character (e.g. ``"n"``, ``"s"``)
Only key-down events are produced (terminals have no key-release), so this is
suitable for discrete commands but NOT for continuous "hold-to-move" teleop.
The terminal is restored on :meth:`stop` and also via an ``atexit`` hook, so a
crash or Ctrl-C never leaves the shell in a no-echo cbreak state. POSIX-only
(``termios`` / ``tty`` / ``select``); ``termios`` and ``tty`` are imported behind a
guarded ``try``/``except`` so this file stays importable on Windows (where
``pynput`` is used instead).
"""
def __init__(self, on_key: Callable[[str], None]):
self._on_key = on_key
self._running = False
self._thread: threading.Thread | None = None
self._fd: int | None = None
self._old_attrs = None
def _read_char(self, timeout: float) -> str | None:
"""Return one character from stdin within ``timeout`` seconds, or ``None``."""
if self._fd is None:
return None
ready, _, _ = select.select([self._fd], [], [], timeout)
if not ready:
return None
try:
data = os.read(self._fd, 1)
except OSError:
return None
if not data:
return None
return data.decode(errors="ignore")
def _parse(self, ch: str) -> str | None:
"""Decode one (possibly multi-byte) key starting at ``ch`` into a key name."""
if ch == "\x1b":
# Possible CSI / SS3 escape sequence (arrow keys) or a bare ESC. Use
# short follow-up reads so a lone ESC is not mistaken for a sequence.
ch2 = self._read_char(timeout=0.02)
if ch2 is None:
return "esc"
if ch2 in ("[", "O"):
ch3 = self._read_char(timeout=0.02)
return _ARROW_FINAL_BYTES.get(ch3 or "")
# Some other escape sequence (e.g. Alt+key); ignore it.
return None
if ch in ("\r", "\n"):
return "enter"
if ch == "\t":
return "tab"
if ch == " ":
return "space"
if ch in ("\x7f", "\x08"):
return "backspace"
if ch.isprintable():
return ch
return None
def _run(self) -> None:
while self._running:
ch = self._read_char(timeout=0.05)
if ch is None:
continue
name = self._parse(ch)
if name is None:
continue
try:
self._on_key(name)
except Exception as e: # never let a handler error kill the reader thread
logger.debug("Terminal key handler error: %s", e)
def start(self) -> None:
"""Switch the terminal to cbreak mode (echo off) and read keys on a daemon thread.
No-op when stdin is not a TTY (piped/redirected input) or on platforms
without ``termios`` (e.g. Windows), so non-interactive runs are unaffected.
"""
if not sys.stdin.isatty():
return
if not _TERMIOS_AVAILABLE: # POSIX-only modules (e.g. unavailable on Windows)
logger.warning("Terminal keyboard input is not supported on this platform.")
return
self._fd = sys.stdin.fileno()
self._old_attrs = termios.tcgetattr(self._fd)
tty.setcbreak(self._fd)
# Explicitly disable ECHO so arrow-key escape sequences (e.g. ^[[C) are not
# echoed as garbage into the recording terminal. (Independent of the
# version-specific behavior of tty.setcbreak.)
new_attrs = termios.tcgetattr(self._fd)
new_attrs[3] &= ~termios.ECHO # index 3 == lflags
termios.tcsetattr(self._fd, termios.TCSADRAIN, new_attrs)
# Safety net: restore the terminal even if stop() is never reached (crash).
atexit.register(self.stop)
self._running = True
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the reader thread and restore the original terminal attributes.
Idempotent: safe to call multiple times (e.g. explicitly and via atexit).
"""
self._running = False
thread = self._thread
if thread is not None:
thread.join(timeout=0.5)
self._thread = None
if self._fd is not None and self._old_attrs is not None and _TERMIOS_AVAILABLE:
try:
termios.tcsetattr(self._fd, termios.TCSADRAIN, self._old_attrs)
finally:
self._old_attrs = None
with contextlib.suppress(Exception):
atexit.unregister(self.stop)
# Map pynput key objects to the same canonical names TerminalKeyListener emits, so a
# single dispatch works across both backends. Empty when pynput is unavailable.
if keyboard is not None:
_PYNPUT_KEY_NAMES = {
keyboard.Key.right: "right",
keyboard.Key.left: "left",
keyboard.Key.up: "up",
keyboard.Key.down: "down",
keyboard.Key.esc: "esc",
keyboard.Key.enter: "enter",
keyboard.Key.tab: "tab",
keyboard.Key.space: "space",
keyboard.Key.backspace: "backspace",
}
else:
_PYNPUT_KEY_NAMES = {}
def _resolve_pynput_key(key) -> str | None:
"""Resolve a pynput key event to the canonical name TerminalKeyListener also emits.
Special keys map through :data:`_PYNPUT_KEY_NAMES`; character keys fall back to their
``.char`` (e.g. ``"n"``). Returns ``None`` for keys with no mapping and no character.
"""
name = _PYNPUT_KEY_NAMES.get(key)
if name is not None:
return name
# ``or None`` keeps the historical truthy-char semantics: an empty/None char is "no key".
return getattr(key, "char", None) or None
def create_key_listener(dispatch: Callable[[str], None], *, controls_help: str = ""):
"""Start a keyboard listener that routes resolved key names to ``dispatch``.
Shared backend selection used by recording and the rollout strategies:
* the ``pynput`` global listener on X11 / trusted-macOS / Windows (on macOS the
listener's ``IS_TRUSTED`` flag is checked after start, and an untrusted listener is
stopped so the terminal backend is used instead);
* the stdlib :class:`TerminalKeyListener` on Wayland / headless sessions with a TTY;
* ``None`` when no backend is usable (non-interactive / piped runs).
Both backends pass ``dispatch`` the same canonical key names ("right" / "left" / "up" /
"down" / "esc" / "enter" / "tab" / "space" / "backspace", or a character), so one
``dispatch`` works regardless of backend. ``controls_help`` is an optional hint
appended to the log messages.
Returns the listener (exposing ``.stop()``) or ``None``.
"""
suffix = f" ({controls_help})" if controls_help else ""
if pynput_can_capture() and keyboard is not None:
def on_press(key):
with contextlib.suppress(Exception):
name = _resolve_pynput_key(key)
if name is not None:
dispatch(name)
listener = keyboard.Listener(on_press=on_press)
listener.start()
if pynput_listener_is_trusted(listener):
logger.info("Keyboard listener started%s.", suffix)
return listener
# macOS without Accessibility / Input-Monitoring permission: the listener never
# fires. Stop it and fall through to the terminal backend.
logger.warning(
"pynput keyboard listener is not trusted (missing macOS Accessibility / "
"Input Monitoring permission); falling back to terminal keyboard input."
)
listener.stop()
if sys.stdin.isatty():
listener = TerminalKeyListener(dispatch)
listener.start()
logger.info("Using terminal keyboard input — keep this terminal focused%s.", suffix)
return listener
logger.warning(
"Keyboard controls unavailable: no usable display (Wayland/headless) and stdin is "
"not an interactive terminal%s.",
suffix,
)
return None
def init_keyboard_listener():
"""Initialize a non-blocking keyboard listener for interactive recording controls.
Backend selection:
* ``pynput`` global listener when :func:`pynput_can_capture` is true (real
X11, macOS, Windows). On macOS the listener's ``IS_TRUSTED`` flag is checked
after start; if the process lacks Accessibility / Input-Monitoring
permission, the listener is stopped and the terminal backend is used.
* a :class:`TerminalKeyListener` reading the controlling TTY when ``pynput``
cannot capture (Wayland / headless-SSH / macOS-untrusted) *and* stdin is a TTY.
* otherwise no listener (non-interactive / piped runs); recording relies on
the episode/reset timers (or Ctrl+C).
Both backends accept the same controls: Right/Left/Esc, plus the single-byte letter
equivalents ``n`` (next), ``r`` (re-record) and ``q`` (quit). The letters are the most
reliable choice over high-latency SSH/VNC links, where arrow-key escape sequences can
be split, delayed, or intercepted by the terminal.
Returns:
A tuple ``(listener, events)`` where ``listener`` exposes ``.stop()`` or is
``None``, and ``events`` is the dict of flags (``exit_early``,
``rerecord_episode``, ``stop_recording``) set by key presses.
"""
events = {
"exit_early": False,
"rerecord_episode": False,
"stop_recording": False,
}
# Accept the single-byte letter equivalents n/r/q alongside the arrow/Esc keys: the
# letters are immune to the escape-sequence split/delay/interception that affects arrows
# over laggy SSH/VNC links. Case-insensitive so Shift+letter still works.
def on_key(name: str) -> None:
key = name.lower()
if key in ("right", "n"):
apply_recording_control("right", events)
elif key in ("left", "r"):
apply_recording_control("left", events)
elif key in ("esc", "q"):
apply_recording_control("esc", events)
# other keys (incl. up/down) are intentionally ignored
listener = create_key_listener(on_key, controls_help="Right/Left/Esc, or n=next, r=re-record, q=quit")
return listener, events
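A caller typically polls the returned ``events`` dict from its own loop. The sketch below shows one way that could look; it does not reproduce the project's record loop, and `make_events`, `run_episode`, and `on_step` are hypothetical names. A key press is simulated by setting the flag directly.

```python
def make_events() -> dict:
    """Same flag layout as the dict returned alongside the listener."""
    return {"exit_early": False, "rerecord_episode": False, "stop_recording": False}


def run_episode(events: dict, max_steps: int, on_step) -> int:
    """Run up to ``max_steps`` steps, stopping once ``exit_early`` is set."""
    steps_done = 0
    for step in range(max_steps):
        if events["exit_early"]:
            events["exit_early"] = False  # consume the flag for the next episode
            break
        on_step(step)
        steps_done += 1
    return steps_done


events = make_events()


def on_step(step: int) -> None:
    # Simulate the user pressing the "next" key during the third step.
    if step == 2:
        events["exit_early"] = True


steps_run = run_episode(events, max_steps=10, on_step=on_step)
```

Because the listener thread only writes booleans into the dict, the main loop can stay free of any keyboard-specific code.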
+228
@@ -0,0 +1,228 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the display-independent keyboard input helpers.
These cover the parts most likely to regress: the environment-detection decision
table (the heart of the Wayland/headless fix), the macOS trust probe, the control
mapping, the terminal escape-sequence parsing, and backend selection. They require
neither ``pynput`` nor a real terminal.
"""
import io
import platform
import sys
import pytest
import lerobot.utils.keyboard_input as ki
from lerobot.utils.keyboard_input import (
TerminalKeyListener,
apply_recording_control,
create_key_listener,
init_keyboard_listener,
is_headless,
is_wayland,
pynput_can_capture,
pynput_listener_is_trusted,
)
@pytest.fixture(autouse=True)
def _clear_detection_caches():
"""The detection helpers are ``@cache``-decorated; clear around each test."""
for fn in (is_headless, is_wayland, pynput_can_capture):
fn.cache_clear()
yield
for fn in (is_headless, is_wayland, pynput_can_capture):
fn.cache_clear()
def _set_platform(monkeypatch, name):
monkeypatch.setattr(platform, "system", lambda: name)
def _set_tty(monkeypatch, is_tty):
stdin = io.StringIO("")
stdin.isatty = lambda: is_tty
monkeypatch.setattr(sys, "stdin", stdin)
# --- Environment detection (the core of the fix) ---------------------------
@pytest.mark.parametrize(
("system", "env", "expected"),
[
("Linux", {}, True), # no display server
("Linux", {"DISPLAY": ":0"}, False), # X11
("Linux", {"WAYLAND_DISPLAY": "wayland-0"}, False), # Wayland
("Darwin", {}, False), # display always assumed present
],
)
def test_is_headless(monkeypatch, system, env, expected):
_set_platform(monkeypatch, system)
monkeypatch.delenv("DISPLAY", raising=False)
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
for key, value in env.items():
monkeypatch.setenv(key, value)
assert is_headless() is expected
@pytest.mark.parametrize(
("env", "expected"),
[
({"XDG_SESSION_TYPE": "wayland"}, True),
({"WAYLAND_DISPLAY": "wayland-0"}, True),
({"XDG_SESSION_TYPE": "x11"}, False),
({}, False),
],
)
def test_is_wayland(monkeypatch, env, expected):
monkeypatch.delenv("XDG_SESSION_TYPE", raising=False)
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
for key, value in env.items():
monkeypatch.setenv(key, value)
assert is_wayland() is expected
@pytest.mark.parametrize(
("system", "env", "pynput_available", "expected"),
[
("Linux", {"DISPLAY": ":0"}, True, True), # X11
("Linux", {"DISPLAY": ":0", "WAYLAND_DISPLAY": "wayland-0"}, True, False), # Wayland
("Linux", {}, True, False), # headless
("Darwin", {}, True, True),
("Linux", {"DISPLAY": ":0"}, False, False), # pynput not installed
],
)
def test_pynput_can_capture(monkeypatch, system, env, pynput_available, expected):
_set_platform(monkeypatch, system)
monkeypatch.setattr(ki, "_pynput_available", pynput_available)
for var in ("DISPLAY", "WAYLAND_DISPLAY", "XDG_SESSION_TYPE"):
monkeypatch.delenv(var, raising=False)
for key, value in env.items():
monkeypatch.setenv(key, value)
assert pynput_can_capture() is expected
# --- macOS trust probe ------------------------------------------------------
class _FakeListener:
def __init__(self, is_trusted):
self.IS_TRUSTED = is_trusted
def test_pynput_listener_is_trusted(monkeypatch):
_set_platform(monkeypatch, "Linux")
assert pynput_listener_is_trusted(_FakeListener(False)) is True # non-macOS: always assumed ok
_set_platform(monkeypatch, "Darwin")
assert pynput_listener_is_trusted(_FakeListener(False), timeout_s=0.05) is False
# --- Control mapping --------------------------------------------------------
def test_apply_recording_control():
events = {"exit_early": False, "rerecord_episode": False, "stop_recording": False}
apply_recording_control("left", events)
assert events == {"exit_early": True, "rerecord_episode": True, "stop_recording": False}
apply_recording_control("esc", events)
assert events["stop_recording"] is True
apply_recording_control("up", events) # unknown control -> no-op (no error)
# --- Terminal escape-sequence parsing (the tricky bit) ----------------------
def _drive(listener, byte_seq):
"""Run the listener's read loop over a scripted list of bytes (no real terminal)."""
script = list(byte_seq)
def fake_read(timeout):
if script:
return script.pop(0)
listener._running = False
return None
listener._read_char = fake_read
listener._running = True
listener._run()


@pytest.mark.parametrize(
    ("byte_seq", "expected"),
    [
        (["\x1b", "[", "C"], ["right"]),  # CSI arrow
        (["\x1b", "O", "D"], ["left"]),  # SS3 arrow (e.g. over SSH/tmux)
        (["\x1b"], ["esc"]),  # bare ESC
        (["\x1b", "[", "A"], ["up"]),  # decoded even though the record handler ignores it
        (["n"], ["n"]),  # letter passthrough
    ],
)
def test_terminal_parsing(byte_seq, expected):
    collected = []
    _drive(TerminalKeyListener(collected.append), byte_seq)
    assert collected == expected
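

# Illustration only: a minimal, self-contained sketch of the decoding that the
# test_terminal_parsing cases exercise. `_SKETCH_ARROWS` and `_sketch_decode_keys`
# are hypothetical names and are NOT the TerminalKeyListener implementation.
# Assumption: ESC followed by "[" (CSI) or "O" (SS3) and a final byte A/B/C/D is
# an arrow key; an ESC not followed by such a pair is reported as a bare "esc".
_SKETCH_ARROWS = {"A": "up", "B": "down", "C": "right", "D": "left"}


def _sketch_decode_keys(chars):
    """Decode a list of single characters into canonical key names (sketch)."""
    keys = []
    i = 0
    while i < len(chars):
        ch = chars[i]
        if ch != "\x1b":
            keys.append(ch)  # plain character: pass through unchanged
            i += 1
            continue
        tail = chars[i + 1 : i + 3]  # up to two characters following ESC
        if len(tail) == 2 and tail[0] in ("[", "O") and tail[1] in _SKETCH_ARROWS:
            keys.append(_SKETCH_ARROWS[tail[1]])  # CSI or SS3 arrow sequence
            i += 3
        else:
            keys.append("esc")  # lone ESC, or ESC before an unrecognized sequence
            i += 1
    return keys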


# --- Backend selection ------------------------------------------------------


def test_init_selects_terminal_when_pynput_cannot_capture(monkeypatch):
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=True)
    monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)  # avoid touching termios
    listener, _ = init_keyboard_listener()
    assert isinstance(listener, TerminalKeyListener)


def test_init_returns_none_without_tty(monkeypatch):
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=False)
    listener, _ = init_keyboard_listener()
    assert listener is None


@pytest.mark.parametrize(
    ("key", "flag"),
    [("right", "exit_early"), ("r", "rerecord_episode"), ("q", "stop_recording")],
)
def test_init_terminal_key_routing(monkeypatch, key, flag):
    """Arrows and their letter equivalents drive the same events (terminal backend)."""
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=True)
    monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
    listener, events = init_keyboard_listener()
    listener._on_key(key)
    assert events[flag] is True


# --- Shared factory + pynput key resolver -----------------------------------


def test_resolve_pynput_key_char_fallback():
    """Unmapped keys fall back to ``.char`` (and yield None when there is none)."""
    assert ki._resolve_pynput_key(type("K", (), {"char": "s"})()) == "s"
    assert ki._resolve_pynput_key(type("K", (), {"char": None})()) is None
    assert ki._resolve_pynput_key(type("K", (), {"char": ""})()) is None  # empty char -> no key


def test_create_key_listener_routes_to_dispatch(monkeypatch):
    """The terminal backend forwards canonical key names straight to ``dispatch``."""
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=True)
    monkeypatch.setattr(TerminalKeyListener, "start", lambda self: None)
    seen = []
    listener = create_key_listener(seen.append, controls_help="save='s'")
    assert isinstance(listener, TerminalKeyListener)
    listener._on_key("space")
    assert seen == ["space"]


def test_create_key_listener_none_without_tty(monkeypatch):
    monkeypatch.setattr(ki, "pynput_can_capture", lambda: False)
    _set_tty(monkeypatch, is_tty=False)
    assert create_key_listener(lambda name: None) is None