mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 07:39:53 +00:00
818892a38b
* feat: HIL data collection, RTC interpolator, and action queue improvements - Add Human-in-the-Loop (HIL) data collection examples (sync + RTC) - Add HIL data collection documentation - Add ActionInterpolator for smoother policy control at higher rates - Integrate interpolator into lerobot-record and eval_with_real_robot - Add action queue clear() and get_processed_left_over() methods - Add rtc/__init__.py for cleaner imports * docs: expand Related Work section with paper summaries * fix: only record dataset frames at original fps, not at interpolated rate The interpolator speeds up robot control (e.g. 2x) but dataset frames should still be recorded at the original fps. Interpolated-only iterations now only send actions to the robot without writing to the dataset. * refactor: merge HIL sync and RTC scripts into single file with --rtc.enabled toggle Combines hil_data_collection.py and hil_data_collection_rtc.py into one script. RTC is toggled via --rtc.enabled=true (defaults to off for sync inference). Deletes the separate hil_data_collection_rtc.py and updates docs to reflect the single-script usage. * test: add ActionInterpolator test suite (29 tests) Covers constructor validation, passthrough (multiplier=1), 2x and 3x interpolation with exact value checks, reset/episode boundaries, control interval calculation, multi-dim actions, and simulated control loop integration. * test: add ActionQueue + ActionInterpolator integration tests Verifies the interpolator doesn't interfere with RTC's leftover chunk tracking: queue consumption rate matches base fps regardless of multiplier, get_left_over/get_processed_left_over only change on queue.get(), merge preserves smooth interpolation across chunks, and interpolator reset is independent of queue state. * feat: register SO follower/leader configs in HIL script Adds SOFollowerRobotConfig and SOLeaderTeleopConfig imports so SO100/SO101 robots can be used via --robot.type=so_follower and --teleop.type=so_leader. 
Updates docs accordingly. Made-with: Cursor * docs: remove em dashes from HIL documentation Made-with: Cursor * refactor: rename examples/rac to examples/hil Updates directory name and all references in docs and script docstrings. Made-with: Cursor * fix: encorperate pr feedback comments * refactor(tests): enhance ActionInterpolator test structure and add detailed docstrings * feedback pr and test fix * fix(test): pass correct real_delay in interpolator delay test The test was passing real_delay=0 and relying on _check_delays to silently override it with the index-based diff. Now passes real_delay=3 to match the 3 actions consumed during the simulated inference period. * fix pr feedback * ordering * update hil script * fix * default name * fix(bi_openarm): use kw_only=True to fix dataclass field ordering BiOpenArmFollowerConfig overrides `id` with a default, making it positional in the child — non-default `left_arm_config` then follows a default field, which Python dataclasses forbid. Adding kw_only=True (matching the parent RobotConfig) removes positional constraints. Made-with: Cursor * style: format long line in hil_data_collection.py Made-with: Cursor * pr feedback --------- Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
678 lines
26 KiB
Python
678 lines
26 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
|
|
|
This script demonstrates:
|
|
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
|
2. Consuming actions from the policy while the robot executes
|
|
3. Periodically requesting new action chunks in the background using threads
|
|
4. Managing action buffers and timing for real-time operation
|
|
|
|
For simulation environments, see eval_with_simulation.py
|
|
|
|
Usage:
|
|
# Run on a real robot with RTC enabled
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
--policy.device=mps \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
|
|
# Run on a real robot with RTC disabled
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
--policy.device=mps \
|
|
--rtc.enabled=false \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
|
|
# Run on a real robot with RTC enabled and a pi0.5 policy
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=<USER>/pi05_check_rtc \
|
|
--policy.device=mps \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
|
|
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
|
python examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=lerobot-data-collection/folding_final \
|
|
--robot.type=bi_openarm_follower \
|
|
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
|
--robot.left_arm_config.port=can0 \
|
|
--robot.left_arm_config.side=left \
|
|
--robot.left_arm_config.can_interface=socketcan \
|
|
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.left_arm_config.max_relative_target=8.0 \
|
|
--robot.right_arm_config.port=can1 \
|
|
--robot.right_arm_config.side=right \
|
|
--robot.right_arm_config.can_interface=socketcan \
|
|
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
|
--robot.right_arm_config.max_relative_target=8.0 \
|
|
--task="Fold the T-shirt properly" \
|
|
--fps=30 \
|
|
--duration=2000 \
|
|
--interpolation_multiplier=3 \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--rtc.max_guidance_weight=5.0 \
|
|
--rtc.prefix_attention_schedule=LINEAR \
|
|
--device=cuda
|
|
"""
|
|
|
|
import logging
|
|
import math
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from threading import Event, Lock, Thread
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
|
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.configs.types import RTCAttentionSchedule
|
|
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
|
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
|
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
|
from lerobot.processor import (
|
|
NormalizerProcessorStep,
|
|
RelativeActionsProcessorStep,
|
|
TransitionKey,
|
|
create_transition,
|
|
)
|
|
from lerobot.processor.factory import (
|
|
make_default_robot_action_processor,
|
|
make_default_robot_observation_processor,
|
|
)
|
|
from lerobot.processor.relative_action_processor import to_relative_actions
|
|
from lerobot.rl.process import ProcessSignalHandler
|
|
from lerobot.robots import ( # noqa: F401
|
|
Robot,
|
|
RobotConfig,
|
|
bi_openarm_follower,
|
|
bi_so_follower,
|
|
koch_follower,
|
|
so_follower,
|
|
unitree_g1,
|
|
)
|
|
from lerobot.robots.utils import make_robot_from_config
|
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
|
from lerobot.utils.hub import HubMixin
|
|
from lerobot.utils.utils import init_logging
|
|
|
|
# Module-level logging setup.
# NOTE(review): demo_cli() also calls init_logging(); this basicConfig presumably only
# affects messages emitted before that call (e.g. at import time) — confirm it is still needed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RobotWrapper:
    """Serialize access to a robot shared by the inference and actor threads.

    ``get_actions`` reads observations and ``actor_control`` sends actions on the
    same underlying robot object, so every call is guarded by one lock.
    """

    def __init__(self, robot: "Robot"):
        self.robot = robot
        self.lock = Lock()

    def get_observation(self) -> dict:
        """Return the raw observation dict produced by ``robot.get_observation()``."""
        with self.lock:
            return self.robot.get_observation()

    def send_action(self, action: dict):
        """Forward an already robot-processed action to ``robot.send_action``."""
        with self.lock:
            self.robot.send_action(action)

    def observation_features(self) -> dict:
        """Return the robot's ``observation_features`` mapping (feature name -> type/shape)."""
        with self.lock:
            return self.robot.observation_features

    def action_features(self) -> dict:
        """Return the robot's ``action_features`` mapping (feature name -> type)."""
        with self.lock:
            return self.robot.action_features
