mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
694 lines
25 KiB
Python
694 lines
25 KiB
Python
#!/usr/bin/env python
|
|
|
|
"""
|
|
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies.
|
|
|
|
This script demonstrates:
|
|
1. Creating a robot/environment and policy (SmolVLA, Pi0, etc.) with RTC
|
|
2. Consuming actions from the policy while the robot/environment executes
|
|
3. Periodically requesting new action chunks in the background using threads
|
|
4. Managing action buffers and timing for real-time operation
|
|
|
|
Usage:
|
|
# With real robot
|
|
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100
|
|
|
|
# With simulation environment
|
|
python rtc_demo.py --policy.path=lerobot/smolvla_base --env.type=pusht
|
|
|
|
# With config file
|
|
python rtc_demo.py --config_path=path/to/config.json
|
|
|
|
# With policy compilation for faster inference (recommended for production)
|
|
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100 --use_torch_compile=true
|
|
|
|
# With aggressive compilation for maximum speed
|
|
python rtc_demo.py --policy.path=lerobot/smolvla_base --robot.type=so100 --use_torch_compile=true --torch_compile_mode=max-autotune
|
|
|
|
Performance Notes:
|
|
- torch.compile() is NOT supported on MPS (Apple Silicon) due to attention operation limitations
|
|
- For MPS optimization, reduce num_steps in the policy config (biggest speedup)
|
|
- CUDA devices will see 2-5x speedup with compilation enabled
|
|
"""
|
|
|
|
import logging
|
|
import math
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from threading import Event, Lock, Thread
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.configs.types import RTCAttentionSchedule
|
|
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
|
from lerobot.envs.configs import EnvConfig # noqa: F401
|
|
from lerobot.envs.factory import make_env
|
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
|
from lerobot.policies.rtc.action_queue import ActionQueue
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
|
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
|
from lerobot.processor.factory import (
|
|
make_default_robot_action_processor,
|
|
make_default_robot_observation_processor,
|
|
)
|
|
from lerobot.rl.process import ProcessSignalHandler
|
|
from lerobot.robots import ( # noqa: F401
|
|
Robot,
|
|
RobotConfig,
|
|
koch_follower,
|
|
so100_follower,
|
|
so101_follower,
|
|
)
|
|
from lerobot.robots.utils import make_robot_from_config
|
|
from lerobot.utils.constants import OBS_IMAGES
|
|
from lerobot.utils.hub import HubMixin
|
|
from lerobot.utils.utils import init_logging
|
|
|
|
# Configure the root logger at INFO level when the module is imported.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions and classes in this script.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RobotWrapper:
    """Guard every access to a robot with a single lock.

    Each method acquires ``self.lock`` before touching the wrapped robot, so
    calls made from different threads do not interleave.
    """

    def __init__(self, robot: Robot):
        self.robot = robot
        self.lock = Lock()

    def get_observation(self) -> dict[str, Tensor]:
        """Read the robot's current observation while holding the lock."""
        with self.lock:
            observation = self.robot.get_observation()
        return observation

    def send_action(self, action: Tensor):
        """Forward ``action`` to the robot while holding the lock."""
        with self.lock:
            self.robot.send_action(action)

    def observation_features(self) -> list[str]:
        """Return the robot's ``observation_features`` attribute."""
        with self.lock:
            features = self.robot.observation_features
        return features

    def action_features(self) -> list[str]:
        """Return the robot's ``action_features`` attribute."""
        with self.lock:
            features = self.robot.action_features
        return features
