From 7ac05c838d4a0c64a0eb086776210799c1f38ea5 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 9 Jan 2026 12:56:43 +0100 Subject: [PATCH] add interpolation option --- examples/openarms/evaluate_with_rtc.py | 84 ++++++++++++++++++-------- 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/examples/openarms/evaluate_with_rtc.py b/examples/openarms/evaluate_with_rtc.py index dd19ac8c2..7926e5306 100644 --- a/examples/openarms/evaluate_with_rtc.py +++ b/examples/openarms/evaluate_with_rtc.py @@ -79,6 +79,7 @@ DEFAULT_TASK_DESCRIPTION = "three-folds-dataset" DEFAULT_NUM_EPISODES = 1 DEFAULT_FPS = 30 +DEFAULT_ROBOT_HZ = 50 DEFAULT_EPISODE_TIME_SEC = 300 DEFAULT_RESET_TIME_SEC = 60 @@ -167,6 +168,9 @@ class OpenArmsRTCEvalConfig(HubMixin): record_dataset: bool = True push_to_hub: bool = True + interpolation: bool = False + robot_hz: float = DEFAULT_ROBOT_HZ + use_torch_compile: bool = False torch_compile_backend: str = "inductor" torch_compile_mode: str = "default" @@ -323,54 +327,79 @@ def actor_thread( ): """Thread function to execute actions on the robot.""" try: - logger.info("[ACTOR] Starting actor thread") - - action_count = 0 - action_interval = 1.0 / cfg.fps action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + action_count = 0 + + if cfg.interpolation: + robot_interval = 1.0 / cfg.robot_hz + interp_steps = int(cfg.robot_hz / cfg.fps) + logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.robot_hz}Hz ({interp_steps} steps)") + else: + robot_interval = 1.0 / cfg.fps + interp_steps = 1 + logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz") + + prev_action: Tensor | None = None + current_action: Tensor | None = None + interp_step = 0 + last_dataset_time = 0.0 while not shutdown_event.is_set(): if not episode_active.is_set(): + prev_action = None + current_action = None + interp_step = 0 time.sleep(0.01) continue start_time = time.perf_counter() - action = action_queue.get() - if action is not None: - action = action.cpu() + if interp_step == 0 or current_action is None: + new_action = action_queue.get() + if new_action is not None: + prev_action = current_action if current_action is not None else new_action.cpu() + current_action = new_action.cpu() + interp_step = 0 + + if current_action is not None: + if cfg.interpolation and prev_action is not None and interp_steps > 1: + alpha = (interp_step + 1) / interp_steps + action_to_send = prev_action + alpha * (current_action - prev_action) + else: + action_to_send = current_action action_dict = {} for i, key in enumerate(action_keys): - if i < len(action): - action_dict[key] = action[i].item() + if i < len(action_to_send): + action_dict[key] = action_to_send[i].item() action_processed = robot_action_processor((action_dict, None)) robot.send_action(action_processed) + action_count += 1 + interp_step = (interp_step + 1) % interp_steps if cfg.record_dataset and dataset is not None: - with dataset_lock: - obs = robot.get_observation() - obs_processed = robot_observation_processor(obs) - action_for_dataset = teleop_action_processor((action_dict, None)) - - frame = {} - for key, value in obs_processed.items(): - frame[f"observation.{key}"] = value - for key, value in action_for_dataset.items(): - frame[f"action.{key}"] = value - frame["task"] = cfg.task - - dataset.add_frame(frame) - - action_count += 1 + now = time.perf_counter() + if now - last_dataset_time >= (1.0 / cfg.fps): + last_dataset_time = now + with dataset_lock: + obs = robot.get_observation() + obs_processed = robot_observation_processor(obs) + action_for_dataset = teleop_action_processor((action_dict, None)) + frame = {} + for key, value in obs_processed.items(): + frame[f"observation.{key}"] = value + for key, value in action_for_dataset.items(): + frame[f"action.{key}"] = value + frame["task"] = cfg.task + dataset.add_frame(frame) dt_s = time.perf_counter() - start_time - sleep_time = max(0, action_interval - dt_s - 0.001) + sleep_time = max(0, robot_interval - dt_s - 0.001) if sleep_time > 0: time.sleep(sleep_time) - logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}") + logger.info(f"[ACTOR] Shutting down. Total actions: {action_count}") except Exception as e: logger.error(f"[ACTOR] Fatal exception: {e}") logger.error(traceback.format_exc()) @@ -434,6 +463,9 @@ def main(cfg: OpenArmsRTCEvalConfig): print(f"RTC Enabled: {cfg.rtc.enabled}") print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}") print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}") + print(f"Policy Hz: {cfg.fps}") + print(f"Robot Hz: {cfg.robot_hz if cfg.interpolation else cfg.fps}") + print(f"Interpolation: {cfg.interpolation}") print(f"Device: {cfg.device}") print("=" * 60)