some more iterations

This commit is contained in:
Steven Palma
2026-04-16 15:52:23 +02:00
parent 783ec6e232
commit 3eda5712d3
21 changed files with 329 additions and 313 deletions
+2
View File
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
"""
from .dataset import DatasetRecordConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .types import (
@@ -39,6 +40,7 @@ __all__ = [
"PolicyFeature",
"RTCAttentionSchedule",
# Config classes
"DatasetRecordConfig",
"DatasetConfig",
"EvalConfig",
"PeftConfig",
+71
View File
@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str = ""
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str = ""
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
root: str | Path | None = None
# Limit the frames per second.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
# Use 'auto' to auto-detect the best available hardware encoder.
vcodec: str = "libsvtav1"
# Enable streaming video encoding: encode frames in real-time during capture instead
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
streaming_encoding: bool = False
# Maximum number of frames to buffer per camera when using streaming encoding.
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
encoder_queue_maxsize: int = 30
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
+2 -1
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
@@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
+3 -115
View File
@@ -1,116 +1,4 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
from lerobot.utils.action_interpolator import ActionInterpolator
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
__all__ = ["ActionInterpolator"]
+12 -12
View File
@@ -17,21 +17,21 @@
from .configs import (
BaseStrategyConfig,
DAggerStrategyConfig,
DatasetRecordConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutDatasetConfig,
RolloutStrategyConfig,
SentryStrategyConfig,
)
from .context import RolloutContext, build_rollout_context
from .inference import (
InferenceStrategy,
InferenceStrategyConfig,
InferenceEngine,
InferenceEngineConfig,
RTCInferenceConfig,
RTCInferenceStrategy,
RTCInferenceEngine,
SyncInferenceConfig,
SyncInferenceStrategy,
create_inference_strategy,
SyncInferenceEngine,
create_inference_engine,
)
from .ring_buffer import RolloutRingBuffer
from .robot_wrapper import ThreadSafeRobot
@@ -41,21 +41,21 @@ __all__ = [
"BaseStrategyConfig",
"DAggerStrategyConfig",
"HighlightStrategyConfig",
"InferenceStrategy",
"InferenceStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"RTCInferenceConfig",
"RTCInferenceStrategy",
"RTCInferenceEngine",
"RolloutConfig",
"RolloutContext",
"RolloutDatasetConfig",
"DatasetRecordConfig",
"RolloutRingBuffer",
"RolloutStrategy",
"RolloutStrategyConfig",
"SentryStrategyConfig",
"SyncInferenceConfig",
"SyncInferenceStrategy",
"SyncInferenceEngine",
"ThreadSafeRobot",
"build_rollout_context",
"create_inference_strategy",
"create_inference_engine",
"create_strategy",
]
+5 -31
View File
@@ -19,15 +19,15 @@ from __future__ import annotations
import abc
import logging
from dataclasses import dataclass, field
from pathlib import Path
import draccus
from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from .inference import InferenceStrategyConfig, SyncInferenceConfig
from .inference import InferenceEngineConfig, SyncInferenceConfig
logger = logging.getLogger(__name__)
@@ -84,6 +84,7 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
ring_buffer_seconds: float = 30.0
ring_buffer_max_memory_mb: float = 2048.0
save_key: str = "s"
push_key: str = "h"
@RolloutStrategyConfig.register_subclass("dagger")
@@ -109,33 +110,6 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
record_autonomous: bool = True
# ---------------------------------------------------------------------------
# Dataset recording config (shared across recording strategies)
# ---------------------------------------------------------------------------
@dataclass
class RolloutDatasetConfig:
"""Dataset configuration for rollout strategies that record data."""
repo_id: str = ""
single_task: str = ""
root: str | Path | None = None
fps: int = 30
video: bool = True
push_to_hub: bool = True
private: bool = False
tags: list[str] | None = None
num_image_writer_processes: int = 0
num_image_writer_threads_per_camera: int = 4
video_encoding_batch_size: int = 1
vcodec: str = "auto"
streaming_encoding: bool = True
encoder_queue_maxsize: int = 30
encoder_threads: int | None = None
rename_map: dict[str, str] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Top-level rollout config
# ---------------------------------------------------------------------------
@@ -161,10 +135,10 @@ class RolloutConfig:
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
# Inference backend (polymorphic: --inference.type=sync|rtc)
inference: InferenceStrategyConfig = field(default_factory=SyncInferenceConfig)
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
# Dataset (required for sentry, highlight, dagger; None for base)
dataset: RolloutDatasetConfig | None = None
dataset: DatasetRecordConfig | None = None
# Runtime
fps: float = 30.0
+6 -6
View File
@@ -49,9 +49,9 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .inference import (
InferenceStrategy,
InferenceEngine,
RTCInferenceConfig,
create_inference_strategy,
create_inference_engine,
)
from .robot_wrapper import ThreadSafeRobot
@@ -106,12 +106,12 @@ class HardwareContext:
@dataclass
class PolicyContext:
"""Loaded policy and its inference strategy."""
"""Loaded policy and its inference engine."""
policy: PreTrainedPolicy
preprocessor: PolicyProcessorPipeline
postprocessor: PolicyProcessorPipeline
inference: InferenceStrategy
inference: InferenceEngine
@dataclass
@@ -159,7 +159,7 @@ def build_rollout_context(
robot_action_processor: RobotProcessorPipeline | None = None,
robot_observation_processor: RobotProcessorPipeline | None = None,
) -> RolloutContext:
"""Wire up policy, processors, hardware, dataset, and inference strategy.
"""Wire up policy, processors, hardware, dataset, and inference engine.
The order is policy-first / hardware-last so a bad ``--policy.path``
fails fast without touching the robot.
@@ -329,7 +329,7 @@ def build_rollout_context(
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
inference_strategy = create_inference_strategy(
inference_strategy = create_inference_engine(
cfg.inference,
policy=policy,
preprocessor=preprocessor,
+11 -11
View File
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference strategy package — backend-agnostic action production.
"""Inference engine package — backend-agnostic action production.
Concrete engines (sync, RTC, …) expose the same small interface so
rollout strategies never branch on the inference backend.
"""
from .base import InferenceStrategy
from .base import InferenceEngine
from .factory import (
InferenceStrategyConfig,
InferenceEngineConfig,
RTCInferenceConfig,
SyncInferenceConfig,
create_inference_strategy,
create_inference_engine,
)
from .rtc import RTCInferenceStrategy
from .sync import SyncInferenceStrategy
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
__all__ = [
"InferenceStrategy",
"InferenceStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"RTCInferenceConfig",
"RTCInferenceStrategy",
"RTCInferenceEngine",
"SyncInferenceConfig",
"SyncInferenceStrategy",
"create_inference_strategy",
"SyncInferenceEngine",
"create_inference_engine",
]
+4 -4
View File
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference strategy ABC.
"""Inference engine ABC.
Rollout strategies consume actions through this small interface so they
do not need to know whether inference is synchronous, runs in a
background thread (RTC), or comes from an external source.
do not need to know whether the inference engine is synchronous, runs in
a background thread (RTC), or comes from an external source.
"""
from __future__ import annotations
@@ -26,7 +26,7 @@ import abc
import torch
class InferenceStrategy(abc.ABC):
class InferenceEngine(abc.ABC):
"""Abstract backend for producing actions during rollout.
Subclasses decide whether inference happens inline, in a background
+17 -17
View File
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference strategy configs and factory.
"""Inference engine configs and factory.
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
backend requires registering its config subclass and dispatching it in
:func:`create_inference_strategy`.
:func:`create_inference_engine`.
"""
from __future__ import annotations
@@ -33,9 +33,9 @@ from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.processor import PolicyProcessorPipeline
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceStrategy
from .rtc import RTCInferenceStrategy
from .sync import SyncInferenceStrategy
from .base import InferenceEngine
from .rtc import RTCInferenceEngine
from .sync import SyncInferenceEngine
logger = logging.getLogger(__name__)
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
@dataclass
class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
"""Abstract base for inference backend configuration.
Use ``--inference.type=<name>`` on the CLI to select a backend.
@@ -57,15 +57,15 @@ class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
return self.get_choice_name(self.__class__)
@InferenceStrategyConfig.register_subclass("sync")
@InferenceEngineConfig.register_subclass("sync")
@dataclass
class SyncInferenceConfig(InferenceStrategyConfig):
class SyncInferenceConfig(InferenceEngineConfig):
"""Inline synchronous inference (one policy call per control tick)."""
@InferenceStrategyConfig.register_subclass("rtc")
@InferenceEngineConfig.register_subclass("rtc")
@dataclass
class RTCInferenceConfig(InferenceStrategyConfig):
class RTCInferenceConfig(InferenceEngineConfig):
"""Real-Time Chunking: async policy inference in a background thread."""
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
@@ -80,8 +80,8 @@ class RTCInferenceConfig(InferenceStrategyConfig):
# ---------------------------------------------------------------------------
def create_inference_strategy(
config: InferenceStrategyConfig,
def create_inference_engine(
config: InferenceEngineConfig,
*,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
@@ -96,10 +96,10 @@ def create_inference_strategy(
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
shutdown_event: Event | None = None,
) -> InferenceStrategy:
"""Instantiate the appropriate inference strategy from a config object."""
) -> InferenceEngine:
"""Instantiate the appropriate inference engine from a config object."""
if isinstance(config, SyncInferenceConfig):
return SyncInferenceStrategy(
return SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
@@ -110,7 +110,7 @@ def create_inference_strategy(
robot_type=robot_wrapper.robot_type,
)
if isinstance(config, RTCInferenceConfig):
return RTCInferenceStrategy(
return RTCInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
@@ -125,4 +125,4 @@ def create_inference_strategy(
rtc_queue_threshold=config.queue_threshold,
shutdown_event=shutdown_event,
)
raise ValueError(f"Unknown inference strategy type: {type(config).__name__}")
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
+4 -4
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Real-Time Chunking inference strategy.
"""Real-Time Chunking inference engine.
A background thread produces action chunks asynchronously via
:meth:`policy.predict_action_chunk`. The main control loop polls
@@ -47,7 +47,7 @@ from lerobot.utils.constants import OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame
from ..robot_wrapper import ThreadSafeRobot
from .base import InferenceStrategy
from .base import InferenceEngine
logger = logging.getLogger(__name__)
@@ -104,11 +104,11 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
# ---------------------------------------------------------------------------
# RTCInferenceStrategy
# RTCInferenceEngine
# ---------------------------------------------------------------------------
class RTCInferenceStrategy(InferenceStrategy):
class RTCInferenceEngine(InferenceEngine):
"""Async RTC inference: a background thread produces action chunks.
``get_action`` pops the next action from the shared queue (or
+3 -3
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Synchronous inference strategy: inline policy call per control tick."""
"""Synchronous inference engine: inline policy call per control tick."""
from __future__ import annotations
@@ -26,12 +26,12 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from .base import InferenceStrategy
from .base import InferenceEngine
logger = logging.getLogger(__name__)
class SyncInferenceStrategy(InferenceStrategy):
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
+2 -5
View File
@@ -45,7 +45,6 @@ class BaseStrategy(RolloutStrategy):
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(cfg.fps)
ordered_keys = ctx.data.ordered_action_keys
start_time = time.perf_counter()
engine.resume()
@@ -63,14 +62,12 @@ class BaseStrategy(RolloutStrategy):
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
send_next_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.data.dataset_features
)
send_next_action(obs_processed, obs, ctx, interpolator)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
def teardown(self, ctx: RolloutContext) -> None:
self._teardown_hardware(ctx)
self._teardown_hardware(ctx.hardware)
logger.info("Base strategy teardown complete")
+14 -13
View File
@@ -20,16 +20,16 @@ import abc
import time
from typing import TYPE_CHECKING
from lerobot.policies.rtc import ActionInterpolator
from lerobot.utils.action_interpolator import ActionInterpolator
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..inference import InferenceStrategy
from ..inference import InferenceEngine
if TYPE_CHECKING:
from ..configs import RolloutStrategyConfig
from ..context import RolloutContext
from ..context import HardwareContext, RolloutContext
class RolloutStrategy(abc.ABC):
@@ -42,12 +42,12 @@ class RolloutStrategy(abc.ABC):
def __init__(self, config: RolloutStrategyConfig) -> None:
self.config = config
self._engine: InferenceStrategy | None = None
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._warmup_flushed: bool = False
def _init_engine(self, ctx: RolloutContext) -> None:
"""Attach the inference strategy + interpolator and start the backend.
"""Attach the inference engine + interpolator and start the backend.
Call this from ``setup()`` so strategies share identical setup
without duplicating code.
@@ -80,14 +80,14 @@ class RolloutStrategy(abc.ABC):
engine.resume()
return False
def _teardown_hardware(self, ctx: RolloutContext) -> None:
def _teardown_hardware(self, hw: HardwareContext) -> None:
"""Stop the inference engine and disconnect hardware."""
if self._engine is not None:
self._engine.stop()
robot = ctx.hardware.robot_wrapper.inner
robot = hw.robot_wrapper.inner
if robot.is_connected:
robot.disconnect()
teleop = ctx.hardware.teleop
teleop = hw.teleop
if teleop is not None and teleop.is_connected:
teleop.disconnect()
@@ -110,24 +110,25 @@ class RolloutStrategy(abc.ABC):
def send_next_action(
engine: InferenceStrategy,
obs_processed: dict,
obs_raw: dict,
ctx: RolloutContext,
interpolator: ActionInterpolator,
ordered_keys: list[str],
features: dict,
) -> dict | None:
"""Dispatch the next action to the robot.
Pulls the next action tensor from the inference strategy, feeds the
Pulls the next action tensor from the inference engine, feeds the
interpolator, and sends the interpolated action through the
``robot_action_processor`` to the robot. Works identically for
sync and async backends — the strategy never needs to branch.
sync and async backends — the rollout strategy never needs to branch.
Returns the action dict that was sent, or ``None`` if no action was
ready (e.g. empty async queue, interpolator not yet primed).
"""
engine = ctx.policy.inference
features = ctx.data.dataset_features
ordered_keys = ctx.data.ordered_action_keys
if interpolator.needs_new_action():
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action(obs_frame)
+3 -6
View File
@@ -407,7 +407,7 @@ class DAggerStrategy(RolloutStrategy):
private=ctx.runtime.cfg.dataset.private,
)
self._teardown_hardware(ctx)
self._teardown_hardware(ctx.hardware)
logger.info("DAgger strategy teardown complete")
# ------------------------------------------------------------------
@@ -429,8 +429,7 @@ class DAggerStrategy(RolloutStrategy):
record_stride = max(1, cfg.interpolation_multiplier)
record_autonomous = self.config.record_autonomous
ordered_keys = ctx.data.ordered_action_keys
features = dataset.features
features = ctx.data.dataset_features
engine.reset()
interpolator.reset()
@@ -499,9 +498,7 @@ class DAggerStrategy(RolloutStrategy):
timestamp = time.perf_counter() - start_t
continue
action_dict = send_next_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
last_action = ctx.processors.robot_action_processor((action_dict, obs))
+46 -18
View File
@@ -21,6 +21,7 @@ import logging
import os
import sys
import time
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Event as ThreadingEvent
from lerobot.common.control_utils import is_headless
@@ -75,7 +76,9 @@ class HighlightStrategy(RolloutStrategy):
self._listener = None
self._save_requested = ThreadingEvent()
self._recording_live = ThreadingEvent()
self._shutdown_event: ThreadingEvent | None = None
self._push_requested = ThreadingEvent()
self._push_executor: ThreadPoolExecutor | None = None
self._pending_push: Future | None = None
def setup(self, ctx: RolloutContext) -> None:
self._init_engine(ctx)
@@ -86,12 +89,13 @@ class HighlightStrategy(RolloutStrategy):
fps=ctx.runtime.cfg.fps,
)
self._shutdown_event = ctx.runtime.shutdown_event
self._setup_keyboard()
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
self._setup_keyboard(ctx.runtime.shutdown_event)
logger.info(
"Highlight strategy ready (buffer=%.0fs, key='%s')",
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
self.config.ring_buffer_seconds,
self.config.save_key,
self.config.push_key,
)
def run(self, ctx: RolloutContext) -> None:
@@ -101,10 +105,9 @@ class HighlightStrategy(RolloutStrategy):
dataset = ctx.data.dataset
ring = self._ring
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
ordered_keys = ctx.data.ordered_action_keys
features = dataset.features
engine.resume()
@@ -126,9 +129,7 @@ class HighlightStrategy(RolloutStrategy):
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
@@ -158,9 +159,10 @@ class HighlightStrategy(RolloutStrategy):
dataset.save_episode()
logger.info("Episode saved")
self._recording_live.clear()
engine.reset()
interpolator.reset()
engine.resume()
if self._push_requested.is_set():
self._push_requested.clear()
self._background_push(dataset, cfg)
if self._recording_live.is_set():
dataset.add_frame(frame)
@@ -180,6 +182,10 @@ class HighlightStrategy(RolloutStrategy):
if self._listener is not None:
self._listener.stop()
if self._push_executor is not None:
self._push_executor.shutdown(wait=True)
self._push_executor = None
if ctx.data.dataset is not None:
ctx.data.dataset.finalize()
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
@@ -188,28 +194,50 @@ class HighlightStrategy(RolloutStrategy):
private=ctx.runtime.cfg.dataset.private,
)
self._teardown_hardware(ctx)
self._teardown_hardware(ctx.hardware)
logger.info("Highlight strategy teardown complete")
def _setup_keyboard(self) -> None:
"""Set up keyboard listener for the save key."""
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
"""Set up keyboard listener for save and push keys."""
if is_headless():
logger.warning("Headless environment — highlight save key unavailable")
logger.warning("Headless environment — highlight keys unavailable")
return
try:
save_key = self.config.save_key
push_key = self.config.push_key
def on_press(key):
with contextlib.suppress(Exception):
if hasattr(key, "char") and key.char == save_key:
self._save_requested.set()
elif hasattr(key, "char") and key.char == push_key:
self._push_requested.set()
elif key == keyboard.Key.esc:
self._save_requested.clear()
if self._shutdown_event is not None:
self._shutdown_event.set()
shutdown_event.set()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
except ImportError:
logger.warning("pynput not available — keyboard listener disabled")
def _background_push(self, dataset, cfg) -> None:
    """Queue a Hub push of ``dataset`` on the single-worker executor.

    Does nothing when no executor is available (``self._push_executor`` is
    ``None``). If an earlier push has not finished yet, the new one is still
    submitted and simply waits its turn on the executor.

    Args:
        dataset: Object exposing ``push_to_hub(tags=..., private=...)``.
        cfg: Rollout config; ``cfg.dataset`` (when set) supplies ``tags`` and
            ``private``. Without it the push uses ``tags=None, private=False``.
    """
    executor = self._push_executor
    if executor is None:
        return

    if self._pending_push is not None and not self._pending_push.done():
        logger.info("Previous push still in progress; queueing next")

    def _push():
        try:
            dataset.push_to_hub(
                tags=cfg.dataset.tags if cfg.dataset else None,
                private=cfg.dataset.private if cfg.dataset else False,
            )
            logger.info("Background push to hub complete")
        except Exception as e:
            # Best-effort upload: never let a failed push propagate out of the
            # worker, but keep the traceback so the failure can be diagnosed.
            logger.exception("Background push failed: %s", e)

    # Only the most recent future is tracked; it is used for the
    # "still in progress" notice above.
    self._pending_push = executor.submit(_push)
