mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
chore(rollout): nice collored cli
This commit is contained in:
@@ -23,6 +23,7 @@ from lerobot.utils.robot_utils import precise_sleep
|
|||||||
|
|
||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from .core import RolloutStrategy, send_next_action
|
from .core import RolloutStrategy, send_next_action
|
||||||
|
from .display import BaseDisplay
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,6 +39,8 @@ class BaseStrategy(RolloutStrategy):
|
|||||||
"""Initialise the inference engine."""
|
"""Initialise the inference engine."""
|
||||||
self._init_engine(ctx)
|
self._init_engine(ctx)
|
||||||
logger.info("Base strategy ready")
|
logger.info("Base strategy ready")
|
||||||
|
self._display = BaseDisplay(duration=ctx.runtime.cfg.duration)
|
||||||
|
self._display.show_banner()
|
||||||
|
|
||||||
def run(self, ctx: RolloutContext) -> None:
|
def run(self, ctx: RolloutContext) -> None:
|
||||||
"""Run the autonomous control loop until shutdown or duration expires."""
|
"""Run the autonomous control loop until shutdown or duration expires."""
|
||||||
@@ -72,9 +75,7 @@ class BaseStrategy(RolloutStrategy):
|
|||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
|
||||||
)
|
|
||||||
|
|
||||||
def teardown(self, ctx: RolloutContext) -> None:
|
def teardown(self, ctx: RolloutContext) -> None:
|
||||||
"""Disconnect hardware and stop inference."""
|
"""Disconnect hardware and stop inference."""
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from ..inference import InferenceEngine
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..configs import RolloutStrategyConfig
|
from ..configs import RolloutStrategyConfig
|
||||||
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
|
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
|
||||||
|
from .display import RolloutStatusDisplay
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,6 +52,17 @@ class RolloutStrategy(abc.ABC):
|
|||||||
self._interpolator: ActionInterpolator | None = None
|
self._interpolator: ActionInterpolator | None = None
|
||||||
self._warmup_flushed: bool = False
|
self._warmup_flushed: bool = False
|
||||||
self._cached_obs_processed: dict | None = None
|
self._cached_obs_processed: dict | None = None
|
||||||
|
self._display: RolloutStatusDisplay | None = None
|
||||||
|
|
||||||
|
def _warn_slow_loop(self, dt: float, control_interval: float, fps: float) -> None:
|
||||||
|
"""Warn when the control loop runs slower than the target FPS."""
|
||||||
|
if dt > control_interval:
|
||||||
|
logger.warning(
|
||||||
|
"Control loop running slower (%.1f Hz) than target (%.0f Hz). "
|
||||||
|
"Possible causes: camera FPS not keeping up, slow policy inference, CPU starvation.",
|
||||||
|
1 / dt,
|
||||||
|
fps,
|
||||||
|
)
|
||||||
|
|
||||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||||
"""Attach the inference engine and action interpolator, then start the backend.
|
"""Attach the inference engine and action interpolator, then start the backend.
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
|
|||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from ..robot_wrapper import ThreadSafeRobot
|
from ..robot_wrapper import ThreadSafeRobot
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
|
from .display import DAggerDisplay
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = _pynput_available
|
PYNPUT_AVAILABLE = _pynput_available
|
||||||
keyboard = None
|
keyboard = None
|
||||||
@@ -286,7 +287,7 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
|||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
listener.start()
|
listener.start()
|
||||||
logger.info(
|
logger.debug(
|
||||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||||
cfg.pause_resume,
|
cfg.pause_resume,
|
||||||
cfg.correction,
|
cfg.correction,
|
||||||
@@ -370,6 +371,28 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
self._episode_duration_s,
|
self._episode_duration_s,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.config.input_device == "keyboard":
|
||||||
|
kb = self.config.keyboard
|
||||||
|
pause_key, correction_key, upload_key = (
|
||||||
|
kb.pause_resume.upper(),
|
||||||
|
kb.correction.upper(),
|
||||||
|
kb.upload.upper(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pb = self.config.pedal
|
||||||
|
pause_key, correction_key, upload_key = pb.pause_resume, pb.correction, pb.upload
|
||||||
|
|
||||||
|
self._display = DAggerDisplay(
|
||||||
|
record_autonomous=self.config.record_autonomous,
|
||||||
|
num_episodes=self.config.num_episodes,
|
||||||
|
episode_duration_s=self._episode_duration_s,
|
||||||
|
input_device=self.config.input_device,
|
||||||
|
pause_key=pause_key,
|
||||||
|
correction_key=correction_key,
|
||||||
|
upload_key=upload_key,
|
||||||
|
)
|
||||||
|
self._display.show_banner()
|
||||||
|
|
||||||
def run(self, ctx: RolloutContext) -> None:
|
def run(self, ctx: RolloutContext) -> None:
|
||||||
"""Run DAgger episodes with human-in-the-loop intervention."""
|
"""Run DAgger episodes with human-in-the-loop intervention."""
|
||||||
if self.config.record_autonomous:
|
if self.config.record_autonomous:
|
||||||
@@ -442,6 +465,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
interpolator.reset()
|
interpolator.reset()
|
||||||
events.reset()
|
events.reset()
|
||||||
engine.resume()
|
engine.resume()
|
||||||
|
self._display.show_state(DAggerPhase.AUTONOMOUS)
|
||||||
|
|
||||||
last_action: dict[str, Any] | None = None
|
last_action: dict[str, Any] | None = None
|
||||||
record_tick = 0
|
record_tick = 0
|
||||||
@@ -472,6 +496,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
ctx,
|
ctx,
|
||||||
last_action,
|
last_action,
|
||||||
)
|
)
|
||||||
|
self._display.show_state(new_phase)
|
||||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||||
last_action = None
|
last_action = None
|
||||||
|
|
||||||
@@ -556,9 +581,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||||
@@ -599,6 +622,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
interpolator.reset()
|
interpolator.reset()
|
||||||
events.reset()
|
events.reset()
|
||||||
engine.resume()
|
engine.resume()
|
||||||
|
self._display.show_state(DAggerPhase.AUTONOMOUS)
|
||||||
|
|
||||||
last_action: dict[str, Any] | None = None
|
last_action: dict[str, Any] | None = None
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -633,6 +657,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
ctx,
|
ctx,
|
||||||
last_action,
|
last_action,
|
||||||
)
|
)
|
||||||
|
self._display.show_state(new_phase)
|
||||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||||
last_action = None
|
last_action = None
|
||||||
|
|
||||||
@@ -705,9 +730,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||||
|
|||||||
@@ -0,0 +1,263 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Console status display for rollout strategies.
|
||||||
|
|
||||||
|
One subclass per strategy — static states/controls are declared as class
|
||||||
|
constants; runtime-dependent values are passed to ``__init__``.
|
||||||
|
|
||||||
|
In each strategy's ``setup()``:
|
||||||
|
|
||||||
|
self._display = DAggerDisplay(
|
||||||
|
record_autonomous=self.config.record_autonomous,
|
||||||
|
num_episodes=self.config.num_episodes,
|
||||||
|
episode_duration_s=self._episode_duration_s,
|
||||||
|
input_device=self.config.input_device,
|
||||||
|
pause_key="SPACE",
|
||||||
|
correction_key="TAB",
|
||||||
|
upload_key="ENTER",
|
||||||
|
)
|
||||||
|
self._display.show_banner()
|
||||||
|
|
||||||
|
On each state transition:
|
||||||
|
|
||||||
|
self._display.show_state("correcting")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_color() -> bool:
|
||||||
|
return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
|
||||||
|
|
||||||
|
|
||||||
|
class _C:
|
||||||
|
"""ANSI escape codes."""
|
||||||
|
|
||||||
|
RESET = "\033[0m"
|
||||||
|
BOLD = "\033[1m"
|
||||||
|
DIM = "\033[2m"
|
||||||
|
GREEN = "\033[1;92m"
|
||||||
|
YELLOW = "\033[1;93m"
|
||||||
|
RED = "\033[1;91m"
|
||||||
|
CYAN = "\033[1;96m"
|
||||||
|
WHITE = "\033[1;97m"
|
||||||
|
GRAY = "\033[2;37m"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StateConfig:
|
||||||
|
"""One named rollout state.
|
||||||
|
|
||||||
|
``key`` must match the string passed to ``RolloutStatusDisplay.show_state()``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
emoji: str
|
||||||
|
label: str
|
||||||
|
description: str
|
||||||
|
color: str = _C.WHITE
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ControlConfig:
|
||||||
|
"""One keyboard/pedal binding shown in the startup banner."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base display class
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class RolloutStatusDisplay:
|
||||||
|
"""Unified console status display. Subclass once per strategy."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strategy: str,
|
||||||
|
states: list[StateConfig],
|
||||||
|
controls: list[ControlConfig],
|
||||||
|
info: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.strategy = strategy
|
||||||
|
self._states = {s.key: s for s in states}
|
||||||
|
self._controls = controls
|
||||||
|
self._info = info or []
|
||||||
|
self._use_color = _supports_color()
|
||||||
|
|
||||||
|
def _c(self, code: str, text: str) -> str:
|
||||||
|
if not self._use_color:
|
||||||
|
return text
|
||||||
|
return f"{code}{text}{_C.RESET}"
|
||||||
|
|
||||||
|
def show_banner(self) -> None:
|
||||||
|
"""Print startup banner: strategy name, states, controls, config info."""
|
||||||
|
width = 62
|
||||||
|
sep = self._c(_C.BOLD, "═" * width)
|
||||||
|
|
||||||
|
print(f"\n{sep}")
|
||||||
|
print(self._c(_C.BOLD, f" lerobot-rollout │ {self.strategy}"))
|
||||||
|
|
||||||
|
if self._states:
|
||||||
|
print()
|
||||||
|
for state in self._states.values():
|
||||||
|
label = self._c(state.color, f"{state.label:<14}")
|
||||||
|
desc = self._c(_C.GRAY, state.description)
|
||||||
|
print(f" {state.emoji} {label} {desc}")
|
||||||
|
|
||||||
|
if self._controls:
|
||||||
|
print()
|
||||||
|
key_width = max(len(c.key) for c in self._controls)
|
||||||
|
for ctrl in self._controls:
|
||||||
|
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
|
||||||
|
print(f" {key_str} {ctrl.description}")
|
||||||
|
|
||||||
|
if self._info:
|
||||||
|
print()
|
||||||
|
for item in self._info:
|
||||||
|
print(f" {item}")
|
||||||
|
|
||||||
|
print(f"{sep}\n")
|
||||||
|
|
||||||
|
def show_state(self, state_key: str | enum.Enum) -> None:
|
||||||
|
"""Print the current state and available controls - call this on every transition."""
|
||||||
|
key = state_key.value if isinstance(state_key, enum.Enum) else state_key
|
||||||
|
state = self._states.get(key)
|
||||||
|
if state is None:
|
||||||
|
return
|
||||||
|
label = self._c(state.color, f"{state.label:<14}")
|
||||||
|
desc = self._c(_C.GRAY, state.description)
|
||||||
|
print(f"\n {state.emoji} {label} {desc}\n")
|
||||||
|
|
||||||
|
if self._controls:
|
||||||
|
key_width = max(len(c.key) for c in self._controls)
|
||||||
|
for ctrl in self._controls:
|
||||||
|
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
|
||||||
|
print(f" {key_str} {ctrl.description}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# One display subclass per strategy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDisplay(RolloutStatusDisplay):
|
||||||
|
"""Status display for the base (eval-only, no recording) strategy."""
|
||||||
|
|
||||||
|
_STATES = [StateConfig("running", "🟢", "RUNNING", "autonomous rollout — no recording", _C.GREEN)]
|
||||||
|
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
|
||||||
|
|
||||||
|
def __init__(self, duration: float = 0) -> None:
|
||||||
|
info = ["No recording — evaluation only."]
|
||||||
|
if duration > 0:
|
||||||
|
info.append(f"Duration: {duration:.0f}s")
|
||||||
|
super().__init__("base", self._STATES, self._CONTROLS, info)
|
||||||
|
|
||||||
|
|
||||||
|
class SentryDisplay(RolloutStatusDisplay):
|
||||||
|
"""Status display for the sentry (continuous autonomous recording) strategy."""
|
||||||
|
|
||||||
|
_STATES = [StateConfig("recording", "🟢", "RECORDING", "continuous autonomous recording", _C.GREEN)]
|
||||||
|
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
|
||||||
|
|
||||||
|
def __init__(self, episode_duration_s: float, upload_every_n_episodes: int) -> None:
|
||||||
|
info = [
|
||||||
|
f"Episode rotation: ~{episode_duration_s:.0f}s | "
|
||||||
|
f"Upload every {upload_every_n_episodes} episodes",
|
||||||
|
]
|
||||||
|
super().__init__("sentry", self._STATES, self._CONTROLS, info)
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightDisplay(RolloutStatusDisplay):
|
||||||
|
"""Status display for the highlight (ring-buffer on-demand save) strategy."""
|
||||||
|
|
||||||
|
def __init__(self, ring_buffer_seconds: float, save_key: str, push_key: str) -> None:
|
||||||
|
states = [
|
||||||
|
StateConfig(
|
||||||
|
"buffering",
|
||||||
|
"⚪",
|
||||||
|
"BUFFERING",
|
||||||
|
f"ring buffer active — last {ring_buffer_seconds:.0f}s captured",
|
||||||
|
_C.WHITE,
|
||||||
|
),
|
||||||
|
StateConfig("recording", "🔴", "RECORDING", "live recording — press [s] to save episode", _C.RED),
|
||||||
|
]
|
||||||
|
controls = [
|
||||||
|
ControlConfig(save_key, "BUFFERING ↔ RECORDING start recording / save episode"),
|
||||||
|
ControlConfig(push_key, "push dataset to Hub (background)"),
|
||||||
|
ControlConfig("ESC", "stop session"),
|
||||||
|
]
|
||||||
|
super().__init__("highlight", states, controls)
|
||||||
|
|
||||||
|
|
||||||
|
class DAggerDisplay(RolloutStatusDisplay):
|
||||||
|
"""Status display for the dagger (human-in-the-loop) strategy."""
|
||||||
|
|
||||||
|
_PAUSED_STATE = StateConfig("paused", "🟡", "PAUSED", "holding last position — awaiting input", _C.YELLOW)
|
||||||
|
_CORRECTING_STATE = StateConfig(
|
||||||
|
"correcting", "🔴", "CORRECTING", "human teleop active — recording correction", _C.RED
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
record_autonomous: bool,
|
||||||
|
num_episodes: int,
|
||||||
|
episode_duration_s: float,
|
||||||
|
input_device: str,
|
||||||
|
pause_key: str,
|
||||||
|
correction_key: str,
|
||||||
|
upload_key: str,
|
||||||
|
) -> None:
|
||||||
|
mode = "continuous recording" if record_autonomous else "corrections only"
|
||||||
|
auto_desc = "policy running — recording" if record_autonomous else "policy running — no recording"
|
||||||
|
states = [
|
||||||
|
StateConfig("autonomous", "🟢", "AUTONOMOUS", auto_desc, _C.GREEN),
|
||||||
|
self._PAUSED_STATE,
|
||||||
|
self._CORRECTING_STATE,
|
||||||
|
]
|
||||||
|
controls = [
|
||||||
|
ControlConfig(pause_key, "AUTONOMOUS ↔ PAUSED pause / resume policy"),
|
||||||
|
ControlConfig(correction_key, "PAUSED ↔ CORRECTING start / stop correction"),
|
||||||
|
ControlConfig(upload_key, "push dataset to Hub"),
|
||||||
|
ControlConfig("ESC", "stop session"),
|
||||||
|
]
|
||||||
|
info = [f"Target: {num_episodes} episodes | Input: {input_device}"]
|
||||||
|
if record_autonomous:
|
||||||
|
info.append(f"Episode rotation: ~{episode_duration_s:.0f}s")
|
||||||
|
super().__init__(f"dagger [{mode}]", states, controls, info)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dagger_display = DAggerDisplay(
|
||||||
|
record_autonomous=False,
|
||||||
|
num_episodes=20,
|
||||||
|
episode_duration_s=30,
|
||||||
|
input_device="keyboard",
|
||||||
|
pause_key="SPACE",
|
||||||
|
correction_key="TAB",
|
||||||
|
upload_key="ENTER",
|
||||||
|
)
|
||||||
|
dagger_display.show_banner()
|
||||||
|
dagger_display.show_state("paused")
|
||||||
|
dagger_display.show_state("correcting")
|
||||||
|
dagger_display.show_state("paused")
|
||||||
|
dagger_display.show_state("autonomous")
|
||||||
@@ -17,6 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -36,6 +37,7 @@ from ..configs import HighlightStrategyConfig
|
|||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from ..ring_buffer import RolloutRingBuffer
|
from ..ring_buffer import RolloutRingBuffer
|
||||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||||
|
from .display import HighlightDisplay
|
||||||
|
|
||||||
PYNPUT_AVAILABLE = _pynput_available
|
PYNPUT_AVAILABLE = _pynput_available
|
||||||
keyboard = None
|
keyboard = None
|
||||||
@@ -53,6 +55,13 @@ if PYNPUT_AVAILABLE:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HighlightPhase(enum.Enum):
|
||||||
|
"""Observable phases of a Highlight session."""
|
||||||
|
|
||||||
|
BUFFERING = "buffering" # Ring buffer accumulating frames, not recording
|
||||||
|
RECORDING = "recording" # Live recording active
|
||||||
|
|
||||||
|
|
||||||
class HighlightStrategy(RolloutStrategy):
|
class HighlightStrategy(RolloutStrategy):
|
||||||
"""Autonomous rollout with on-demand recording via ring buffer.
|
"""Autonomous rollout with on-demand recording via ring buffer.
|
||||||
|
|
||||||
@@ -105,6 +114,13 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
self.config.save_key,
|
self.config.save_key,
|
||||||
self.config.push_key,
|
self.config.push_key,
|
||||||
)
|
)
|
||||||
|
self._display = HighlightDisplay(
|
||||||
|
ring_buffer_seconds=self.config.ring_buffer_seconds,
|
||||||
|
save_key=self.config.save_key,
|
||||||
|
push_key=self.config.push_key,
|
||||||
|
)
|
||||||
|
self._display.show_banner()
|
||||||
|
self._display.show_state(HighlightPhase.BUFFERING)
|
||||||
|
|
||||||
def run(self, ctx: RolloutContext) -> None:
|
def run(self, ctx: RolloutContext) -> None:
|
||||||
"""Run the autonomous loop, buffering frames and recording on demand."""
|
"""Run the autonomous loop, buffering frames and recording on demand."""
|
||||||
@@ -162,6 +178,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
for buffered_frame in ring.drain():
|
for buffered_frame in ring.drain():
|
||||||
dataset.add_frame(buffered_frame)
|
dataset.add_frame(buffered_frame)
|
||||||
self._recording_live.set()
|
self._recording_live.set()
|
||||||
|
self._display.show_state(HighlightPhase.RECORDING)
|
||||||
else:
|
else:
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
with self._episode_lock:
|
with self._episode_lock:
|
||||||
@@ -172,6 +189,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
play_sounds,
|
play_sounds,
|
||||||
)
|
)
|
||||||
self._recording_live.clear()
|
self._recording_live.clear()
|
||||||
|
self._display.show_state(HighlightPhase.BUFFERING)
|
||||||
continue # frame already consumed — skip ring.append
|
continue # frame already consumed — skip ring.append
|
||||||
|
|
||||||
if self._push_requested.is_set():
|
if self._push_requested.is_set():
|
||||||
@@ -188,9 +206,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("Highlight control loop ended")
|
logger.info("Highlight control loop ended")
|
||||||
@@ -255,7 +271,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
|
|
||||||
self._listener = keyboard.Listener(on_press=on_press)
|
self._listener = keyboard.Listener(on_press=on_press)
|
||||||
self._listener.start()
|
self._listener.start()
|
||||||
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
logger.debug("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("pynput not available — keyboard listener disabled")
|
logger.warning("pynput not available — keyboard listener disabled")
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from lerobot.utils.utils import log_say
|
|||||||
from ..configs import SentryStrategyConfig
|
from ..configs import SentryStrategyConfig
|
||||||
from ..context import RolloutContext
|
from ..context import RolloutContext
|
||||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||||
|
from .display import SentryDisplay
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -79,6 +80,11 @@ class SentryStrategy(RolloutStrategy):
|
|||||||
self._episode_duration_s,
|
self._episode_duration_s,
|
||||||
self.config.upload_every_n_episodes,
|
self.config.upload_every_n_episodes,
|
||||||
)
|
)
|
||||||
|
self._display = SentryDisplay(
|
||||||
|
episode_duration_s=self._episode_duration_s,
|
||||||
|
upload_every_n_episodes=self.config.upload_every_n_episodes,
|
||||||
|
)
|
||||||
|
self._display.show_banner()
|
||||||
|
|
||||||
def run(self, ctx: RolloutContext) -> None:
|
def run(self, ctx: RolloutContext) -> None:
|
||||||
"""Run the continuous recording loop with automatic episode rotation."""
|
"""Run the continuous recording loop with automatic episode rotation."""
|
||||||
@@ -160,9 +166,7 @@ class SentryStrategy(RolloutStrategy):
|
|||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("Sentry control loop ended — saving final episode")
|
logger.info("Sentry control loop ended — saving final episode")
|
||||||
|
|||||||
Reference in New Issue
Block a user