mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
some more iterations
This commit is contained in:
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
"""
|
||||
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .types import (
|
||||
@@ -39,6 +40,7 @@ __all__ = [
|
||||
"PolicyFeature",
|
||||
"RTCAttentionSchedule",
|
||||
# Config classes
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"PeftConfig",
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str = ""
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str = ""
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
@@ -21,7 +23,6 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
|
||||
@@ -1,116 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
||||
self._buffer = [action.clone()]
|
||||
self._prev = action.clone()
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
__all__ = ["ActionInterpolator"]
|
||||
|
||||
@@ -17,21 +17,21 @@
|
||||
from .configs import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
DatasetRecordConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutDatasetConfig,
|
||||
RolloutStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
from .context import RolloutContext, build_rollout_context
|
||||
from .inference import (
|
||||
InferenceStrategy,
|
||||
InferenceStrategyConfig,
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceStrategy,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
SyncInferenceStrategy,
|
||||
create_inference_strategy,
|
||||
SyncInferenceEngine,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .ring_buffer import RolloutRingBuffer
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
@@ -41,21 +41,21 @@ __all__ = [
|
||||
"BaseStrategyConfig",
|
||||
"DAggerStrategyConfig",
|
||||
"HighlightStrategyConfig",
|
||||
"InferenceStrategy",
|
||||
"InferenceStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceStrategy",
|
||||
"RTCInferenceEngine",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutDatasetConfig",
|
||||
"DatasetRecordConfig",
|
||||
"RolloutRingBuffer",
|
||||
"RolloutStrategy",
|
||||
"RolloutStrategyConfig",
|
||||
"SentryStrategyConfig",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceStrategy",
|
||||
"SyncInferenceEngine",
|
||||
"ThreadSafeRobot",
|
||||
"build_rollout_context",
|
||||
"create_inference_strategy",
|
||||
"create_inference_engine",
|
||||
"create_strategy",
|
||||
]
|
||||
|
||||
@@ -19,15 +19,15 @@ from __future__ import annotations
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import PreTrainedConfig, parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
from .inference import InferenceStrategyConfig, SyncInferenceConfig
|
||||
from .inference import InferenceEngineConfig, SyncInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,6 +84,7 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
|
||||
ring_buffer_seconds: float = 30.0
|
||||
ring_buffer_max_memory_mb: float = 2048.0
|
||||
save_key: str = "s"
|
||||
push_key: str = "h"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@@ -109,33 +110,6 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
record_autonomous: bool = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset recording config (shared across recording strategies)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutDatasetConfig:
|
||||
"""Dataset configuration for rollout strategies that record data."""
|
||||
|
||||
repo_id: str = ""
|
||||
single_task: str = ""
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
vcodec: str = "auto"
|
||||
streaming_encoding: bool = True
|
||||
encoder_queue_maxsize: int = 30
|
||||
encoder_threads: int | None = None
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level rollout config
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -161,10 +135,10 @@ class RolloutConfig:
|
||||
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
|
||||
|
||||
# Inference backend (polymorphic: --inference.type=sync|rtc)
|
||||
inference: InferenceStrategyConfig = field(default_factory=SyncInferenceConfig)
|
||||
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
|
||||
|
||||
# Dataset (required for sentry, highlight, dagger; None for base)
|
||||
dataset: RolloutDatasetConfig | None = None
|
||||
dataset: DatasetRecordConfig | None = None
|
||||
|
||||
# Runtime
|
||||
fps: float = 30.0
|
||||
|
||||
@@ -49,9 +49,9 @@ from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_fea
|
||||
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceStrategy,
|
||||
InferenceEngine,
|
||||
RTCInferenceConfig,
|
||||
create_inference_strategy,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
|
||||
@@ -106,12 +106,12 @@ class HardwareContext:
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference strategy."""
|
||||
"""Loaded policy and its inference engine."""
|
||||
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
inference: InferenceStrategy
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -159,7 +159,7 @@ def build_rollout_context(
|
||||
robot_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None,
|
||||
) -> RolloutContext:
|
||||
"""Wire up policy, processors, hardware, dataset, and inference strategy.
|
||||
"""Wire up policy, processors, hardware, dataset, and inference engine.
|
||||
|
||||
The order is policy-first / hardware-last so a bad ``--policy.path``
|
||||
fails fast without touching the robot.
|
||||
@@ -329,7 +329,7 @@ def build_rollout_context(
|
||||
|
||||
# --- 7. Inference engine (needs policy + pre/post + hardware) --
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
inference_strategy = create_inference_strategy(
|
||||
inference_strategy = create_inference_engine(
|
||||
cfg.inference,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -12,28 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference strategy package — backend-agnostic action production.
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete engines (sync, RTC, …) expose the same small interface so
|
||||
rollout strategies never branch on the inference backend.
|
||||
"""
|
||||
|
||||
from .base import InferenceStrategy
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
InferenceStrategyConfig,
|
||||
InferenceEngineConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_strategy,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .rtc import RTCInferenceStrategy
|
||||
from .sync import SyncInferenceStrategy
|
||||
from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"InferenceStrategy",
|
||||
"InferenceStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceStrategy",
|
||||
"RTCInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceStrategy",
|
||||
"create_inference_strategy",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
|
||||
@@ -12,11 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference strategy ABC.
|
||||
"""Inference engine ABC.
|
||||
|
||||
Rollout strategies consume actions through this small interface so they
|
||||
do not need to know whether inference is synchronous, runs in a
|
||||
background thread (RTC), or comes from an external source.
|
||||
do not need to know whether the inference engine is synchronous, runs in
|
||||
a background thread (RTC), or comes from an external source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -26,7 +26,7 @@ import abc
|
||||
import torch
|
||||
|
||||
|
||||
class InferenceStrategy(abc.ABC):
|
||||
class InferenceEngine(abc.ABC):
|
||||
"""Abstract backend for producing actions during rollout.
|
||||
|
||||
Subclasses decide whether inference happens inline, in a background
|
||||
|
||||
@@ -12,11 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference strategy configs and factory.
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_strategy`.
|
||||
:func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -33,9 +33,9 @@ from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .base import InferenceStrategy
|
||||
from .rtc import RTCInferenceStrategy
|
||||
from .sync import SyncInferenceStrategy
|
||||
from .base import InferenceEngine
|
||||
from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
"""Abstract base for inference backend configuration.
|
||||
|
||||
Use ``--inference.type=<name>`` on the CLI to select a backend.
|
||||
@@ -57,15 +57,15 @@ class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@InferenceStrategyConfig.register_subclass("sync")
|
||||
@InferenceEngineConfig.register_subclass("sync")
|
||||
@dataclass
|
||||
class SyncInferenceConfig(InferenceStrategyConfig):
|
||||
class SyncInferenceConfig(InferenceEngineConfig):
|
||||
"""Inline synchronous inference (one policy call per control tick)."""
|
||||
|
||||
|
||||
@InferenceStrategyConfig.register_subclass("rtc")
|
||||
@InferenceEngineConfig.register_subclass("rtc")
|
||||
@dataclass
|
||||
class RTCInferenceConfig(InferenceStrategyConfig):
|
||||
class RTCInferenceConfig(InferenceEngineConfig):
|
||||
"""Real-Time Chunking: async policy inference in a background thread."""
|
||||
|
||||
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
|
||||
@@ -80,8 +80,8 @@ class RTCInferenceConfig(InferenceStrategyConfig):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_inference_strategy(
|
||||
config: InferenceStrategyConfig,
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
@@ -96,10 +96,10 @@ def create_inference_strategy(
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> InferenceStrategy:
|
||||
"""Instantiate the appropriate inference strategy from a config object."""
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
return SyncInferenceStrategy(
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
@@ -110,7 +110,7 @@ def create_inference_strategy(
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
return RTCInferenceStrategy(
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
@@ -125,4 +125,4 @@ def create_inference_strategy(
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference strategy type: {type(config).__name__}")
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Real-Time Chunking inference strategy.
|
||||
"""Real-Time Chunking inference engine.
|
||||
|
||||
A background thread produces action chunks asynchronously via
|
||||
:meth:`policy.predict_action_chunk`. The main control loop polls
|
||||
@@ -47,7 +47,7 @@ from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .base import InferenceStrategy
|
||||
from .base import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -104,11 +104,11 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTCInferenceStrategy
|
||||
# RTCInferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RTCInferenceStrategy(InferenceStrategy):
|
||||
class RTCInferenceEngine(InferenceEngine):
|
||||
"""Async RTC inference: a background thread produces action chunks.
|
||||
|
||||
``get_action`` pops the next action from the shared queue (or
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Synchronous inference strategy: inline policy call per control tick."""
|
||||
"""Synchronous inference engine: inline policy call per control tick."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -26,12 +26,12 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
from .base import InferenceStrategy
|
||||
from .base import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncInferenceStrategy(InferenceStrategy):
|
||||
class SyncInferenceEngine(InferenceEngine):
|
||||
"""Inline synchronous inference: compute one action per call.
|
||||
|
||||
``get_action`` runs the full policy pipeline (pre/post-processor +
|
||||
|
||||
@@ -45,7 +45,6 @@ class BaseStrategy(RolloutStrategy):
|
||||
interpolator = self._interpolator
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
|
||||
start_time = time.perf_counter()
|
||||
engine.resume()
|
||||
@@ -63,14 +62,12 @@ class BaseStrategy(RolloutStrategy):
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
send_next_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.data.dataset_features
|
||||
)
|
||||
send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
self._teardown_hardware(ctx)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Base strategy teardown complete")
|
||||
|
||||
@@ -20,16 +20,16 @@ import abc
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..inference import InferenceStrategy
|
||||
from ..inference import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..configs import RolloutStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..context import HardwareContext, RolloutContext
|
||||
|
||||
|
||||
class RolloutStrategy(abc.ABC):
|
||||
@@ -42,12 +42,12 @@ class RolloutStrategy(abc.ABC):
|
||||
|
||||
def __init__(self, config: RolloutStrategyConfig) -> None:
|
||||
self.config = config
|
||||
self._engine: InferenceStrategy | None = None
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._warmup_flushed: bool = False
|
||||
|
||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||
"""Attach the inference strategy + interpolator and start the backend.
|
||||
"""Attach the inference engine + interpolator and start the backend.
|
||||
|
||||
Call this from ``setup()`` so strategies share identical setup
|
||||
without duplicating code.
|
||||
@@ -80,14 +80,14 @@ class RolloutStrategy(abc.ABC):
|
||||
engine.resume()
|
||||
return False
|
||||
|
||||
def _teardown_hardware(self, ctx: RolloutContext) -> None:
|
||||
def _teardown_hardware(self, hw: HardwareContext) -> None:
|
||||
"""Stop the inference engine and disconnect hardware."""
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
robot = ctx.hardware.robot_wrapper.inner
|
||||
robot = hw.robot_wrapper.inner
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
teleop = ctx.hardware.teleop
|
||||
teleop = hw.teleop
|
||||
if teleop is not None and teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
@@ -110,24 +110,25 @@ class RolloutStrategy(abc.ABC):
|
||||
|
||||
|
||||
def send_next_action(
|
||||
engine: InferenceStrategy,
|
||||
obs_processed: dict,
|
||||
obs_raw: dict,
|
||||
ctx: RolloutContext,
|
||||
interpolator: ActionInterpolator,
|
||||
ordered_keys: list[str],
|
||||
features: dict,
|
||||
) -> dict | None:
|
||||
"""Dispatch the next action to the robot.
|
||||
|
||||
Pulls the next action tensor from the inference strategy, feeds the
|
||||
Pulls the next action tensor from the inference engine, feeds the
|
||||
interpolator, and sends the interpolated action through the
|
||||
``robot_action_processor`` to the robot. Works identically for
|
||||
sync and async backends — the strategy never needs to branch.
|
||||
sync and async backends — the rollout strategy never needs to branch.
|
||||
|
||||
Returns the action dict that was sent, or ``None`` if no action was
|
||||
ready (e.g. empty async queue, interpolator not yet primed).
|
||||
"""
|
||||
engine = ctx.policy.inference
|
||||
features = ctx.data.dataset_features
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action(obs_frame)
|
||||
|
||||
@@ -407,7 +407,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
)
|
||||
|
||||
self._teardown_hardware(ctx)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -429,8 +429,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
record_stride = max(1, cfg.interpolation_multiplier)
|
||||
record_autonomous = self.config.record_autonomous
|
||||
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
features = dataset.features
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
@@ -499,9 +498,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
timestamp = time.perf_counter() - start_t
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
|
||||
@@ -21,6 +21,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event as ThreadingEvent
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
@@ -75,7 +76,9 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self._listener = None
|
||||
self._save_requested = ThreadingEvent()
|
||||
self._recording_live = ThreadingEvent()
|
||||
self._shutdown_event: ThreadingEvent | None = None
|
||||
self._push_requested = ThreadingEvent()
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
self._init_engine(ctx)
|
||||
@@ -86,12 +89,13 @@ class HighlightStrategy(RolloutStrategy):
|
||||
fps=ctx.runtime.cfg.fps,
|
||||
)
|
||||
|
||||
self._shutdown_event = ctx.runtime.shutdown_event
|
||||
self._setup_keyboard()
|
||||
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
|
||||
self._setup_keyboard(ctx.runtime.shutdown_event)
|
||||
logger.info(
|
||||
"Highlight strategy ready (buffer=%.0fs, key='%s')",
|
||||
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
|
||||
self.config.ring_buffer_seconds,
|
||||
self.config.save_key,
|
||||
self.config.push_key,
|
||||
)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
@@ -101,10 +105,9 @@ class HighlightStrategy(RolloutStrategy):
|
||||
dataset = ctx.data.dataset
|
||||
ring = self._ring
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
features = dataset.features
|
||||
|
||||
engine.resume()
|
||||
|
||||
@@ -126,9 +129,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
@@ -158,9 +159,10 @@ class HighlightStrategy(RolloutStrategy):
|
||||
dataset.save_episode()
|
||||
logger.info("Episode saved")
|
||||
self._recording_live.clear()
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
engine.resume()
|
||||
|
||||
if self._push_requested.is_set():
|
||||
self._push_requested.clear()
|
||||
self._background_push(dataset, cfg)
|
||||
|
||||
if self._recording_live.is_set():
|
||||
dataset.add_frame(frame)
|
||||
@@ -180,6 +182,10 @@ class HighlightStrategy(RolloutStrategy):
|
||||
if self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if self._push_executor is not None:
|
||||
self._push_executor.shutdown(wait=True)
|
||||
self._push_executor = None
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
ctx.data.dataset.finalize()
|
||||
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
@@ -188,28 +194,50 @@ class HighlightStrategy(RolloutStrategy):
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
)
|
||||
|
||||
self._teardown_hardware(ctx)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Highlight strategy teardown complete")
|
||||
|
||||
def _setup_keyboard(self) -> None:
|
||||
"""Set up keyboard listener for the save key."""
|
||||
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
|
||||
"""Set up keyboard listener for save and push keys."""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment — highlight save key unavailable")
|
||||
logger.warning("Headless environment — highlight keys unavailable")
|
||||
return
|
||||
|
||||
try:
|
||||
save_key = self.config.save_key
|
||||
push_key = self.config.push_key
|
||||
|
||||
def on_press(key):
|
||||
with contextlib.suppress(Exception):
|
||||
if hasattr(key, "char") and key.char == save_key:
|
||||
self._save_requested.set()
|
||||
elif hasattr(key, "char") and key.char == push_key:
|
||||
self._push_requested.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
self._save_requested.clear()
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
shutdown_event.set()
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
except ImportError:
|
||||
logger.warning("pynput not available — keyboard listener disabled")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Queue a Hub push on the single-worker executor."""
|
||||
if self._push_executor is None:
|
||||
return
|
||||
|
||||
if self._pending_push is not None and not self._pending_push.done():
|
||||
logger.info("Previous push still in progress; queueing next")
|
||||
|
||||
def _push():
|
||||
try:
|
||||
dataset.push_to_hub(
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
)
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
self._pending_push = self._push_executor.submit(_push)
|
||||
|
||||
@@ -73,10 +73,9 @@ class SentryStrategy(RolloutStrategy):
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
dataset = ctx.data.dataset
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
features = dataset.features
|
||||
|
||||
engine.resume()
|
||||
|
||||
@@ -100,9 +99,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
@@ -156,7 +153,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
)
|
||||
|
||||
self._teardown_hardware(ctx)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Sentry strategy teardown complete")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
|
||||
@@ -67,8 +67,7 @@ lerobot-record \\
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
@@ -82,6 +81,7 @@ from lerobot.common.control_utils import (
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
@@ -137,63 +137,6 @@ from lerobot.utils.utils import (
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordConfig:
|
||||
robot: RobotConfig
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
||||
self._buffer = [action.clone()]
|
||||
self._prev = action.clone()
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
@@ -17,9 +17,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@@ -21,8 +21,9 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
|
||||
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.scripts.lerobot_record import RecordConfig, record
|
||||
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
Reference in New Issue
Block a user