debug rtc
@@ -330,120 +330,110 @@ def make_identity_processors():
 def rtc_inference_thread(
     policy,
-    obs_holder: dict,  # {"obs": filtered_obs, "features": observation_features} - set by main loop
+    obs_holder: dict,
     hw_features: dict,
     preprocessor,
     postprocessor,
-    queue_holder: dict,  # {"queue": ActionQueue} - mutable so we can update per episode
+    queue_holder: dict,
     shutdown_event: Event,
     policy_active: Event,
     cfg: RaCRTCConfig,
 ):
"""Background thread that generates action chunks using RTC.
|
||||
|
||||
IMPORTANT: This thread does NOT access the robot directly!
|
||||
It reads observations from obs_holder which is updated by the main loop.
|
||||
This avoids race conditions on the CAN bus.
|
||||
"""
|
||||
logger.info("[RTC] Inference thread started (reads obs from main loop, no direct robot access)")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
policy_device = policy.config.device
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
wait_logged = False
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
if not wait_logged:
|
||||
logger.info("[RTC] Waiting for policy_active...")
|
||||
wait_logged = True
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
"""Background thread that generates action chunks using RTC."""
|
||||
try:
|
||||
logger.info("[RTC] ========== INFERENCE THREAD STARTED ==========")
|
||||
logger.info(f"[RTC] policy={policy.name}, hw_features has {len(hw_features)} keys")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
policy_device = policy.config.device
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
inference_count = 0
|
||||
wait_logged = False
|
||||
-
-        action_queue = queue_holder["queue"]
-        if action_queue is None:
-            logger.warning("[RTC] queue_holder['queue'] is None!")
-            time.sleep(0.01)
-            continue
+
+        while not shutdown_event.is_set():
+            if not policy_active.is_set():
+                if not wait_logged:
+                    logger.info("[RTC] Waiting for policy_active...")
+                    wait_logged = True
+                time.sleep(0.01)
+                continue
+
+            wait_logged = False
+
+            action_queue = queue_holder["queue"]
+            if action_queue is None:
+                logger.warning("[RTC] queue_holder['queue'] is None!")
+                time.sleep(0.01)
+                continue
-
-        obs_filtered = obs_holder.get("obs")
-        if obs_filtered is None:
-            logger.warning("[RTC] obs_holder['obs'] is None!")
-            time.sleep(0.01)
-            continue
-
-        qsize = action_queue.qsize()
-        if qsize <= get_actions_threshold:
-            try:
-                if inference_count == 0:
-                    logger.info(f"[RTC] Starting first inference, obs keys={len(obs_filtered)}, qsize={qsize}")
-
-                current_time = time.perf_counter()
-                action_index_before_inference = action_queue.get_action_index()
-                prev_actions = action_queue.get_left_over()
-
-                inference_latency = latency_tracker.max()
-                inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
-
-                obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
-
-                for name in obs_with_policy_features:
-                    obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
-                    if "image" in name:
-                        obs_with_policy_features[name] = obs_with_policy_features[name].float() / 255
-                        obs_with_policy_features[name] = obs_with_policy_features[name].permute(2, 0, 1).contiguous()
-                    obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0).to(policy_device)
-
-                obs_with_policy_features["task"] = [cfg.dataset.single_task]
-                obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
-
-                preprocessed_obs = preprocessor(obs_with_policy_features)
-
-                actions = policy.predict_action_chunk(
-                    preprocessed_obs,
-                    inference_delay=inference_delay,
-                    prev_chunk_left_over=prev_actions,
-                )
-
-                original_actions = actions.squeeze(0).clone()
-                postprocessed_actions = postprocessor(actions).squeeze(0)
-
-                new_latency = time.perf_counter() - current_time
-                new_delay = math.ceil(new_latency / time_per_chunk)
-                latency_tracker.add(new_latency)
-
-                action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
-
-                inference_count += 1
-                logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}")
-            except Exception as e:
-                logger.error(f"[RTC] Inference error: {e}")
-                import traceback
-                traceback.print_exc()
-                time.sleep(1.0)
-        else:
-            time.sleep(0.01)
+
+            # Get observation from shared holder (set by main loop)
+            obs_filtered = obs_holder.get("obs")
+            if obs_filtered is None:
+                logger.warning("[RTC] obs_holder['obs'] is None - main loop not setting it?")
+                time.sleep(0.01)
+                continue
+
+            qsize = action_queue.qsize()
+            if qsize <= get_actions_threshold:
+                try:
+                    if inference_count == 0:
+                        logger.info(f"[RTC] Starting first inference, obs has {len(obs_filtered)} keys, qsize={qsize}, threshold={get_actions_threshold}")
+                    current_time = time.perf_counter()
+                    action_index_before_inference = action_queue.get_action_index()
+                    prev_actions = action_queue.get_left_over()
+
+                    inference_latency = latency_tracker.max()
+                    inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
+
+                    # Build observation for policy (using obs from main loop)
+                    obs_with_policy_features = build_dataset_frame(hw_features, obs_filtered, prefix="observation")
+
+                    # Convert to tensors (like evaluate_with_rtc.py)
+                    for name in obs_with_policy_features:
+                        obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
+                        if "image" in name:
+                            obs_with_policy_features[name] = (
+                                obs_with_policy_features[name].type(torch.float32) / 255
+                            )
+                            obs_with_policy_features[name] = (
+                                obs_with_policy_features[name].permute(2, 0, 1).contiguous()
+                            )
+                        obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
+                        obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
+
+                    obs_with_policy_features["task"] = [cfg.dataset.single_task]
+                    obs_with_policy_features["robot_type"] = obs_holder.get("robot_type", "openarms_follower")
+
+                    # Preprocess and run inference
+                    preprocessed_obs = preprocessor(obs_with_policy_features)
+
+                    actions = policy.predict_action_chunk(
+                        preprocessed_obs,
+                        inference_delay=inference_delay,
+                        prev_chunk_left_over=prev_actions,
+                    )
+
+                    original_actions = actions.squeeze(0).clone()
+                    postprocessed_actions = postprocessor(actions).squeeze(0)
+
+                    new_latency = time.perf_counter() - current_time
+                    new_delay = math.ceil(new_latency / time_per_chunk)
+                    latency_tracker.add(new_latency)
+
+                    # Put actions in queue!
+                    action_queue.merge(
+                        original_actions, postprocessed_actions, new_delay, action_index_before_inference
+                    )
+
+                    inference_count += 1
+                    logger.info(f"[RTC] Inference #{inference_count}, latency={new_latency:.2f}s, queue={action_queue.qsize()}, shape={postprocessed_actions.shape}")
+                except Exception as e:
+                    logger.error(f"[RTC] Inference failed: {e}")
+                    import traceback
+                    traceback.print_exc()
+                    time.sleep(1.0)  # Don't spam errors
+            else:
+                time.sleep(0.01)
-
-    logger.info("[RTC] Inference thread shutting down")
+
+        logger.info("[RTC] Inference thread shutting down")
+    except Exception as e:
+        logger.error(f"[RTC] THREAD CRASHED: {e}")
+        import traceback
+        traceback.print_exc()
 
 
 # ============================================================================
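
The hunk only shows the inference side of the threading contract: the thread never touches the robot; it reads the freshest observation out of obs_holder and pushes chunks into queue_holder["queue"], while the main loop owns all hardware I/O. Below is a minimal sketch of the main-loop side, assuming the surrounding script provides the policy, processors, cfg, and robot; make_action_queue(), episode_running(), robot.get_observation(), robot.send_action(), and the queue's pop() are placeholders for illustration, not lerobot APIs.

import time
from threading import Event, Thread

obs_holder = {"obs": None, "robot_type": "openarms_follower"}
queue_holder = {"queue": None}
shutdown_event = Event()
policy_active = Event()

# Argument order matches the signature in the hunk above.
thread = Thread(
    target=rtc_inference_thread,
    args=(policy, obs_holder, hw_features, preprocessor, postprocessor,
          queue_holder, shutdown_event, policy_active, cfg),
    daemon=True,
)
thread.start()

queue_holder["queue"] = make_action_queue()  # placeholder: fresh queue per episode
policy_active.set()
try:
    while episode_running():  # placeholder stop condition
        # Only this loop talks to the robot, so the CAN bus is never
        # accessed from two threads at once.
        obs_holder["obs"] = robot.get_observation()  # placeholder robot API
        action = queue_holder["queue"].pop()  # placeholder queue API
        if action is not None:
            robot.send_action(action)
        time.sleep(1.0 / cfg.dataset.fps)
finally:
    shutdown_event.set()
    thread.join()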
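
The ActionQueue class itself never appears in this hunk. The skeleton below is inferred purely from the call sites above (qsize, get_action_index, get_left_over, merge); the name ActionQueueLike and the semantics in the docstrings are assumptions, not the lerobot implementation.

from typing import Optional, Protocol

import torch


class ActionQueueLike(Protocol):
    def qsize(self) -> int:
        """Number of actions still waiting to be executed."""
        ...

    def get_action_index(self) -> int:
        """Index of the action currently being executed."""
        ...

    def get_left_over(self) -> Optional[torch.Tensor]:
        """Unexecuted tail of the previous chunk, fed back to the policy as prev_chunk_left_over."""
        ...

    def merge(
        self,
        original_actions: torch.Tensor,
        postprocessed_actions: torch.Tensor,
        delay: int,
        action_index: int,
    ) -> None:
        """Splice a freshly predicted chunk into the queue, accounting for the `delay` steps that played out during inference."""
        ...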
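
The inference_delay arithmetic deserves a worked example: the thread estimates how many control steps will elapse while the next chunk is being computed, using the worst latency seen so far and rounding up. With an assumed cfg.dataset.fps of 30, a 0.25 s worst-case latency maps to 8 action steps.

import math

fps = 30  # example value for cfg.dataset.fps
time_per_chunk = 1.0 / fps  # about 0.033 s per action step
inference_latency = 0.25  # worst latency observed so far, in seconds

# Same formula as the hunk: ceil so the delay never underestimates.
inference_delay = math.ceil(inference_latency / time_per_chunk)
assert inference_delay == 8  # 8 steps will play out during inference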