+3 -6
View File
@@ -73,10 +73,9 @@ class SentryStrategy(RolloutStrategy):
robot = ctx.hardware.robot_wrapper
dataset = ctx.data.dataset
interpolator = self._interpolator
features = ctx.data.dataset_features
control_interval = interpolator.get_control_interval(cfg.fps)
ordered_keys = ctx.data.ordered_action_keys
features = dataset.features
engine.resume()
@@ -100,9 +99,7 @@ class SentryStrategy(RolloutStrategy):
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
@@ -156,7 +153,7 @@ class SentryStrategy(RolloutStrategy):
private=ctx.runtime.cfg.dataset.private,
)
self._teardown_hardware(ctx)
self._teardown_hardware(ctx.hardware)
logger.info("Sentry strategy teardown complete")
def _background_push(self, dataset, cfg) -> None:
+2 -59
View File
@@ -67,8 +67,7 @@ lerobot-record \\
import logging
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from dataclasses import asdict, dataclass
from pprint import pformat
from lerobot.cameras import CameraConfig # noqa: F401
@@ -82,6 +81,7 @@ from lerobot.common.control_utils import (
sanity_check_dataset_robot_compatibility,
)
from lerobot.configs import parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
@@ -137,63 +137,6 @@ from lerobot.utils.utils import (
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@dataclass
class DatasetRecordConfig:
    """Options describing how a dataset is recorded, encoded and optionally pushed to the Hub.

    ``repo_id`` and ``single_task`` have no default and must be supplied; every other
    field falls back to the default declared below.
    """

    # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
    repo_id: str
    # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
    single_task: str
    # Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
    root: str | Path | None = None
    # Limit the frames per second.
    fps: int = 30
    # Number of seconds for data recording for each episode.
    episode_time_s: int | float = 60
    # Number of seconds for resetting the environment after each episode.
    reset_time_s: int | float = 60
    # Number of episodes to record.
    num_episodes: int = 50
    # Encode frames in the dataset into video
    video: bool = True
    # Upload dataset to Hugging Face hub.
    push_to_hub: bool = True
    # Upload on private repository on the Hugging Face hub.
    private: bool = False
    # Add tags to your dataset on the hub.
    tags: list[str] | None = None
    # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
    # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
    # and threads depends on your system. We recommend 4 threads per camera with 0 processes.
    # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
    num_image_writer_processes: int = 0
    # Number of threads writing the frames as png images on disk, per camera.
    # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
    # Not enough threads might cause low camera fps.
    num_image_writer_threads_per_camera: int = 4
    # Number of episodes to record before batch encoding videos
    # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
    video_encoding_batch_size: int = 1
    # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
    # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
    # Use 'auto' to auto-detect the best available hardware encoder.
    vcodec: str = "libsvtav1"
    # Enable streaming video encoding: encode frames in real-time during capture instead
    # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
    streaming_encoding: bool = False
    # Maximum number of frames to buffer per camera when using streaming encoding.
    # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
    encoder_queue_maxsize: int = 30
    # Number of threads per encoder instance. None = auto (codec default).
    # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
    encoder_threads: int | None = None
    # Rename map for the observation to override the image and state keys
    rename_map: dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Reject a configuration whose ``single_task`` is ``None``.

        Raises:
            ValueError: If ``single_task`` is ``None``.
        """
        # `single_task` is a required field, so this only triggers when None is passed
        # explicitly; an empty string is not rejected by this check.
        if self.single_task is None:
            raise ValueError("You need to provide a task as argument in `single_task`.")
@dataclass
class RecordConfig:
robot: RobotConfig
+116
View File
@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action interpolation for smoother robot control.
Provides configurable Nx control rate by interpolating between consecutive actions.
Useful with RTC and action-chunking policies to reduce jerkiness.
"""
from torch import Tensor
class ActionInterpolator:
"""Interpolates between consecutive actions for smoother control.
When enabled with multiplier N, produces N actions per policy action
by linearly interpolating between the previous and current action.
Example with multiplier=3:
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
This effectively multiplies the control rate for smoother motion.
Usage:
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
# In control loop:
if interpolator.needs_new_action():
new_action = queue.get()
if new_action:
interpolator.add(new_action.cpu())
action = interpolator.get()
if action:
robot.send_action(action)
"""
def __init__(self, multiplier: int = 1):
"""Initialize the interpolator.
Args:
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
"""
if multiplier < 1:
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
self.multiplier = multiplier
self._prev: Tensor | None = None
self._buffer: list[Tensor] = []
self._idx = 0
@property
def enabled(self) -> bool:
"""Whether interpolation is active (multiplier > 1)."""
return self.multiplier > 1
def reset(self):
"""Reset interpolation state (call between episodes)."""
self._prev = None
self._buffer = []
self._idx = 0
def needs_new_action(self) -> bool:
"""Check if a new action is needed from the queue."""
return self._idx >= len(self._buffer)
def add(self, action: Tensor) -> None:
"""Add a new action and compute interpolated sequence.
Args:
action: New action tensor from policy/queue (already on CPU).
"""
if self.multiplier > 1 and self._prev is not None:
self._buffer = []
for i in range(1, self.multiplier + 1):
t = i / self.multiplier
interp = self._prev + t * (action - self._prev)
self._buffer.append(interp)
else:
# First step: no previous action yet, so run at base FPS without interpolation.
self._buffer = [action.clone()]
self._prev = action.clone()
self._idx = 0
def get(self) -> Tensor | None:
"""Get the next interpolated action.
Returns:
Next action tensor, or None if buffer is exhausted.
"""
if self._idx >= len(self._buffer):
return None
action = self._buffer[self._idx]
self._idx += 1
return action
def get_control_interval(self, fps: float) -> float:
"""Get the control interval based on interpolation multiplier.
Args:
fps: Base frames per second.
Returns:
Control interval in seconds (divided by multiplier).
"""
return 1.0 / (fps * self.multiplier)
@@ -17,9 +17,9 @@
import pytest
import torch
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.action_interpolator import ActionInterpolator
# ====================== Fixtures ======================
+2 -1
View File
@@ -21,8 +21,9 @@ import pytest
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
from lerobot.scripts.lerobot_record import RecordConfig, record
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID