Change data/video file-size defaults; add detailed logging; derive episode rotation from target video file size (MB); return robot to initial position on shutdown; support remote Rerun logging; append timestamp suffix to repo_id

This commit is contained in:
Steven Palma
2026-04-19 16:50:19 +02:00
parent 8cee56e2d6
commit 32a27cae8a
15 changed files with 406 additions and 54 deletions
+6
View File
@@ -15,6 +15,7 @@
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@@ -69,3 +70,8 @@ class DatasetRecordConfig:
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
if self.repo_id:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.repo_id = f"{self.repo_id}_{timestamp}"
+2 -2
View File
@@ -71,8 +71,8 @@ class ForwardCompatibilityError(CompatibilityError):
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
DEFAULT_DATA_FILE_SIZE_IN_MB = 50 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
+24 -6
View File
@@ -62,12 +62,19 @@ class BaseStrategyConfig(RolloutStrategyConfig):
class SentryStrategyConfig(RolloutStrategyConfig):
"""Continuous autonomous rollout with always-on recording.
Episodes are auto-rotated every ``episode_duration_s`` seconds and
uploaded in the background every ``upload_every_n_episodes`` episodes.
Episode duration is derived from camera resolution, FPS, and
``target_video_file_size_mb`` so that each saved episode produces a
video file that has crossed the target size. This aligns episode
boundaries with the dataset's video file chunking, so each
``push_to_hub`` call uploads complete video files rather than
re-uploading a growing file that hasn't crossed the chunk boundary.
"""
episode_duration_s: float = 20.0
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation. Episodes are
# saved once the estimated video duration would exceed this limit.
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
target_video_file_size_mb: float | None = None
@RolloutStrategyConfig.register_subclass("highlight")
@@ -129,15 +136,18 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
3. **upload** push dataset to hub on demand (corrections-only mode).
When ``record_autonomous=True`` (default) both autonomous and correction
frames are recorded with sentry-like time-based episode rotation and
background uploading. Set to ``False`` to record only the human-correction
frames are recorded with size-based episode rotation (same as Sentry)
and background uploading. ``push_to_hub`` is blocked while a correction
is in progress. Set to ``False`` to record only the human-correction
windows, where each correction becomes its own episode.
"""
episode_time_s: float = 20.0
num_episodes: int = 10
record_autonomous: bool = False
upload_every_n_episodes: int = 5
# Target video file size in MB for episode rotation (record_autonomous
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
target_video_file_size_mb: float | None = None
input_device: str = "keyboard"
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
@@ -184,6 +194,14 @@ class RolloutConfig:
device: str | None = None
task: str = ""
display_data: bool = False
# Display data on a remote Rerun server
display_ip: str | None = None
# Port of the remote Rerun server
display_port: int | None = None
# Whether to display compressed images in Rerun
display_compressed_images: bool = False
# Use vocal synthesis to read events
play_sounds: bool = True
resume: bool = False
# Torch compile
+29 -1
View File
@@ -98,10 +98,15 @@ class HardwareContext:
The raw robot is available via ``robot_wrapper.inner`` when needed
(e.g. for disconnect); strategies should otherwise go through the
thread-safe wrapper.
``initial_position`` stores the robot's joint positions at connect
time. Strategies use it to return the robot to a safe pose before
shutting down.
"""
robot_wrapper: ThreadSafeRobot
teleop: Teleoperator | None
initial_position: dict | None = None
@dataclass
@@ -167,6 +172,7 @@ def build_rollout_context(
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
policy_config = cfg.policy
policy_class = get_policy_class(policy_config.type)
@@ -199,6 +205,7 @@ def build_rollout_context(
policy = policy.to(cfg.device)
policy.eval()
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
try:
@@ -225,14 +232,24 @@ def build_rollout_context(
robot_observation_processor = robot_observation_processor or _o
# --- 3. Hardware (heaviest side-effect, deferred) -----------------
logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
robot = make_robot_from_config(cfg.robot)
robot.connect()
logger.info("Robot connected: %s", robot.name)
# Store the initial joint positions so we can return to a safe pose on shutdown.
initial_obs = robot.get_observation()
initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
logger.info("Captured initial robot position (%d keys)", len(initial_position))
robot_wrapper = ThreadSafeRobot(robot)
teleop = None
if cfg.teleop is not None:
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
teleop = make_teleoperator_from_config(cfg.teleop)
teleop.connect()
logger.info("Teleoperator connected")
# DAgger requires teleop with motor control capabilities (enable_torque,
# disable_torque, write_goal_positions).
@@ -280,6 +297,7 @@ def build_rollout_context(
# --- 5. Dataset -------------
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
if cfg.resume:
dataset = LeRobotDataset.resume(
cfg.dataset.repo_id,
@@ -318,6 +336,9 @@ def build_rollout_context(
encoder_threads=cfg.dataset.encoder_threads,
)
if dataset is not None:
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
dataset_stats = None
if dataset is not None:
@@ -337,6 +358,10 @@ def build_rollout_context(
)
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
logger.info(
"Creating inference engine (type=%s)...",
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
inference_strategy = create_inference_engine(
cfg.inference,
@@ -356,9 +381,12 @@ def build_rollout_context(
)
# --- 8. Assemble ---------------------------------------------------
logger.info("Rollout context assembled successfully")
return RolloutContext(
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
hardware=HardwareContext(robot_wrapper=robot_wrapper, teleop=teleop),
hardware=HardwareContext(
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
),
policy=PolicyContext(
policy=policy,
preprocessor=preprocessor,
+1
View File
@@ -98,6 +98,7 @@ def create_inference_engine(
shutdown_event: Event | None = None,
) -> InferenceEngine:
"""Instantiate the appropriate inference engine from a config object."""
logger.info("Creating inference engine: %s", config.type)
if isinstance(config, SyncInferenceConfig):
return SyncInferenceEngine(
policy=policy,
+14
View File
@@ -158,6 +158,12 @@ class RTCInferenceEngine(InferenceEngine):
if not self._use_torch_compile:
self._compile_warmup_done.set()
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
else:
logger.info(
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
compile_warmup_inferences,
)
# Processor introspection for relative-action re-anchoring.
self._relative_step = next(
@@ -216,22 +222,30 @@ class RTCInferenceEngine(InferenceEngine):
def stop(self) -> None:
"""Signal the RTC thread to stop and wait for it."""
logger.info("Stopping RTC inference thread...")
self._shutdown_event.set()
self._policy_active.clear()
if self._rtc_thread is not None and self._rtc_thread.is_alive():
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
if self._rtc_thread.is_alive():
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
else:
logger.info("RTC inference thread stopped")
self._rtc_thread = None
def pause(self) -> None:
"""Pause the RTC background thread."""
logger.info("Pausing RTC inference thread")
self._policy_active.clear()
def resume(self) -> None:
"""Resume the RTC background thread."""
logger.info("Resuming RTC inference thread")
self._policy_active.set()
def reset(self) -> None:
"""Reset the policy, processors, and action queue."""
logger.info("Resetting RTC inference state (policy + processors + queue)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
+8
View File
@@ -58,15 +58,23 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
len(ordered_action_keys),
)
def start(self) -> None:
"""No background resources to start."""
logger.info("SyncInferenceEngine started (inline mode — no background thread)")
def stop(self) -> None:
"""No background resources to stop."""
logger.info("SyncInferenceEngine stopped")
def reset(self) -> None:
"""Reset the policy and pre/post-processors."""
logger.info("Resetting sync inference state (policy + processors)")
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
+3 -1
View File
@@ -14,11 +14,13 @@
"""Rollout strategies — public API re-exports."""
from .core import RolloutStrategy, send_next_action
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .factory import create_strategy
__all__ = [
"RolloutStrategy",
"create_strategy",
"estimate_max_episode_seconds",
"safe_push_to_hub",
"send_next_action",
]
+4 -1
View File
@@ -50,11 +50,13 @@ class BaseStrategy(RolloutStrategy):
start_time = time.perf_counter()
engine.resume()
logger.info("Base strategy control loop started")
while not ctx.runtime.shutdown_event.is_set():
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
@@ -64,7 +66,8 @@ class BaseStrategy(RolloutStrategy):
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
send_next_action(obs_processed, obs, ctx, interpolator)
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
+127 -2
View File
@@ -17,19 +17,24 @@
from __future__ import annotations
import abc
import logging
import time
from typing import TYPE_CHECKING
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data
from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, RolloutContext
from ..context import HardwareContext, RolloutContext, RuntimeContext
logger = logging.getLogger(__name__)
class RolloutStrategy(abc.ABC):
@@ -54,8 +59,10 @@ class RolloutStrategy(abc.ABC):
"""
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
self._engine = ctx.policy.inference
logger.info("Starting inference engine...")
self._engine.start()
self._warmup_flushed = False
logger.info("Inference engine started")
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
"""Handle torch.compile warmup phase.
@@ -74,6 +81,7 @@ class RolloutStrategy(abc.ABC):
precise_sleep(sleep_t)
return True
if not self._warmup_flushed:
logger.info("Warmup complete — flushing stale state and resuming engine")
engine.reset()
interpolator.reset()
self._warmup_flushed = True
@@ -81,16 +89,57 @@ class RolloutStrategy(abc.ABC):
return False
def _teardown_hardware(self, hw: HardwareContext) -> None:
"""Stop the inference engine and disconnect hardware."""
"""Stop the inference engine, return robot to initial position, and disconnect hardware."""
if self._engine is not None:
logger.info("Stopping inference engine...")
self._engine.stop()
robot = hw.robot_wrapper.inner
if robot.is_connected:
if hw.initial_position:
logger.info("Returning robot to initial position before shutdown...")
self._return_to_initial_position(hw)
logger.info("Disconnecting robot...")
robot.disconnect()
teleop = hw.teleop
if teleop is not None and teleop.is_connected:
logger.info("Disconnecting teleoperator...")
teleop.disconnect()
@staticmethod
def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
    """Drive the robot back to ``hw.initial_position`` with a linear ramp.

    Best-effort: any hardware error is logged as a warning rather than
    raised, so shutdown can still proceed.
    """
    wrapper = hw.robot_wrapper
    goal = hw.initial_position
    try:
        observation = wrapper.get_observation()
        # Only command joints that exist both in the live observation and the stored pose.
        start = {key: val for key, val in observation.items() if key in goal}
        n_steps = max(int(duration_s * fps), 1)
        step_period = 1 / fps
        for i in range(1, n_steps + 1):
            alpha = i / n_steps
            command = {key: start[key] * (1 - alpha) + goal[key] * alpha for key in start}
            wrapper.send_action(command)
            precise_sleep(step_period)
    except Exception as e:
        logger.warning("Could not return to initial position: %s", e)
@staticmethod
def _log_telemetry(
    obs_processed: dict | None,
    action_dict: dict | None,
    runtime_ctx: RuntimeContext,
) -> None:
    """Send the observation/action pair to Rerun when ``display_data`` is enabled."""
    runtime_cfg = runtime_ctx.cfg
    if not runtime_cfg.display_data:
        return
    log_rerun_data(
        observation=obs_processed,
        action=action_dict,
        compress_images=runtime_cfg.display_compressed_images,
    )
@abc.abstractmethod
def setup(self, ctx: RolloutContext) -> None:
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
@@ -104,6 +153,82 @@ class RolloutStrategy(abc.ABC):
"""Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
    """Push *dataset* to the Hub unless it contains no saved episodes.

    Returns ``True`` when ``push_to_hub`` was actually called, ``False``
    when the push was skipped because the dataset is empty.
    """
    if not dataset.num_episodes:
        logger.warning("No episodes saved — skipping push to hub")
        return False
    dataset.push_to_hub(tags=tags, private=private)
    return True
def estimate_max_episode_seconds(
dataset_features: dict,
fps: float,
target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
) -> float:
"""Conservatively estimate how many seconds of video will exceed *target_size_mb*.
Each camera produces its own video file, so the episode duration is
driven by the **slowest** camera to fill ``target_size_mb`` i.e.
the one with the fewest pixels per frame (lowest bitrate).
Uses a deliberately **low** bits-per-pixel estimate so the computed
duration is *longer* than reality. By the time the timer fires the
actual video file is guaranteed to have crossed the target size,
which aligns episode boundaries with the dataset's video-file
chunking each ``push_to_hub`` uploads complete files rather than
re-uploading a still-growing one.
The estimate ignores codec-specific settings (CRF, preset) on purpose:
we only need a rough lower bound on bitrate, not a precise prediction.
Falls back to 600 s (10 min) when no video features are present.
"""
# 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
# robot footage (real-world is typically 0.1 0.3 bpp). Under-
# estimating the bitrate over-estimates the time → the episode will be
# *larger* than target_size_mb when we save, which is what we want.
conservative_bpp = 0.1
# Collect per-camera pixel counts — each camera has its own video file.
camera_pixels = []
for feat in dataset_features.values():
if feat.get("dtype") == "video":
shape = feat.get("shape", ())
# Assuming shape could be (C, H, W) or (T, C, H, W)
# We want to extract the spatial dimensions.
if len(shape) >= 3:
h, w = shape[-2], shape[-1]
pixels = h * w
if pixels > 0:
camera_pixels.append(pixels)
if not camera_pixels:
return 600.0
# Use the smallest camera: it produces the lowest bitrate and therefore
# takes the longest to reach the target — the conservative choice.
min_pixels = min(camera_pixels)
bits_per_frame = min_pixels * conservative_bpp
bytes_per_second = (bits_per_frame * fps) / 8
# Guard against division by zero just in case
if bytes_per_second <= 0:
return 600.0
return (target_size_mb * 1024 * 1024) / bytes_per_second
# ---------------------------------------------------------------------------
# Shared action-dispatch helper
# ---------------------------------------------------------------------------
+80 -15
View File
@@ -50,17 +50,19 @@ import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
from lerobot.utils.pedal import start_pedal_listener
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, send_next_action
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -249,6 +251,12 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
listener = keyboard.Listener(on_press=on_press)
listener.start()
logger.info(
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
cfg.pause_resume,
cfg.correction,
cfg.upload,
)
return listener
@@ -268,6 +276,7 @@ def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
if code == cfg.upload:
events.upload_requested.set()
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
return start_pedal_listener(on_press, device_path=cfg.device_path)
@@ -307,6 +316,10 @@ class DAggerStrategy(RolloutStrategy):
"""Initialise the inference engine and input device listener."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
if self.config.input_device == "keyboard":
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
@@ -315,10 +328,11 @@ class DAggerStrategy(RolloutStrategy):
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
logger.info(
"DAgger strategy ready (input=%s, episodes=%d, record=%s)",
"DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
self.config.input_device,
self.config.num_episodes,
record_mode,
self._episode_duration_s,
)
def run(self, ctx: RolloutContext) -> None:
@@ -330,21 +344,32 @@ class DAggerStrategy(RolloutStrategy):
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping DAgger recording")
log_say("Stopping DAgger recording", play_sounds)
if self._listener is not None and not is_headless():
logger.info("Stopping keyboard listener")
self._listener.stop()
# Flush any queued/running push cleanly
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
ctx.data.dataset.push_to_hub(
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("DAgger strategy teardown complete")
@@ -373,6 +398,7 @@ class DAggerStrategy(RolloutStrategy):
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
@@ -384,9 +410,11 @@ class DAggerStrategy(RolloutStrategy):
last_action: dict[str, Any] | None = None
record_tick = 0
episode_start = time.perf_counter()
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
episode_duration_s = self._episode_duration_s
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
@@ -394,6 +422,7 @@ class DAggerStrategy(RolloutStrategy):
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
# Process transitions
@@ -415,6 +444,7 @@ class DAggerStrategy(RolloutStrategy):
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
@@ -440,6 +470,7 @@ class DAggerStrategy(RolloutStrategy):
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_tick % record_stride == 0:
@@ -452,14 +483,21 @@ class DAggerStrategy(RolloutStrategy):
dataset.add_frame(frame)
record_tick += 1
# Sentry-like episode rotation
# Episode rotation derived from video file-size target.
# Do NOT save mid-correction — wait for the correction
# to finish so the episode boundary is clean.
elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_time_s:
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
with self._episode_lock:
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
@@ -472,6 +510,7 @@ class DAggerStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
@@ -480,6 +519,7 @@ class DAggerStrategy(RolloutStrategy):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# Corrections-only mode (record_autonomous=False)
@@ -505,6 +545,7 @@ class DAggerStrategy(RolloutStrategy):
control_interval = interpolator.get_control_interval(cfg.fps)
record_stride = max(1, cfg.interpolation_multiplier)
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
play_sounds = cfg.play_sounds
engine.reset()
interpolator.reset()
@@ -517,6 +558,9 @@ class DAggerStrategy(RolloutStrategy):
last_action: dict[str, Any] | None = None
record_tick = 0
recorded = 0
logger.info(
"DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
)
with VideoEncodingManager(dataset):
try:
@@ -540,11 +584,17 @@ class DAggerStrategy(RolloutStrategy):
dataset.save_episode()
recorded += 1
self._needs_push.set()
logger.info("Episode %d saved", recorded)
logger.info(
"Correction %d/%d saved",
recorded,
self.config.num_episodes,
)
log_say(f"Correction {recorded} saved", play_sounds)
# On-demand upload
if events.upload_requested.is_set():
events.upload_requested.clear()
logger.info("Upload requested by user")
self._background_push(dataset, cfg)
phase = events.phase
@@ -558,6 +608,7 @@ class DAggerStrategy(RolloutStrategy):
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
last_action = robot_action_to_send
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
@@ -586,6 +637,7 @@ class DAggerStrategy(RolloutStrategy):
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
last_action = ctx.processors.robot_action_processor((action_dict, obs))
dt = time.perf_counter() - loop_start
@@ -593,6 +645,7 @@ class DAggerStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
engine.pause()
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
@@ -601,6 +654,7 @@ class DAggerStrategy(RolloutStrategy):
with self._episode_lock:
dataset.save_episode()
self._needs_push.set()
logger.info("Final in-progress episode saved")
# ------------------------------------------------------------------
# State-machine transition side-effects
@@ -616,7 +670,9 @@ class DAggerStrategy(RolloutStrategy):
teleop: Teleoperator,
) -> None:
"""Execute side-effects for a validated phase transition."""
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
logger.info("Pausing engine — robot holds position")
engine.pause()
obs = robot.get_observation()
_robot_pos = {
@@ -627,12 +683,13 @@ class DAggerStrategy(RolloutStrategy):
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode — human teleop control")
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
# teleop.disable_torque()
pass
elif new_phase == DAggerPhase.AUTONOMOUS:
logger.info("Resuming autonomous mode — resetting engine and interpolator")
interpolator.reset()
engine.reset()
engine.resume()
@@ -645,24 +702,32 @@ class DAggerStrategy(RolloutStrategy):
"""Queue a Hub push on the single-worker executor.
The executor's max_workers=1 guarantees at most one push runs at
a time; submitted tasks are queued rather than dropped.
a time; submitted tasks are queued rather than dropped. Pushes
are blocked while the operator is mid-correction to avoid
uploading a partially-recorded episode.
"""
if self._push_executor is None:
return
if self._events.phase == DAggerPhase.CORRECTING:
logger.info("Skipping push — correction in progress")
return
if self._pending_push is not None and not self._pending_push.done():
logger.info("Previous push still in progress; queueing next")
def _push():
try:
with self._episode_lock:
dataset.push_to_hub(
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push.clear()
logger.info("Background push to hub complete")
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+38 -7
View File
@@ -30,11 +30,12 @@ from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available, require_package
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..ring_buffer import RolloutRingBuffer
from .core import RolloutStrategy, send_next_action
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
keyboard = None
@@ -91,6 +92,11 @@ class HighlightStrategy(RolloutStrategy):
)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
logger.info(
"Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)",
self.config.ring_buffer_seconds,
self.config.ring_buffer_max_memory_mb,
)
self._setup_keyboard(ctx.runtime.shutdown_event)
logger.info(
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
@@ -112,9 +118,11 @@ class HighlightStrategy(RolloutStrategy):
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
start_time = time.perf_counter()
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key)
with VideoEncodingManager(dataset):
try:
@@ -122,6 +130,7 @@ class HighlightStrategy(RolloutStrategy):
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
@@ -134,6 +143,7 @@ class HighlightStrategy(RolloutStrategy):
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
@@ -159,11 +169,16 @@ class HighlightStrategy(RolloutStrategy):
else:
dataset.add_frame(frame)
dataset.save_episode()
logger.info("Episode saved")
logger.info("Episode saved (total: %d)", dataset.num_episodes)
log_say(
f"Episode {dataset.num_episodes} saved",
play_sounds,
)
self._recording_live.clear()
if self._push_requested.is_set():
self._push_requested.clear()
logger.info("Push requested by user")
self._background_push(dataset, cfg)
if self._recording_live.is_set():
@@ -176,26 +191,39 @@ class HighlightStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
logger.info("Highlight control loop ended")
if self._recording_live.is_set():
logger.info("Saving in-progress live episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def teardown(self, ctx: RolloutContext) -> None:
"""Stop listeners, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping highlight recording")
log_say("Stopping highlight recording", play_sounds)
if self._listener is not None:
logger.info("Stopping keyboard listener")
self._listener.stop()
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
ctx.data.dataset.push_to_hub(
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Highlight strategy teardown complete")
@@ -222,6 +250,7 @@ class HighlightStrategy(RolloutStrategy):
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
@@ -235,12 +264,14 @@ class HighlightStrategy(RolloutStrategy):
def _push():
try:
dataset.push_to_hub(
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
logger.info("Background push to hub complete")
):
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+52 -13
View File
@@ -23,13 +23,15 @@ from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event, Lock
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, send_next_action
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
@@ -37,9 +39,15 @@ logger = logging.getLogger(__name__)
class SentryStrategy(RolloutStrategy):
"""Continuous autonomous rollout with always-on recording.
Episodes are auto-rotated every ``episode_duration_s`` seconds. The
dataset is pushed to the Hub via a bounded single-worker executor so
no push is ever silently dropped and exactly one push runs at a time.
Episode duration is derived from camera resolution, FPS, and
``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode
produces a video file that has crossed the chunk-size boundary.
This keeps ``push_to_hub`` efficient — it uploads complete video
files rather than re-uploading a still-growing one.
The dataset is pushed to the Hub via a bounded single-worker executor
so no push is ever silently dropped and exactly one push runs at a
time.
Policy state (hidden state, RTC queue) intentionally persists across
episode boundaries — Sentry slices one continuous rollout, the robot
@@ -62,9 +70,13 @@ class SentryStrategy(RolloutStrategy):
"""Initialise the inference engine and background push executor."""
self._init_engine(ctx)
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push")
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
self._episode_duration_s = estimate_max_episode_seconds(
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
)
logger.info(
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
self.config.episode_duration_s,
self._episode_duration_s,
self.config.upload_every_n_episodes,
)
@@ -80,11 +92,14 @@ class SentryStrategy(RolloutStrategy):
control_interval = interpolator.get_control_interval(cfg.fps)
engine.resume()
play_sounds = cfg.play_sounds
episode_duration_s = self._episode_duration_s
start_time = time.perf_counter()
episode_start = time.perf_counter()
episodes_since_push = 0
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s)
with VideoEncodingManager(dataset):
try:
@@ -92,6 +107,7 @@ class SentryStrategy(RolloutStrategy):
loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
obs = robot.get_observation()
@@ -104,6 +120,7 @@ class SentryStrategy(RolloutStrategy):
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
@@ -113,8 +130,12 @@ class SentryStrategy(RolloutStrategy):
# ``add_frame`` does not need ``_episode_lock``.
dataset.add_frame(frame)
# Episode rotation derived from video file-size target.
# The duration is a conservative estimate so the actual
# video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now,
# keeping push_to_hub efficient (uploads complete files).
elapsed = time.perf_counter() - episode_start
if elapsed >= self.config.episode_duration_s:
if elapsed >= episode_duration_s:
# ``save_episode`` finalises the in-progress episode and
# flushes it to disk; ``_episode_lock`` serialises this with
# ``push_to_hub`` (run in the background executor) so the
@@ -123,7 +144,12 @@ class SentryStrategy(RolloutStrategy):
dataset.save_episode()
episodes_since_push += 1
self._needs_push.set()
logger.info("Episode saved (total: %d)", dataset.num_episodes)
logger.info(
"Episode saved (total: %d, elapsed: %.1fs)",
dataset.num_episodes,
elapsed,
)
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
if episodes_since_push >= self.config.upload_every_n_episodes:
self._background_push(dataset, cfg)
@@ -136,6 +162,7 @@ class SentryStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
logger.info("Sentry control loop ended — saving final episode")
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
@@ -143,18 +170,28 @@ class SentryStrategy(RolloutStrategy):
def teardown(self, ctx: RolloutContext) -> None:
"""Flush pending pushes, finalise the dataset, and disconnect hardware."""
play_sounds = ctx.runtime.cfg.play_sounds
logger.info("Stopping sentry recording")
log_say("Stopping sentry recording", play_sounds)
# Flush any queued/running push cleanly.
if self._push_executor is not None:
logger.info("Shutting down push executor (waiting for pending pushes)...")
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
ctx.data.dataset.push_to_hub(
logger.info("Pushing final dataset to hub...")
if safe_push_to_hub(
ctx.data.dataset,
tags=ctx.runtime.cfg.dataset.tags,
private=ctx.runtime.cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(ctx.hardware)
logger.info("Sentry strategy teardown complete")
@@ -174,13 +211,15 @@ class SentryStrategy(RolloutStrategy):
def _push():
try:
with self._episode_lock:
dataset.push_to_hub(
if safe_push_to_hub(
dataset,
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push.clear()
logger.info("Background push to hub complete")
):
self._needs_push.clear()
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
self._pending_push = self._push_executor.submit(_push)
logger.info("Background push task submitted")
+4 -1
View File
@@ -484,7 +484,10 @@ def record(
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
if dataset and dataset.num_episodes > 0:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
else:
logging.warning("No episodes saved — skipping push to hub")
log_say("Exiting", cfg.play_sounds)
return dataset
+14 -5
View File
@@ -44,7 +44,6 @@ Usage examples::
# Sentry mode (continuous recording)
lerobot-rollout \\
--strategy.type=sentry \\
--strategy.episode_duration_s=120 \\
--strategy.upload_every_n_episodes=5 \\
--policy.path=lerobot/pi0_base \\
--inference.type=rtc \\
@@ -82,9 +81,7 @@ from lerobot.robots import ( # noqa: F401
so_follower,
unitree_g1 as unitree_g1_robot,
)
from lerobot.rollout.configs import RolloutConfig
from lerobot.rollout.context import build_rollout_context
from lerobot.rollout.strategies import create_strategy
from lerobot.rollout import RolloutConfig, build_rollout_context, create_strategy
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
@@ -102,6 +99,7 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.utils import init_logging
from lerobot.utils.visualization_utils import init_rerun
logger = logging.getLogger(__name__)
@@ -111,6 +109,10 @@ def rollout(cfg: RolloutConfig):
"""Main entry point for policy deployment."""
init_logging()
if cfg.display_data:
logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
@@ -118,10 +120,17 @@ def rollout(cfg: RolloutConfig):
ctx = build_rollout_context(cfg, shutdown_event)
strategy = create_strategy(cfg.strategy)
logger.info("Strategy: %s", cfg.strategy.type)
logger.info("Rollout strategy: %s", cfg.strategy.type)
logger.info(
"Robot: %s | FPS: %.0f | Duration: %s",
cfg.robot.type if cfg.robot else "?",
cfg.fps,
f"{cfg.duration}s" if cfg.duration > 0 else "infinite",
)
try:
strategy.setup(ctx)
logger.info("Rollout setup complete, starting rollout...")
strategy.run(ctx)
except KeyboardInterrupt:
logger.info("Interrupted by user")