diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index a8d0a1d15..7f281dc33 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -544,11 +544,6 @@ class RTCEvaluator: logging.info("Plotting results...") self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps) - # Validate RTC behavior - # logging.info("=" * 80) - # logging.info("Validating RTC behavior...") - # self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) - # Plot final actions comparison logging.info("=" * 80) logging.info("Plotting final actions comparison...") @@ -557,159 +552,6 @@ class RTCEvaluator: logging.info("=" * 80) logging.info("Evaluation completed successfully") - def validate_rtc_behavior(self, rtc_actions, no_rtc_actions, prev_chunk_left_over): - """Validate RTC behavior by comparing final action predictions with expected values. - - Validation rules: - 1. During delay [0:inference_delay]: RTC should equal prev_chunk - 2. After delay, within execution horizon [inference_delay:execution_horizon]: - RTC should be between prev_chunk and no_rtc - 3. After execution horizon [execution_horizon:]: RTC should equal no_rtc - - Args: - rtc_actions: Final actions from RTC policy (batch, time, action_dim) - no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim) - prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim) - """ - # Remove batch dimension if present and move to CPU - rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu() - no_rtc_actions_t = ( - no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu() - ) - prev_chunk = prev_chunk_left_over.cpu() - - logging.info(f" rtc_actions shape: {rtc_actions_t.shape}") - logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}") - logging.info(f" prev_chunk shape: {prev_chunk.shape}") - - # Determine chunk length for comparison - chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0]) - inference_delay = self.cfg.inference_delay - execution_horizon = self.cfg.rtc.execution_horizon - - # Tolerance for floating point comparison - rtol = 1e-2 # Relative tolerance - - validation_passed = True - warnings = [] - - logging.info(" Validating RTC behavior:") - logging.info(f" Chunk length: {chunk_len}") - logging.info(f" Inference delay: {inference_delay}") - logging.info(f" Execution horizon: {execution_horizon}") - logging.info(f" Tolerance: rtol={rtol}") - - # ============================================================================ - # Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk - # ============================================================================ - if inference_delay > 0: - delay_end = min(inference_delay, chunk_len) - rtc_delay = rtc_actions_t[:delay_end] - prev_delay = prev_chunk[:delay_end] - - logging.info(f" rtc_delay: {rtc_delay.shape}") - logging.info(f" prev_delay: {prev_delay.shape}") - - if not torch.allclose(rtc_delay, prev_delay, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item() - mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item() - logging.info(f" rtc_delay: {rtc_delay}") - logging.info(f" prev_delay: {prev_delay}") - logging.info(f" max_diff: {max_diff}") - logging.info(f" mean_diff: {mean_diff}") - warnings.append( - f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], " - f"RTC does NOT equal prev_chunk!\n" - f" Max difference: {max_diff:.6f}\n" - f" Mean difference: {mean_diff:.6f}" - ) - validation_passed = False - else: - logging.info(f" ✓ During delay [0:{delay_end}]: RTC equals prev_chunk") - - # ============================================================================ - # Rule 2: After delay, within execution horizon [inference_delay:execution_horizon] - # RTC should be between prev_chunk and no_rtc - # ============================================================================ - blend_start = inference_delay - blend_end = min(execution_horizon, chunk_len) - - if blend_end > blend_start: - rtc_blend = rtc_actions_t[blend_start:blend_end] - prev_blend = prev_chunk[blend_start:blend_end] - no_rtc_blend = no_rtc_actions_t[blend_start:blend_end] - - # Check if RTC is between prev_chunk and no_rtc (element-wise) - # For each element, check if it's between the min and max of prev_chunk and no_rtc - min_bound = torch.minimum(prev_blend, no_rtc_blend) - max_bound = torch.maximum(prev_blend, no_rtc_blend) - - within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) - - if not torch.all(within_bounds): - violations = torch.sum(~within_bounds).item() - total_elements = within_bounds.numel() - violation_pct = 100.0 * violations / total_elements - - # Find max violation - lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend) - upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound) - max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item() - - warnings.append( - f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], " - f"RTC is NOT always between prev_chunk and no_rtc!\n" - f" Violations: {violations}/{total_elements} elements ({violation_pct:.1f}%)\n" - f" Max violation distance: {max_violation:.6f}" - ) - validation_passed = False - else: - logging.info( - f" ✓ Blend region [{blend_start}:{blend_end}]: RTC is between prev_chunk and no_rtc" - ) - - # ============================================================================ - # Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc - # ============================================================================ - if execution_horizon < chunk_len: - rtc_after = rtc_actions_t[execution_horizon:chunk_len] - no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len] - - logging.info(f" rtc_after: {rtc_after}") - logging.info(f" no_rtc_after: {no_rtc_after}") - - if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item() - mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item() - warnings.append( - f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], " - f"RTC does NOT equal no_rtc!\n" - f" Max difference: {max_diff:.6f}\n" - f" Mean difference: {mean_diff:.6f}" - ) - validation_passed = False - else: - logging.info( - f" ✓ After execution horizon [{execution_horizon}:{chunk_len}]: RTC equals no_rtc" - ) - - # ============================================================================ - # Report results - # ============================================================================ - logging.info("=" * 80) - if validation_passed: - logging.info(" ✅ VALIDATION PASSED: All RTC behavior checks passed!") - logging.info(" • During delay: RTC = prev_chunk ✓") - logging.info(" • Blend region: prev_chunk ≤ RTC ≤ no_rtc ✓") - logging.info(" • After execution horizon: RTC = no_rtc ✓") - else: - logging.error(" ❌ VALIDATION FAILED: RTC behavior does not match expected!") - logging.error("") - for warning in warnings: - logging.error(warning) - logging.error("") - logging.error(" Please check the implementation of RTC guidance.") - def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over): """Plot final action predictions comparison on a single chart.