mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
734 lines
31 KiB
Python
734 lines
31 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""DAgger rollout strategy: Human-in-the-Loop data collection.
|
|
|
|
Implements the RaC paradigm (Recovery and Correction) for interactive
|
|
imitation learning. Alternates between autonomous policy execution and
|
|
human intervention via teleoperator.
|
|
|
|
Input is controlled via either a keyboard or foot pedal, selected by
|
|
the ``input_device`` config field. Each device exposes three actions:
|
|
|
|
1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
|
|
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
|
|
3. **upload** — Push dataset to hub on demand (corrections-only mode).
|
|
ESC (keyboard only) — Stop session.
|
|
|
|
Recording Modes:
|
|
``record_autonomous=True``: Sentry-like continuous recording with
|
|
time-based episode rotation. Both autonomous and correction
|
|
frames are recorded; corrections tagged ``intervention=True``.
|
|
``record_autonomous=False``: Only correction windows are recorded.
|
|
Each correction (start to stop) becomes one episode.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import enum
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
from threading import Event, Lock
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from lerobot.common.control_utils import is_headless
|
|
from lerobot.datasets import VideoEncodingManager
|
|
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
|
from lerobot.teleoperators import Teleoperator
|
|
from lerobot.utils.constants import ACTION, OBS_STR
|
|
from lerobot.utils.feature_utils import build_dataset_frame
|
|
from lerobot.utils.import_utils import _pynput_available
|
|
from lerobot.utils.pedal import start_pedal_listener
|
|
from lerobot.utils.robot_utils import precise_sleep
|
|
from lerobot.utils.utils import log_say
|
|
|
|
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
|
from ..context import RolloutContext
|
|
from ..robot_wrapper import ThreadSafeRobot
|
|
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
|
|
|
# Optional keyboard backend.  ``keyboard`` stays ``None`` and
# ``PYNPUT_AVAILABLE`` is forced to ``False`` whenever pynput cannot be used,
# so the rest of the module only has to check the flag.
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
if PYNPUT_AVAILABLE:
    try:
        if "DISPLAY" in os.environ or "linux" not in sys.platform:
            from pynput import keyboard
        else:
            # On Linux without a DISPLAY variable the import is skipped
            # altogether rather than attempted.
            logging.info("No DISPLAY set. Skipping pynput import.")
            PYNPUT_AVAILABLE = False
    except Exception as e:
        # Any import-time failure downgrades to "keyboard unavailable"
        # instead of breaking the module import.
        PYNPUT_AVAILABLE = False
        logging.info(f"Could not import pynput: {e}")

logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DAgger state machine
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class DAggerPhase(enum.Enum):
    """Phases the DAgger state machine can be in during a session."""

    # The policy is driving the robot.
    AUTONOMOUS = "autonomous"
    # Engine paused, teleop aligned, awaiting operator input.
    PAUSED = "paused"
    # The human drives via the teleoperator; interventions are recorded.
    CORRECTING = "correcting"