|
|
|
|
|
|
@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies and real robots.

    Parsed from the command line via ``parser.wrap()``. ``__post_init__`` requires
    ``--policy.path``: it loads the pretrained policy config from that path
    (re-applying any other ``--policy.*`` CLI overrides) and raises ``ValueError``
    when the path or the robot configuration is missing.
    """

    # Policy configuration. Replaced in __post_init__ by the config loaded from --policy.path.
    policy: PreTrainedConfig | None = None

    # Robot configuration. Required: __post_init__ raises ValueError when it is None.
    robot: RobotConfig | None = None

    # RTC configuration. Also used to build the ActionQueue; when `rtc.enabled` is false
    # a new chunk is only requested once the action queue is empty.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            execution_horizon=10,
            max_guidance_weight=1.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    # Demo parameters
    duration: float = 30.0  # Duration to run the demo (seconds)
    fps: float = 10.0  # Base action execution frequency (Hz); also used to convert latency into steps
    interpolation_multiplier: int = 1  # Control rate multiplier (1=off, 2=2x, 3=3x)

    # Compute device passed to `policy.to(...)` (e.g. cuda, mps, cpu).
    device: str | None = None

    # Queue-size threshold for requesting new actions: the inference thread asks the
    # policy for a new chunk once at most this many actions remain in the queue
    # (treated as 0 when RTC is disabled).
    # It should be higher than inference delay + execution horizon.
    action_queue_size_to_get_new_actions: int = 30

    # Task to execute (passed to the policy as a one-element list)
    task: str = field(default="", metadata={"help": "Task to execute"})

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        """Load the policy config from ``--policy.path`` and validate required fields.

        Raises:
            ValueError: If no ``--policy.path`` was given or ``robot`` is None.
        """
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Re-apply the remaining --policy.* CLI arguments on top of the pretrained config.
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        else:
            raise ValueError("Policy path is required")

        # Validate that robot configuration is provided
        if self.robot is None:
            raise ValueError("Robot configuration must be provided")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
|
|
|
|
|
|
def is_image_key(k: str) -> bool:
    """Tell whether the feature name ``k`` lives under the ``OBS_IMAGES`` namespace."""
    image_prefix = OBS_IMAGES
    return k[: len(image_prefix)] == image_prefix
|
|
|
|
|
|
def _reanchor_relative_rtc_prefix(
    prev_actions_absolute: Tensor,
    current_state: Tensor,
    relative_step: RelativeActionsProcessorStep,
    normalizer_step: NormalizerProcessorStep | None,
    policy_device: torch.device | str,
) -> Tensor:
    """Re-express absolute RTC leftovers in the policy's (relative, normalized) action space.

    For relative-action policies the leftover prefix from the previous chunk is
    kept in absolute space. Before it is fed back to the policy it is made
    relative to the *current* robot state (only on the dimensions selected by
    ``relative_step``'s mask) and then, if a normalizer is present, normalized.

    Args:
        prev_actions_absolute: Leftover actions in absolute space.
        current_state: Current robot state; a 1-D state gets a leading batch dim.
        relative_step: Preprocessor step that defines which action dims are relative.
        normalizer_step: Optional normalizer applied after the relative conversion.
        policy_device: Device the resulting tensor is moved to.

    Returns:
        The re-anchored leftover actions on ``policy_device``.
    """
    state_cpu = current_state.detach().cpu()
    state_cpu = state_cpu.unsqueeze(0) if state_cpu.dim() == 1 else state_cpu

    absolute_cpu = prev_actions_absolute.detach().cpu()
    relative_mask = relative_step._build_mask(absolute_cpu.shape[-1])

    transition = create_transition(action=to_relative_actions(absolute_cpu, state_cpu, relative_mask))
    if normalizer_step is not None:
        transition = normalizer_step(transition)

    return transition[TransitionKey.ACTION].to(policy_device)
|
|
|
|
|
|
def _to_policy_tensors(frame: dict, device) -> dict:
    """Convert a numpy observation frame into batched torch tensors on ``device``.

    Every entry gets a leading batch dimension of size 1. Image entries (see
    ``is_image_key``) are additionally cast to float32, divided by 255 and
    permuted with ``permute(2, 0, 1)``; they are assumed to arrive as HWC arrays
    in [0, 255] (TODO confirm for every camera backend).
    """
    tensors = {}
    for name, array in frame.items():
        tensor = torch.from_numpy(array)
        if is_image_key(name):
            tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
        tensors[name] = tensor.unsqueeze(0).to(device)
    return tensors


def _snapshot_absolute_leftover(action_queue: ActionQueue) -> Tensor | None:
    """Copy the queued postprocessed (absolute) actions that have not been executed yet.

    Reads ``action_queue.queue[last_index:]`` while holding the queue's lock and
    returns None when the queue holds no chunk.
    """
    with action_queue.lock:
        if action_queue.queue is None:
            return None
        return action_queue.queue[action_queue.last_index :].clone()


def get_actions(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to request action chunks from the policy.

    Whenever the action queue drops to ``cfg.action_queue_size_to_get_new_actions``
    (or to 0 when RTC is disabled), grabs an observation, runs the policy with the
    current leftover prefix and estimated inference delay, and merges the new chunk
    into ``action_queue``. On an unexpected error the exception is logged and
    ``shutdown_event`` is set so the main thread stops the demo.

    Args:
        policy: The policy instance (SmolVLA, Pi0, etc.)
        robot: The robot instance for getting observations
        robot_observation_processor: Processor for raw robot observations
        action_queue: Queue to put new action chunks
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[GET_ACTIONS] Starting get actions thread")

        latency_tracker = LatencyTracker()  # Track latency of action chunks
        # Duration of one executed action at the base rate; converts latency into steps.
        time_per_step = 1.0 / cfg.fps

        # Only keep .pos joints + camera streams if the policy was trained on positions,
        # not the full pos/vel/torque state the robot exposes.
        observation_features_hw = {
            key: value
            for key, value in robot.observation_features().items()
            if key.endswith(".pos") or isinstance(value, tuple)
        }

        dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
        policy_device = policy.config.device

        # Load preprocessor and postprocessor from pretrained files
        # The stats are embedded in the processor .safetensors files
        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=None,  # Will load from pretrained processor files
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
            },
        )

        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")

        relative_step = next(
            (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
            None,
        )
        normalizer_step = next(
            (s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
            None,
        )
        if relative_step is not None:
            if relative_step.action_names is None:
                cfg_names = getattr(cfg.policy, "action_feature_names", None)
                if cfg_names:
                    relative_step.action_names = list(cfg_names)
                else:
                    relative_step.action_names = [k for k in robot.action_features() if k.endswith(".pos")]
            logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")

        # Without RTC, only request a new chunk once the current one is fully consumed.
        get_actions_threshold = cfg.action_queue_size_to_get_new_actions if cfg.rtc.enabled else 0
        robot_type = robot.robot.name if hasattr(robot.robot, "name") else ""

        while not shutdown_event.is_set():
            if action_queue.qsize() > get_actions_threshold:
                # Small sleep to prevent busy waiting
                time.sleep(0.1)
                continue

            current_time = time.perf_counter()
            action_index_before_inference = action_queue.get_action_index()
            prev_actions = action_queue.get_left_over()

            # Worst-case latency so far, expressed in executed action steps.
            inference_delay = math.ceil(latency_tracker.max() / time_per_step)

            obs_processed = robot_observation_processor(robot.get_observation())
            obs_frame = build_dataset_frame(dataset_features, obs_processed, prefix="observation")

            obs_with_policy_features = _to_policy_tensors(obs_frame, policy_device)
            obs_with_policy_features["task"] = [cfg.task]  # Task should be a list, not a string!
            obs_with_policy_features["robot_type"] = robot_type

            preprocessed_obs = preprocessor(obs_with_policy_features)

            # Re-anchor leftover actions for relative-action policies.
            # We need the *postprocessed* (absolute) leftover, not the original
            # (normalized/relative) one that get_left_over() returns.
            if prev_actions is not None and relative_step is not None and OBS_STATE in obs_with_policy_features:
                prev_actions_abs = _snapshot_absolute_leftover(action_queue)
                if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
                    prev_actions = _reanchor_relative_rtc_prefix(
                        prev_actions_absolute=prev_actions_abs,
                        current_state=obs_with_policy_features[OBS_STATE],
                        relative_step=relative_step,
                        normalizer_step=normalizer_step,
                        policy_device=policy_device,
                    )

            # Generate actions WITH RTC
            actions = policy.predict_action_chunk(
                preprocessed_obs,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_actions,
            )

            # Store original actions (before postprocessing) for RTC
            original_actions = actions.squeeze(0).clone()
            postprocessed_actions = postprocessor(actions).squeeze(0)

            new_latency = time.perf_counter() - current_time
            new_delay = math.ceil(new_latency / time_per_step)
            latency_tracker.add(new_latency)

            # The threshold only matters with RTC; without it the queue is drained to 0 anyway.
            if cfg.rtc.enabled and cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                logger.warning(
                    "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions (%d) is too small: it should be "
                    "higher than inference delay (%d) + execution horizon (%d).",
                    cfg.action_queue_size_to_get_new_actions,
                    new_delay,
                    cfg.rtc.execution_horizon,
                )

            action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)

        logger.info("[GET_ACTIONS] get actions thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
        logger.error(traceback.format_exc())
        # sys.exit() in a worker thread only ends this thread; ask the main thread to stop instead.
        shutdown_event.set()
|
|
|
|
|
|
def actor_control(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to execute actions on the robot.

    Pulls actions from ``action_queue`` whenever the interpolator needs a new
    one, sends the (possibly interpolated) result to the robot, and paces the
    loop at ``interpolator.get_control_interval(cfg.fps)``. On an unexpected
    error the exception is logged and ``shutdown_event`` is set so the main
    thread stops the demo and disconnects the robot.

    Args:
        robot: The robot instance
        robot_action_processor: Processor applied to the action dict before sending it
        action_queue: Queue to get actions from
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[ACTOR] Starting actor thread")

        # NOTE(review): assumes the policy action vector is ordered like these ".pos" features.
        action_keys = [k for k in robot.action_features() if k.endswith(".pos")]

        action_count = 0
        interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
        action_interval = interpolator.get_control_interval(cfg.fps)

        while not shutdown_event.is_set():
            start_time = time.perf_counter()

            if interpolator.needs_new_action():
                new_action = action_queue.get()
                if new_action is not None:
                    interpolator.add(new_action.cpu())

            action = interpolator.get()
            if action is not None:
                action = action.cpu()
                action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
                action_processed = robot_action_processor((action_dict, None))
                robot.send_action(action_processed)
                action_count += 1

            dt_s = time.perf_counter() - start_time
            # Wake up 1 ms early, presumably to absorb sleep/scheduler overshoot.
            time.sleep(max(0, (action_interval - dt_s) - 0.001))

        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
        logger.error(traceback.format_exc())
        # sys.exit() in a worker thread only ends this thread; ask the main thread to stop instead.
        shutdown_event.set()
|
|
|
|
|
|
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
|
"""Apply torch.compile to the policy's predict_action_chunk method.
|
|
|
|
Args:
|
|
policy: Policy instance to compile
|
|
cfg: Configuration containing torch compile settings
|
|
|
|
Returns:
|
|
Policy with compiled predict_action_chunk method
|
|
"""
|
|
|
|
# PI models handle their own compilation
|
|
if policy.type == "pi05" or policy.type == "pi0":
|
|
return policy
|
|
|
|
try:
|
|
# Check if torch.compile is available (PyTorch 2.0+)
|
|
if not hasattr(torch, "compile"):
|
|
logger.warning(
|
|
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
|
f"Current version: {torch.__version__}. Skipping compilation."
|
|
)
|
|
return policy
|
|
|
|
logger.info("Applying torch.compile to predict_action_chunk...")
|
|
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
|
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
|
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
|
|
|
# Compile the predict_action_chunk method
|
|
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
|
compile_kwargs = {
|
|
"backend": cfg.torch_compile_backend,
|
|
"mode": cfg.torch_compile_mode,
|
|
}
|
|
|
|
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
|
if cfg.torch_compile_disable_cudagraphs:
|
|
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
|
|
|
original_method = policy.predict_action_chunk
|
|
compiled_method = torch.compile(original_method, **compile_kwargs)
|
|
policy.predict_action_chunk = compiled_method
|
|
logger.info("✓ Successfully compiled predict_action_chunk")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to apply torch.compile: {e}")
|
|
logger.warning("Continuing without torch.compile")
|
|
|
|
return policy
|
|
|
|
|
|
def _load_policy(cfg: RTCDemoConfig, device):
    """Load the pretrained policy (or PEFT adapter), enable RTC and move it to ``device``.

    Raises:
        ValueError: If the loaded policy is not one of smolvla, pi05 or pi0.
    """
    policy_class = get_policy_class(cfg.policy.type)

    # Reload a fresh config for the model, re-applying the --policy.* CLI overrides so that
    # settings such as --policy.device match the config used for the pre/post-processors.
    config = PreTrainedConfig.from_pretrained(
        cfg.policy.pretrained_path, cli_overrides=parser.get_cli_overrides("policy")
    )

    # pi0/pi05 compile themselves; forward the user's choice via compile_model.
    if cfg.policy.type in ("pi05", "pi0"):
        config.compile_model = cfg.use_torch_compile

    if config.use_peft:
        from peft import PeftConfig, PeftModel

        peft_pretrained_path = cfg.policy.pretrained_path
        peft_config = PeftConfig.from_pretrained(peft_pretrained_path)

        policy = policy_class.from_pretrained(
            pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
        )
        policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
    else:
        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    # Turn on RTC. The RTC processor is created explicitly because it is not
    # created by default when RTC is disabled in the saved config.
    policy.config.rtc_config = cfg.rtc
    policy.init_rtc_processor()

    if policy.name not in ("smolvla", "pi05", "pi0"):
        raise ValueError("Only smolvla, pi05, and pi0 are supported for RTC")

    policy = policy.to(device)
    policy.eval()

    # Apply torch.compile to predict_action_chunk method if enabled
    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

    return policy


@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
    """Main entry point for RTC demo with draccus configuration.

    Loads the policy, connects the robot and starts two daemon threads: one
    requesting action chunks (``get_actions``) and one executing them
    (``actor_control``). Runs until ``cfg.duration`` seconds elapse or
    ``shutdown_event`` is set, then stops the threads and always disconnects
    the robot, even if an error occurs after it was connected.
    """
    # Initialize logging
    init_logging()

    # Fall back to the policy's configured device so the model and the
    # preprocessor (which uses cfg.policy.device) end up on the same device.
    device = cfg.device or cfg.policy.device
    logger.info(f"Using device: {device}")

    # Setup signal handler for graceful shutdown
    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    policy = _load_policy(cfg, device)

    # Create robot
    logger.info(f"Initializing robot: {cfg.robot.type}")
    robot = make_robot_from_config(cfg.robot)
    robot.connect()

    get_actions_thread = None
    actor_thread = None
    try:
        robot_wrapper = RobotWrapper(robot)

        robot_observation_processor = make_default_robot_observation_processor()
        robot_action_processor = make_default_robot_action_processor()

        # Create action queue for communication between threads
        action_queue = ActionQueue(cfg.rtc)

        # Start chunk requester thread
        get_actions_thread = Thread(
            target=get_actions,
            args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
            daemon=True,
            name="GetActions",
        )
        get_actions_thread.start()
        logger.info("Started get actions thread")

        # Start action executor thread
        actor_thread = Thread(
            target=actor_control,
            args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
            daemon=True,
            name="Actor",
        )
        actor_thread.start()
        logger.info("Started actor thread")

        # Main thread monitors for duration or shutdown
        logger.info(f"Running demo for {cfg.duration} seconds...")
        start_time = time.time()
        while not shutdown_event.is_set():
            remaining = cfg.duration - (time.time() - start_time)
            if remaining <= 0:
                break
            # Event.wait returns as soon as shutdown is requested, unlike time.sleep.
            shutdown_event.wait(timeout=min(5.0, remaining))
            # Log queue status periodically
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")

        logger.info("Demo duration reached or shutdown requested")
    finally:
        # Signal shutdown
        shutdown_event.set()

        # Wait for threads to finish before releasing the robot they use
        for thread, label in ((get_actions_thread, "chunk requester"), (actor_thread, "action executor")):
            if thread is not None and thread.is_alive():
                logger.info(f"Waiting for {label} thread to finish...")
                thread.join()

        # Cleanup robot
        robot.disconnect()
        logger.info("Robot disconnected")

    logger.info("Cleanup completed")
|
|
|
|
|
|
if __name__ == "__main__":
    demo_cli()
    # Use the module logger, consistent with the rest of the script.
    logger.info("RTC demo finished")
|