From 3eda5712d35a8262375784b28731964263d239bf Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 16 Apr 2026 15:52:23 +0200 Subject: [PATCH] some more iterations --- src/lerobot/configs/__init__.py | 2 + src/lerobot/configs/dataset.py | 71 +++++++++++ src/lerobot/policies/__init__.py | 3 +- .../policies/rtc/action_interpolator.py | 118 +----------------- src/lerobot/rollout/__init__.py | 24 ++-- src/lerobot/rollout/configs.py | 36 +----- src/lerobot/rollout/context.py | 12 +- src/lerobot/rollout/inference/__init__.py | 22 ++-- src/lerobot/rollout/inference/base.py | 8 +- src/lerobot/rollout/inference/factory.py | 34 ++--- src/lerobot/rollout/inference/rtc.py | 8 +- src/lerobot/rollout/inference/sync.py | 6 +- src/lerobot/rollout/strategies/base.py | 7 +- src/lerobot/rollout/strategies/core.py | 27 ++-- src/lerobot/rollout/strategies/dagger.py | 9 +- src/lerobot/rollout/strategies/highlight.py | 64 +++++++--- src/lerobot/rollout/strategies/sentry.py | 9 +- src/lerobot/scripts/lerobot_record.py | 61 +-------- src/lerobot/utils/action_interpolator.py | 116 +++++++++++++++++ .../policies/rtc/test_action_interpolator.py | 2 +- tests/test_control_robot.py | 3 +- 21 files changed, 329 insertions(+), 313 deletions(-) create mode 100644 src/lerobot/configs/dataset.py create mode 100644 src/lerobot/utils/action_interpolator.py diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py index 3ddaec1af..ab74c3cd3 100644 --- a/src/lerobot/configs/__init__.py +++ b/src/lerobot/configs/__init__.py @@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` """ +from .dataset import DatasetRecordConfig from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig from .types import ( @@ -39,6 +40,7 @@ __all__ = [ "PolicyFeature", "RTCAttentionSchedule", # Config classes + "DatasetRecordConfig", 
"DatasetConfig", "EvalConfig", "PeftConfig", diff --git a/src/lerobot/configs/dataset.py b/src/lerobot/configs/dataset.py new file mode 100644 index 000000000..a56de0ccf --- /dev/null +++ b/src/lerobot/configs/dataset.py @@ -0,0 +1,71 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``.""" + +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class DatasetRecordConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str = "" + # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") + single_task: str = "" + # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. + root: str | Path | None = None + # Limit the frames per second. + fps: int = 30 + # Number of seconds for data recording for each episode. + episode_time_s: int | float = 60 + # Number of seconds for resetting the environment after each episode. + reset_time_s: int | float = 60 + # Number of episodes to record. + num_episodes: int = 50 + # Encode frames in the dataset into video + video: bool = True + # Upload dataset to Hugging Face hub. 
+ push_to_hub: bool = True + # Upload on private repository on the Hugging Face hub. + private: bool = False + # Add tags to your dataset on the hub. + tags: list[str] | None = None + # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; + # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes + # and threads depends on your system. We recommend 4 threads per camera with 0 processes. + # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. + num_image_writer_processes: int = 0 + # Number of threads writing the frames as png images on disk, per camera. + # Too many threads might cause unstable teleoperation fps due to main thread being blocked. + # Not enough threads might cause low camera fps. + num_image_writer_threads_per_camera: int = 4 + # Number of episodes to record before batch encoding videos + # Set to 1 for immediate encoding (default behavior), or higher for batched encoding + video_encoding_batch_size: int = 1 + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto', + # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. + # Use 'auto' to auto-detect the best available hardware encoder. + vcodec: str = "libsvtav1" + # Enable streaming video encoding: encode frames in real-time during capture instead + # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding + streaming_encoding: bool = False + # Maximum number of frames to buffer per camera when using streaming encoding. + # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. + encoder_queue_maxsize: int = 30 + # Number of threads per encoder instance. None = auto (codec default). + # Lower values reduce CPU usage; maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc. 
+ encoder_threads: int | None = None + # Rename map for the observation to override the image and state keys + rename_map: dict[str, str] = field(default_factory=dict) diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index e138a84d9..905276642 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator + from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors @@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .pretrained import PreTrainedPolicy as PreTrainedPolicy -from .rtc import ActionInterpolator as ActionInterpolator from .sac.configuration_sac import SACConfig as SACConfig from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig from .sarm.configuration_sarm import SARMConfig as SARMConfig diff --git a/src/lerobot/policies/rtc/action_interpolator.py b/src/lerobot/policies/rtc/action_interpolator.py index 222dc33b5..c30481d3b 100644 --- a/src/lerobot/policies/rtc/action_interpolator.py +++ b/src/lerobot/policies/rtc/action_interpolator.py @@ -1,116 +1,4 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility. +from lerobot.utils.action_interpolator import ActionInterpolator -"""Action interpolation for smoother robot control. - -Provides configurable Nx control rate by interpolating between consecutive actions. -Useful with RTC and action-chunking policies to reduce jerkiness. -""" - -from torch import Tensor - - -class ActionInterpolator: - """Interpolates between consecutive actions for smoother control. - - When enabled with multiplier N, produces N actions per policy action - by linearly interpolating between the previous and current action. - - Example with multiplier=3: - prev_action -> [1/3 interpolated, 2/3 interpolated, current_action] - - This effectively multiplies the control rate for smoother motion. - - Usage: - interpolator = ActionInterpolator(multiplier=2) # 2x control rate - - # In control loop: - if interpolator.needs_new_action(): - new_action = queue.get() - if new_action: - interpolator.add(new_action.cpu()) - - action = interpolator.get() - if action: - robot.send_action(action) - """ - - def __init__(self, multiplier: int = 1): - """Initialize the interpolator. - - Args: - multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.) 
- """ - if multiplier < 1: - raise ValueError(f"multiplier must be >= 1, got {multiplier}") - self.multiplier = multiplier - self._prev: Tensor | None = None - self._buffer: list[Tensor] = [] - self._idx = 0 - - @property - def enabled(self) -> bool: - """Whether interpolation is active (multiplier > 1).""" - return self.multiplier > 1 - - def reset(self): - """Reset interpolation state (call between episodes).""" - self._prev = None - self._buffer = [] - self._idx = 0 - - def needs_new_action(self) -> bool: - """Check if a new action is needed from the queue.""" - return self._idx >= len(self._buffer) - - def add(self, action: Tensor) -> None: - """Add a new action and compute interpolated sequence. - - Args: - action: New action tensor from policy/queue (already on CPU). - """ - if self.multiplier > 1 and self._prev is not None: - self._buffer = [] - for i in range(1, self.multiplier + 1): - t = i / self.multiplier - interp = self._prev + t * (action - self._prev) - self._buffer.append(interp) - else: - # First step: no previous action yet, so run at base FPS without interpolation. - self._buffer = [action.clone()] - self._prev = action.clone() - self._idx = 0 - - def get(self) -> Tensor | None: - """Get the next interpolated action. - - Returns: - Next action tensor, or None if buffer is exhausted. - """ - if self._idx >= len(self._buffer): - return None - action = self._buffer[self._idx] - self._idx += 1 - return action - - def get_control_interval(self, fps: float) -> float: - """Get the control interval based on interpolation multiplier. - - Args: - fps: Base frames per second. - - Returns: - Control interval in seconds (divided by multiplier). 
- """ - return 1.0 / (fps * self.multiplier) +__all__ = ["ActionInterpolator"] diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py index f0f5bf140..896d6f91a 100644 --- a/src/lerobot/rollout/__init__.py +++ b/src/lerobot/rollout/__init__.py @@ -17,21 +17,21 @@ from .configs import ( BaseStrategyConfig, DAggerStrategyConfig, + DatasetRecordConfig, HighlightStrategyConfig, RolloutConfig, - RolloutDatasetConfig, RolloutStrategyConfig, SentryStrategyConfig, ) from .context import RolloutContext, build_rollout_context from .inference import ( - InferenceStrategy, - InferenceStrategyConfig, + InferenceEngine, + InferenceEngineConfig, RTCInferenceConfig, - RTCInferenceStrategy, + RTCInferenceEngine, SyncInferenceConfig, - SyncInferenceStrategy, - create_inference_strategy, + SyncInferenceEngine, + create_inference_engine, ) from .ring_buffer import RolloutRingBuffer from .robot_wrapper import ThreadSafeRobot @@ -41,21 +41,21 @@ __all__ = [ "BaseStrategyConfig", "DAggerStrategyConfig", "HighlightStrategyConfig", - "InferenceStrategy", - "InferenceStrategyConfig", + "InferenceEngine", + "InferenceEngineConfig", "RTCInferenceConfig", - "RTCInferenceStrategy", + "RTCInferenceEngine", "RolloutConfig", "RolloutContext", - "RolloutDatasetConfig", + "DatasetRecordConfig", "RolloutRingBuffer", "RolloutStrategy", "RolloutStrategyConfig", "SentryStrategyConfig", "SyncInferenceConfig", - "SyncInferenceStrategy", + "SyncInferenceEngine", "ThreadSafeRobot", "build_rollout_context", - "create_inference_strategy", + "create_inference_engine", "create_strategy", ] diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 8d9ac776c..f30abbacb 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -19,15 +19,15 @@ from __future__ import annotations import abc import logging from dataclasses import dataclass, field -from pathlib import Path import draccus from lerobot.configs import PreTrainedConfig, parser 
+from lerobot.configs.dataset import DatasetRecordConfig from lerobot.robots.config import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig -from .inference import InferenceStrategyConfig, SyncInferenceConfig +from .inference import InferenceEngineConfig, SyncInferenceConfig logger = logging.getLogger(__name__) @@ -84,6 +84,7 @@ class HighlightStrategyConfig(RolloutStrategyConfig): ring_buffer_seconds: float = 30.0 ring_buffer_max_memory_mb: float = 2048.0 save_key: str = "s" + push_key: str = "h" @RolloutStrategyConfig.register_subclass("dagger") @@ -109,33 +110,6 @@ class DAggerStrategyConfig(RolloutStrategyConfig): record_autonomous: bool = True -# --------------------------------------------------------------------------- -# Dataset recording config (shared across recording strategies) -# --------------------------------------------------------------------------- - - -@dataclass -class RolloutDatasetConfig: - """Dataset configuration for rollout strategies that record data.""" - - repo_id: str = "" - single_task: str = "" - root: str | Path | None = None - fps: int = 30 - video: bool = True - push_to_hub: bool = True - private: bool = False - tags: list[str] | None = None - num_image_writer_processes: int = 0 - num_image_writer_threads_per_camera: int = 4 - video_encoding_batch_size: int = 1 - vcodec: str = "auto" - streaming_encoding: bool = True - encoder_queue_maxsize: int = 30 - encoder_threads: int | None = None - rename_map: dict[str, str] = field(default_factory=dict) - - # --------------------------------------------------------------------------- # Top-level rollout config # --------------------------------------------------------------------------- @@ -161,10 +135,10 @@ class RolloutConfig: strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig) # Inference backend (polymorphic: --inference.type=sync|rtc) - inference: InferenceStrategyConfig = field(default_factory=SyncInferenceConfig) + inference: 
InferenceEngineConfig = field(default_factory=SyncInferenceConfig) # Dataset (required for sentry, highlight, dagger; None for base) - dataset: RolloutDatasetConfig | None = None + dataset: DatasetRecordConfig | None = None # Runtime fps: float = 30.0 diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 81cab5713..96d8682b2 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -49,9 +49,9 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig from .inference import ( - InferenceStrategy, + InferenceEngine, RTCInferenceConfig, - create_inference_strategy, + create_inference_engine, ) from .robot_wrapper import ThreadSafeRobot @@ -106,12 +106,12 @@ class HardwareContext: @dataclass class PolicyContext: - """Loaded policy and its inference strategy.""" + """Loaded policy and its inference engine.""" policy: PreTrainedPolicy preprocessor: PolicyProcessorPipeline postprocessor: PolicyProcessorPipeline - inference: InferenceStrategy + inference: InferenceEngine @dataclass @@ -159,7 +159,7 @@ def build_rollout_context( robot_action_processor: RobotProcessorPipeline | None = None, robot_observation_processor: RobotProcessorPipeline | None = None, ) -> RolloutContext: - """Wire up policy, processors, hardware, dataset, and inference strategy. + """Wire up policy, processors, hardware, dataset, and inference engine. The order is policy-first / hardware-last so a bad ``--policy.path`` fails fast without touching the robot. @@ -329,7 +329,7 @@ def build_rollout_context( # --- 7. 
Inference strategy (needs policy + pre/post + hardware) -- task_str = cfg.dataset.single_task if cfg.dataset else cfg.task - inference_strategy = create_inference_strategy( + inference_strategy = create_inference_engine( cfg.inference, policy=policy, preprocessor=preprocessor, diff --git a/src/lerobot/rollout/inference/__init__.py b/src/lerobot/rollout/inference/__init__.py index b85801de9..1d0cbaafe 100644 --- a/src/lerobot/rollout/inference/__init__.py +++ b/src/lerobot/rollout/inference/__init__.py @@ -12,28 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Inference strategy package — backend-agnostic action production. +"""Inference engine package — backend-agnostic action production. Concrete strategies (sync, RTC, …) expose the same small interface so rollout strategies never branch on the inference backend. """ -from .base import InferenceStrategy +from .base import InferenceEngine from .factory import ( - InferenceStrategyConfig, + InferenceEngineConfig, RTCInferenceConfig, SyncInferenceConfig, - create_inference_strategy, + create_inference_engine, ) -from .rtc import RTCInferenceStrategy -from .sync import SyncInferenceStrategy +from .rtc import RTCInferenceEngine +from .sync import SyncInferenceEngine __all__ = [ - "InferenceStrategy", - "InferenceStrategyConfig", + "InferenceEngine", + "InferenceEngineConfig", "RTCInferenceConfig", - "RTCInferenceStrategy", + "RTCInferenceEngine", "SyncInferenceConfig", - "SyncInferenceStrategy", - "create_inference_strategy", + "SyncInferenceEngine", + "create_inference_engine", ] diff --git a/src/lerobot/rollout/inference/base.py b/src/lerobot/rollout/inference/base.py index 9ef51845e..439efb9da 100644 --- a/src/lerobot/rollout/inference/base.py +++ b/src/lerobot/rollout/inference/base.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Inference strategy ABC. 
+"""Inference engine ABC. Rollout strategies consume actions through this small interface so they -do not need to know whether inference is synchronous, runs in a -background thread (RTC), or comes from an external source. +do not need to know whether the inference engine is synchronous, runs in +a background thread (RTC), or comes from an external source. """ from __future__ import annotations @@ -26,7 +26,7 @@ import abc import torch -class InferenceStrategy(abc.ABC): +class InferenceEngine(abc.ABC): """Abstract backend for producing actions during rollout. Subclasses decide whether inference happens inline, in a background diff --git a/src/lerobot/rollout/inference/factory.py b/src/lerobot/rollout/inference/factory.py index 6e762f710..88482f845 100644 --- a/src/lerobot/rollout/inference/factory.py +++ b/src/lerobot/rollout/inference/factory.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Inference strategy configs and factory. +"""Inference engine configs and factory. Selection is explicit via ``--inference.type=sync|rtc``. Adding a new backend requires registering its config subclass and dispatching it in -:func:`create_inference_strategy`. +:func:`create_inference_engine`. 
""" from __future__ import annotations @@ -33,9 +33,9 @@ from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.processor import PolicyProcessorPipeline from ..robot_wrapper import ThreadSafeRobot -from .base import InferenceStrategy -from .rtc import RTCInferenceStrategy -from .sync import SyncInferenceStrategy +from .base import InferenceEngine +from .rtc import RTCInferenceEngine +from .sync import SyncInferenceEngine logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) @dataclass -class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC): +class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC): """Abstract base for inference backend configuration. Use ``--inference.type=`` on the CLI to select a backend. @@ -57,15 +57,15 @@ class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC): return self.get_choice_name(self.__class__) -@InferenceStrategyConfig.register_subclass("sync") +@InferenceEngineConfig.register_subclass("sync") @dataclass -class SyncInferenceConfig(InferenceStrategyConfig): +class SyncInferenceConfig(InferenceEngineConfig): """Inline synchronous inference (one policy call per control tick).""" -@InferenceStrategyConfig.register_subclass("rtc") +@InferenceEngineConfig.register_subclass("rtc") @dataclass -class RTCInferenceConfig(InferenceStrategyConfig): +class RTCInferenceConfig(InferenceEngineConfig): """Real-Time Chunking: async policy inference in a background thread.""" # ``RTCConfig`` is a small dataclass with default-only fields, so eagerly @@ -80,8 +80,8 @@ class RTCInferenceConfig(InferenceStrategyConfig): # --------------------------------------------------------------------------- -def create_inference_strategy( - config: InferenceStrategyConfig, +def create_inference_engine( + config: InferenceEngineConfig, *, policy: PreTrainedPolicy, preprocessor: PolicyProcessorPipeline, @@ -96,10 +96,10 @@ def create_inference_strategy( use_torch_compile: bool = False, 
compile_warmup_inferences: int = 2, shutdown_event: Event | None = None, -) -> InferenceStrategy: - """Instantiate the appropriate inference strategy from a config object.""" +) -> InferenceEngine: + """Instantiate the appropriate inference engine from a config object.""" if isinstance(config, SyncInferenceConfig): - return SyncInferenceStrategy( + return SyncInferenceEngine( policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, @@ -110,7 +110,7 @@ def create_inference_strategy( robot_type=robot_wrapper.robot_type, ) if isinstance(config, RTCInferenceConfig): - return RTCInferenceStrategy( + return RTCInferenceEngine( policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, @@ -125,4 +125,4 @@ def create_inference_strategy( rtc_queue_threshold=config.queue_threshold, shutdown_event=shutdown_event, ) - raise ValueError(f"Unknown inference strategy type: {type(config).__name__}") + raise ValueError(f"Unknown inference engine type: {type(config).__name__}") diff --git a/src/lerobot/rollout/inference/rtc.py b/src/lerobot/rollout/inference/rtc.py index a7f507d3f..d7b2511bb 100644 --- a/src/lerobot/rollout/inference/rtc.py +++ b/src/lerobot/rollout/inference/rtc.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Real-Time Chunking inference strategy. +"""Real-Time Chunking inference engine. A background thread produces action chunks asynchronously via :meth:`policy.predict_action_chunk`. 
The main control loop polls @@ -47,7 +47,7 @@ from lerobot.utils.constants import OBS_STATE from lerobot.utils.feature_utils import build_dataset_frame from ..robot_wrapper import ThreadSafeRobot -from .base import InferenceStrategy +from .base import InferenceEngine logger = logging.getLogger(__name__) @@ -104,11 +104,11 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int # --------------------------------------------------------------------------- -# RTCInferenceStrategy +# RTCInferenceEngine # --------------------------------------------------------------------------- -class RTCInferenceStrategy(InferenceStrategy): +class RTCInferenceEngine(InferenceEngine): """Async RTC inference: a background thread produces action chunks. ``get_action`` pops the next action from the shared queue (or diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py index d1c4b0e53..d1d055075 100644 --- a/src/lerobot/rollout/inference/sync.py +++ b/src/lerobot/rollout/inference/sync.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Synchronous inference strategy: inline policy call per control tick.""" +"""Synchronous inference engine: inline policy call per control tick.""" from __future__ import annotations @@ -26,12 +26,12 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference from lerobot.processor import PolicyProcessorPipeline -from .base import InferenceStrategy +from .base import InferenceEngine logger = logging.getLogger(__name__) -class SyncInferenceStrategy(InferenceStrategy): +class SyncInferenceEngine(InferenceEngine): """Inline synchronous inference: compute one action per call. 
``get_action`` runs the full policy pipeline (pre/post-processor + diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index b0714a297..26d6b081e 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -45,7 +45,6 @@ class BaseStrategy(RolloutStrategy): interpolator = self._interpolator control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.data.ordered_action_keys start_time = time.perf_counter() engine.resume() @@ -63,14 +62,12 @@ class BaseStrategy(RolloutStrategy): if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - send_next_action( - engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.data.dataset_features - ) + send_next_action(obs_processed, obs, ctx, interpolator) dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) def teardown(self, ctx: RolloutContext) -> None: - self._teardown_hardware(ctx) + self._teardown_hardware(ctx.hardware) logger.info("Base strategy teardown complete") diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 4ae8e0196..991cb053d 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -20,16 +20,16 @@ import abc import time from typing import TYPE_CHECKING -from lerobot.policies.rtc import ActionInterpolator +from lerobot.utils.action_interpolator import ActionInterpolator from lerobot.utils.constants import OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep -from ..inference import InferenceStrategy +from ..inference import InferenceEngine if TYPE_CHECKING: from ..configs import RolloutStrategyConfig - from ..context import RolloutContext + from ..context import HardwareContext, RolloutContext class RolloutStrategy(abc.ABC): @@ -42,12 +42,12 @@ class 
RolloutStrategy(abc.ABC): def __init__(self, config: RolloutStrategyConfig) -> None: self.config = config - self._engine: InferenceStrategy | None = None + self._engine: InferenceEngine | None = None self._interpolator: ActionInterpolator | None = None self._warmup_flushed: bool = False def _init_engine(self, ctx: RolloutContext) -> None: - """Attach the inference strategy + interpolator and start the backend. + """Attach the inference engine + interpolator and start the backend. Call this from ``setup()`` so strategies share identical setup without duplicating code. @@ -80,14 +80,14 @@ class RolloutStrategy(abc.ABC): engine.resume() return False - def _teardown_hardware(self, ctx: RolloutContext) -> None: + def _teardown_hardware(self, hw: HardwareContext) -> None: """Stop the inference engine and disconnect hardware.""" if self._engine is not None: self._engine.stop() - robot = ctx.hardware.robot_wrapper.inner + robot = hw.robot_wrapper.inner if robot.is_connected: robot.disconnect() - teleop = ctx.hardware.teleop + teleop = hw.teleop if teleop is not None and teleop.is_connected: teleop.disconnect() @@ -110,24 +110,25 @@ class RolloutStrategy(abc.ABC): def send_next_action( - engine: InferenceStrategy, obs_processed: dict, obs_raw: dict, ctx: RolloutContext, interpolator: ActionInterpolator, - ordered_keys: list[str], - features: dict, ) -> dict | None: """Dispatch the next action to the robot. - Pulls the next action tensor from the inference strategy, feeds the + Pulls the next action tensor from the inference engine, feeds the interpolator, and sends the interpolated action through the ``robot_action_processor`` to the robot. Works identically for - sync and async backends — the strategy never needs to branch. + sync and async backends — the rollout strategy never needs to branch. Returns the action dict that was sent, or ``None`` if no action was ready (e.g. empty async queue, interpolator not yet primed). 
""" + engine = ctx.policy.inference + features = ctx.data.dataset_features + ordered_keys = ctx.data.ordered_action_keys + if interpolator.needs_new_action(): obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) action_tensor = engine.get_action(obs_frame) diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 5639d10ba..499b10d15 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -407,7 +407,7 @@ class DAggerStrategy(RolloutStrategy): private=ctx.runtime.cfg.dataset.private, ) - self._teardown_hardware(ctx) + self._teardown_hardware(ctx.hardware) logger.info("DAgger strategy teardown complete") # ------------------------------------------------------------------ @@ -429,8 +429,7 @@ class DAggerStrategy(RolloutStrategy): record_stride = max(1, cfg.interpolation_multiplier) record_autonomous = self.config.record_autonomous - ordered_keys = ctx.data.ordered_action_keys - features = dataset.features + features = ctx.data.dataset_features engine.reset() interpolator.reset() @@ -499,9 +498,7 @@ class DAggerStrategy(RolloutStrategy): timestamp = time.perf_counter() - start_t continue - action_dict = send_next_action( - engine, obs_processed, obs, ctx, interpolator, ordered_keys, features - ) + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) if action_dict is not None: last_action = ctx.processors.robot_action_processor((action_dict, obs)) diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 766929f45..455edc45c 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -21,6 +21,7 @@ import logging import os import sys import time +from concurrent.futures import Future, ThreadPoolExecutor from threading import Event as ThreadingEvent from lerobot.common.control_utils import is_headless @@ -75,7 +76,9 @@ class 
HighlightStrategy(RolloutStrategy): self._listener = None self._save_requested = ThreadingEvent() self._recording_live = ThreadingEvent() - self._shutdown_event: ThreadingEvent | None = None + self._push_requested = ThreadingEvent() + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None def setup(self, ctx: RolloutContext) -> None: self._init_engine(ctx) @@ -86,12 +89,13 @@ class HighlightStrategy(RolloutStrategy): fps=ctx.runtime.cfg.fps, ) - self._shutdown_event = ctx.runtime.shutdown_event - self._setup_keyboard() + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push") + self._setup_keyboard(ctx.runtime.shutdown_event) logger.info( - "Highlight strategy ready (buffer=%.0fs, key='%s')", + "Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')", self.config.ring_buffer_seconds, self.config.save_key, + self.config.push_key, ) def run(self, ctx: RolloutContext) -> None: @@ -101,10 +105,9 @@ class HighlightStrategy(RolloutStrategy): dataset = ctx.data.dataset ring = self._ring interpolator = self._interpolator + features = ctx.data.dataset_features control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.data.ordered_action_keys - features = dataset.features engine.resume() @@ -126,9 +129,7 @@ class HighlightStrategy(RolloutStrategy): if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - action_dict = send_next_action( - engine, obs_processed, obs, ctx, interpolator, ordered_keys, features - ) + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) if action_dict is not None: obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) @@ -158,9 +159,10 @@ class HighlightStrategy(RolloutStrategy): dataset.save_episode() logger.info("Episode saved") self._recording_live.clear() - engine.reset() - interpolator.reset() - engine.resume() + + if self._push_requested.is_set(): + 
self._push_requested.clear() + self._background_push(dataset, cfg) if self._recording_live.is_set(): dataset.add_frame(frame) @@ -180,6 +182,10 @@ class HighlightStrategy(RolloutStrategy): if self._listener is not None: self._listener.stop() + if self._push_executor is not None: + self._push_executor.shutdown(wait=True) + self._push_executor = None + if ctx.data.dataset is not None: ctx.data.dataset.finalize() if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: @@ -188,28 +194,50 @@ class HighlightStrategy(RolloutStrategy): private=ctx.runtime.cfg.dataset.private, ) - self._teardown_hardware(ctx) + self._teardown_hardware(ctx.hardware) logger.info("Highlight strategy teardown complete") - def _setup_keyboard(self) -> None: - """Set up keyboard listener for the save key.""" + def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None: + """Set up keyboard listener for save and push keys.""" if is_headless(): - logger.warning("Headless environment — highlight save key unavailable") + logger.warning("Headless environment — highlight keys unavailable") return try: save_key = self.config.save_key + push_key = self.config.push_key def on_press(key): with contextlib.suppress(Exception): if hasattr(key, "char") and key.char == save_key: self._save_requested.set() + elif hasattr(key, "char") and key.char == push_key: + self._push_requested.set() elif key == keyboard.Key.esc: self._save_requested.clear() - if self._shutdown_event is not None: - self._shutdown_event.set() + shutdown_event.set() self._listener = keyboard.Listener(on_press=on_press) self._listener.start() except ImportError: logger.warning("pynput not available — keyboard listener disabled") + + def _background_push(self, dataset, cfg) -> None: + """Queue a Hub push on the single-worker executor.""" + if self._push_executor is None: + return + + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + + def 
_push(): + try: + dataset.push_to_hub( + tags=cfg.dataset.tags if cfg.dataset else None, + private=cfg.dataset.private if cfg.dataset else False, + ) + logger.info("Background push to hub complete") + except Exception as e: + logger.error("Background push failed: %s", e) + + self._pending_push = self._push_executor.submit(_push) diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 515fb380b..1b580cc10 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -73,10 +73,9 @@ class SentryStrategy(RolloutStrategy): robot = ctx.hardware.robot_wrapper dataset = ctx.data.dataset interpolator = self._interpolator + features = ctx.data.dataset_features control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.data.ordered_action_keys - features = dataset.features engine.resume() @@ -100,9 +99,7 @@ class SentryStrategy(RolloutStrategy): if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - action_dict = send_next_action( - engine, obs_processed, obs, ctx, interpolator, ordered_keys, features - ) + action_dict = send_next_action(obs_processed, obs, ctx, interpolator) if action_dict is not None: obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) @@ -156,7 +153,7 @@ class SentryStrategy(RolloutStrategy): private=ctx.runtime.cfg.dataset.private, ) - self._teardown_hardware(ctx) + self._teardown_hardware(ctx.hardware) logger.info("Sentry strategy teardown complete") def _background_push(self, dataset, cfg) -> None: diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index fc4b5779c..32e14d705 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -67,8 +67,7 @@ lerobot-record \\ import logging import time -from dataclasses import asdict, dataclass, field -from pathlib import Path +from dataclasses import asdict, dataclass 
from pprint import pformat from lerobot.cameras import CameraConfig # noqa: F401 @@ -82,6 +81,7 @@ from lerobot.common.control_utils import ( sanity_check_dataset_robot_compatibility, ) from lerobot.configs import parser +from lerobot.configs.dataset import DatasetRecordConfig from lerobot.datasets import ( LeRobotDataset, VideoEncodingManager, @@ -137,63 +137,6 @@ from lerobot.utils.utils import ( from lerobot.utils.visualization_utils import init_rerun, log_rerun_data -@dataclass -class DatasetRecordConfig: - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") - single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id. - root: str | Path | None = None - # Limit the frames per second. - fps: int = 30 - # Number of seconds for data recording for each episode. - episode_time_s: int | float = 60 - # Number of seconds for resetting the environment after each episode. - reset_time_s: int | float = 60 - # Number of episodes to record. - num_episodes: int = 50 - # Encode frames in the dataset into video - video: bool = True - # Upload dataset to Hugging Face hub. - push_to_hub: bool = True - # Upload on private repository on the Hugging Face hub. - private: bool = False - # Add tags to your dataset on the hub. - tags: list[str] | None = None - # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; - # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes - # and threads depends on your system. We recommend 4 threads per camera with 0 processes. - # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. 
- num_image_writer_processes: int = 0 - # Number of threads writing the frames as png images on disk, per camera. - # Too many threads might cause unstable teleoperation fps due to main thread being blocked. - # Not enough threads might cause low camera fps. - num_image_writer_threads_per_camera: int = 4 - # Number of episodes to record before batch encoding videos - # Set to 1 for immediate encoding (default behavior), or higher for batched encoding - video_encoding_batch_size: int = 1 - # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto', - # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. - # Use 'auto' to auto-detect the best available hardware encoder. - vcodec: str = "libsvtav1" - # Enable streaming video encoding: encode frames in real-time during capture instead - # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding - streaming_encoding: bool = False - # Maximum number of frames to buffer per camera when using streaming encoding. - # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. - encoder_queue_maxsize: int = 30 - # Number of threads per encoder instance. None = auto (codec default). - # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. 
- encoder_threads: int | None = None - # Rename map for the observation to override the image and state keys - rename_map: dict[str, str] = field(default_factory=dict) - - def __post_init__(self): - if self.single_task is None: - raise ValueError("You need to provide a task as argument in `single_task`.") - - @dataclass class RecordConfig: robot: RobotConfig diff --git a/src/lerobot/utils/action_interpolator.py b/src/lerobot/utils/action_interpolator.py new file mode 100644 index 000000000..222dc33b5 --- /dev/null +++ b/src/lerobot/utils/action_interpolator.py @@ -0,0 +1,116 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Action interpolation for smoother robot control. + +Provides configurable Nx control rate by interpolating between consecutive actions. +Useful with RTC and action-chunking policies to reduce jerkiness. +""" + +from torch import Tensor + + +class ActionInterpolator: + """Interpolates between consecutive actions for smoother control. + + When enabled with multiplier N, produces N actions per policy action + by linearly interpolating between the previous and current action. + + Example with multiplier=3: + prev_action -> [1/3 interpolated, 2/3 interpolated, current_action] + + This effectively multiplies the control rate for smoother motion. 
+ + Usage: + interpolator = ActionInterpolator(multiplier=2) # 2x control rate + + # In control loop: + if interpolator.needs_new_action(): + new_action = queue.get() + if new_action: + interpolator.add(new_action.cpu()) + + action = interpolator.get() + if action: + robot.send_action(action) + """ + + def __init__(self, multiplier: int = 1): + """Initialize the interpolator. + + Args: + multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.) + """ + if multiplier < 1: + raise ValueError(f"multiplier must be >= 1, got {multiplier}") + self.multiplier = multiplier + self._prev: Tensor | None = None + self._buffer: list[Tensor] = [] + self._idx = 0 + + @property + def enabled(self) -> bool: + """Whether interpolation is active (multiplier > 1).""" + return self.multiplier > 1 + + def reset(self): + """Reset interpolation state (call between episodes).""" + self._prev = None + self._buffer = [] + self._idx = 0 + + def needs_new_action(self) -> bool: + """Check if a new action is needed from the queue.""" + return self._idx >= len(self._buffer) + + def add(self, action: Tensor) -> None: + """Add a new action and compute interpolated sequence. + + Args: + action: New action tensor from policy/queue (already on CPU). + """ + if self.multiplier > 1 and self._prev is not None: + self._buffer = [] + for i in range(1, self.multiplier + 1): + t = i / self.multiplier + interp = self._prev + t * (action - self._prev) + self._buffer.append(interp) + else: + # First step: no previous action yet, so run at base FPS without interpolation. + self._buffer = [action.clone()] + self._prev = action.clone() + self._idx = 0 + + def get(self) -> Tensor | None: + """Get the next interpolated action. + + Returns: + Next action tensor, or None if buffer is exhausted. 
+ """ + if self._idx >= len(self._buffer): + return None + action = self._buffer[self._idx] + self._idx += 1 + return action + + def get_control_interval(self, fps: float) -> float: + """Get the control interval based on interpolation multiplier. + + Args: + fps: Base frames per second. + + Returns: + Control interval in seconds (divided by multiplier). + """ + return 1.0 / (fps * self.multiplier) diff --git a/tests/policies/rtc/test_action_interpolator.py b/tests/policies/rtc/test_action_interpolator.py index 9a4276df1..3eb239d7e 100644 --- a/tests/policies/rtc/test_action_interpolator.py +++ b/tests/policies/rtc/test_action_interpolator.py @@ -17,9 +17,9 @@ import pytest import torch -from lerobot.policies.rtc.action_interpolator import ActionInterpolator from lerobot.policies.rtc.action_queue import ActionQueue from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.action_interpolator import ActionInterpolator # ====================== Fixtures ====================== diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 28e91a149..dd10c0c1c 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -21,8 +21,9 @@ import pytest pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])") +from lerobot.configs.dataset import DatasetRecordConfig from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate -from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record +from lerobot.scripts.lerobot_record import RecordConfig, record from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID