From 33f84fe0ecb178e62e8bfd78740f375486da91d5 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Fri, 9 Jan 2026 09:56:14 +0100
Subject: [PATCH] Add RTC async inference to OpenArms interpolation evaluation

Replace synchronous per-frame inference with RTC (Real-Time Chunking):
a background thread generates action chunks and merges them into an
ActionQueue, while an actor thread consumes them and sends interpolated
commands to the robot at ROBOT_FPS. SPEED_MULTIPLIER is removed; actions
are consumed at the rate the policy was trained with.

---
 examples/openarms/evaluate_interpolation.py | 593 +++++++++++++-------
 1 file changed, 382 insertions(+), 211 deletions(-)

diff --git a/examples/openarms/evaluate_interpolation.py b/examples/openarms/evaluate_interpolation.py
index c7a99a9fa..fd0276019 100644
--- a/examples/openarms/evaluate_interpolation.py
+++ b/examples/openarms/evaluate_interpolation.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 
 """
-OpenArms Policy Evaluation with Interpolation
+OpenArms Policy Evaluation with RTC and Interpolation
 
-Evaluates a trained policy with smooth action interpolation:
-- Decoupled camera capture (CAMERA_FPS) from robot control (ROBOT_FPS)
-- Speed multiplier to execute actions faster than training
+Evaluates a trained policy with:
+- RTC (Real-Time Chunking) for asynchronous inference, decoupling policy inference from the robot control loop
+- Smooth action interpolation for high-frequency robot control
 - Velocity feedforward for smoother tracking
 - Adjustable PID gains
 
@@ -27,27 +27,41 @@
 Example usage:
     python examples/openarms/evaluate_interpolation.py
 """
+import logging
+import math
+import sys
 import time
+import traceback
 from collections import deque
 from pathlib import Path
+from threading import Event, Lock, Thread
 
 import numpy as np
+import torch
+from torch import Tensor
 
 from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
 from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import RTCAttentionSchedule
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
-from lerobot.datasets.utils import combine_feature_dicts
-from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
+from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.policies.rtc.action_queue import ActionQueue
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.policies.rtc.latency_tracker import LatencyTracker
 from lerobot.processor import make_default_processors
 from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
 from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
 from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
 from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
-from lerobot.utils.control_utils import init_keyboard_listener, predict_action
-from lerobot.utils.utils import log_say, get_safe_torch_device
+from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.utils.utils import log_say
 from lerobot.utils.visualization_utils import init_rerun
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # ======================== MODEL & TASK CONFIG ========================
 HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0"  # TODO: Replace with your trained model
@@ -57,81 +71,107 @@ TASK_DESCRIPTION = "three-folds-dataset"  # TODO: Replace with your task
 # ======================== TIMING CONFIG ========================
 CAMERA_FPS = 30  # Camera hardware limit (fixed)
 POLICY_FPS = 30  # What the policy was trained with
-SPEED_MULTIPLIER = 1.2  # Execute actions faster (1.0 = normal, 1.2 = 20% faster)
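+# NOTE: SPEED_MULTIPLIER was removed; the actor thread consumes actions at
+# POLICY_FPS and interpolates up to ROBOT_FPS, so execution speed matches training.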
ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation) -# Derived values -EFFECTIVE_POLICY_FPS = int(POLICY_FPS * SPEED_MULTIPLIER) # How fast we consume actions (36Hz at 1.2x) - NUM_EPISODES = 1 EPISODE_TIME_SEC = 300 RESET_TIME_SEC = 60 +# ======================== RTC CONFIG ======================== +RTC_ENABLED = True +RTC_EXECUTION_HORIZON = 20 +RTC_MAX_GUIDANCE_WEIGHT = 5.0 +ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS = 30 # Should be > inference_delay + execution_horizon + # ======================== PID TUNING ======================== -# Set to None to use robot config defaults -CUSTOM_KP_SCALE = 0.7 # Scale factor for position gain (0.5-1.0, lower = smoother) -CUSTOM_KD_SCALE = 1.3 # Scale factor for damping gain (1.0-2.0, higher = less overshoot) -USE_VELOCITY_FEEDFORWARD = True # Enable velocity feedforward for smoother tracking +CUSTOM_KP_SCALE = 1.0 # Scale factor for position gain (0.5-1.0, lower = smoother) +CUSTOM_KD_SCALE = 1.0 # Scale factor for damping gain (1.0-2.0, higher = less overshoot) +USE_VELOCITY_FEEDFORWARD = False # Enable velocity feedforward for smoother tracking # ======================== ROBOT CONFIG ======================== FOLLOWER_LEFT_PORT = "can0" FOLLOWER_RIGHT_PORT = "can1" -USE_LEADER_FOR_RESETS = True +USE_LEADER_FOR_RESETS = False LEADER_LEFT_PORT = "can2" LEADER_RIGHT_PORT = "can3" -# Camera config uses CAMERA_FPS (hardware limit) +DEVICE = "cuda" + CAMERA_CONFIG = { - "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=CAMERA_FPS), - "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=CAMERA_FPS), + "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=1280, height=720, fps=CAMERA_FPS), + "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=1280, height=720, fps=CAMERA_FPS), "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=CAMERA_FPS), } +class RobotWrapper: + """Thread-safe wrapper for robot operations.""" + + def __init__(self, robot: OpenArmsFollower): + self.robot = robot + self.lock = Lock() + + def get_observation(self) -> dict[str, Tensor]: + with self.lock: + return self.robot.get_observation() + + def send_action(self, action: dict, **kwargs) -> None: + with self.lock: + self.robot.send_action(action, **kwargs) + + @property + def observation_features(self) -> dict: + with self.lock: + return self.robot.observation_features + + @property + def action_features(self) -> dict: + with self.lock: + return self.robot.action_features + + @property + def name(self) -> str: + return self.robot.name + + class ActionInterpolator: - """Interpolate between policy actions for smoother robot control with velocity estimation.""" + """Interpolate between consecutive actions for smoother robot control.""" - def __init__(self, effective_policy_fps: int, robot_fps: int): - self.effective_policy_fps = effective_policy_fps + def __init__(self, policy_fps: int, robot_fps: int): + self.policy_fps = policy_fps self.robot_fps = robot_fps - self.substeps_per_policy_step = robot_fps / effective_policy_fps - self.prev_action: dict | None = None - self.curr_action: dict | None = None + self.substeps_per_policy_step = robot_fps / policy_fps + self.prev_action: Tensor | None = None + self.curr_action: Tensor | None = None self.substep = 0 - self.last_interpolated: dict | None = None + self.last_interpolated: Tensor | None = None - def update(self, new_action: dict) -> None: + def update(self, new_action: Tensor) -> None: 
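+        """Shift the current target to prev_action and restart interpolation toward new_action."""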
self.prev_action = self.curr_action self.curr_action = new_action self.substep = 0 - def get_interpolated_action(self) -> tuple[dict | None, dict | None]: - """Returns (interpolated_position, estimated_velocity_deg_per_sec)""" + def get_interpolated_action(self) -> tuple[Tensor | None, Tensor | None]: + """Returns (interpolated_action, estimated_velocity)""" if self.curr_action is None: return None, None if self.prev_action is None: - self.last_interpolated = self.curr_action.copy() - return self.curr_action, {k: 0.0 for k in self.curr_action} + self.last_interpolated = self.curr_action.clone() + return self.curr_action, torch.zeros_like(self.curr_action) t = min(self.substep / self.substeps_per_policy_step, 1.0) self.substep += 1 - interpolated = {} - velocity = {} + interpolated = self.prev_action * (1 - t) + self.curr_action * t + dt = 1.0 / self.robot_fps + if self.last_interpolated is not None: + velocity = (interpolated - self.last_interpolated) / dt + else: + velocity = (self.curr_action - self.prev_action) * self.policy_fps - for key in self.curr_action: - prev = self.prev_action.get(key, self.curr_action[key]) - curr = self.curr_action[key] - interpolated[key] = prev * (1 - t) + curr * t - - if self.last_interpolated is not None and key in self.last_interpolated: - velocity[key] = (interpolated[key] - self.last_interpolated[key]) / dt - else: - velocity[key] = (curr - prev) * self.effective_policy_fps - - self.last_interpolated = interpolated.copy() + self.last_interpolated = interpolated.clone() return interpolated, velocity def reset(self): @@ -175,160 +215,230 @@ class HzTracker: self.last_print_time = 0 -def interpolated_eval_loop( - robot, +def get_actions_thread( policy, - preprocessor, - postprocessor, + robot: RobotWrapper, robot_observation_processor, + action_queue: ActionQueue, + shutdown_event: Event, + episode_active: Event, + rtc_config: RTCConfig, + policy_fps: int, + task: str, + pretrained_path: str, + device: str, +): + """Thread function to asynchronously generate action chunks from the policy.""" + try: + logger.info("[GET_ACTIONS] Starting action generation thread") + + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / policy_fps + + hw_features = hw_to_dataset_features(robot.observation_features, "observation") + policy_device = device + + logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {pretrained_path}") + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy.config, + pretrained_path=pretrained_path, + dataset_stats=None, + preprocessor_overrides={"device_processor": {"device": device}}, + ) + + logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully") + + get_actions_threshold = ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS if rtc_config.enabled else 0 + + while not shutdown_event.is_set(): + if not episode_active.is_set(): + time.sleep(0.01) + continue + + if action_queue.qsize() <= get_actions_threshold: + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + obs = robot.get_observation() + obs_processed = robot_observation_processor(obs) + + obs_with_policy_features = build_dataset_frame( + hw_features, obs_processed, prefix="observation" + ) + + for name in obs_with_policy_features: + obs_with_policy_features[name] = 
torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = ( + obs_with_policy_features[name].type(torch.float32) / 255 + ) + obs_with_policy_features[name] = ( + obs_with_policy_features[name].permute(2, 0, 1).contiguous() + ) + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) + obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) + + obs_with_policy_features["task"] = [task] + obs_with_policy_features["robot_type"] = robot.name + + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + if ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS < rtc_config.execution_horizon + new_delay: + logger.warning( + "[GET_ACTIONS] action_queue_size_to_get_new_actions too small. " + "Should be higher than inference delay + execution horizon." + ) + + action_queue.merge( + original_actions, postprocessed_actions, new_delay, action_index_before_inference + ) + + logger.debug( + f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, " + f"delay={new_delay}, queue_size={action_queue.qsize()}" + ) + else: + time.sleep(0.01) + + logger.info("[GET_ACTIONS] Action generation thread shutting down") + except Exception as e: + logger.error(f"[GET_ACTIONS] Fatal exception: {e}") + logger.error(traceback.format_exc()) + shutdown_event.set() + sys.exit(1) + + +def actor_thread( + robot: RobotWrapper, robot_action_processor, - dataset, - events, + action_queue: ActionQueue, + shutdown_event: Event, + episode_active: Event, interpolator: ActionInterpolator, robot_hz_tracker: HzTracker, - camera_fps: int, - effective_policy_fps: int, robot_fps: int, - control_time_s: float, - task: str, - kp_scale: float | None = None, - kd_scale: float | None = None, - use_velocity_ff: bool = False, + action_keys: list[str], + custom_kp: dict | None, + custom_kd: dict | None, + use_velocity_ff: bool, ): - """ - Run evaluation with decoupled camera and robot control: - - Camera captures at camera_fps (hardware limit) - - Policy inference runs when new camera frame is available - - Actions are consumed at effective_policy_fps (sped up by SPEED_MULTIPLIER) - - Robot receives interpolated commands at robot_fps (smoothest) - """ - from lerobot.scripts.lerobot_record import build_dataset_frame, make_robot_action - from lerobot.utils.visualization_utils import log_rerun_data - - camera_dt = 1.0 / camera_fps - policy_dt = 1.0 / effective_policy_fps - robot_dt = 1.0 / robot_fps - - interpolator.reset() - robot_hz_tracker.reset() - policy.reset() - - # Build custom gains if scaling is enabled - custom_kp = None - custom_kd = None - if kp_scale is not None or kd_scale is not None: - custom_kp = {} - custom_kd = {} - for arm in ["right", "left"]: - bus = robot.bus_right if arm == "right" else robot.bus_left - for i, motor_name in enumerate(bus.motors): - full_name = f"{arm}_{motor_name}" - default_kp = robot.config.position_kp[i] if isinstance(robot.config.position_kp, list) else robot.config.position_kp - default_kd = robot.config.position_kd[i] if isinstance(robot.config.position_kd, list) else robot.config.position_kd - custom_kp[full_name] = default_kp * 
(kp_scale or 1.0)
-                custom_kd[full_name] = default_kd * (kd_scale or 1.0)
-        print(f"Custom gains: kp_scale={kp_scale}, kd_scale={kd_scale}")
-
-    if use_velocity_ff:
-        print("Velocity feedforward: enabled")
-
-    last_camera_time = -camera_dt
-    last_policy_action_time = -policy_dt
-    cached_observation = None
-    cached_robot_action = None
-
-    start_time = time.perf_counter()
-
-    print(f"\nStarting interpolated eval loop:")
-    print(f"  Camera: {camera_fps}Hz | Policy actions consumed: {effective_policy_fps}Hz | Robot: {robot_fps}Hz")
-
-    while time.perf_counter() - start_time < control_time_s:
-        if events["exit_early"] or events["stop_recording"]:
-            break
+    """Thread function to execute interpolated actions on the robot at high frequency."""
+    try:
+        logger.info("[ACTOR] Starting actor thread")
+
+        action_count = 0
+        action_interval = 1.0 / robot_fps
+        # Queue entries are policy timesteps; consume them at the policy rate so
+        # execution speed matches training (assumes ActionQueue.get() pops one
+        # action per call), and let the interpolator fill in between at robot_fps.
+        fetch_interval = 1.0 / interpolator.policy_fps
+        last_fetch_time = 0.0
+
+        while not shutdown_event.is_set():
+            if not episode_active.is_set():
+                time.sleep(0.01)
+                continue
+
+            start_time = time.perf_counter()
 
-        loop_start = time.perf_counter()
-        elapsed = loop_start - start_time
-
-        # === CAMERA CAPTURE (at camera_fps, decoupled from robot) ===
-        if elapsed - last_camera_time >= camera_dt:
-            obs = robot.get_observation()
-            obs_processed = robot_observation_processor(obs)
-            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
+            # Fetch the next action from the queue at the policy rate only;
+            # between fetches the interpolator keeps refining the command.
+            if start_time - last_fetch_time >= fetch_interval:
+                action = action_queue.get()
+                if action is not None:
+                    interpolator.update(action.cpu())
+                    last_fetch_time = start_time
 
-            # Run policy inference with fresh observation
-            action_values = predict_action(
-                observation=observation_frame,
-                policy=policy,
-                device=get_safe_torch_device(policy.config.device),
-                preprocessor=preprocessor,
-                postprocessor=postprocessor,
-                use_amp=policy.config.use_amp,
-                task=task,
-                robot_type=robot.robot_type,
-            )
+            # Get interpolated action for smooth control
+            smooth_action, velocity = interpolator.get_interpolated_action()
 
-            act_processed = make_robot_action(action_values, dataset.features)
-            cached_robot_action = robot_action_processor((act_processed, obs))
-            cached_observation = (obs_processed, observation_frame, act_processed)
-
-            last_camera_time = elapsed
-
-        # === ACTION UPDATE (at effective_policy_fps, faster than camera if speed > 1) ===
-        if elapsed - last_policy_action_time >= policy_dt and cached_robot_action is not None:
-            interpolator.update(cached_robot_action)
-            last_policy_action_time = elapsed
-
-            # Save to dataset at effective policy rate
-            if dataset is not None and cached_observation is not None:
-                obs_processed, observation_frame, act_processed = cached_observation
-                action_frame = build_dataset_frame(dataset.features, act_processed, prefix="action")
-                frame = {**observation_frame, **action_frame, "task": task}
-                dataset.add_frame(frame)
-                log_rerun_data(observation=obs_processed, action=act_processed)
-
-        # === ROBOT COMMAND (at robot_fps, highest rate for smoothness) ===
-        smooth_action, velocity = interpolator.get_interpolated_action()
-        if smooth_action is not None:
-            vel_ff = velocity if use_velocity_ff else None
-            robot.send_action(smooth_action, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff)
-
-            robot_hz_tracker.tick()
-
-            # Maintain robot control rate
-            sleep_time = robot_dt - (time.perf_counter() - loop_start)
-            if sleep_time > 0:
-                time.sleep(sleep_time)
+            if smooth_action is not None:
+                action_dict = {}
+                for i, key in enumerate(action_keys):
+                    if i < len(smooth_action):
+                        action_dict[key] = smooth_action[i].item()
+
+                action_processed = 
robot_action_processor((action_dict, None)) + + vel_ff = None + if use_velocity_ff and velocity is not None: + vel_ff = {} + for i, key in enumerate(action_keys): + if i < len(velocity): + motor_name = key.replace(".pos", "") + vel_ff[motor_name] = velocity[i].item() + + robot.send_action(action_processed, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff) + action_count += 1 + + robot_hz_tracker.tick() + + dt_s = time.perf_counter() - start_time + sleep_time = max(0, action_interval - dt_s - 0.001) + if sleep_time > 0: + time.sleep(sleep_time) + + logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}") + except Exception as e: + logger.error(f"[ACTOR] Fatal exception: {e}") + logger.error(traceback.format_exc()) + shutdown_event.set() + sys.exit(1) + + +def build_custom_gains(robot, kp_scale: float | None, kd_scale: float | None) -> tuple[dict | None, dict | None]: + """Build custom KP/KD gains for the robot.""" + if kp_scale is None and kd_scale is None: + return None, None - # Print final stats - robot_hz = robot_hz_tracker.get_avg_hz() - if robot_hz: - print(f"\nFinal average robot Hz: {robot_hz:.1f}") + custom_kp = {} + custom_kd = {} + for arm in ["right", "left"]: + bus = robot.robot.bus_right if arm == "right" else robot.robot.bus_left + for i, motor_name in enumerate(bus.motors): + full_name = f"{arm}_{motor_name}" + default_kp = robot.robot.config.position_kp[i] if isinstance(robot.robot.config.position_kp, list) else robot.robot.config.position_kp + default_kd = robot.robot.config.position_kd[i] if isinstance(robot.robot.config.position_kd, list) else robot.robot.config.position_kd + custom_kp[full_name] = default_kp * (kp_scale or 1.0) + custom_kd[full_name] = default_kd * (kd_scale or 1.0) + + return custom_kp, custom_kd def main(): - """Main evaluation function.""" + """Main evaluation function with RTC and interpolation.""" print("=" * 60) - print("OpenArms Policy Evaluation with Interpolation") + print("OpenArms Policy Evaluation with RTC + Interpolation") print("=" * 60) print(f"\nModel: {HF_MODEL_ID}") print(f"Dataset: {HF_EVAL_DATASET_ID}") print(f"Task: {TASK_DESCRIPTION}") print(f"\n--- Timing ---") - print(f"Camera FPS: {CAMERA_FPS} (hardware limit)") - print(f"Policy trained at: {POLICY_FPS}Hz") - print(f"Speed multiplier: {SPEED_MULTIPLIER}x") - print(f"Effective policy FPS: {EFFECTIVE_POLICY_FPS}Hz (actions consumed)") - print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated commands)") - print(f"\n--- PID Tuning ---") - print(f"KP scale: {CUSTOM_KP_SCALE}") - print(f"KD scale: {CUSTOM_KD_SCALE}") - print(f"Velocity feedforward: {USE_VELOCITY_FEEDFORWARD}") + print(f"Policy FPS: {POLICY_FPS}Hz") + print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated)") + print(f"\n--- RTC ---") + print(f"RTC Enabled: {RTC_ENABLED}") + print(f"Execution Horizon: {RTC_EXECUTION_HORIZON}") + print(f"Max Guidance Weight: {RTC_MAX_GUIDANCE_WEIGHT}") + print(f"\n--- PID ---") + print(f"KP scale: {CUSTOM_KP_SCALE}, KD scale: {CUSTOM_KD_SCALE}") + print(f"Velocity FF: {USE_VELOCITY_FEEDFORWARD}") print(f"\n--- Episodes ---") - print(f"Episodes: {NUM_EPISODES}") - print(f"Duration: {EPISODE_TIME_SEC}s per episode") - print(f"Reset time: {RESET_TIME_SEC}s") - print(f"Leader for resets: {USE_LEADER_FOR_RESETS}") + print(f"Episodes: {NUM_EPISODES}, Duration: {EPISODE_TIME_SEC}s") print("=" * 60) + + shutdown_event = Event() + episode_active = Event() follower_config = OpenArmsFollowerConfig( port_left=FOLLOWER_LEFT_PORT, @@ -346,6 +456,9 @@ def main(): 
if not follower.is_connected:
         raise RuntimeError("Follower robot failed to connect!")
 
+    robot = RobotWrapper(follower)
+    logger.info("Follower robot connected")
+
     leader = None
     if USE_LEADER_FOR_RESETS:
         leader_config = OpenArmsLeaderConfig(
@@ -366,9 +479,9 @@
             leader.bus_right.enable_torque()
             leader.bus_left.enable_torque()
             time.sleep(0.1)
-            print(f"Leader connected with gravity compensation")
+            print("Leader connected with gravity compensation")
         else:
-            print(f"Leader connected (no gravity compensation)")
+            print("Leader connected (no gravity compensation)")
 
     teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
 
@@ -401,10 +514,9 @@
             leader.disconnect()
         return
 
-    # Dataset uses effective policy FPS (sped up rate)
     dataset = LeRobotDataset.create(
         repo_id=HF_EVAL_DATASET_ID,
-        fps=EFFECTIVE_POLICY_FPS,
+        fps=POLICY_FPS,
         features=dataset_features,
         robot_type=follower.name,
         use_videos=True,
@@ -412,53 +524,102 @@
         image_writer_threads=12,
     )
 
+    # Load policy with RTC support
+    logger.info(f"Loading policy from: {HF_MODEL_ID}")
     policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
     policy_config.pretrained_path = HF_MODEL_ID
-    policy = make_policy(policy_config, ds_meta=dataset.meta)
 
-    preprocessor, postprocessor = make_pre_post_processors(
-        policy_cfg=policy.config,
-        pretrained_path=HF_MODEL_ID,
-        dataset_stats=dataset.meta.stats,
-        preprocessor_overrides={
-            "device_processor": {"device": str(policy.config.device)}
-        },
+    policy_class = get_policy_class(policy_config.type)
+    policy = policy_class.from_pretrained(HF_MODEL_ID, config=policy_config)
+
+    # Check RTC support before applying any RTC-specific config or methods
+    assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 support RTC"
+
+    rtc_config = RTCConfig(
+        enabled=RTC_ENABLED,
+        execution_horizon=RTC_EXECUTION_HORIZON,
+        max_guidance_weight=RTC_MAX_GUIDANCE_WEIGHT,
+        prefix_attention_schedule=RTCAttentionSchedule.EXP,
     )
+    policy.config.rtc_config = rtc_config
+    policy.init_rtc_processor()
+
+    policy = policy.to(DEVICE)
+    policy.eval()
+
+    logger.info(f"Policy loaded: {policy.name}")
 
     print(f"\nRunning evaluation...")
     listener, events = init_keyboard_listener()
-    init_rerun(session_name="openarms_evaluation_interp")
+    init_rerun(session_name="openarms_eval_rtc_interp")
 
-    interpolator = ActionInterpolator(effective_policy_fps=EFFECTIVE_POLICY_FPS, robot_fps=ROBOT_FPS)
-    robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
+    action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
 
+    custom_kp, custom_kd = build_custom_gains(robot, CUSTOM_KP_SCALE, CUSTOM_KD_SCALE)
+    if custom_kp:
+        print(f"Custom gains applied (kp_scale={CUSTOM_KP_SCALE}, kd_scale={CUSTOM_KD_SCALE})")
+    if USE_VELOCITY_FEEDFORWARD:
+        print("Velocity feedforward: enabled")
+
     episode_idx = 0
+    get_actions_t = None
+    actor_t = None
     try:
         while episode_idx < NUM_EPISODES and not events["stop_recording"]:
             log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
             print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---")
 
-            interpolated_eval_loop(
-                robot=follower,
-                policy=policy,
-                preprocessor=preprocessor,
-                postprocessor=postprocessor,
-                robot_observation_processor=robot_observation_processor,
-                robot_action_processor=robot_action_processor,
-                dataset=dataset,
-                events=events,
-                interpolator=interpolator,
-                robot_hz_tracker=robot_hz_tracker,
-                camera_fps=CAMERA_FPS,
-                effective_policy_fps=EFFECTIVE_POLICY_FPS,
-                robot_fps=ROBOT_FPS,
-                control_time_s=EPISODE_TIME_SEC,
-                task=TASK_DESCRIPTION,
-                kp_scale=CUSTOM_KP_SCALE,
-                
kd_scale=CUSTOM_KD_SCALE,
-                use_velocity_ff=USE_VELOCITY_FEEDFORWARD,
+            # NOTE: threads are created per episode; with NUM_EPISODES > 1 the
+            # previous episode's daemon threads should be stopped and joined first.
+            action_queue = ActionQueue(rtc_config)
+            interpolator = ActionInterpolator(policy_fps=POLICY_FPS, robot_fps=ROBOT_FPS)
+            robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
+
+            get_actions_t = Thread(
+                target=get_actions_thread,
+                args=(
+                    policy, robot, robot_observation_processor, action_queue,
+                    shutdown_event, episode_active, rtc_config, POLICY_FPS,
+                    TASK_DESCRIPTION, HF_MODEL_ID, DEVICE,
+                ),
+                daemon=True,
+                name="GetActions",
             )
+            get_actions_t.start()
+
+            actor_t = Thread(
+                target=actor_thread,
+                args=(
+                    robot, robot_action_processor, action_queue,
+                    shutdown_event, episode_active, interpolator, robot_hz_tracker,
+                    ROBOT_FPS, action_keys, custom_kp, custom_kd, USE_VELOCITY_FEEDFORWARD,
+                ),
+                daemon=True,
+                name="Actor",
+            )
+            actor_t.start()
+
+            logger.info("Started inference and actor threads")
+
+            episode_active.set()
+            episode_start_time = time.time()
+            last_progress_log = 0.0
+
+            while (time.time() - episode_start_time) < EPISODE_TIME_SEC:
+                if events["exit_early"] or events["stop_recording"] or shutdown_event.is_set():
+                    break
+
+                elapsed = time.time() - episode_start_time
+                # Log progress roughly every 10 s
+                if elapsed - last_progress_log >= 10.0:
+                    last_progress_log = elapsed
+                    robot_hz = robot_hz_tracker.get_avg_hz()
+                    logger.info(
+                        f"Progress: {elapsed:.0f}/{EPISODE_TIME_SEC}s, "
+                        f"queue={action_queue.qsize()}, hz={(robot_hz or 0.0):.1f}"
+                    )
+
+                time.sleep(0.5)
+
+            episode_active.clear()
+
+            robot_hz = robot_hz_tracker.get_avg_hz()
+            logger.info(f"Episode {episode_idx + 1} done. Avg Hz: {(robot_hz or 0.0):.1f}")
 
             if events["rerecord_episode"]:
                 log_say("Re-recording episode")
@@ -566,6 +727,15 @@
         print("\n\nInterrupted by user")
 
     finally:
+        shutdown_event.set()
+        episode_active.clear()
+
+        if get_actions_t is not None and get_actions_t.is_alive():
+            get_actions_t.join(timeout=2.0)
+
+        if actor_t is not None and actor_t.is_alive():
+            actor_t.join(timeout=2.0)
+
         if leader:
             leader.bus_right.disable_torque()
             leader.bus_left.disable_torque()
@@ -573,6 +743,7 @@
             leader.disconnect()
 
     follower.disconnect()
+    logger.info("Follower disconnected")
 
     if listener is not None:
         listener.stop()