diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
index c861e128c..c3d540339 100644
--- a/examples/rtc/eval_dataset.py
+++ b/examples/rtc/eval_dataset.py
@@ -452,12 +452,13 @@ class RTCEvaluator:
         noise_clone = noise.clone()
         policy_no_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = policy_no_rtc_policy.predict_action_chunk(
+            no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise,
             )
         no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
         logging.info(f"  Tracked {len(no_rtc_tracked_steps)} steps without RTC")
+        logging.info(f"  Generated no_rtc_actions shape: {no_rtc_actions.shape}")
 
         # Destroy policy_no_rtc to free memory before loading policy_rtc
         self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
@@ -477,7 +478,7 @@ class RTCEvaluator:
         )
         policy_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = policy_rtc_policy.predict_action_chunk(
+            rtc_actions = policy_rtc_policy.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise_clone,
                 inference_delay=self.cfg.inference_delay,
@@ -486,6 +487,7 @@ class RTCEvaluator:
             )
         rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
         logging.info(f"  Tracked {len(rtc_tracked_steps)} steps with RTC")
+        logging.info(f"  Generated rtc_actions shape: {rtc_actions.shape}")
 
         # Save num_steps before destroying policy (needed for plotting)
         try:
@@ -502,9 +504,163 @@ class RTCEvaluator:
         logging.info("=" * 80)
         logging.info("Plotting results...")
         self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
+
+        # Validate RTC behavior
+        logging.info("=" * 80)
+        logging.info("Validating RTC behavior...")
+        self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
+
         logging.info("=" * 80)
         logging.info("Evaluation completed successfully")
 
+    def validate_rtc_behavior(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
+        """Validate RTC behavior by comparing final action predictions with expected values.
+
+        Validation rules:
+        1. During delay [0:inference_delay]: RTC should equal prev_chunk
+        2. After delay, within execution horizon [inference_delay:execution_horizon]:
+           RTC should be between prev_chunk and no_rtc
+        3. After execution horizon [execution_horizon:]: RTC should equal no_rtc
+
+        Args:
+            rtc_actions: Final actions from RTC policy (batch, time, action_dim)
+            no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
+            prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
+        """
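+        # Illustration of the rules above with assumed numbers (not taken from any config):
+        # for inference_delay=2, execution_horizon=4 and a 6-step chunk, the checks below expect
+        #   steps 0-1: rtc_actions == prev_chunk_left_over                     (Rule 1)
+        #   steps 2-3: rtc_actions elementwise between prev_chunk and no_rtc   (Rule 2)
+        #   steps 4-5: rtc_actions == no_rtc_actions                           (Rule 3)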
+        if rtc_actions is None or no_rtc_actions is None:
+            logging.warning("  ⚠ Cannot validate: missing action predictions")
+            return
+
+        # Convert to numpy for comparison (remove batch dimension if present)
+        rtc_actions_np = (
+            rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy()
+        )
+        no_rtc_actions_np = (
+            no_rtc_actions.squeeze(0).cpu().numpy()
+            if len(no_rtc_actions.shape) == 3
+            else no_rtc_actions.cpu().numpy()
+        )
+        prev_chunk = prev_chunk_left_over.cpu().numpy()
+
+        # Determine chunk length for comparison
+        chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0])
+        inference_delay = self.cfg.inference_delay
+        execution_horizon = self.cfg.rtc.execution_horizon
+
+        # Tolerance for floating point comparison
+        rtol = 1e-3  # Relative tolerance
+        atol = 1e-3  # Absolute tolerance
+
+        validation_passed = True
+        warnings = []
+
+        logging.info("  Validating RTC behavior:")
+        logging.info(f"    Chunk length: {chunk_len}")
+        logging.info(f"    Inference delay: {inference_delay}")
+        logging.info(f"    Execution horizon: {execution_horizon}")
+        logging.info(f"    Tolerance: rtol={rtol}, atol={atol}")
+
+        # ============================================================================
+        # Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
+        # ============================================================================
+        if inference_delay > 0:
+            delay_end = min(inference_delay, chunk_len)
+            rtc_delay = rtc_actions_np[:delay_end]
+            prev_delay = prev_chunk[:delay_end]
+
+            if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol):
+                max_diff = np.max(np.abs(rtc_delay - prev_delay))
+                mean_diff = np.mean(np.abs(rtc_delay - prev_delay))
+                warnings.append(
+                    f"  ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
+                    f"RTC does NOT equal prev_chunk!\n"
+                    f"    Max difference: {max_diff:.6f}\n"
+                    f"    Mean difference: {mean_diff:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(f"  ✓ During delay [0:{delay_end}]: RTC equals prev_chunk")
+
+        # ============================================================================
+        # Rule 2: After delay, within execution horizon [inference_delay:execution_horizon]
+        # RTC should be between prev_chunk and no_rtc
+        # ============================================================================
+        blend_start = inference_delay
+        blend_end = min(execution_horizon, chunk_len)
+
+        if blend_end > blend_start:
+            rtc_blend = rtc_actions_np[blend_start:blend_end]
+            prev_blend = prev_chunk[blend_start:blend_end]
+            no_rtc_blend = no_rtc_actions_np[blend_start:blend_end]
+
+            # Check if RTC is between prev_chunk and no_rtc (element-wise)
+            # For each element, check if it's between the min and max of prev_chunk and no_rtc
+            min_bound = np.minimum(prev_blend, no_rtc_blend) - atol
+            max_bound = np.maximum(prev_blend, no_rtc_blend) + atol
+
+            within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
+
+            if not np.all(within_bounds):
+                violations = np.sum(~within_bounds)
+                total_elements = within_bounds.size
+                violation_pct = 100.0 * violations / total_elements
+
+                # Find max violation
+                lower_violations = np.maximum(0, min_bound - rtc_blend)
+                upper_violations = np.maximum(0, rtc_blend - max_bound)
+                max_violation = np.max(np.maximum(lower_violations, upper_violations))
+
+                warnings.append(
+                    f"  ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
+                    f"RTC is NOT always between prev_chunk and no_rtc!\n"
+                    f"    Violations: {violations}/{total_elements} elements ({violation_pct:.1f}%)\n"
+                    f"    Max violation distance: {max_violation:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(
+                    f"  ✓ Blend region [{blend_start}:{blend_end}]: RTC is between prev_chunk and no_rtc"
+                )
+
+        # ============================================================================
+        # Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
+        # ============================================================================
+        if execution_horizon < chunk_len:
+            rtc_after = rtc_actions_np[execution_horizon:chunk_len]
+            no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len]
+
+            if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol):
+                max_diff = np.max(np.abs(rtc_after - no_rtc_after))
+                mean_diff = np.mean(np.abs(rtc_after - no_rtc_after))
+                warnings.append(
+                    f"  ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
+                    f"RTC does NOT equal no_rtc!\n"
+                    f"    Max difference: {max_diff:.6f}\n"
+                    f"    Mean difference: {mean_diff:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(
+                    f"  ✓ After execution horizon [{execution_horizon}:{chunk_len}]: RTC equals no_rtc"
+                )
+
+        # ============================================================================
+        # Report results
+        # ============================================================================
+        logging.info("=" * 80)
+        if validation_passed:
+            logging.info("  ✅ VALIDATION PASSED: All RTC behavior checks passed!")
+            logging.info("    • During delay: RTC = prev_chunk ✓")
+            logging.info("    • Blend region: RTC between prev_chunk and no_rtc ✓")
+            logging.info("    • After execution horizon: RTC = no_rtc ✓")
+        else:
+            logging.error("  ❌ VALIDATION FAILED: RTC behavior does not match expectations!")
+            logging.error("")
+            for warning in warnings:
+                logging.error(warning)
+            logging.error("")
+            logging.error("  Please check the implementation of RTC guidance.")
+
     def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
         # Create side-by-side figures for denoising visualization
         fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
diff --git a/examples/rtc/run_dataset_evaluation.sh b/examples/rtc/run_dataset_evaluation.sh
deleted file mode 100755
index 81370682f..000000000
--- a/examples/rtc/run_dataset_evaluation.sh
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/bin/bash
-
-# Example script to run RTC evaluation on dataset
-# This shows different usage scenarios
-
-set -e  # Exit on error
-
-POLICY_PATH="lerobot/smolvla_base"
-DATASET="lerobot/pusht"
-DEVICE="cuda"  # Change to "cpu" or "mps" if needed
-
-echo "========================================"
-echo "RTC Dataset Evaluation Examples"
-echo "========================================"
-
-# Example 1: Quick evaluation (100 samples, every step)
-echo -e "\n[Example 1] Quick evaluation - 100 samples, every step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=100 \
-    --skip_steps=1 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_quick.json"
-
-# Example 2: Simulating realistic inference delay (every 3rd step)
-echo -e "\n[Example 2] Realistic inference delay - 200 samples, every 3rd step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=200 \
-    --skip_steps=3 \
-    --rtc.execution_horizon=10 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_delay3.json"
-
-# Example 3: Higher inference delay (every 5th step)
-echo -e "\n[Example 3] High inference delay - 200 samples, every 5th step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=200 \
-    --skip_steps=5 \
-    --rtc.execution_horizon=12 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_delay5.json"
-
-# Example 4: Testing different RTC configurations
-echo -e "\n[Example 4] Different RTC config - LINEAR schedule"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=100 \
-    --skip_steps=3 \
-    --rtc.execution_horizon=8 \
-    --rtc.prefix_attention_schedule=LINEAR \
-    --rtc.max_guidance_weight=5.0 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_linear.json"
-
-# Example 5: Verbose mode for debugging
-echo -e "\n[Example 5] Verbose mode - 20 samples with detailed output"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=20 \
-    --skip_steps=3 \
-    --device="${DEVICE}" \
-    --verbose=true \
-    --output_path="results/rtc_eval_verbose.json"
-
-echo -e "\n========================================"
-echo "All evaluations completed!"
-echo "Results saved in results/ directory"
-echo "========================================"
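
Not part of the patch: a minimal, self-contained sketch of the three checks that validate_rtc_behavior applies, using made-up arrays and assumed values (inference_delay=2, execution_horizon=4). It shows that a chunk which copies prev_chunk during the delay, blends linearly across the execution horizon, and follows no_rtc afterwards satisfies all three rules.

import numpy as np

# Assumed toy values; the real evaluator reads these from self.cfg.
inference_delay, execution_horizon, chunk_len, action_dim = 2, 4, 6, 3
atol = 1e-3

rng = np.random.default_rng(0)
prev_chunk = rng.normal(size=(chunk_len, action_dim))
no_rtc = rng.normal(size=(chunk_len, action_dim))

# Build an RTC-like chunk: keep prev_chunk during the delay, blend linearly
# over the rest of the execution horizon, follow no_rtc afterwards.
rtc = no_rtc.copy()
rtc[:inference_delay] = prev_chunk[:inference_delay]
for t in range(inference_delay, execution_horizon):
    alpha = (execution_horizon - t) / (execution_horizon - inference_delay)  # 1 -> 0
    rtc[t] = alpha * prev_chunk[t] + (1 - alpha) * no_rtc[t]

# Rule 1: during the delay, RTC equals prev_chunk.
assert np.allclose(rtc[:inference_delay], prev_chunk[:inference_delay], atol=atol)

# Rule 2: in the blend region, RTC stays elementwise between prev_chunk and no_rtc.
lo = np.minimum(prev_chunk, no_rtc)[inference_delay:execution_horizon] - atol
hi = np.maximum(prev_chunk, no_rtc)[inference_delay:execution_horizon] + atol
blend = rtc[inference_delay:execution_horizon]
assert np.all((blend >= lo) & (blend <= hi))

# Rule 3: past the execution horizon, RTC equals no_rtc.
assert np.allclose(rtc[execution_horizon:], no_rtc[execution_horizon:], atol=atol)
print("Toy chunk satisfies all three validation rules.")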