Debug RTC: add startup logging and a top-level crash handler to the inference thread

This commit is contained in:
Pepijn
2026-01-09 16:58:57 +01:00
parent feedababd2
commit 3316301693
@@ -330,22 +330,19 @@ def make_identity_processors():
def rtc_inference_thread( def rtc_inference_thread(
policy, policy,
obs_holder: dict, # {"obs": filtered_obs, "features": observation_features} - set by main loop obs_holder: dict,
hw_features: dict, hw_features: dict,
preprocessor, preprocessor,
postprocessor, postprocessor,
queue_holder: dict, # {"queue": ActionQueue} - mutable so we can update per episode queue_holder: dict,
shutdown_event: Event, shutdown_event: Event,
policy_active: Event, policy_active: Event,
cfg: RaCRTCConfig, cfg: RaCRTCConfig,
): ):
"""Background thread that generates action chunks using RTC. """Background thread that generates action chunks using RTC."""
try:
IMPORTANT: This thread does NOT access the robot directly! logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========")
It reads observations from obs_holder which is updated by the main loop. logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys")
This avoids race conditions on the CAN bus.
"""
logger.info("[RTC] Inference thread started (reads obs from main loop, no direct robot access)")
latency_tracker = LatencyTracker() latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / cfg.dataset.fps time_per_chunk = 1.0 / cfg.dataset.fps
@@ -374,10 +371,9 @@ def rtc_inference_thread(
time.sleep(0.01) time.sleep(0.01)
continue continue
# Get observation from shared holder (set by main loop)
obs_filtered = obs_holder.get("obs") obs_filtered = obs_holder.get("obs")
if obs_filtered is None: if obs_filtered is None:
logger.warning("[RTC] obs_holder['obs'] is None - main loop not setting it?") logger.warning("[RTC] obs_holder['obs'] is None!")
time.sleep(0.01) time.sleep(0.01)
continue continue
@@ -385,7 +381,8 @@ def rtc_inference_thread(
if qsize <= get_actions_threshold: if qsize <= get_actions_threshold:
try: try:
if inference_count == 0: if inference_count == 0:
logger.info(f"[RTC] Starting first inference, obs has {len(obs_filtered)} keys, qsize={qsize}, threshold={get_actions_threshold}") logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}")
current_time = time.perf_counter() current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index() action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over() prev_actions = action_queue.get_left_over()
@@ -393,26 +390,18 @@ def rtc_inference_thread(
inference_latency = latency_tracker.max() inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0 inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
# Build observation for policy (using obs from main loop)
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation") obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
# Convert to tensors (like evaluate_with_rtc.py)
for name in obs_with_policy_features: for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name]) obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name: if "image" in name:
obs_with_policy_features[name] = ( obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255
obs_with_policy_features[name].type(torch.float32) / 255 obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous()
) obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.dataset.single_task] obs_with_policy_features["task"] = [cfg.dataset.single_task]
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower") obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
# Preprocess and run inference
preprocessed_obs = preprocessor(obs_with_policy_features) preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk( actions = policy.predict_action_chunk(
@@ -428,22 +417,23 @@ def rtc_inference_thread(
new_delay = math.ceil(new_latency / time_per_chunk) new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency) latency_tracker.add(new_latency)
# Put actions in queue! action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
inference_count += 1 inference_count += 1
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}, shape={postprocessed_actions.shape}") logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
except Exception as e: except Exception as e:
logger.error(f"[RTC] Inference failed: {e}") logger.error(f"[RTC] Inference error: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
time.sleep(1.0) # Don't spam errors time.sleep(1.0)
else: else:
time.sleep(0.01) time.sleep(0.01)
logger.info("[RTC] Inference thread shutting down") logger.info("[RTC] Inference thread shutting down")
except Exception as e:
logger.error(f"[RTC] THREAD CRASHED: {e}")
import traceback
traceback.print_exc()
# ============================================================================ # ============================================================================