e670ac5daf
* Add basic support for PEFT adapter methods

  This change adds support for training policies with far fewer parameters by applying adapter methods such as LoRA to specific parts of the policies, therefore potentially allowing higher learning rates / batch sizes. To make this as accessible as possible, I thought it useful to provide defaults for `target_modules` and `modules_to_save`. Currently only SmolVLA has such defaults, but once we agree that this change is useful I will set out to generate more of them. While the user can override these settings, they are expected to only change the peft_method, rank and init_type parameters.

* Implement loading of PEFT adapters

  Loading a PEFT adapter is currently done by initializing a policy with the default config and then applying the adapter to the resulting model. This has the obvious drawback that any configuration done during training is not applied in the adapted model. Currently the `use_peft` attribute of `PreTrainedConfig` is only set during loading, to signal to the following code that it has to deal with a PEFT adapter. However, we could imagine a scenario where this is already set at training time and stored alongside the adapter.

* Store policy config alongside PEFT checkpoint

  Before this change the PEFT-wrapped policy did not save the policy's config alongside the adapter config / weights, which prevented us from changing the policy config. Now the policy config is saved both in full training and PEFT training. This change makes loading the PEFT policy adapter much easier as well.

* Add default config for ACT

* Support targets like `all-linear`

* Formatting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Fix failing tests

* Remove PEFT compatibility changes in config

  We'll wait for the PEFT release that fixes this for good.

* Remove `use_peft` parameter from training script

  Instead we make the PEFT config optional, which has the same effect.

* Log adapter config to WandB

* Better documentation for CLI arguments

* Don't unload & merge the PEFT model

  This can make things hard when using quantized layers (the user expects quantized base layers with unquantized adapters, for example; merging defaults to upcasting the layers, leading to higher memory use).

* Correct way of identifying when to save config

* Add CLI end-to-end tests

  Currently there doesn't seem to be any way to test the CLI commands. Since this change mostly happens in those, I thought it best to add a way to test these commands end-to-end. More integrated commands like `lerobot-record` need patching, but standalone commands like training seem to work fine.

* Update default targets

  Removed ACT since it doesn't make sense to fine-tune ACT without having pretrained it beforehand. SmolVLA and Pi0/0.5 are much more sensible targets.

* Clean up loading code

  - Centralized instantiation of the PEFT wrapper in `make_policy` for inference (e.g. in `lerobot-record`)
  - Training a PEFT policy also sets `cfg.use_peft` so that all inference code loading the policy can rely on that attribute to identify whether PEFT loading is needed
  - Modified the RTC example to also include PEFT policies, mostly because this is an example I'm currently exploring

* Make sure push_to_hub works

  Since PEFT only wraps `push_to_hub` and not `push_model_to_hub`, the reference to `self` in `policy.push_model_to_hub` is the unwrapped policy which, of course, doesn't know anything about PEFT. To make the upload process aware of PEFT, we pass the unwrapped policy down to `push_model_to_hub` as a kwarg. This is not ideal, but I think it is the best way for now.

* Formatting

* Warn when encountering from-scratch training

* Revamp pretrained model loading

  There were quite a few factors that convinced me the status quo could load pretrained models from the PEFT adapter config, but in fact that didn't work. This commit fixes the following things:

  - policies wrapped in PEFT will now have a `name_or_path` attribute containing the name or path of the pretrained model we're fine-tuning
  - we further assume that SmolVLA without `pretrained_path` and `load_vlm_weights==False` must be a user-side error
  - we assume that using PEFT on from-scratch policies must be a user-side error

* Make it possible to unset policy features

  This is necessary to train pretrained policies on new datasets, so that the features are inferred from the new dataset and not from the pretrained policy.

* Use correct loading for PEFT in RTC example

* Make it possible to use PeftModels in eval

* Add test checking that PEFT actually reduces params

* Adapt state/action projections instead of full fine-tuning

  There doesn't seem to be a benefit to fully fine-tuning these layers over just adapting them, so we do that instead.

* Disallow PEFT training on non-pretrained policies

  At first I thought it would make sense to have this feature in case you want to fine-tune a pre-trained section, but in the end it makes more trouble than it's worth. It's still possible to allow this in the future when a concrete need arises.

* Add basic documentation

* Formatting

* Add peft as extra dependency, mark tests

  Fast tests currently fail because of the missing dependency.

* Fix pre-commit issues

* Add walx <> peft conflict for uv

* Exclude peft from pi install for now

---------

Co-authored-by: nemo <git@ningu.net>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
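For reference, a minimal sketch of what the training-side adapter wrapping described above could look like with the peft library. The chosen rank, `target_modules`, and `modules_to_save` below are illustrative assumptions, not lerobot's actual defaults:

```python
# Hypothetical sketch: wrap a pretrained policy with a LoRA adapter via peft.
# Target module names and hyperparameters are illustrative only.
from peft import LoraConfig, get_peft_model

from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")

lora_config = LoraConfig(
    r=16,                                 # adapter rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # which linear layers get adapters (assumed)
    modules_to_save=["state_proj"],       # modules trained in full (assumed)
)

# Only the adapter (and modules_to_save) remain trainable; the base stays frozen.
peft_policy = get_peft_model(policy, lora_config)
peft_policy.print_trainable_parameters()  # prints the trainable-parameter fraction
```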
561 lines
20 KiB
Python
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
|
|
|
This script demonstrates:
|
|
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
|
2. Consuming actions from the policy while the robot executes
|
|
3. Periodically requesting new action chunks in the background using threads
|
|
4. Managing action buffers and timing for real-time operation
|
|
|
|
For simulation environments, see eval_with_simulation.py
|
|
|
|
Usage:
|
|
# Run RTC with Real robot with RTC
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
|
--policy.device=mps \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
|
|
# Run RTC with Real robot without RTC
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
|
--policy.device=mps \
|
|
--rtc.enabled=false \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
|
|
# Run RTC with Real robot with pi0.5 policy
|
|
uv run examples/rtc/eval_with_real_robot.py \
|
|
--policy.path=helper2424/pi05_check_rtc \
|
|
--policy.device=mps \
|
|
--rtc.enabled=true \
|
|
--rtc.execution_horizon=20 \
|
|
--robot.type=so100_follower \
|
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
--robot.id=so100_follower \
|
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
|
--task="Move green small object into the purple platform" \
|
|
--duration=120
|
|
"""

import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread

import torch
from torch import Tensor

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig  # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig  # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
    make_default_robot_action_processor,
    make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import (  # noqa: F401
    Robot,
    RobotConfig,
    koch_follower,
    so100_follower,
    so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
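

# A single robot connection is shared by the two worker threads below; this thin
# wrapper serializes all access behind one Lock so reads and writes never interleave.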
class RobotWrapper:
    def __init__(self, robot: Robot):
        self.robot = robot
        self.lock = Lock()

    def get_observation(self) -> dict[str, Tensor]:
        with self.lock:
            return self.robot.get_observation()

    def send_action(self, action: Tensor):
        with self.lock:
            self.robot.send_action(action)

    def observation_features(self) -> list[str]:
        with self.lock:
            return self.robot.observation_features

    def action_features(self) -> list[str]:
        with self.lock:
            return self.robot.action_features


@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies and real robots."""

    # Policy configuration
    policy: PreTrainedConfig | None = None

    # Robot configuration
    robot: RobotConfig | None = None

    # RTC configuration
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            execution_horizon=10,
            max_guidance_weight=1.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    # Demo parameters
    duration: float = 30.0  # Duration to run the demo (seconds)
    fps: float = 10.0  # Action execution frequency (Hz)

    # Compute device
    device: str | None = None  # Device to run on (cuda, cpu, auto)

    # Queue-size threshold for requesting new actions: a new chunk is requested once
    # the number of queued actions drops to this value or below. It should be higher
    # than the inference delay plus the execution horizon.
    action_queue_size_to_get_new_actions: int = 30

    # Task to execute
    task: str = field(default="", metadata={"help": "Task to execute"})

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        # HACK: we parse the CLI args again here to get the pretrained path, if one was given.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        else:
            raise ValueError("Policy path is required")

        # Validate that robot configuration is provided
        if self.robot is None:
            raise ValueError("Robot configuration must be provided")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`."""
        return ["policy"]


def is_image_key(k: str) -> bool:
    return k.startswith(OBS_IMAGES)


def get_actions(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to request action chunks from the policy.

    Args:
        policy: The policy instance (SmolVLA, Pi0, etc.)
        robot: The robot instance for getting observations
        robot_observation_processor: Processor for raw robot observations
        action_queue: Queue to put new action chunks
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[GET_ACTIONS] Starting get actions thread")

        latency_tracker = LatencyTracker()  # Track latency of action chunks
        fps = cfg.fps
        time_per_chunk = 1.0 / fps

        dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
        policy_device = policy.config.device

        # Load preprocessor and postprocessor from pretrained files.
        # The stats are embedded in the processor .safetensors files.
        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=None,  # Will load from pretrained processor files
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
            },
        )

        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")

        get_actions_threshold = cfg.action_queue_size_to_get_new_actions
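
        # Without RTC there is no chunk blending, so only request a new chunk once
        # the queue has fully drained.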
        if not cfg.rtc.enabled:
            get_actions_threshold = 0

        while not shutdown_event.is_set():
            if action_queue.qsize() <= get_actions_threshold:
                current_time = time.perf_counter()
                action_index_before_inference = action_queue.get_action_index()
                prev_actions = action_queue.get_left_over()
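
                # Convert the worst-case inference latency (seconds) into a number of
                # action steps: while a new chunk is being computed, the actor thread
                # keeps consuming roughly latency / time_per_chunk actions.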
                inference_latency = latency_tracker.max()
                inference_delay = math.ceil(inference_latency / time_per_chunk)

                obs = robot.get_observation()

                # Apply robot observation processor
                obs_processed = robot_observation_processor(obs)

                obs_with_policy_features = build_dataset_frame(
                    dataset_features, obs_processed, prefix="observation"
                )
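
                # Convert numpy observations to tensors; images go from HWC uint8 in
                # [0, 255] to a batched CHW float32 tensor in [0, 1] on the policy device.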
                for name in obs_with_policy_features:
                    obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
                    if "image" in name:
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].type(torch.float32) / 255
                        )
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].permute(2, 0, 1).contiguous()
                        )
                    obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
                    obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)

                obs_with_policy_features["task"] = [cfg.task]  # Task should be a list, not a string!
                obs_with_policy_features["robot_type"] = (
                    robot.robot.name if hasattr(robot.robot, "name") else ""
                )

                preprocessed_obs = preprocessor(obs_with_policy_features)

                # Generate actions WITH RTC
                actions = policy.predict_action_chunk(
                    preprocessed_obs,
                    inference_delay=inference_delay,
                    prev_chunk_left_over=prev_actions,
                )

                # Store original actions (before postprocessing) for RTC
                original_actions = actions.squeeze(0).clone()

                postprocessed_actions = postprocessor(actions)

                postprocessed_actions = postprocessed_actions.squeeze(0)

                new_latency = time.perf_counter() - current_time
                new_delay = math.ceil(new_latency / time_per_chunk)
                latency_tracker.add(new_latency)

                if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                    logger.warning(
                        "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions is too small; it "
                        "should be higher than inference delay + execution horizon."
                    )
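
                # Merge the fresh chunk into the queue: RTC blends the overlapping
                # region with the leftover actions from the previous chunk so the
                # executed trajectory stays continuous despite inference latency.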
                action_queue.merge(
                    original_actions, postprocessed_actions, new_delay, action_index_before_inference
                )
            else:
                # Small sleep to prevent busy waiting
                time.sleep(0.1)

        logger.info("[GET_ACTIONS] get actions thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def actor_control(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to execute actions on the robot.

    Args:
        robot: The robot instance
        robot_action_processor: Processor for raw robot actions
        action_queue: Queue to get actions from
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[ACTOR] Starting actor thread")

        action_count = 0
        action_interval = 1.0 / cfg.fps

        while not shutdown_event.is_set():
            start_time = time.perf_counter()

            # Try to get an action from the queue with timeout
            action = action_queue.get()

            if action is not None:
                action = action.cpu()
                action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
                action_processed = robot_action_processor((action_dict, None))
                robot.send_action(action_processed)

                action_count += 1
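
            # Pace the loop to cfg.fps: sleep for whatever remains of the action
            # interval, with a 1 ms margin for loop overhead.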
            dt_s = time.perf_counter() - start_time
            time.sleep(max(0, (action_interval - dt_s) - 0.001))

        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
        logger.error(traceback.format_exc())
        sys.exit(1)


def _apply_torch_compile(policy, cfg: RTCDemoConfig):
    """Apply torch.compile to the policy's predict_action_chunk method.

    Args:
        policy: Policy instance to compile
        cfg: Configuration containing torch compile settings

    Returns:
        Policy with compiled predict_action_chunk method
    """

    # PI models handle their own compilation
    if policy.type == "pi05" or policy.type == "pi0":
        return policy

    try:
        # Check if torch.compile is available (PyTorch 2.0+)
        if not hasattr(torch, "compile"):
            logger.warning(
                "torch.compile is not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")
        logger.info(f"  Backend: {cfg.torch_compile_backend}")
        logger.info(f"  Mode: {cfg.torch_compile_mode}")
        logger.info(f"  Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

        # Compile the predict_action_chunk method
        # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
        compile_kwargs = {
            "backend": cfg.torch_compile_backend,
            "mode": cfg.torch_compile_mode,
        }

        # Disable CUDA graphs if requested (prevents tensor aliasing issues)
        if cfg.torch_compile_disable_cudagraphs:
            compile_kwargs["options"] = {"triton.cudagraphs": False}

        original_method = policy.predict_action_chunk
        compiled_method = torch.compile(original_method, **compile_kwargs)
        policy.predict_action_chunk = compiled_method
        logger.info("✓ Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy


@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
    """Main entry point for RTC demo with draccus configuration."""

    # Initialize logging
    init_logging()

    logger.info(f"Using device: {cfg.device}")

    # Setup signal handler for graceful shutdown
    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    policy = None
    robot = None
    get_actions_thread = None
    actor_thread = None

    policy_class = get_policy_class(cfg.policy.type)

    # Load config and set compile_model for pi0/pi05 models
    config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

    if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
        config.compile_model = cfg.use_torch_compile
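
    # PEFT checkpoints store only the adapter weights plus a reference to the base
    # model, so loading is two-step: instantiate the base policy from
    # `base_model_name_or_path`, then apply the adapter weights on top.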
    if config.use_peft:
        from peft import PeftConfig, PeftModel

        peft_pretrained_path = cfg.policy.pretrained_path
        peft_config = PeftConfig.from_pretrained(peft_pretrained_path)

        policy = policy_class.from_pretrained(
            pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
        )
        policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
    else:
        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    # Turn on RTC
    policy.config.rtc_config = cfg.rtc

    # Init the RTC processor; by default it is not created when RTC is disabled
    # in the config.
    policy.init_rtc_processor()

    assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"

    policy = policy.to(cfg.device)
    policy.eval()

    # Apply torch.compile to predict_action_chunk method if enabled
    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

    # Create robot
    logger.info(f"Initializing robot: {cfg.robot.type}")
    robot = make_robot_from_config(cfg.robot)
    robot.connect()
    robot_wrapper = RobotWrapper(robot)

    # Create robot observation and action processors
    robot_observation_processor = make_default_robot_observation_processor()
    robot_action_processor = make_default_robot_action_processor()

    # Create action queue for communication between threads
    action_queue = ActionQueue(cfg.rtc)

    # Start chunk requester thread
    get_actions_thread = Thread(
        target=get_actions,
        args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="GetActions",
    )
    get_actions_thread.start()
    logger.info("Started get actions thread")

    # Start action executor thread
    actor_thread = Thread(
        target=actor_control,
        args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
        daemon=True,
        name="Actor",
    )
    actor_thread.start()
    logger.info("Started actor thread")

    # Main thread monitors for duration or shutdown
    logger.info(f"Running demo for {cfg.duration} seconds...")
    start_time = time.time()

    while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
        time.sleep(1)

        # Log queue status periodically
        if int(time.time() - start_time) % 5 == 0:
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")

    logger.info("Demo duration reached or shutdown requested")

    # Signal shutdown
    shutdown_event.set()

    # Wait for threads to finish
    if get_actions_thread and get_actions_thread.is_alive():
        logger.info("Waiting for chunk requester thread to finish...")
        get_actions_thread.join()

    if actor_thread and actor_thread.is_alive():
        logger.info("Waiting for action executor thread to finish...")
        actor_thread.join()

    # Cleanup robot
    if robot:
        robot.disconnect()
        logger.info("Robot disconnected")

    logger.info("Cleanup completed")


if __name__ == "__main__":
    demo_cli()
    logger.info("RTC demo finished")