From 35bb2c7459d920c0628189655de5cb7bdd2a020f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 17 Apr 2026 15:55:03 +0200 Subject: [PATCH] simplify dagger --- src/lerobot/rollout/__init__.py | 4 + src/lerobot/rollout/configs.py | 63 ++- src/lerobot/rollout/context.py | 12 + src/lerobot/rollout/strategies/dagger.py | 679 +++++++++++++---------- 4 files changed, 448 insertions(+), 310 deletions(-) diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py index 896d6f91a..8d7a90d43 100644 --- a/src/lerobot/rollout/__init__.py +++ b/src/lerobot/rollout/__init__.py @@ -16,6 +16,8 @@ from .configs import ( BaseStrategyConfig, + DAggerKeyboardConfig, + DAggerPedalConfig, DAggerStrategyConfig, DatasetRecordConfig, HighlightStrategyConfig, @@ -39,6 +41,8 @@ from .strategies import RolloutStrategy, create_strategy __all__ = [ "BaseStrategyConfig", + "DAggerKeyboardConfig", + "DAggerPedalConfig", "DAggerStrategyConfig", "HighlightStrategyConfig", "InferenceEngine", diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 66a91a31e..7d5417299 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -66,7 +66,7 @@ class SentryStrategyConfig(RolloutStrategyConfig): uploaded in the background every ``upload_every_n_episodes`` episodes. """ - episode_duration_s: float = 120.0 + episode_duration_s: float = 20.0 upload_every_n_episodes: int = 5 @@ -87,6 +87,32 @@ class HighlightStrategyConfig(RolloutStrategyConfig): push_key: str = "h" +@dataclass +class DAggerKeyboardConfig: + """Keyboard key bindings for DAgger controls. + + Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or + special key names (``"space"``). + """ + + pause_resume: str = "space" + correction: str = "c" + upload: str = "h" + + +@dataclass +class DAggerPedalConfig: + """Foot pedal configuration for DAgger controls. + + Pedal codes are evdev key code strings (e.g. ``"KEY_A"``). + """ + + device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + pause_resume: str = "KEY_A" + correction: str = "KEY_B" + upload: str = "KEY_C" + + @RolloutStrategyConfig.register_subclass("dagger") @dataclass class DAggerStrategyConfig(RolloutStrategyConfig): @@ -95,19 +121,30 @@ class DAggerStrategyConfig(RolloutStrategyConfig): Alternates between autonomous policy execution and human intervention. Intervention frames are tagged with ``intervention=True``. + Input is controlled via either a keyboard or foot pedal, selected by + ``input_device``. Each device exposes three actions: + + 1. **pause_resume** — toggle policy execution on/off. + 2. **correction** — toggle human correction recording. + 3. **upload** — push dataset to hub on demand (corrections-only mode). + When ``record_autonomous=True`` (default) both autonomous and correction - frames are recorded — this requires streaming encoding so the policy - loop never blocks on disk I/O. Set to ``False`` to record only the - human-correction windows; encoding can then happen between phases. + frames are recorded with sentry-like time-based episode rotation and + background uploading. Set to ``False`` to record only the human-correction + windows, where each correction becomes its own episode. """ - episode_time_s: float = 120.0 - num_episodes: int = 50 - play_sounds: bool = True - calibrate: bool = False - log_hz: bool = True - hz_log_interval_s: float = 2.0 - record_autonomous: bool = True + episode_time_s: float = 20.0 + num_episodes: int = 10 + record_autonomous: bool = False + upload_every_n_episodes: int = 5 + input_device: str = "keyboard" + keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig) + pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig) + + def __post_init__(self): + if self.input_device not in ("keyboard", "pedal"): + raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'") # --------------------------------------------------------------------------- @@ -160,9 +197,7 @@ class RolloutConfig: if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None: raise ValueError("DAgger strategy requires --teleop.type to be set") - needs_dataset = isinstance( - self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig) - ) + needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig)) if needs_dataset and (self.dataset is None or not self.dataset.repo_id): raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set") diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 7ed8ae784..d202647f9 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -234,6 +234,18 @@ def build_rollout_context( teleop = make_teleoperator_from_config(cfg.teleop) teleop.connect() + # DAgger requires teleop with motor control capabilities (enable_torque, + # disable_torque, write_goal_positions). + if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None: + required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions") + missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))] + if missing: + teleop.disconnect() + raise ValueError( + f"DAgger strategy requires a teleoperator with motor control methods " + f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}" + ) + # --- 4. Features + action-key reconciliation --------------------- all_obs_features = robot.observation_features observation_features_hw = { diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 66f822db4..1d7297031 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -18,13 +18,20 @@ Implements the RaC paradigm (Recovery and Correction) for interactive imitation learning. Alternates between autonomous policy execution and human intervention via teleoperator. -Keyboard Controls: - SPACE - Pause policy (robot holds position, no recording) - c - Take control (start correction, recording resumes) - p - Resume policy after pause/correction - -> - End episode (save and continue) - <- - Re-record episode - ESC - Stop recording and push to hub +Input is controlled via either a keyboard or foot pedal, selected by +the ``input_device`` config field. Each device exposes three actions: + + 1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED). + 2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING). + 3. **upload** — Push dataset to hub on demand (corrections-only mode). + ESC (keyboard only) — Stop session. + +Recording Modes: + ``record_autonomous=True``: Sentry-like continuous recording with + time-based episode rotation. Both autonomous and correction + frames are recorded; corrections tagged ``intervention=True``. + ``record_autonomous=False``: Only correction windows are recorded. + Each correction (start to stop) becomes one episode. """ from __future__ import annotations @@ -32,7 +39,10 @@ from __future__ import annotations import contextlib import enum import logging +import os +import sys import time +from concurrent.futures import Future, ThreadPoolExecutor from threading import Event, Lock from typing import Any @@ -40,19 +50,31 @@ import numpy as np from lerobot.common.control_utils import is_headless from lerobot.datasets import VideoEncodingManager -from lerobot.processor import RobotProcessorPipeline from lerobot.teleoperators import Teleoperator from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.import_utils import _pynput_available from lerobot.utils.pedal import start_pedal_listener from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import log_say -from ..configs import DAggerStrategyConfig +from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig from ..context import RolloutContext from ..robot_wrapper import ThreadSafeRobot from .core import RolloutStrategy, send_next_action +PYNPUT_AVAILABLE = _pynput_available +keyboard = None +if PYNPUT_AVAILABLE: + try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + PYNPUT_AVAILABLE = False + else: + from pynput import keyboard + except Exception as e: + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + logger = logging.getLogger(__name__) @@ -64,22 +86,22 @@ logger = logging.getLogger(__name__) class DAggerPhase(enum.Enum): """Observable phases of a DAgger episode.""" - AUTONOMOUS = "autonomous" # Policy driving, recording autonomous frames - PAUSED = "paused" # Engine paused, teleop aligned, awaiting takeover/resume + AUTONOMOUS = "autonomous" # Policy driving + PAUSED = "paused" # Engine paused, teleop aligned, awaiting input CORRECTING = "correcting" # Human driving via teleop, recording interventions -# Valid (current_phase, event) → next_phase +# Valid (current_phase, event) -> next_phase _DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = { - (DAggerPhase.AUTONOMOUS, "pause"): DAggerPhase.PAUSED, - (DAggerPhase.PAUSED, "takeover"): DAggerPhase.CORRECTING, - (DAggerPhase.PAUSED, "resume"): DAggerPhase.AUTONOMOUS, - (DAggerPhase.CORRECTING, "resume"): DAggerPhase.AUTONOMOUS, + (DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED, + (DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS, + (DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING, + (DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED, } class DAggerEvents: - """Thread-safe container for DAgger keyboard/pedal events. + """Thread-safe container for DAgger input device events. The keyboard/pedal threads write transition requests; the main loop consumes them. @@ -90,16 +112,9 @@ class DAggerEvents: self._phase = DAggerPhase.AUTONOMOUS self._pending_transition: str | None = None - # Episode-level flags written by keyboard/pedal threads, consumed by - # the main loop. ``threading.Event`` gives us atomic set/clear/check - # semantics without taking ``self._lock``. - self.exit_early = Event() - self.rerecord_episode = Event() + # Session-level flags self.stop_recording = Event() - - # Reset-phase flags (simpler lifecycle, shared between threads). - self.in_reset = Event() - self.start_next_episode = Event() + self.upload_requested = Event() # -- Thread-safe phase access ------------------------------------------ @@ -138,13 +153,12 @@ class DAggerEvents: self._phase = new_phase return old_phase, new_phase - def reset_for_episode(self) -> None: - """Reset all transient state at the start of an episode.""" + def reset(self) -> None: + """Reset all transient state for a fresh session.""" with self._lock: self._phase = DAggerPhase.AUTONOMOUS self._pending_transition = None - self.exit_early.clear() - self.rerecord_episode.clear() + self.upload_requested.clear() # --------------------------------------------------------------------------- @@ -152,29 +166,15 @@ class DAggerEvents: # --------------------------------------------------------------------------- -def _teleop_has_motor_control(teleop: Teleoperator) -> bool: - return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions")) - - -def _teleop_disable_torque(teleop: Teleoperator) -> None: - if hasattr(teleop, "disable_torque"): - teleop.disable_torque() - - -def _teleop_enable_torque(teleop: Teleoperator) -> None: - if hasattr(teleop, "enable_torque"): - teleop.enable_torque() - - def _teleop_smooth_move_to( teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50 ) -> None: - """Smoothly move teleop to target position if motor control is available.""" - if not _teleop_has_motor_control(teleop): - logger.warning("Teleop does not support motor control — cannot mirror robot position") - return + """Smoothly move teleop to target position via linear interpolation. - _teleop_enable_torque(teleop) + The teleoperator is guaranteed to have motor control methods + (validated at context build time). + """ + teleop.enable_torque() current = teleop.get_action() steps = max(int(duration_s * fps), 1) @@ -190,103 +190,58 @@ def _teleop_smooth_move_to( time.sleep(1 / fps) -def _reset_loop( - robot: ThreadSafeRobot, - teleop: Teleoperator, - events: DAggerEvents, - fps: int, - teleop_action_processor: RobotProcessorPipeline, - robot_action_processor: RobotProcessorPipeline, -) -> None: - """Reset period where the human repositions the environment.""" - logger.info("RESET — press any key to enable teleoperation") - - events.in_reset.set() - events.start_next_episode.clear() - - obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} - _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - - while not events.start_next_episode.is_set() and not events.stop_recording.is_set(): - precise_sleep(0.05) - - if events.stop_recording.is_set(): - return - - events.start_next_episode.clear() - _teleop_disable_torque(teleop) - logger.info("Teleop enabled — press any key to start episode") - - while not events.start_next_episode.is_set() and not events.stop_recording.is_set(): - loop_start = time.perf_counter() - obs = robot.get_observation() - action = teleop.get_action() - processed_teleop = teleop_action_processor((action, obs)) - robot_action_to_send = robot_action_processor((processed_teleop, obs)) - robot.send_action(robot_action_to_send) - precise_sleep(1 / fps - (time.perf_counter() - loop_start)) - - events.in_reset.clear() - events.start_next_episode.clear() - events.reset_for_episode() +# --------------------------------------------------------------------------- +# Input device handlers +# --------------------------------------------------------------------------- -def _init_dagger_keyboard(events: DAggerEvents): - """Initialise keyboard listener with DAgger/HIL controls. +def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig): + """Initialise keyboard listener with DAgger 3-key controls. - Returns the pynput Listener (or ``None`` in headless mode). + Returns the pynput Listener (or ``None`` in headless mode or when + pynput is unavailable). """ - if is_headless(): - logger.warning("Headless environment — keyboard controls unavailable") + if not PYNPUT_AVAILABLE or is_headless(): + logger.warning("Headless environment or pynput unavailable — keyboard controls disabled") return None - from pynput import keyboard + # Map config key names to pynput Key objects for special keys + special_keys = { + "space": keyboard.Key.space, + "tab": keyboard.Key.tab, + "enter": keyboard.Key.enter, + } + + def _resolve_key(key) -> str | None: + """Resolve a pynput key event to a config-comparable string.""" + if key == keyboard.Key.esc: + return "esc" + for name, pynput_key in special_keys.items(): + if key == pynput_key: + return name + if hasattr(key, "char") and key.char: + return key.char + return None + + # Build mapping: resolved key string -> DAgger event name + key_to_event = { + cfg.pause_resume: "pause_resume", + cfg.correction: "correction", + } def on_press(key): try: - if events.in_reset.is_set(): - if ( - key in [keyboard.Key.space, keyboard.Key.right] - or hasattr(key, "char") - and key.char == "c" - ): - events.start_next_episode.set() - elif key == keyboard.Key.esc: - events.stop_recording.set() - events.start_next_episode.set() + resolved = _resolve_key(key) + if resolved is None: return - - phase = events.phase - if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS: - logger.info("PAUSED — press 'c' to take control or 'p' to resume policy") - events.request_transition("pause") - elif hasattr(key, "char") and key.char == "c" and phase == DAggerPhase.PAUSED: - logger.info("Taking control...") - events.request_transition("takeover") - elif ( - hasattr(key, "char") - and key.char == "p" - and phase - in ( - DAggerPhase.PAUSED, - DAggerPhase.CORRECTING, - ) - ): - logger.info("Resuming policy...") - events.request_transition("resume") - - elif key == keyboard.Key.right: - logger.info("End episode") - events.exit_early.set() - elif key == keyboard.Key.left: - logger.info("Re-record episode") - events.rerecord_episode.set() - events.exit_early.set() - elif key == keyboard.Key.esc: + if resolved == "esc": logger.info("Stop recording...") events.stop_recording.set() - events.exit_early.set() + return + if resolved in key_to_event: + events.request_transition(key_to_event[resolved]) + if resolved == cfg.upload: + events.upload_requested.set() except Exception as e: logger.debug("Key error: %s", e) @@ -295,27 +250,23 @@ def _init_dagger_keyboard(events: DAggerEvents): return listener -_DAGGER_PEDAL_KEYS = ("KEY_A", "KEY_C") +def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig): + """Initialise foot pedal listener with DAgger 3-pedal controls. - -def _dagger_pedal_callback(events: DAggerEvents): - """Build the pedal key-press handler for DAgger's state machine.""" + Returns the pedal listener thread (or ``None`` if evdev is unavailable). + """ + code_to_event = { + cfg.pause_resume: "pause_resume", + cfg.correction: "correction", + } def on_press(code: str) -> None: - if code not in _DAGGER_PEDAL_KEYS: - return - if events.in_reset.is_set(): - events.start_next_episode.set() - return - phase = events.phase - if phase == DAggerPhase.CORRECTING: - events.request_transition("resume") - elif phase == DAggerPhase.PAUSED: - events.request_transition("takeover") - elif phase == DAggerPhase.AUTONOMOUS: - events.request_transition("pause") + if code in code_to_event: + events.request_transition(code_to_event[code]) + if code == cfg.upload: + events.upload_requested.set() - return on_press + return start_pedal_listener(on_press, device_path=cfg.device_path) # --------------------------------------------------------------------------- @@ -328,12 +279,14 @@ class DAggerStrategy(RolloutStrategy): State machine:: - AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS - --(p)--> AUTONOMOUS + AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED + --(key1)--> AUTONOMOUS - Intervention frames are tagged with ``intervention=True`` (bool) in - the dataset; autonomous frames with ``intervention=False``. When - ``record_autonomous=False`` only corrections are recorded. + Recording modes: + ``record_autonomous=True``: Sentry-like continuous recording with + time-based episode rotation. Intervention frames tagged True. + ``record_autonomous=False``: Only correction windows recorded. + Each correction = one episode. Upload on demand via key3. """ config: DAggerStrategyConfig @@ -341,71 +294,51 @@ class DAggerStrategy(RolloutStrategy): def __init__(self, config: DAggerStrategyConfig): super().__init__(config) self._listener = None + self._pedal_thread = None self._events = DAggerEvents() + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None + self._needs_push = Event() + self._episode_lock = Lock() def setup(self, ctx: RolloutContext) -> None: - """Initialise the inference engine, keyboard listener, and pedal handler.""" + """Initialise the inference engine and input device listener.""" self._init_engine(ctx) + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push") - self._listener = _init_dagger_keyboard(self._events) - start_pedal_listener(_dagger_pedal_callback(self._events)) + if self.config.input_device == "keyboard": + self._listener = _init_dagger_keyboard(self._events, self.config.keyboard) + else: + self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal) + record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only" logger.info( - "DAgger strategy ready (episodes=%d, episode_time=%.0fs, record_autonomous=%s)", + "DAgger strategy ready (input=%s, episodes=%d, record=%s)", + self.config.input_device, self.config.num_episodes, - self.config.episode_time_s, - self.config.record_autonomous, + record_mode, ) - logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop") def run(self, ctx: RolloutContext) -> None: """Run DAgger episodes with human-in-the-loop intervention.""" - dataset = ctx.data.dataset - events = self._events - teleop = ctx.hardware.teleop - - with VideoEncodingManager(dataset): - try: - recorded = 0 - while recorded < self.config.num_episodes and not events.stop_recording.is_set(): - log_say(f"Episode {dataset.num_episodes}", self.config.play_sounds) - - self._run_episode(ctx) - - if events.rerecord_episode.is_set(): - log_say("Re-recording", self.config.play_sounds) - events.rerecord_episode.clear() - events.exit_early.clear() - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded += 1 - - if recorded < self.config.num_episodes and not events.stop_recording.is_set(): - _reset_loop( - ctx.hardware.robot_wrapper, - teleop, - events, - int(ctx.runtime.cfg.fps), - ctx.processors.teleop_action_processor, - ctx.processors.robot_action_processor, - ) - - finally: - with contextlib.suppress(Exception): - dataset.save_episode() + if self.config.record_autonomous: + self._run_continuous(ctx) + else: + self._run_corrections_only(ctx) def teardown(self, ctx: RolloutContext) -> None: """Stop listeners, finalise the dataset, and disconnect hardware.""" - log_say("Stop recording", self.config.play_sounds, blocking=True) - if self._listener is not None and not is_headless(): self._listener.stop() + # Flush any queued/running push cleanly + if self._push_executor is not None: + self._push_executor.shutdown(wait=True) + self._push_executor = None + if ctx.data.dataset is not None: ctx.data.dataset.finalize() - if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: ctx.data.dataset.push_to_hub( tags=ctx.runtime.cfg.dataset.tags, private=ctx.runtime.cfg.dataset.private, @@ -415,11 +348,17 @@ class DAggerStrategy(RolloutStrategy): logger.info("DAgger strategy teardown complete") # ------------------------------------------------------------------ - # Episode rollout (state machine) + # Continuous recording mode (record_autonomous=True) # ------------------------------------------------------------------ - def _run_episode(self, ctx: RolloutContext) -> None: - """Run a single DAgger episode with the HIL state machine.""" + def _run_continuous(self, ctx: RolloutContext) -> None: + """Sentry-like continuous recording with intervention tagging. + + Episodes are auto-rotated every ``episode_time_s`` seconds and + uploaded in the background every ``upload_every_n_episodes`` episodes. + Both autonomous and correction frames are recorded; corrections are + tagged with ``intervention=True``. + """ engine = self._engine cfg = ctx.runtime.cfg robot = ctx.hardware.robot_wrapper @@ -427,111 +366,231 @@ class DAggerStrategy(RolloutStrategy): dataset = ctx.data.dataset events = self._events interpolator = self._interpolator + features = ctx.data.dataset_features control_interval = interpolator.get_control_interval(cfg.fps) - stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False record_stride = max(1, cfg.interpolation_multiplier) - record_autonomous = self.config.record_autonomous - - features = ctx.data.dataset_features + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task engine.reset() interpolator.reset() - events.reset_for_episode() - _teleop_disable_torque(teleop) - - last_action: dict[str, Any] | None = None - frame_buffer: list[dict] = [] - task_str = cfg.dataset.single_task if cfg.dataset else cfg.task - - timestamp = 0.0 - record_tick = 0 - start_t = time.perf_counter() - + events.reset() + teleop.disable_torque() engine.resume() - while timestamp < self.config.episode_time_s: - loop_start = time.perf_counter() + last_action: dict[str, Any] | None = None + record_tick = 0 + episode_start = time.perf_counter() + start_time = time.perf_counter() + episodes_since_push = 0 - if events.exit_early.is_set(): - events.exit_early.clear() - break + with VideoEncodingManager(dataset): + try: + while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set(): + loop_start = time.perf_counter() - transition = events.consume_transition() - if transition is not None: - old_phase, new_phase = transition - self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) - last_action = None + if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: + break - phase = events.phase + # Process transitions + transition = events.consume_transition() + if transition is not None: + old_phase, new_phase = transition + self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) + last_action = None - obs = robot.get_observation() - obs_processed = ctx.processors.robot_observation_processor(obs) - obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + phase = events.phase + obs = robot.get_observation() + obs_processed = ctx.processors.robot_observation_processor(obs) + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) - # --- CORRECTING: human teleop control --- - if phase == DAggerPhase.CORRECTING: - teleop_action = teleop.get_action() - processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) - robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) - robot.send_action(robot_action_to_send) - action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) - if record_tick % record_stride == 0: - frame = { - **obs_frame, - **action_frame, - "task": task_str, - "intervention": np.array([True], dtype=bool), - } - if stream_online: - dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 - - # --- PAUSED: hold position --- - elif phase == DAggerPhase.PAUSED: - if last_action: - robot.send_action(last_action) - - # --- AUTONOMOUS: policy control --- - else: - engine.notify_observation(obs_processed) - - if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): - timestamp = time.perf_counter() - start_t - continue - - action_dict = send_next_action(obs_processed, obs, ctx, interpolator) - - if action_dict is not None: - last_action = ctx.processors.robot_action_processor((action_dict, obs)) - action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) - if record_autonomous and record_tick % record_stride == 0: - frame = { - **obs_frame, - **action_frame, - "task": task_str, - "intervention": np.array([False], dtype=bool), - } - if stream_online: + # --- CORRECTING: human teleop control --- + if phase == DAggerPhase.CORRECTING: + teleop_action = teleop.get_action() + processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) + robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) + robot.send_action(robot_action_to_send) + last_action = robot_action_to_send + action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) + if record_tick % record_stride == 0: + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([True], dtype=bool), + } dataset.add_frame(frame) - else: - frame_buffer.append(frame) - record_tick += 1 + record_tick += 1 - dt = time.perf_counter() - loop_start - if (sleep_t := control_interval - dt) > 0: - precise_sleep(sleep_t) - timestamp = time.perf_counter() - start_t + # --- PAUSED: hold position --- + elif phase == DAggerPhase.PAUSED: + if last_action: + robot.send_action(last_action) - # End of episode: pause engine, disable teleop, flush buffer - engine.pause() - _teleop_disable_torque(teleop) + # --- AUTONOMOUS: policy control --- + else: + engine.notify_observation(obs_processed) - if not stream_online: - for frame in frame_buffer: - dataset.add_frame(frame) + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + if action_dict is not None: + last_action = ctx.processors.robot_action_processor((action_dict, obs)) + action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) + if record_tick % record_stride == 0: + frame = { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([False], dtype=bool), + } + dataset.add_frame(frame) + record_tick += 1 + + # Sentry-like episode rotation + elapsed = time.perf_counter() - episode_start + if elapsed >= self.config.episode_time_s: + with self._episode_lock: + dataset.save_episode() + episodes_since_push += 1 + self._needs_push.set() + logger.info("Episode saved (total: %d)", dataset.num_episodes) + + if episodes_since_push >= self.config.upload_every_n_episodes: + self._background_push(dataset, cfg) + episodes_since_push = 0 + + episode_start = time.perf_counter() + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + + finally: + engine.pause() + teleop.disable_torque() + with contextlib.suppress(Exception): + with self._episode_lock: + dataset.save_episode() + self._needs_push.set() + + # ------------------------------------------------------------------ + # Corrections-only mode (record_autonomous=False) + # ------------------------------------------------------------------ + + def _run_corrections_only(self, ctx: RolloutContext) -> None: + """Record only human correction windows. Each correction = one episode. + + The policy runs autonomously without recording. When the user + pauses and starts a correction, frames are recorded with + ``intervention=True``. Stopping the correction saves the episode. + The dataset can be uploaded on demand via the upload key/pedal. + """ + engine = self._engine + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + teleop = ctx.hardware.teleop + dataset = ctx.data.dataset + events = self._events + interpolator = self._interpolator + features = ctx.data.dataset_features + + control_interval = interpolator.get_control_interval(cfg.fps) + record_stride = max(1, cfg.interpolation_multiplier) + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + + engine.reset() + interpolator.reset() + events.reset() + teleop.disable_torque() + engine.resume() + + last_action: dict[str, Any] | None = None + record_tick = 0 + recorded = 0 + + with VideoEncodingManager(dataset): + try: + while ( + recorded < self.config.num_episodes + and not events.stop_recording.is_set() + and not ctx.runtime.shutdown_event.is_set() + ): + loop_start = time.perf_counter() + + # Process transitions + transition = events.consume_transition() + if transition is not None: + old_phase, new_phase = transition + self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop) + last_action = None + + # Correction ended -> save episode (blocking if not streaming) + if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED: + with self._episode_lock: + dataset.save_episode() + recorded += 1 + self._needs_push.set() + logger.info("Episode %d saved", recorded) + + # On-demand upload + if events.upload_requested.is_set(): + events.upload_requested.clear() + self._background_push(dataset, cfg) + + phase = events.phase + obs = robot.get_observation() + obs_processed = ctx.processors.robot_observation_processor(obs) + + # --- CORRECTING: human teleop control + recording --- + if phase == DAggerPhase.CORRECTING: + teleop_action = teleop.get_action() + processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) + robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) + robot.send_action(robot_action_to_send) + last_action = robot_action_to_send + + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) + if record_tick % record_stride == 0: + dataset.add_frame( + { + **obs_frame, + **action_frame, + "task": task_str, + "intervention": np.array([True], dtype=bool), + } + ) + record_tick += 1 + + # --- PAUSED: hold position --- + elif phase == DAggerPhase.PAUSED: + if last_action: + robot.send_action(last_action) + + # --- AUTONOMOUS: policy control (no recording) --- + else: + engine.notify_observation(obs_processed) + + if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): + continue + + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) + if action_dict is not None: + last_action = ctx.processors.robot_action_processor((action_dict, obs)) + + dt = time.perf_counter() - loop_start + if (sleep_t := control_interval - dt) > 0: + precise_sleep(sleep_t) + + finally: + engine.pause() + teleop.disable_torque() + with contextlib.suppress(Exception): + with self._episode_lock: + dataset.save_episode() + self._needs_push.set() # ------------------------------------------------------------------ # State-machine transition side-effects @@ -554,13 +613,41 @@ class DAggerStrategy(RolloutStrategy): k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features } _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) - interpolator.reset() elif new_phase == DAggerPhase.CORRECTING: - _teleop_disable_torque(teleop) - engine.reset() + teleop.disable_torque() elif new_phase == DAggerPhase.AUTONOMOUS: interpolator.reset() engine.reset() engine.resume() + + # ------------------------------------------------------------------ + # Background push (shared by both modes) + # ------------------------------------------------------------------ + + def _background_push(self, dataset, cfg) -> None: + """Queue a Hub push on the single-worker executor. + + The executor's max_workers=1 guarantees at most one push runs at + a time; submitted tasks are queued rather than dropped. + """ + if self._push_executor is None: + return + + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + + def _push(): + try: + with self._episode_lock: + dataset.push_to_hub( + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ) + self._needs_push.clear() + logger.info("Background push to hub complete") + except Exception as e: + logger.error("Background push failed: %s", e) + + self._pending_push = self._push_executor.submit(_push)