|
|
|
|
|
|
# State-machine table: (current phase, input event) -> next phase.
# A (phase, event) pair that is absent from this table is not a valid
# transition.
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
    # "pause_resume" toggles AUTONOMOUS <-> PAUSED.
    (DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
    (DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
    # "correction" toggles PAUSED <-> CORRECTING.
    (DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
    (DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
}
|
|
|
|
|
|
class DAggerEvents:
    """Lock-protected mailbox between the input-device threads and the main loop.

    Input callbacks (keyboard/pedal) post transition requests with
    :meth:`request_transition`; the control loop applies them with
    :meth:`consume_transition`.  Only one request is kept at a time: a newer
    valid request overwrites an unconsumed older one.
    """

    def __init__(self) -> None:
        # Guards ``_phase`` and ``_pending_transition``.
        self._lock = Lock()
        self._phase = DAggerPhase.AUTONOMOUS
        self._pending_transition: str | None = None

        # Session-level flags, set by the input callbacks.
        self.stop_recording = Event()
        self.upload_requested = Event()

    # -- Thread-safe phase access ------------------------------------------

    @property
    def phase(self) -> DAggerPhase:
        """Return the current phase, read under the lock."""
        with self._lock:
            return self._phase

    @phase.setter
    def phase(self, value: DAggerPhase) -> None:
        with self._lock:
            self._phase = value

    def request_transition(self, event: str) -> None:
        """Post a transition request (called from keyboard/pedal threads).

        The request is stored only when ``(current phase, event)`` is a key
        of ``_DAGGER_TRANSITIONS``; anything else is dropped.
        """
        with self._lock:
            if (self._phase, event) not in _DAGGER_TRANSITIONS:
                return
            self._pending_transition = event

    def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
        """Apply the pending request, if any (called from the main loop).

        Returns ``(old_phase, new_phase)`` when a transition was applied, or
        ``None`` when nothing was pending or the stored event is not valid
        from the current phase.  The pending slot is cleared either way.
        """
        with self._lock:
            event = self._pending_transition
            if event is None:
                return None
            self._pending_transition = None

            previous = self._phase
            upcoming = _DAGGER_TRANSITIONS.get((previous, event))
            if upcoming is None:
                return None

            self._phase = upcoming
            return previous, upcoming

    def reset(self) -> None:
        """Return to AUTONOMOUS and drop any pending request or upload flag.

        ``stop_recording`` is not cleared here.
        """
        with self._lock:
            self._phase = DAggerPhase.AUTONOMOUS
            self._pending_transition = None
            self.upload_requested.clear()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Teleoperator helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
|
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
|
def _teleop_smooth_move_to(
|
|
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
|
) -> None:
|
|
"""Smoothly move teleop to target position via linear interpolation.
|
|
|
|
Requires the teleoperator to support motor control methods
|
|
(``enable_torque``, ``write_goal_positions``, ``get_action``).
|
|
"""
|
|
teleop.enable_torque()
|
|
current = teleop.get_action()
|
|
steps = max(int(duration_s * fps), 1)
|
|
|
|
for step in range(steps + 1):
|
|
t = step / steps
|
|
interp = {}
|
|
for k in current:
|
|
if k in target_pos:
|
|
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
|
else:
|
|
interp[k] = current[k]
|
|
teleop.write_goal_positions(interp)
|
|
time.sleep(1 / fps)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Input device handlers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
    """Start a pynput listener that feeds the DAgger controls into ``events``.

    ``cfg.pause_resume`` and ``cfg.correction`` post state-machine
    transitions, ``cfg.upload`` sets ``events.upload_requested`` and ESC sets
    ``events.stop_recording``.

    Returns:
        The started pynput ``Listener``, or ``None`` when pynput is
        unavailable or the environment is headless.
    """
    if not PYNPUT_AVAILABLE or is_headless():
        logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
        return None

    # Config names for the non-printable keys that may be bound.
    special_keys = {
        "space": keyboard.Key.space,
        "tab": keyboard.Key.tab,
        "enter": keyboard.Key.enter,
    }

    def _resolve_key(key) -> str | None:
        """Translate a pynput key object into the string used in the config."""
        if key == keyboard.Key.esc:
            return "esc"
        named = next((name for name, candidate in special_keys.items() if key == candidate), None)
        if named is not None:
            return named
        # Printable keys carry a ``char``; other keys lack it or have it empty.
        char = getattr(key, "char", None)
        return char if char else None

    # Resolved key string -> state-machine event name.
    key_to_event = {
        cfg.pause_resume: "pause_resume",
        cfg.correction: "correction",
    }

    def on_press(key):
        try:
            name = _resolve_key(key)
            if name is None:
                return
            if name == "esc":
                logger.info("Stop recording...")
                events.stop_recording.set()
                return
            event_name = key_to_event.get(name)
            if event_name is not None:
                events.request_transition(event_name)
            # Checked independently of the transition keys above.
            if name == cfg.upload:
                events.upload_requested.set()
        except Exception as e:
            # Errors are only logged so they do not propagate out of the callback.
            logger.debug("Key error: %s", e)

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    logger.info(
        "DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
        cfg.pause_resume,
        cfg.correction,
        cfg.upload,
    )
    return listener
|
|
|
|
|
|
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
    """Start the foot-pedal listener that feeds the DAgger controls into ``events``.

    ``cfg.pause_resume`` and ``cfg.correction`` codes post state-machine
    transitions; the ``cfg.upload`` code sets ``events.upload_requested``.

    Returns:
        Whatever ``start_pedal_listener`` returns — described upstream as the
        listener thread, or ``None`` if evdev is unavailable.
    """
    # Pedal code -> state-machine event name.
    code_to_event = {
        cfg.pause_resume: "pause_resume",
        cfg.correction: "correction",
    }

    def on_press(code: str) -> None:
        mapped = code_to_event.get(code)
        if mapped is not None:
            events.request_transition(mapped)
        # Checked independently of the transition codes above.
        if code == cfg.upload:
            events.upload_requested.set()

    logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
    return start_pedal_listener(on_press, device_path=cfg.device_path)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DAgger Strategy
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class DAggerStrategy(RolloutStrategy):
    """Human-in-the-Loop data collection with intervention tagging.

    State machine::

        AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
                                      --(key1)--> AUTONOMOUS

    Recording modes:
      ``record_autonomous=True``: Sentry-like continuous recording with
        time-based episode rotation.  Intervention frames tagged True.
      ``record_autonomous=False``: Only correction windows recorded.
        Each correction = one episode.  Upload on demand via key3.
    """

    config: DAggerStrategyConfig

    def __init__(self, config: DAggerStrategyConfig):
        super().__init__(config)
        # Input device handles; at most one of them is populated by setup().
        self._listener = None
        self._pedal_thread = None
        self._events = DAggerEvents()
        # Single-worker executor used for Hub uploads (created in setup()).
        self._push_executor: ThreadPoolExecutor | None = None
        self._pending_push: Future | None = None
        # Set whenever an episode was saved that has not been pushed yet.
        self._needs_push = Event()
        # Held around save_episode() and around the background push so the
        # two never run at the same time.
        self._episode_lock = Lock()

    def setup(self, ctx: RolloutContext) -> None:
        """Initialise the inference engine, push executor and input listener."""
        self._init_engine(ctx)
        self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
        target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
        # Episode length used for rotation in continuous mode.
        self._episode_duration_s = estimate_max_episode_seconds(
            ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
        )

        # Any input_device value other than "keyboard" selects the pedal.
        if self.config.input_device == "keyboard":
            self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
        else:
            self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)

        record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
        logger.info(
            "DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
            self.config.input_device,
            self.config.num_episodes,
            record_mode,
            self._episode_duration_s,
        )

    def run(self, ctx: RolloutContext) -> None:
        """Run the control loop matching ``config.record_autonomous``."""
        if self.config.record_autonomous:
            self._run_continuous(ctx)
        else:
            self._run_corrections_only(ctx)

    def teardown(self, ctx: RolloutContext) -> None:
        """Stop listeners, finalise the dataset, and disconnect hardware."""
        play_sounds = ctx.runtime.cfg.play_sounds
        logger.info("Stopping DAgger recording")
        log_say("Stopping DAgger recording", play_sounds)

        if self._listener is not None and not is_headless():
            logger.info("Stopping keyboard listener")
            self._listener.stop()

        # Flush any queued/running push cleanly before touching the dataset.
        if self._push_executor is not None:
            logger.info("Shutting down push executor (waiting for pending pushes)...")
            self._push_executor.shutdown(wait=True)
            self._push_executor = None

        if ctx.data.dataset is not None:
            logger.info("Finalizing dataset...")
            ctx.data.dataset.finalize()
            # Only push when something was saved since the last successful
            # upload and the run is configured to push.
            if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
                logger.info("Pushing final dataset to hub...")
                if safe_push_to_hub(
                    ctx.data.dataset,
                    tags=ctx.runtime.cfg.dataset.tags,
                    private=ctx.runtime.cfg.dataset.private,
                ):
                    logger.info("Dataset uploaded to hub")
                    log_say("Dataset uploaded to hub", play_sounds)

        self._teardown_hardware(ctx.hardware)
        logger.info("DAgger strategy teardown complete")

    # ------------------------------------------------------------------
    # Continuous recording mode (record_autonomous=True)
    # ------------------------------------------------------------------

    def _run_continuous(self, ctx: RolloutContext) -> None:
        """Sentry-like continuous recording with intervention tagging.

        Episodes are rotated once the duration estimated in :meth:`setup`
        (derived from the target video file size) has elapsed, and a
        background upload is queued every ``upload_every_n_episodes``
        saved episodes.  Both autonomous and correction frames are
        recorded; corrections are tagged with ``intervention=True``.
        Rotation is postponed while a correction is in progress.
        """
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        teleop = ctx.hardware.teleop
        dataset = ctx.data.dataset
        events = self._events
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)
        # Only every ``record_stride``-th control tick is written to the dataset.
        record_stride = max(1, cfg.interpolation_multiplier)
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        play_sounds = cfg.play_sounds

        engine.reset()
        interpolator.reset()
        events.reset()
        # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
        # user is responsible for moving the teleop to the same position as the robot when starting the correction.
        # teleop.disable_torque()
        engine.resume()

        last_action: dict[str, Any] | None = None
        record_tick = 0
        start_time = time.perf_counter()
        episode_start = time.perf_counter()
        episodes_since_push = 0
        episode_duration_s = self._episode_duration_s
        logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)

        with VideoEncodingManager(dataset):
            try:
                while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
                    loop_start = time.perf_counter()

                    if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
                        logger.info("Duration limit reached (%.0fs)", cfg.duration)
                        break

                    # Process transitions
                    transition = events.consume_transition()
                    if transition is not None:
                        old_phase, new_phase = transition
                        self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
                        last_action = None

                    phase = events.phase
                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)
                    obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)

                    # --- CORRECTING: human teleop control ---
                    if phase == DAggerPhase.CORRECTING:
                        teleop_action = teleop.get_action()
                        processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
                        robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
                        robot.send_action(robot_action_to_send)
                        last_action = robot_action_to_send
                        self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
                        action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
                        if record_tick % record_stride == 0:
                            frame = {
                                **obs_frame,
                                **action_frame,
                                "task": task_str,
                                "intervention": np.array([True], dtype=bool),
                            }
                            dataset.add_frame(frame)
                        record_tick += 1

                    # --- PAUSED: hold position ---
                    elif phase == DAggerPhase.PAUSED:
                        # NOTE(review): ``last_action`` is cleared on every
                        # transition and only assigned in the other two
                        # phases, so this resend looks unreachable as
                        # written — confirm the intended hold behaviour.
                        if last_action:
                            robot.send_action(last_action)

                    # --- AUTONOMOUS: policy control ---
                    else:
                        engine.notify_observation(obs_processed)

                        # While warm-up reports it handled this tick, skip the
                        # rest of the iteration (rotation and pacing included).
                        if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                            continue

                        action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
                        if action_dict is not None:
                            self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                            last_action = ctx.processors.robot_action_processor((action_dict, obs))
                            action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
                            if record_tick % record_stride == 0:
                                frame = {
                                    **obs_frame,
                                    **action_frame,
                                    "task": task_str,
                                    "intervention": np.array([False], dtype=bool),
                                }
                                dataset.add_frame(frame)
                            record_tick += 1

                    # Episode rotation derived from video file-size target.
                    # Do NOT save mid-correction — wait for the correction
                    # to finish so the episode boundary is clean.
                    elapsed = time.perf_counter() - episode_start
                    if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
                        with self._episode_lock:
                            dataset.save_episode()
                        episodes_since_push += 1
                        self._needs_push.set()
                        logger.info(
                            "Episode saved (total: %d, elapsed: %.1fs)",
                            dataset.num_episodes,
                            elapsed,
                        )
                        log_say(f"Episode {dataset.num_episodes} saved", play_sounds)

                        if episodes_since_push >= self.config.upload_every_n_episodes:
                            self._background_push(dataset, cfg)
                            episodes_since_push = 0

                        episode_start = time.perf_counter()

                    # Pace the loop to the control interval.
                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("DAgger continuous control loop ended — pausing engine")
                engine.pause()
                # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
                # user is responsible for moving the teleop to the same position as the robot when starting the correction.
                # teleop.disable_torque()
                # Best effort: keep whatever was buffered for the episode in
                # progress; a failure here must not mask the original exit.
                with contextlib.suppress(Exception):
                    with self._episode_lock:
                        dataset.save_episode()
                    self._needs_push.set()
                    logger.info("Final in-progress episode saved")

    # ------------------------------------------------------------------
    # Corrections-only mode (record_autonomous=False)
    # ------------------------------------------------------------------

    def _run_corrections_only(self, ctx: RolloutContext) -> None:
        """Record only human correction windows.  Each correction = one episode.

        The policy runs autonomously without recording.  When the user
        pauses and starts a correction, frames are recorded with
        ``intervention=True``.  Stopping the correction saves the episode.
        The dataset can be uploaded on demand via the upload key/pedal.
        The loop ends after ``config.num_episodes`` corrections were saved,
        or when a stop/shutdown is requested.
        """
        engine = self._engine
        cfg = ctx.runtime.cfg
        robot = ctx.hardware.robot_wrapper
        teleop = ctx.hardware.teleop
        dataset = ctx.data.dataset
        events = self._events
        interpolator = self._interpolator
        features = ctx.data.dataset_features

        control_interval = interpolator.get_control_interval(cfg.fps)
        # Only every ``record_stride``-th control tick is written to the dataset.
        record_stride = max(1, cfg.interpolation_multiplier)
        task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
        play_sounds = cfg.play_sounds

        engine.reset()
        interpolator.reset()
        events.reset()
        # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
        # user is responsible for moving the teleop to the same position as the robot when starting the correction.
        # teleop.disable_torque()
        engine.resume()

        last_action: dict[str, Any] | None = None
        record_tick = 0
        recorded = 0
        logger.info(
            "DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
        )

        with VideoEncodingManager(dataset):
            try:
                while (
                    recorded < self.config.num_episodes
                    and not events.stop_recording.is_set()
                    and not ctx.runtime.shutdown_event.is_set()
                ):
                    loop_start = time.perf_counter()

                    # Process transitions
                    transition = events.consume_transition()
                    if transition is not None:
                        old_phase, new_phase = transition
                        self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
                        last_action = None

                        # Correction ended -> save episode (blocking if not streaming)
                        if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
                            with self._episode_lock:
                                dataset.save_episode()
                            recorded += 1
                            self._needs_push.set()
                            logger.info(
                                "Correction %d/%d saved",
                                recorded,
                                self.config.num_episodes,
                            )
                            log_say(f"Correction {recorded} saved", play_sounds)
                        elif new_phase == DAggerPhase.CORRECTING:
                            # Re-align the stride counter so the first tick of
                            # every correction is recorded.  Without this, a
                            # correction could start on a skipped tick and a
                            # very short one could end with no frames at all.
                            record_tick = 0

                    # On-demand upload
                    if events.upload_requested.is_set():
                        events.upload_requested.clear()
                        logger.info("Upload requested by user")
                        self._background_push(dataset, cfg)

                    phase = events.phase
                    obs = robot.get_observation()
                    obs_processed = ctx.processors.robot_observation_processor(obs)

                    # --- CORRECTING: human teleop control + recording ---
                    if phase == DAggerPhase.CORRECTING:
                        teleop_action = teleop.get_action()
                        processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
                        robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
                        robot.send_action(robot_action_to_send)
                        last_action = robot_action_to_send
                        self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)

                        obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
                        action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
                        if record_tick % record_stride == 0:
                            dataset.add_frame(
                                {
                                    **obs_frame,
                                    **action_frame,
                                    "task": task_str,
                                    "intervention": np.array([True], dtype=bool),
                                }
                            )
                        record_tick += 1

                    # --- PAUSED: hold position ---
                    elif phase == DAggerPhase.PAUSED:
                        # NOTE(review): ``last_action`` is cleared on every
                        # transition and only assigned in the other two
                        # phases, so this resend looks unreachable as
                        # written — confirm the intended hold behaviour.
                        if last_action:
                            robot.send_action(last_action)

                    # --- AUTONOMOUS: policy control (no recording) ---
                    else:
                        engine.notify_observation(obs_processed)

                        # While warm-up reports it handled this tick, skip the
                        # rest of the iteration (pacing included).
                        if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
                            continue

                        action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
                        if action_dict is not None:
                            self._log_telemetry(obs_processed, action_dict, ctx.runtime)
                            last_action = ctx.processors.robot_action_processor((action_dict, obs))

                    # Pace the loop to the control interval.
                    dt = time.perf_counter() - loop_start
                    if (sleep_t := control_interval - dt) > 0:
                        precise_sleep(sleep_t)

            finally:
                logger.info("DAgger corrections-only loop ended — pausing engine")
                engine.pause()
                # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
                # user is responsible for moving the teleop to the same position as the robot when starting the correction.
                # teleop.disable_torque()
                # Best effort: keep a correction that was still being recorded
                # when the loop ended; failures are deliberately ignored.
                with contextlib.suppress(Exception):
                    with self._episode_lock:
                        dataset.save_episode()
                    self._needs_push.set()
                    logger.info("Final in-progress episode saved")

    # ------------------------------------------------------------------
    # State-machine transition side-effects
    # ------------------------------------------------------------------

    @staticmethod
    def _apply_transition(
        old_phase: DAggerPhase,
        new_phase: DAggerPhase,
        engine,
        interpolator,
        robot: ThreadSafeRobot,
        teleop: Teleoperator,
    ) -> None:
        """Execute side-effects for a validated phase transition.

        * AUTONOMOUS -> PAUSED: pause the engine and read the robot state.
        * any -> CORRECTING: log only (teleop hand-over is still a TODO).
        * any -> AUTONOMOUS: reset interpolator and engine, then resume.

        CORRECTING -> PAUSED has no side-effect beyond the log line.
        """
        logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
        if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
            logger.info("Pausing engine — robot holds position")
            engine.pause()
            obs = robot.get_observation()
            # Currently unused: kept for the teleop alignment sketched in the
            # TODO below.
            _robot_pos = {
                k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
            }
            # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
            # user is responsible for moving the teleop to the same position as the robot when starting the correction.
            # _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)

        elif new_phase == DAggerPhase.CORRECTING:
            logger.info("Entering correction mode — human teleop control")
            # TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
            # user is responsible for moving the teleop to the same position as the robot when starting the correction.
            # teleop.disable_torque()

        elif new_phase == DAggerPhase.AUTONOMOUS:
            logger.info("Resuming autonomous mode — resetting engine and interpolator")
            interpolator.reset()
            engine.reset()
            engine.resume()

    # ------------------------------------------------------------------
    # Background push (shared by both modes)
    # ------------------------------------------------------------------

    def _background_push(self, dataset, cfg) -> None:
        """Queue a Hub push on the single-worker executor.

        The executor's max_workers=1 guarantees at most one push runs at
        a time; submitted tasks are queued rather than dropped.  A request
        made while the operator is mid-correction is skipped (not queued)
        to avoid uploading a partially-recorded episode.  Nothing happens
        when the executor has not been created or was already shut down.
        """
        if self._push_executor is None:
            return

        if self._events.phase == DAggerPhase.CORRECTING:
            logger.info("Skipping push — correction in progress")
            return

        if self._pending_push is not None and not self._pending_push.done():
            logger.info("Previous push still in progress; queueing next")

        def _push():
            try:
                # Hold the episode lock so no episode is saved mid-upload.
                with self._episode_lock:
                    if safe_push_to_hub(
                        dataset,
                        tags=cfg.dataset.tags if cfg.dataset else None,
                        private=cfg.dataset.private if cfg.dataset else False,
                    ):
                        self._needs_push.clear()
                        logger.info("Background push to hub complete")
            except Exception as e:
                # Runs on a worker thread: log instead of losing the error.
                logger.error("Background push failed: %s", e)

        self._pending_push = self._push_executor.submit(_push)
        logger.info("Background push task submitted")
|