|
|
|
|
|
|
class EnvWrapper:
    """Wrapper for gym environments to provide same interface as RobotWrapper.

    The wrapper keeps the most recent observation returned by the environment
    and hands it out from ``get_observation``. ``send_action`` steps the
    environment and resets it when the episode terminates or is truncated.
    All environment access happens while holding ``self.lock``.
    """

    def __init__(self, env, env_cfg: "EnvConfig"):
        """Store the environment and its config, then reset the environment.

        Args:
            env: Environment exposing ``reset()`` and ``step(action_batch)``.
            env_cfg: Config exposing ``features`` and ``features_map``.
        """
        self.env = env
        self.env_cfg = env_cfg
        self.lock = Lock()
        self._episode_count = 0
        self._step_count = 0

        # Feature-name caches, filled on first request.
        self._observation_features = None
        self._action_features = None

        # Initialize environment
        obs, _ = self.env.reset()
        self._last_obs = self._unbatch(obs)

    @staticmethod
    def _unbatch(obs):
        """Return the first element of a tuple/sequence observation; dicts pass through."""
        if isinstance(obs, dict):
            return obs
        if isinstance(obs, tuple):
            return obs[0]
        if hasattr(obs, "__getitem__") and len(obs) > 0:
            return obs[0]
        return obs

    def get_observation(self) -> dict[str, np.ndarray]:
        """Get current observation from environment.

        Returns:
            A dict mapping feature names to values. For dict observations,
            numpy arrays whose leading dimension is 1 have that dimension
            removed. A non-dict observation is returned under the key
            ``"observation"``.
        """
        with self.lock:
            if self._last_obs is None:
                # No observation cached yet: reset to obtain one.
                obs, _ = self.env.reset()
                self._last_obs = self._unbatch(obs)

            obs = self._last_obs
            if isinstance(obs, dict):
                # Drop the batch dimension of single-env vectorized observations.
                return {
                    key: value[0]
                    if isinstance(value, np.ndarray) and value.ndim > 0 and value.shape[0] == 1
                    else value
                    for key, value in obs.items()
                }

            # Handle array observations - shouldn't happen with our configs but handle it
            return {"observation": obs[0] if len(obs.shape) > 1 else obs}

    def send_action(self, action: dict):
        """Execute action in environment and update observation.

        Args:
            action: Mapping from action feature name (see ``action_features``)
                to its value. Names missing from the mapping are skipped.
        """
        # Resolve the names BEFORE taking the lock: action_features() acquires
        # the same non-reentrant lock while its cache is empty, so calling it
        # with the lock held would deadlock.
        feature_names = self.action_features()

        with self.lock:
            action_array = np.array([action[name] for name in feature_names if name in action])

            # VectorEnv expects actions with batch dimension
            action_batch = action_array.reshape(1, -1)

            obs, _reward, terminated, truncated, _info = self.env.step(action_batch)
            self._last_obs = self._unbatch(obs)
            self._step_count += 1

            # Check if episode is done (handle vectorized env format)
            is_done = terminated[0] if isinstance(terminated, (np.ndarray, list)) else terminated
            is_truncated = truncated[0] if isinstance(truncated, (np.ndarray, list)) else truncated

            if is_done or is_truncated:
                logger.info(f"Episode {self._episode_count} finished after {self._step_count} steps")
                obs, _ = self.env.reset()
                self._last_obs = self._unbatch(obs)
                self._episode_count += 1
                self._step_count = 0

    def observation_features(self) -> list[str]:
        """Get observation feature names from environment config.

        Every key of ``env_cfg.features`` except ``"action"`` is translated
        through ``env_cfg.features_map`` (falling back to the key itself).
        The result is cached after the first call.
        """
        if self._observation_features is not None:
            return self._observation_features

        with self.lock:
            features = [
                self.env_cfg.features_map.get(feature_name, feature_name)
                for feature_name in self.env_cfg.features
                if feature_name != "action"
            ]
            self._observation_features = features
            return features

    def action_features(self) -> list[str]:
        """Get action feature names from environment config.

        Names are ``action_0 … action_{n-1}`` where ``n`` is the first
        dimension of ``env_cfg.features["action"].shape``. The result is
        cached after the first call.
        """
        if self._action_features is not None:
            return self._action_features

        with self.lock:
            action_dim = self.env_cfg.features["action"].shape[0]
            self._action_features = [f"action_{i}" for i in range(action_dim)]
            return self._action_features
