Files
lerobot/src/lerobot/rollout/strategies/core.py
T
2026-04-19 23:53:53 +02:00

273 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC and shared action-dispatch helper."""
from __future__ import annotations
import abc
import logging
import time
from typing import TYPE_CHECKING
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import log_rerun_data
from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import HardwareContext, RolloutContext, RuntimeContext
logger = logging.getLogger(__name__)
class RolloutStrategy(abc.ABC):
    """Abstract base for rollout execution strategies.

    Each concrete strategy implements a self-contained control loop with
    its own recording/interaction semantics.  Strategies are mutually
    exclusive — only one runs per session.

    Subclasses implement :meth:`setup`, :meth:`run` and :meth:`teardown`.
    The underscore-prefixed helpers hold the engine start-up, warmup,
    telemetry and hardware-shutdown logic that strategies can share.
    """

    def __init__(self, config: RolloutStrategyConfig) -> None:
        """Store *config*; the engine and interpolator are attached by ``_init_engine``."""
        self.config = config
        self._engine: InferenceEngine | None = None
        self._interpolator: ActionInterpolator | None = None
        # True once the first post-warmup iteration has reset engine/interpolator.
        self._warmup_flushed: bool = False

    def _init_engine(self, ctx: RolloutContext) -> None:
        """Attach the inference engine and action interpolator, then start the backend.

        Creates an :class:`ActionInterpolator` from the runtime config's
        ``interpolation_multiplier``, takes the engine from
        ``ctx.policy.inference`` and starts it.  Call this from ``setup()``
        so strategies share identical initialisation without duplicating
        code.
        """
        self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
        self._engine = ctx.policy.inference
        logger.info("Starting inference engine...")
        self._engine.start()
        # A freshly started engine has not been flushed after warmup yet.
        self._warmup_flushed = False
        logger.info("Inference engine started")

    def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
        """Handle the torch.compile warmup phase.

        Args:
            use_torch_compile: When ``False`` there is no warmup phase and
                the method returns ``False`` immediately.
            loop_start: ``time.perf_counter()`` value taken at the start of
                the current control-loop iteration.
            control_interval: Target duration of one loop iteration, in
                seconds.

        Returns:
            ``True`` if the caller should ``continue`` (still warming up;
            the remainder of the control interval has already been slept).
            ``False`` once the engine is ready.  On the first post-warmup
            iteration the engine and interpolator are reset so stale
            warmup state is discarded, and the engine is resumed.
        """
        if not use_torch_compile:
            return False

        engine = self._engine
        interpolator = self._interpolator

        if not engine.ready:
            # Keep the loop cadence while the engine is still compiling.
            dt = time.perf_counter() - loop_start
            if (sleep_t := control_interval - dt) > 0:
                precise_sleep(sleep_t)
            return True

        if not self._warmup_flushed:
            logger.info("Warmup complete — flushing stale state and resuming engine")
            engine.reset()
            interpolator.reset()
            self._warmup_flushed = True
            engine.resume()
        return False

    def _teardown_hardware(self, hw: HardwareContext) -> None:
        """Stop the inference engine, return robot to initial position, and disconnect hardware.

        The three shutdown stages (engine, robot, teleoperator) are chained
        with ``try``/``finally`` so that a failure in an earlier stage does
        not leave later devices connected.  Any exception raised by a stage
        still propagates to the caller once the remaining stages have run.
        """
        try:
            if self._engine is not None:
                logger.info("Stopping inference engine...")
                self._engine.stop()
        finally:
            try:
                robot = hw.robot_wrapper.inner
                if robot.is_connected:
                    if hw.initial_position:
                        logger.info("Returning robot to initial position before shutdown...")
                        self._return_to_initial_position(hw)
                    logger.info("Disconnecting robot...")
                    robot.disconnect()
            finally:
                teleop = hw.teleop
                if teleop is not None and teleop.is_connected:
                    logger.info("Disconnecting teleoperator...")
                    teleop.disconnect()

    @staticmethod
    def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
        """Smoothly interpolate the robot back to its initial position.

        Reads one observation, keeps the entries whose keys also appear in
        ``hw.initial_position`` and linearly blends them towards the target
        over ``max(int(duration_s * fps), 1)`` steps, sleeping ``1 / fps``
        seconds after each command.

        This is a best-effort move: any exception is logged as a warning
        and swallowed so that shutdown can continue.
        """
        robot = hw.robot_wrapper
        target = hw.initial_position
        try:
            current_obs = robot.get_observation()
            current_pos = {k: v for k, v in current_obs.items() if k in target}
            steps = max(int(duration_s * fps), 1)
            for step in range(1, steps + 1):
                t = step / steps  # blend factor; reaches 1.0 on the last step
                interp = {k: start * (1 - t) + target[k] * t for k, start in current_pos.items()}
                robot.send_action(interp)
                precise_sleep(1 / fps)
        except Exception as e:  # deliberate best-effort: never block shutdown
            logger.warning("Could not return to initial position: %s", e)

    @staticmethod
    def _log_telemetry(
        obs_processed: dict | None,
        action_dict: dict | None,
        runtime_ctx: RuntimeContext,
    ) -> None:
        """Log observation/action telemetry to Rerun if display_data is enabled.

        Does nothing when ``runtime_ctx.cfg.display_data`` is falsy.
        """
        cfg = runtime_ctx.cfg
        if not cfg.display_data:
            return
        log_rerun_data(
            observation=obs_processed,
            action=action_dict,
            compress_images=cfg.display_compressed_images,
        )

    @abc.abstractmethod
    def setup(self, ctx: RolloutContext) -> None:
        """Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""

    @abc.abstractmethod
    def run(self, ctx: RolloutContext) -> None:
        """Main rollout loop. Returns when shutdown is requested or duration expires."""

    @abc.abstractmethod
    def teardown(self, ctx: RolloutContext) -> None:
        """Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
    """Upload *dataset* to the hub unless it holds no saved episodes.

    Args:
        dataset: Object exposing ``num_episodes`` and
            ``push_to_hub(tags=..., private=...)``.
        tags: Forwarded unchanged to ``dataset.push_to_hub``.
        private: Forwarded unchanged to ``dataset.push_to_hub``.

    Returns:
        ``True`` if the push was attempted, ``False`` if it was skipped
        because ``dataset.num_episodes`` is zero (a warning is logged).
    """
    nothing_saved = dataset.num_episodes == 0
    if not nothing_saved:
        dataset.push_to_hub(tags=tags, private=private)
        return True

    logger.warning("No episodes saved — skipping push to hub")
    return False
def _video_frame_pixels(feature: dict) -> int:
    """Return ``height * width`` for one video feature, or 0 if it cannot be determined.

    The spatial axes are resolved in this order:

    1. From the feature's ``names`` entry, when it is a list/tuple with one
       name per axis and contains ``height``/``width`` (or ``h``/``w``).
    2. Otherwise, a trailing dimension of 4 or less is taken to be the
       channel axis (channels-last, e.g. ``(H, W, C)``).
    3. Otherwise the last two axes are used (channels-first, e.g.
       ``(C, H, W)`` or ``(T, C, H, W)``).
    """
    shape = tuple(feature.get("shape") or ())
    if len(shape) < 3:
        return 0

    names = feature.get("names")
    if isinstance(names, (list, tuple)) and len(names) == len(shape):
        axis_of = {str(name).lower(): idx for idx, name in enumerate(names)}
        h_idx = axis_of.get("height", axis_of.get("h"))
        w_idx = axis_of.get("width", axis_of.get("w"))
        if h_idx is not None and w_idx is not None:
            return shape[h_idx] * shape[w_idx]

    # No usable axis names: an image is never <= 4 pixels wide in practice,
    # so such a small trailing dimension is treated as the channel axis.
    if shape[-1] <= 4:
        return shape[-3] * shape[-2]
    return shape[-2] * shape[-1]


def estimate_max_episode_seconds(
    dataset_features: dict,
    fps: float,
    target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
) -> float:
    """Conservatively estimate how many seconds of video will exceed *target_size_mb*.

    Each camera produces its own video file, so the episode duration is
    driven by the **slowest** camera to fill ``target_size_mb`` — i.e.
    the one with the fewest pixels per frame (lowest bitrate).

    Uses a deliberately **low** bits-per-pixel estimate so the computed
    duration is *longer* than reality.  By the time the timer fires the
    actual video file is expected to have crossed the target size,
    which aligns episode boundaries with the dataset's video-file
    chunking — each ``push_to_hub`` uploads complete files rather than
    re-uploading a still-growing one.

    The estimate ignores codec-specific settings (CRF, preset) on purpose:
    we only need a rough lower bound on bitrate, not a precise prediction.

    Args:
        dataset_features: Mapping of feature name to a feature spec dict.
            Only entries with ``"dtype" == "video"`` and a ``"shape"`` of
            at least three dimensions are considered.  Both channels-last
            and channels-first shapes are understood; an optional
            ``"names"`` entry naming the axes takes precedence.
            NOTE(review): datasets are believed to store camera shapes as
            ``(height, width, channels)`` with matching ``names`` —
            confirm against the dataset feature builder.
        fps: Recording frame rate in frames per second.
        target_size_mb: Target video file size in MiB.

    Returns:
        Estimated duration in seconds.  Falls back to 600 s (10 min) when
        no usable video feature is present or the computed byte rate is
        not positive (e.g. ``fps <= 0``).
    """
    # 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
    # robot footage (real-world is typically 0.1-0.3 bpp).  Under-
    # estimating the bitrate over-estimates the time → the episode will be
    # *larger* than target_size_mb when we save, which is what we want.
    conservative_bpp = 0.1

    # Collect per-camera pixel counts — each camera has its own video file.
    camera_pixels = [
        pixels
        for feat in dataset_features.values()
        if feat.get("dtype") == "video" and (pixels := _video_frame_pixels(feat)) > 0
    ]
    if not camera_pixels:
        return 600.0

    # Use the smallest camera: it produces the lowest bitrate and therefore
    # takes the longest to reach the target — the conservative choice.
    min_pixels = min(camera_pixels)
    bits_per_frame = min_pixels * conservative_bpp
    bytes_per_second = (bits_per_frame * fps) / 8

    # Guard against division by zero (fps == 0) and nonsensical negative rates.
    if bytes_per_second <= 0:
        return 600.0
    return (target_size_mb * 1024 * 1024) / bytes_per_second
# ---------------------------------------------------------------------------
# Shared action-dispatch helper
# ---------------------------------------------------------------------------
def send_next_action(
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
) -> dict | None:
    """Dispatch the next action to the robot.

    When the interpolator asks for a new action, an observation frame is
    built from *obs_processed* and handed to the inference engine; a
    non-``None`` result is moved to CPU and fed to the interpolator.  The
    interpolator's current output is then mapped onto
    ``ctx.data.ordered_action_keys``, passed together with *obs_raw*
    through ``ctx.processors.robot_action_processor`` and sent to the
    robot.  The same code path serves sync and async backends, so the
    rollout strategy never needs to branch.

    Returns:
        The key -> scalar action dict as it was *before* the robot action
        processor ran, or ``None`` if the interpolator had nothing to
        emit (e.g. empty async queue, interpolator not yet primed).
    """
    inference = ctx.policy.inference
    dataset_features = ctx.data.dataset_features
    action_keys = ctx.data.ordered_action_keys

    # Refill the interpolator only when it has run out of actions.
    if interpolator.needs_new_action():
        frame = build_dataset_frame(dataset_features, obs_processed, prefix=OBS_STR)
        fresh_action = inference.get_action(frame)
        if fresh_action is not None:
            interpolator.add(fresh_action.cpu())

    current = interpolator.get()
    if current is None:
        return None

    # Pair each action key with the value at the same index; keys beyond
    # the length of the interpolated action are dropped.
    action_dict = {}
    for idx, key in enumerate(action_keys):
        if idx >= len(current):
            break
        action_dict[key] = current[idx].item()

    robot_action = ctx.processors.robot_action_processor((action_dict, obs_raw))
    ctx.hardware.robot_wrapper.send_action(robot_action)
    return action_dict