diff --git a/examples/openarms/evaluate_with_rtc.py b/examples/openarms/evaluate_with_rtc.py index 7926e5306..860c919f8 100644 --- a/examples/openarms/evaluate_with_rtc.py +++ b/examples/openarms/evaluate_with_rtc.py @@ -79,7 +79,6 @@ DEFAULT_TASK_DESCRIPTION = "three-folds-dataset" DEFAULT_NUM_EPISODES = 1 DEFAULT_FPS = 30 -DEFAULT_ROBOT_HZ = 50 DEFAULT_EPISODE_TIME_SEC = 300 DEFAULT_RESET_TIME_SEC = 60 @@ -169,7 +168,6 @@ class OpenArmsRTCEvalConfig(HubMixin): push_to_hub: bool = True interpolation: bool = False - robot_hz: float = DEFAULT_ROBOT_HZ use_torch_compile: bool = False torch_compile_backend: str = "inductor" @@ -328,45 +326,56 @@ def actor_thread( """Thread function to execute actions on the robot.""" try: action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] - action_count = 0 if cfg.interpolation: - robot_interval = 1.0 / cfg.robot_hz - interp_steps = int(cfg.robot_hz / cfg.fps) - logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.robot_hz}Hz ({interp_steps} steps)") + interp_factor = 2 + robot_interval = 1.0 / (cfg.fps * interp_factor) + logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.fps * interp_factor}Hz (2x)") else: + interp_factor = 1 robot_interval = 1.0 / cfg.fps - interp_steps = 1 logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz") prev_action: Tensor | None = None - current_action: Tensor | None = None - interp_step = 0 + interpolated_actions: list[Tensor] = [] + interp_idx = 0 + + robot_send_count = 0 + policy_consume_count = 0 + last_hz_print = time.perf_counter() last_dataset_time = 0.0 while not shutdown_event.is_set(): if not episode_active.is_set(): prev_action = None - current_action = None - interp_step = 0 + interpolated_actions = [] + interp_idx = 0 + robot_send_count = 0 + policy_consume_count = 0 + last_hz_print = time.perf_counter() time.sleep(0.01) continue start_time = time.perf_counter() - if interp_step == 0 or current_action is None: + if interp_idx >= len(interpolated_actions): new_action = action_queue.get() if new_action is not None: - prev_action = current_action if current_action is not None else new_action.cpu() current_action = new_action.cpu() - interp_step = 0 + policy_consume_count += 1 - if current_action is not None: - if cfg.interpolation and prev_action is not None and interp_steps > 1: - alpha = (interp_step + 1) / interp_steps - action_to_send = prev_action + alpha * (current_action - prev_action) - else: - action_to_send = current_action + if cfg.interpolation and prev_action is not None: + mid = prev_action + 0.5 * (current_action - prev_action) + interpolated_actions = [mid, current_action] + else: + interpolated_actions = [current_action] + + prev_action = current_action + interp_idx = 0 + + if interp_idx < len(interpolated_actions): + action_to_send = interpolated_actions[interp_idx] + interp_idx += 1 action_dict = {} for i, key in enumerate(action_keys): @@ -375,8 +384,7 @@ def actor_thread( action_processed = robot_action_processor((action_dict, None)) robot.send_action(action_processed) - action_count += 1 - interp_step = (interp_step + 1) % interp_steps + robot_send_count += 1 if cfg.record_dataset and dataset is not None: now = time.perf_counter() @@ -394,12 +402,22 @@ def actor_thread( frame["task"] = cfg.task dataset.add_frame(frame) + now = time.perf_counter() + if now - last_hz_print >= 5.0: + elapsed = now - last_hz_print + actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0 + actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0 + logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}") + robot_send_count = 0 + policy_consume_count = 0 + last_hz_print = now + dt_s = time.perf_counter() - start_time sleep_time = max(0, robot_interval - dt_s - 0.001) if sleep_time > 0: time.sleep(sleep_time) - logger.info(f"[ACTOR] Shutting down. Total actions: {action_count}") + logger.info("[ACTOR] Shutting down") except Exception as e: logger.error(f"[ACTOR] Fatal exception: {e}") logger.error(traceback.format_exc()) @@ -464,7 +482,7 @@ def main(cfg: OpenArmsRTCEvalConfig): print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}") print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}") print(f"Policy Hz: {cfg.fps}") - print(f"Robot Hz: {cfg.robot_hz if cfg.interpolation else cfg.fps}") + print(f"Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}") print(f"Interpolation: {cfg.interpolation}") print(f"Device: {cfg.device}") print("=" * 60)