diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index bfae221cf..185dd41d0 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -330,120 +330,110 @@ def make_identity_processors(): def rtc_inference_thread( policy, - obs_holder: dict, # {"obs": filtered_obs, "features": observation_features} - set by main loop + obs_holder: dict, hw_features: dict, preprocessor, postprocessor, - queue_holder: dict, # {"queue": ActionQueue} - mutable so we can update per episode + queue_holder: dict, shutdown_event: Event, policy_active: Event, cfg: RaCRTCConfig, ): - """Background thread that generates action chunks using RTC. - - IMPORTANT: This thread does NOT access the robot directly! - It reads observations from obs_holder which is updated by the main loop. - This avoids race conditions on the CAN bus. - """ - logger.info("[RTC] Inference thread started (reads obs from main loop, no direct robot access)") - - latency_tracker = LatencyTracker() - time_per_chunk = 1.0 / cfg.dataset.fps - policy_device = policy.config.device - - get_actions_threshold = cfg.action_queue_size_to_get_new_actions - if not cfg.rtc.enabled: - get_actions_threshold = 0 - - inference_count = 0 - wait_logged = False - - while not shutdown_event.is_set(): - if not policy_active.is_set(): - if not wait_logged: - logger.info("[RTC] Waiting for policy_active...") - wait_logged = True - time.sleep(0.01) - continue + """Background thread that generates action chunks using RTC.""" + try: + logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========") + logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys") + latency_tracker = LatencyTracker() + time_per_chunk = 1.0 / cfg.dataset.fps + policy_device = policy.config.device + + get_actions_threshold = cfg.action_queue_size_to_get_new_actions + if not cfg.rtc.enabled: + get_actions_threshold = 0 + + inference_count = 0 wait_logged = False - action_queue = queue_holder["queue"] - if action_queue is None: - logger.warning("[RTC] queue_holder['queue'] is None!") - time.sleep(0.01) - continue + while not shutdown_event.is_set(): + if not policy_active.is_set(): + if not wait_logged: + logger.info("[RTC] Waiting for policy_active...") + wait_logged = True + time.sleep(0.01) + continue + + wait_logged = False + + action_queue = queue_holder["queue"] + if action_queue is None: + logger.warning("[RTC] queue_holder['queue'] is None!") + time.sleep(0.01) + continue + + obs_filtered = obs_holder.get("obs") + if obs_filtered is None: + logger.warning("[RTC] obs_holder['obs'] is None!") + time.sleep(0.01) + continue + + qsize = action_queue.qsize() + if qsize <= get_actions_threshold: + try: + if inference_count == 0: + logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}") + + current_time = time.perf_counter() + action_index_before_inference = action_queue.get_action_index() + prev_actions = action_queue.get_left_over() + + inference_latency = latency_tracker.max() + inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 + + obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") + + for name in obs_with_policy_features: + obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) + if "image" in name: + obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255 + obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous() + obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device) + + obs_with_policy_features["task"] = [cfg.dataset.single_task] + obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") + + preprocessed_obs = preprocessor(obs_with_policy_features) + + actions = policy.predict_action_chunk( + preprocessed_obs, + inference_delay=inference_delay, + prev_chunk_left_over=prev_actions, + ) + + original_actions = actions.squeeze(0).clone() + postprocessed_actions = postprocessor(actions).squeeze(0) + + new_latency = time.perf_counter() - current_time + new_delay = math.ceil(new_latency / time_per_chunk) + latency_tracker.add(new_latency) + + action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference) + + inference_count += 1 + logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}") + except Exception as e: + logger.error(f"[RTC] Inference error: {e}") + import traceback + traceback.print_exc() + time.sleep(1.0) + else: + time.sleep(0.01) - # Get observation from shared holder (set by main loop) - obs_filtered = obs_holder.get("obs") - if obs_filtered is None: - logger.warning("[RTC] obs_holder['obs'] is None - main loop not setting it?") - time.sleep(0.01) - continue - - qsize = action_queue.qsize() - if qsize <= get_actions_threshold: - try: - if inference_count == 0: - logger.info(f"[RTC] Starting first inference, obs has {len(obs_filtered)} keys, qsize={qsize}, threshold={get_actions_threshold}") - current_time = time.perf_counter() - action_index_before_inference = action_queue.get_action_index() - prev_actions = action_queue.get_left_over() - - inference_latency = latency_tracker.max() - inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 - - # Build observation for policy (using obs from main loop) - obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") - - # Convert to tensors (like evaluate_with_rtc.py) - for name in obs_with_policy_features: - obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) - if "image" in name: - obs_with_policy_features[name] = ( - obs_with_policy_features[name].type(torch.float32) / 255 - ) - obs_with_policy_features[name] = ( - obs_with_policy_features[name].permute(2, 0, 1).contiguous() - ) - obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0) - obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device) - - obs_with_policy_features["task"] = [cfg.dataset.single_task] - obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") - - # Preprocess and run inference - preprocessed_obs = preprocessor(obs_with_policy_features) - - actions = policy.predict_action_chunk( - preprocessed_obs, - inference_delay=inference_delay, - prev_chunk_left_over=prev_actions, - ) - - original_actions = actions.squeeze(0).clone() - postprocessed_actions = postprocessor(actions).squeeze(0) - - new_latency = time.perf_counter() - current_time - new_delay = math.ceil(new_latency / time_per_chunk) - latency_tracker.add(new_latency) - - # Put actions in queue! - action_queue.merge( - original_actions, postprocessed_actions, new_delay, action_index_before_inference - ) - - inference_count += 1 - logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}, shape={postprocessed_actions.shape}") - except Exception as e: - logger.error(f"[RTC] Inference failed: {e}") - import traceback - traceback.print_exc() - time.sleep(1.0) # Don't spam errors - else: - time.sleep(0.01) - - logger.info("[RTC] Inference thread shutting down") + logger.info("[RTC] Inference thread shutting down") + except Exception as e: + logger.error(f"[RTC] THREAD CRASHED: {e}") + import traceback + traceback.print_exc() # ============================================================================