This commit is contained in:
Pepijn
2026-01-09 16:54:11 +01:00
parent 480ee3299f
commit feedababd2
@@ -356,77 +356,90 @@ def rtc_inference_thread(
get_actions_threshold = 0
inference_count = 0
wait_logged = False
while not shutdown_event.is_set():
if not policy_active.is_set():
if not wait_logged:
logger.info("[RTC] Waiting for policy_active...")
wait_logged = True
time.sleep(0.01)
continue
wait_logged = False
action_queue = queue_holder["queue"]
if action_queue is None:
logger.warning("[RTC] queue_holder['queue'] is None!")
time.sleep(0.01)
continue
# Get observation from shared holder (set by main loop)
obs_filtered = obs_holder.get("obs")
if obs_filtered is None:
if inference_count == 0:
logger.warning("[RTC] Waiting for observation from main loop...")
logger.warning("[RTC] obs_holder['obs'] is None - main loop not setting it?")
time.sleep(0.01)
continue
if action_queue.qsize() <= get_actions_threshold:
if inference_count == 0:
logger.info(f"[RTC] Starting first inference, obs has {len(obs_filtered)} keys")
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
# Build observation for policy (using obs from main loop)
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
# Convert to tensors (like evaluate_with_rtc.py)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.dataset.single_task]
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
# Preprocess and run inference
preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
# Put actions in queue!
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
inference_count += 1
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}, shape={postprocessed_actions.shape}")
qsize = action_queue.qsize()
if qsize <= get_actions_threshold:
try:
if inference_count == 0:
logger.info(f"[RTC] Starting first inference, obs has {len(obs_filtered)} keys, qsize={qsize}, threshold={get_actions_threshold}")
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
# Build observation for policy (using obs from main loop)
obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
# Convert to tensors (like evaluate_with_rtc.py)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.dataset.single_task]
obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
# Preprocess and run inference
preprocessed_obs = preprocessor(obs_with_policy_features)
actions = policy.predict_action_chunk(
preprocessed_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions).squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
# Put actions in queue!
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
inference_count += 1
logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}, shape={postprocessed_actions.shape}")
except Exception as e:
logger.error(f"[RTC] Inference failed: {e}")
import traceback
traceback.print_exc()
time.sleep(1.0) # Don't spam errors
else:
time.sleep(0.01)