|
|
|
|
|
|
@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies.

    ``__post_init__`` loads ``policy`` from the policy path given on the
    command line and checks that exactly one of ``robot`` / ``env`` is set.
    """

    # Policy config; replaced in __post_init__ by the one loaded from the CLI policy path.
    policy: PreTrainedConfig | None = None

    # Real-robot target (cannot be combined with ``env``).
    robot: RobotConfig | None = None

    # Environment target (cannot be combined with ``robot``).
    env: EnvConfig | None = None

    # Real-Time Chunking settings.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            execution_horizon=10,
            max_guidance_weight=1.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    # How long the demo runs, in seconds.
    duration: float = 30.0
    # How many actions are executed per second (Hz).
    fps: float = 10.0

    # Device to run on (cuda, cpu, auto).
    device: str | None = None

    # Queue-size threshold at or below which a new action chunk is requested.
    # It should be higher than inference delay + execution horizon.
    action_queue_size_to_get_new_actions: int = 30

    # Task description handed to the policy with every observation.
    task: str = field(default="", metadata={"help": "Task to execute"})

    # torch.compile settings.
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        """Load the policy config from the CLI path and validate robot/env.

        Raises:
            ValueError: If no policy path was given, or if neither or both of
                ``robot`` and ``env`` are configured.
        """
        # HACK: the CLI args are parsed a second time here to recover the pretrained path.
        policy_path = parser.get_path_arg("policy")
        if not policy_path:
            raise ValueError("Policy path is required")

        overrides = parser.get_cli_overrides("policy")
        self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
        self.policy.pretrained_path = policy_path

        # Exactly one execution target must be configured.
        has_robot = self.robot is not None
        has_env = self.env is not None
        if not (has_robot or has_env):
            raise ValueError("Either robot or env configuration must be provided")
        if has_robot and has_env:
            raise ValueError("Cannot specify both robot and env configuration. Choose one.")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        path_fields = ["policy"]
        return path_fields
|
|
|
|
|
|
def is_image_key(k: str) -> bool:
    """Return True when ``k`` begins with the ``OBS_IMAGES`` prefix."""
    image_prefix = OBS_IMAGES
    return k.startswith(image_prefix)
|
|
|
|
|
|
def get_actions(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to request action chunks from the policy.

    Loops until ``shutdown_event`` is set. Whenever the action queue holds no
    more than the configured threshold of actions, it reads an observation,
    runs the policy and merges the new chunk into the queue. On a fatal
    exception it logs the traceback and sets ``shutdown_event`` so the rest of
    the demo stops too.

    Args:
        policy: The policy instance (SmolVLA, Pi0, etc.)
        robot: The robot instance for getting observations
        robot_observation_processor: Processor for raw robot observations
        action_queue: Queue to put new action chunks
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[GET_ACTIONS] Starting get actions thread")

        latency_tracker = LatencyTracker()  # Track latency of action chunks
        # Duration of one executed action; used to express latency in action steps.
        time_per_step = 1.0 / cfg.fps

        dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
        policy_device = policy.config.device

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            preprocessor_overrides={
                "device_processor": {"device": cfg.policy.device},
            },
        )

        # Without RTC a new chunk is requested only once the queue is empty.
        get_actions_threshold = cfg.action_queue_size_to_get_new_actions if cfg.rtc.enabled else 0

        while not shutdown_event.is_set():
            if action_queue.qsize() > get_actions_threshold:
                # Enough actions queued: idle briefly, waking early on shutdown.
                shutdown_event.wait(0.1)
                continue

            current_time = time.perf_counter()
            action_index_before_inference = action_queue.get_action_index()
            prev_actions = action_queue.get_left_over()

            # Expected delay, in action steps, based on the slowest inference seen so far.
            inference_latency = latency_tracker.max()
            inference_delay = math.ceil(inference_latency / time_per_step)

            obs = robot.get_observation()

            # Apply robot observation processor
            obs_processed = robot_observation_processor(obs)

            obs_with_policy_features = build_dataset_frame(
                dataset_features, obs_processed, prefix="observation"
            )

            for name, value in obs_with_policy_features.items():
                tensor = torch.from_numpy(value)
                if "image" in name:
                    # Scale to [0, 1] and move the last axis first
                    # (presumably HWC -> CHW; confirm against the camera output).
                    tensor = tensor.type(torch.float32) / 255
                    tensor = tensor.permute(2, 0, 1).contiguous()
                obs_with_policy_features[name] = tensor.unsqueeze(0).to(policy_device)

            obs_with_policy_features["task"] = cfg.task

            preprocessed_obs = preprocessor(obs_with_policy_features)

            # Generate actions WITH RTC
            actions = policy.predict_action_chunk(
                preprocessed_obs,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_actions,
            )

            # Store original actions (before postprocessing) for RTC
            original_actions = actions.squeeze(0).clone()

            postprocessed_actions = postprocessor(actions).squeeze(0)

            new_latency = time.perf_counter() - current_time
            new_delay = math.ceil(new_latency / time_per_step)
            latency_tracker.add(new_latency)

            if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                logger.warning(
                    "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
                )

            action_queue.merge(
                original_actions, postprocessed_actions, new_delay, action_index_before_inference
            )

        logger.info("[GET_ACTIONS] get actions thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
        logger.error(traceback.format_exc())
        # sys.exit() in a worker thread only ends that thread, so explicitly
        # ask the other threads and the main loop to stop as well.
        shutdown_event.set()
        sys.exit(1)
|
|
|
|
|
|
def actor_control(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: RTCDemoConfig,
):
    """Thread function to execute actions on the robot.

    Loops at roughly ``cfg.fps`` iterations per second until
    ``shutdown_event`` is set, taking one action from the queue per iteration
    and sending it to the robot. On a fatal exception it logs the traceback
    and sets ``shutdown_event`` so the rest of the demo stops too.

    Args:
        robot: The robot instance
        robot_action_processor: Callable applied to ``(action_dict, None)``
            before the action is sent to the robot
        action_queue: Queue to get actions from
        shutdown_event: Event to signal shutdown
        cfg: Demo configuration
    """
    try:
        logger.info("[ACTOR] Starting actor thread")

        action_count = 0
        action_interval = 1.0 / cfg.fps

        while not shutdown_event.is_set():
            start_time = time.perf_counter()

            # Take the next action; the queue yields None when nothing is available.
            action = action_queue.get()

            if action is not None:
                action = action.cpu()
                action = {key: action[i].item() for i, key in enumerate(robot.action_features())}
                action = robot_action_processor((action, None))
                robot.send_action(action)

                action_count += 1

            dt_s = time.perf_counter() - start_time
            # Clamp at zero: time.sleep() raises ValueError for a negative
            # duration, which happens whenever an iteration overruns the period.
            time.sleep(max(0.0, (action_interval - dt_s) - 0.001))

        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
        logger.error(traceback.format_exc())
        # sys.exit() in a worker thread only ends that thread, so explicitly
        # ask the other threads and the main loop to stop as well.
        shutdown_event.set()
        sys.exit(1)
|
|
|
|
|
|
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
    """Apply torch.compile to the policy's predict_action_chunk method.

    Compilation is best effort: any failure is logged and the policy is
    returned uncompiled.

    Args:
        policy: Policy instance to compile
        cfg: Configuration containing torch compile settings

    Returns:
        Policy with compiled predict_action_chunk method
    """

    # PI models handle their own compilation.
    # Use ``policy.name`` (as demo_cli does): on an nn.Module, ``policy.type``
    # is the dtype-casting method and never equals a string.
    if policy.name in ("pi05", "pi0"):
        return policy

    try:
        # Check if torch.compile is available (PyTorch 2.0+)
        if not hasattr(torch, "compile"):
            logger.warning(
                f"torch.compile is not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")
        logger.info(f" Backend: {cfg.torch_compile_backend}")
        logger.info(f" Mode: {cfg.torch_compile_mode}")
        logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

        # torch.compile rejects ``mode`` and ``options`` given together, so
        # CUDA graphs are disabled by choosing the graph-free equivalent of the
        # requested mode instead of passing an ``options`` dict.
        # CUDA graphs are disabled to prevent tensor aliasing from in-place ops
        # (x_t += dt * v_t).
        mode = cfg.torch_compile_mode
        if cfg.torch_compile_disable_cudagraphs:
            graph_free_modes = {
                "reduce-overhead": "default",
                "max-autotune": "max-autotune-no-cudagraphs",
            }
            mode = graph_free_modes.get(mode, mode)
            if mode != cfg.torch_compile_mode:
                logger.info(f" Effective mode without CUDA graphs: {mode}")

        compiled_method = torch.compile(
            policy.predict_action_chunk,
            backend=cfg.torch_compile_backend,
            mode=mode,
        )
        policy.predict_action_chunk = compiled_method
        logger.info("✓ Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy
|
|
|
|
|
|
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
    """Main entry point for RTC demo with draccus configuration.

    Loads the policy, creates the robot or environment, starts the
    chunk-requester and actor threads, then waits until ``cfg.duration``
    seconds have elapsed, shutdown is requested, or a worker thread stops.
    The threads are always joined and the robot/environment always released,
    even when setup or the run fails.

    Args:
        cfg: Demo configuration
    """

    # Initialize logging
    init_logging()

    logger.info(f"Using device: {cfg.device}")

    # Setup signal handler for graceful shutdown
    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event

    robot = None
    vec_env = None
    get_actions_thread = None
    actor_thread = None

    try:
        policy_class = get_policy_class(cfg.policy.type)

        # Load config and set compile_model for pi0/pi05 models
        config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

        if cfg.policy.type in ("pi05", "pi0"):
            config.compile_model = cfg.use_torch_compile

        policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

        # Turn on RTC
        policy.config.rtc_config = cfg.rtc

        # Init the RTC processor explicitly: by default, if RTC is disabled in
        # the config, the processor won't be created.
        policy.init_rtc_processor()

        assert policy.name in ["smolvla"], "Only smolvla are supported for RTC"

        policy = policy.to(cfg.device)
        policy.eval()

        # Apply torch.compile to predict_action_chunk method if enabled
        if cfg.use_torch_compile:
            policy = _apply_torch_compile(policy, cfg)

        # Create robot or environment
        if cfg.robot is not None:
            logger.info(f"Initializing robot: {cfg.robot.type}")
            connecting_robot = make_robot_from_config(cfg.robot)
            connecting_robot.connect()
            # Only remember the robot once connected, so cleanup never calls
            # disconnect() on a robot whose connect() failed.
            robot = connecting_robot
            agent_wrapper = RobotWrapper(robot)
        else:
            logger.info(f"Initializing environment: {cfg.env.type}")
            # Create environment using make_env
            env_dict = make_env(cfg.env, n_envs=1, use_async_envs=False)

            # Validate environment structure: should have exactly one suite
            if len(env_dict) != 1:
                raise ValueError(
                    f"Expected exactly one environment suite, but got {len(env_dict)}. "
                    f"Suites: {list(env_dict.keys())}"
                )

            # Extract the actual env from the dict structure {suite: {task_id: vec_env}}
            suite_name, task_dict = next(iter(env_dict.items()))

            # Validate task structure: should have exactly one task
            if len(task_dict) != 1:
                raise ValueError(
                    f"Expected exactly one task in suite '{suite_name}', but got {len(task_dict)}. "
                    f"Tasks: {list(task_dict.keys())}"
                )

            # Take the single task whatever its id is, rather than assuming id 0.
            task_id, vec_env = next(iter(task_dict.items()))
            logger.info(
                f"Created environment: suite='{suite_name}', task_id={task_id}, num_envs={vec_env.num_envs}"
            )

            # Validate that we have exactly 1 parallel environment
            if vec_env.num_envs != 1:
                raise ValueError(
                    f"Expected exactly 1 parallel environment, but got {vec_env.num_envs}. "
                    f"The EnvWrapper is designed for single environment instances."
                )

            agent_wrapper = EnvWrapper(vec_env, cfg.env)

        # Create robot observation and action processors
        robot_observation_processor = make_default_robot_observation_processor()
        robot_action_processor = make_default_robot_action_processor()

        # Create action queue for communication between threads
        action_queue = ActionQueue(cfg.rtc)

        # Start chunk requester thread
        get_actions_thread = Thread(
            target=get_actions,
            args=(policy, agent_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
            daemon=True,
            name="GetActions",
        )
        get_actions_thread.start()
        logger.info("Started get actions thread")

        # Start action executor thread
        actor_thread = Thread(
            target=actor_control,
            args=(agent_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
            daemon=True,
            name="Actor",
        )
        actor_thread.start()
        logger.info("Started actor thread")

        # Main thread monitors for duration or shutdown
        logger.info(f"Running demo for {cfg.duration} seconds...")
        start_time = time.time()
        log_interval_s = 5.0

        while not shutdown_event.is_set():
            remaining = cfg.duration - (time.time() - start_time)
            if remaining <= 0:
                break

            if not (get_actions_thread.is_alive() and actor_thread.is_alive()):
                logger.error("[MAIN] A worker thread stopped unexpectedly; shutting down")
                break

            # Wait on the event rather than sleeping so a shutdown request is
            # noticed immediately and the run never overshoots cfg.duration.
            shutdown_event.wait(timeout=min(remaining, log_interval_s))

            # Log queue status periodically
            logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")

        logger.info("Demo duration reached or shutdown requested")
    finally:
        # Signal shutdown
        shutdown_event.set()

        # Wait for threads to finish
        if get_actions_thread and get_actions_thread.is_alive():
            logger.info("Waiting for chunk requester thread to finish...")
            get_actions_thread.join()

        if actor_thread and actor_thread.is_alive():
            logger.info("Waiting for action executor thread to finish...")
            actor_thread.join()

        # Cleanup robot or environment
        if robot is not None:
            robot.disconnect()
            logger.info("Robot disconnected")
        if vec_env is not None:
            vec_env.close()
            logger.info("Environment closed")

        logger.info("Cleanup completed")
|
|
|
|
|
|
# Script entry point: run the demo, then report completion through the
# module logger (the rest of the file logs through ``logger``, not the root
# ``logging`` functions).
if __name__ == "__main__":
    demo_cli()
    logger.info("RTC demo finished")
|