From 60215547701b71965116d11677a3208ad28b5332 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Thu, 7 May 2026 11:12:02 +0200 Subject: [PATCH] chore(rollout): nice collored cli --- src/lerobot/rollout/strategies/base.py | 7 +- src/lerobot/rollout/strategies/core.py | 12 + src/lerobot/rollout/strategies/dagger.py | 37 ++- src/lerobot/rollout/strategies/display.py | 263 ++++++++++++++++++++ src/lerobot/rollout/strategies/highlight.py | 24 +- src/lerobot/rollout/strategies/sentry.py | 10 +- 6 files changed, 336 insertions(+), 17 deletions(-) create mode 100644 src/lerobot/rollout/strategies/display.py diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index e47b65209..9d737be94 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -23,6 +23,7 @@ from lerobot.utils.robot_utils import precise_sleep from ..context import RolloutContext from .core import RolloutStrategy, send_next_action +from .display import BaseDisplay logger = logging.getLogger(__name__) @@ -38,6 +39,8 @@ class BaseStrategy(RolloutStrategy): """Initialise the inference engine.""" self._init_engine(ctx) logger.info("Base strategy ready") + self._display = BaseDisplay(duration=ctx.runtime.cfg.duration) + self._display.show_banner() def run(self, ctx: RolloutContext) -> None: """Run the autonomous control loop until shutdown or duration expires.""" @@ -72,9 +75,7 @@ class BaseStrategy(RolloutStrategy): if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) else: - logger.warning( - f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" - ) + self._warn_slow_loop(dt, control_interval, cfg.fps) def teardown(self, ctx: RolloutContext) -> None: """Disconnect hardware and stop inference.""" diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 9c897522f..c9ee7d4c5 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -33,6 +33,7 @@ from ..inference import InferenceEngine if TYPE_CHECKING: from ..configs import RolloutStrategyConfig from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext + from .display import RolloutStatusDisplay logger = logging.getLogger(__name__) @@ -51,6 +52,17 @@ class RolloutStrategy(abc.ABC): self._interpolator: ActionInterpolator | None = None self._warmup_flushed: bool = False self._cached_obs_processed: dict | None = None + self._display: RolloutStatusDisplay | None = None + + def _warn_slow_loop(self, dt: float, control_interval: float, fps: float) -> None: + """Warn when the control loop runs slower than the target FPS.""" + if dt > control_interval: + logger.warning( + "Control loop running slower (%.1f Hz) than target (%.0f Hz). " + "Possible causes: camera FPS not keeping up, slow policy inference, CPU starvation.", + 1 / dt, + fps, + ) def _init_engine(self, ctx: RolloutContext) -> None: """Attach the inference engine and action interpolator, then start the backend. diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 1bd6d9bb0..df170291e 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -71,6 +71,7 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon from ..context import RolloutContext from ..robot_wrapper import ThreadSafeRobot from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action +from .display import DAggerDisplay PYNPUT_AVAILABLE = _pynput_available keyboard = None @@ -286,7 +287,7 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): listener = keyboard.Listener(on_press=on_press) listener.start() - logger.info( + logger.debug( "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)", cfg.pause_resume, cfg.correction, @@ -370,6 +371,28 @@ class DAggerStrategy(RolloutStrategy): self._episode_duration_s, ) + if self.config.input_device == "keyboard": + kb = self.config.keyboard + pause_key, correction_key, upload_key = ( + kb.pause_resume.upper(), + kb.correction.upper(), + kb.upload.upper(), + ) + else: + pb = self.config.pedal + pause_key, correction_key, upload_key = pb.pause_resume, pb.correction, pb.upload + + self._display = DAggerDisplay( + record_autonomous=self.config.record_autonomous, + num_episodes=self.config.num_episodes, + episode_duration_s=self._episode_duration_s, + input_device=self.config.input_device, + pause_key=pause_key, + correction_key=correction_key, + upload_key=upload_key, + ) + self._display.show_banner() + def run(self, ctx: RolloutContext) -> None: """Run DAgger episodes with human-in-the-loop intervention.""" if self.config.record_autonomous: @@ -442,6 +465,7 @@ class DAggerStrategy(RolloutStrategy): interpolator.reset() events.reset() engine.resume() + self._display.show_state(DAggerPhase.AUTONOMOUS) last_action: dict[str, Any] | None = None record_tick = 0 @@ -472,6 +496,7 @@ class DAggerStrategy(RolloutStrategy): ctx, last_action, ) + self._display.show_state(new_phase) if new_phase == DAggerPhase.AUTONOMOUS: last_action = None @@ -556,9 +581,7 @@ class DAggerStrategy(RolloutStrategy): if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) else: - logger.warning( - f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" - ) + self._warn_slow_loop(dt, control_interval, cfg.fps) finally: logger.info("DAgger continuous control loop ended — pausing engine") @@ -599,6 +622,7 @@ class DAggerStrategy(RolloutStrategy): interpolator.reset() events.reset() engine.resume() + self._display.show_state(DAggerPhase.AUTONOMOUS) last_action: dict[str, Any] | None = None start_time = time.perf_counter() @@ -633,6 +657,7 @@ class DAggerStrategy(RolloutStrategy): ctx, last_action, ) + self._display.show_state(new_phase) if new_phase == DAggerPhase.AUTONOMOUS: last_action = None @@ -705,9 +730,7 @@ class DAggerStrategy(RolloutStrategy): if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) else: - logger.warning( - f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" - ) + self._warn_slow_loop(dt, control_interval, cfg.fps) finally: logger.info("DAgger corrections-only loop ended — pausing engine") diff --git a/src/lerobot/rollout/strategies/display.py b/src/lerobot/rollout/strategies/display.py new file mode 100644 index 000000000..de27efeec --- /dev/null +++ b/src/lerobot/rollout/strategies/display.py @@ -0,0 +1,263 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Console status display for rollout strategies. + +One subclass per strategy — static states/controls are declared as class +constants; runtime-dependent values are passed to ``__init__``. + +In each strategy's ``setup()``: + + self._display = DAggerDisplay( + record_autonomous=self.config.record_autonomous, + num_episodes=self.config.num_episodes, + episode_duration_s=self._episode_duration_s, + input_device=self.config.input_device, + pause_key="SPACE", + correction_key="TAB", + upload_key="ENTER", + ) + self._display.show_banner() + +On each state transition: + + self._display.show_state("correcting") +""" + +from __future__ import annotations + +import enum +import sys +from dataclasses import dataclass + + +def _supports_color() -> bool: + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + + +class _C: + """ANSI escape codes.""" + + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + GREEN = "\033[1;92m" + YELLOW = "\033[1;93m" + RED = "\033[1;91m" + CYAN = "\033[1;96m" + WHITE = "\033[1;97m" + GRAY = "\033[2;37m" + + +@dataclass +class StateConfig: + """One named rollout state. + + ``key`` must match the string passed to ``RolloutStatusDisplay.show_state()``. + """ + + key: str + emoji: str + label: str + description: str + color: str = _C.WHITE + + +@dataclass +class ControlConfig: + """One keyboard/pedal binding shown in the startup banner.""" + + key: str + description: str + + +# --------------------------------------------------------------------------- +# Base display class +# --------------------------------------------------------------------------- + + +class RolloutStatusDisplay: + """Unified console status display. Subclass once per strategy.""" + + def __init__( + self, + strategy: str, + states: list[StateConfig], + controls: list[ControlConfig], + info: list[str] | None = None, + ) -> None: + self.strategy = strategy + self._states = {s.key: s for s in states} + self._controls = controls + self._info = info or [] + self._use_color = _supports_color() + + def _c(self, code: str, text: str) -> str: + if not self._use_color: + return text + return f"{code}{text}{_C.RESET}" + + def show_banner(self) -> None: + """Print startup banner: strategy name, states, controls, config info.""" + width = 62 + sep = self._c(_C.BOLD, "═" * width) + + print(f"\n{sep}") + print(self._c(_C.BOLD, f" lerobot-rollout │ {self.strategy}")) + + if self._states: + print() + for state in self._states.values(): + label = self._c(state.color, f"{state.label:<14}") + desc = self._c(_C.GRAY, state.description) + print(f" {state.emoji} {label} {desc}") + + if self._controls: + print() + key_width = max(len(c.key) for c in self._controls) + for ctrl in self._controls: + key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]") + print(f" {key_str} {ctrl.description}") + + if self._info: + print() + for item in self._info: + print(f" {item}") + + print(f"{sep}\n") + + def show_state(self, state_key: str | enum.Enum) -> None: + """Print the current state and available controls - call this on every transition.""" + key = state_key.value if isinstance(state_key, enum.Enum) else state_key + state = self._states.get(key) + if state is None: + return + label = self._c(state.color, f"{state.label:<14}") + desc = self._c(_C.GRAY, state.description) + print(f"\n {state.emoji} {label} {desc}\n") + + if self._controls: + key_width = max(len(c.key) for c in self._controls) + for ctrl in self._controls: + key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]") + print(f" {key_str} {ctrl.description}") + print() + + +# --------------------------------------------------------------------------- +# One display subclass per strategy +# --------------------------------------------------------------------------- + + +class BaseDisplay(RolloutStatusDisplay): + """Status display for the base (eval-only, no recording) strategy.""" + + _STATES = [StateConfig("running", "🟢", "RUNNING", "autonomous rollout — no recording", _C.GREEN)] + _CONTROLS = [ControlConfig("Ctrl+C", "stop session")] + + def __init__(self, duration: float = 0) -> None: + info = ["No recording — evaluation only."] + if duration > 0: + info.append(f"Duration: {duration:.0f}s") + super().__init__("base", self._STATES, self._CONTROLS, info) + + +class SentryDisplay(RolloutStatusDisplay): + """Status display for the sentry (continuous autonomous recording) strategy.""" + + _STATES = [StateConfig("recording", "🟢", "RECORDING", "continuous autonomous recording", _C.GREEN)] + _CONTROLS = [ControlConfig("Ctrl+C", "stop session")] + + def __init__(self, episode_duration_s: float, upload_every_n_episodes: int) -> None: + info = [ + f"Episode rotation: ~{episode_duration_s:.0f}s | " + f"Upload every {upload_every_n_episodes} episodes", + ] + super().__init__("sentry", self._STATES, self._CONTROLS, info) + + +class HighlightDisplay(RolloutStatusDisplay): + """Status display for the highlight (ring-buffer on-demand save) strategy.""" + + def __init__(self, ring_buffer_seconds: float, save_key: str, push_key: str) -> None: + states = [ + StateConfig( + "buffering", + "⚪", + "BUFFERING", + f"ring buffer active — last {ring_buffer_seconds:.0f}s captured", + _C.WHITE, + ), + StateConfig("recording", "🔴", "RECORDING", "live recording — press [s] to save episode", _C.RED), + ] + controls = [ + ControlConfig(save_key, "BUFFERING ↔ RECORDING start recording / save episode"), + ControlConfig(push_key, "push dataset to Hub (background)"), + ControlConfig("ESC", "stop session"), + ] + super().__init__("highlight", states, controls) + + +class DAggerDisplay(RolloutStatusDisplay): + """Status display for the dagger (human-in-the-loop) strategy.""" + + _PAUSED_STATE = StateConfig("paused", "🟡", "PAUSED", "holding last position — awaiting input", _C.YELLOW) + _CORRECTING_STATE = StateConfig( + "correcting", "🔴", "CORRECTING", "human teleop active — recording correction", _C.RED + ) + + def __init__( + self, + record_autonomous: bool, + num_episodes: int, + episode_duration_s: float, + input_device: str, + pause_key: str, + correction_key: str, + upload_key: str, + ) -> None: + mode = "continuous recording" if record_autonomous else "corrections only" + auto_desc = "policy running — recording" if record_autonomous else "policy running — no recording" + states = [ + StateConfig("autonomous", "🟢", "AUTONOMOUS", auto_desc, _C.GREEN), + self._PAUSED_STATE, + self._CORRECTING_STATE, + ] + controls = [ + ControlConfig(pause_key, "AUTONOMOUS ↔ PAUSED pause / resume policy"), + ControlConfig(correction_key, "PAUSED ↔ CORRECTING start / stop correction"), + ControlConfig(upload_key, "push dataset to Hub"), + ControlConfig("ESC", "stop session"), + ] + info = [f"Target: {num_episodes} episodes | Input: {input_device}"] + if record_autonomous: + info.append(f"Episode rotation: ~{episode_duration_s:.0f}s") + super().__init__(f"dagger [{mode}]", states, controls, info) + + +if __name__ == "__main__": + dagger_display = DAggerDisplay( + record_autonomous=False, + num_episodes=20, + episode_duration_s=30, + input_device="keyboard", + pause_key="SPACE", + correction_key="TAB", + upload_key="ENTER", + ) + dagger_display.show_banner() + dagger_display.show_state("paused") + dagger_display.show_state("correcting") + dagger_display.show_state("paused") + dagger_display.show_state("autonomous") diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index baff70da7..5666e2194 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import enum import logging import os import sys @@ -36,6 +37,7 @@ from ..configs import HighlightStrategyConfig from ..context import RolloutContext from ..ring_buffer import RolloutRingBuffer from .core import RolloutStrategy, safe_push_to_hub, send_next_action +from .display import HighlightDisplay PYNPUT_AVAILABLE = _pynput_available keyboard = None @@ -53,6 +55,13 @@ if PYNPUT_AVAILABLE: logger = logging.getLogger(__name__) +class HighlightPhase(enum.Enum): + """Observable phases of a Highlight session.""" + + BUFFERING = "buffering" # Ring buffer accumulating frames, not recording + RECORDING = "recording" # Live recording active + + class HighlightStrategy(RolloutStrategy): """Autonomous rollout with on-demand recording via ring buffer. @@ -105,6 +114,13 @@ class HighlightStrategy(RolloutStrategy): self.config.save_key, self.config.push_key, ) + self._display = HighlightDisplay( + ring_buffer_seconds=self.config.ring_buffer_seconds, + save_key=self.config.save_key, + push_key=self.config.push_key, + ) + self._display.show_banner() + self._display.show_state(HighlightPhase.BUFFERING) def run(self, ctx: RolloutContext) -> None: """Run the autonomous loop, buffering frames and recording on demand.""" @@ -162,6 +178,7 @@ class HighlightStrategy(RolloutStrategy): for buffered_frame in ring.drain(): dataset.add_frame(buffered_frame) self._recording_live.set() + self._display.show_state(HighlightPhase.RECORDING) else: dataset.add_frame(frame) with self._episode_lock: @@ -172,6 +189,7 @@ class HighlightStrategy(RolloutStrategy): play_sounds, ) self._recording_live.clear() + self._display.show_state(HighlightPhase.BUFFERING) continue # frame already consumed — skip ring.append if self._push_requested.is_set(): @@ -188,9 +206,7 @@ class HighlightStrategy(RolloutStrategy): if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) else: - logger.warning( - f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" - ) + self._warn_slow_loop(dt, control_interval, cfg.fps) finally: logger.info("Highlight control loop ended") @@ -255,7 +271,7 @@ class HighlightStrategy(RolloutStrategy): self._listener = keyboard.Listener(on_press=on_press) self._listener.start() - logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key) + logger.debug("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key) except ImportError: logger.warning("pynput not available — keyboard listener disabled") diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 61e38aa68..89bd999b8 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -32,6 +32,7 @@ from lerobot.utils.utils import log_say from ..configs import SentryStrategyConfig from ..context import RolloutContext from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action +from .display import SentryDisplay logger = logging.getLogger(__name__) @@ -79,6 +80,11 @@ class SentryStrategy(RolloutStrategy): self._episode_duration_s, self.config.upload_every_n_episodes, ) + self._display = SentryDisplay( + episode_duration_s=self._episode_duration_s, + upload_every_n_episodes=self.config.upload_every_n_episodes, + ) + self._display.show_banner() def run(self, ctx: RolloutContext) -> None: """Run the continuous recording loop with automatic episode rotation.""" @@ -160,9 +166,7 @@ class SentryStrategy(RolloutStrategy): if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) else: - logger.warning( - f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" - ) + self._warn_slow_loop(dt, control_interval, cfg.fps) finally: logger.info("Sentry control loop ended — saving final episode")