diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 4729856aa..5bffd111c 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -201,33 +201,23 @@ class RTCEvaluator: noise = self.policy.model.sample_noise(noise_size, self.device) noise_clone = noise.clone() - # Create side-by-side figures for denoising visualization - fig_xt, axs_xt = plt.subplots(6, 2, figsize=(24, 12)) - fig_xt.suptitle("x_t Denoising: No RTC (left) vs RTC (right)", fontsize=16) - - fig_vt, axs_vt = plt.subplots(6, 2, figsize=(24, 12)) - fig_vt.suptitle("v_t Denoising: No RTC (left) vs RTC (right)", fontsize=16) - - fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12)) - fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16) - # Generate actions WITHOUT RTC logger.info("Generating actions WITHOUT RTC") self.policy.config.rtc_config.enabled = False with torch.no_grad(): - no_rtc_actions = self.policy.predict_action_chunk( + _ = self.policy.predict_action_chunk( preprocessed_second_sample, noise=noise, ) - # Plot denoising steps from tracker (no RTC - left column) - # Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists + no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps() + self.policy.rtc_processor.reset_tracker() # Generate actions WITH RTC logger.info("Generating actions WITH RTC") self.policy.config.rtc_config.enabled = True with torch.no_grad(): - rtc_actions = self.policy.predict_action_chunk( + _ = self.policy.predict_action_chunk( preprocessed_second_sample, noise=noise_clone, inference_delay=self.cfg.inference_delay, @@ -235,195 +225,86 @@ class RTCEvaluator: execution_horizon=self.cfg.rtc.execution_horizon, ) - # Plot denoising steps from tracker (RTC - right column) - if self.policy.rtc_processor is not None: - num_steps = self.policy.config.num_steps - self._plot_denoising_steps_from_tracker( - self.policy.rtc_processor.tracker, - axs_xt[:, 1], # Right column for x_t - axs_vt[:, 1], # Right column for v_t - axs_x1t[:, 1], # Right column for x1_t - num_steps, - ) + # ================================================================ - # Plot ground truth on x_t axes - RTCDebugVisualizer.plot_waypoints( - axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - - # Plot ground truth on x1_t axes - RTCDebugVisualizer.plot_waypoints( - axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - - # Set titles for denoising plots - for ax in axs_xt[:, 0]: - ax.set_title("No RTC" if ax == axs_xt[0, 0] else "", fontsize=12) - for ax in axs_xt[:, 1]: - ax.set_title("RTC" if ax == axs_xt[0, 1] else "", fontsize=12) - - for ax in axs_vt[:, 0]: - ax.set_title("No RTC" if ax == axs_vt[0, 0] else "", fontsize=12) - for ax in axs_vt[:, 1]: - ax.set_title("RTC" if ax == axs_vt[0, 1] else "", fontsize=12) - - for ax in axs_x1t[:, 0]: - ax.set_title("No RTC (N/A)" if ax == axs_x1t[0, 0] else "", fontsize=12) - for ax in axs_x1t[:, 1]: - ax.set_title("RTC" if ax == axs_x1t[0, 1] else "", fontsize=12) - - # Save denoising plots - fig_xt.tight_layout() - xt_path = os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png") - fig_xt.savefig(xt_path, dpi=150) - logger.info(f"Saved x_t denoising comparison to {xt_path}") - plt.close(fig_xt) - - fig_vt.tight_layout() - vt_path = os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png") - fig_vt.savefig(vt_path, dpi=150) - logger.info(f"Saved v_t denoising comparison to {vt_path}") - plt.close(fig_vt) - - fig_x1t.tight_layout() - x1t_path = os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png") - fig_x1t.savefig(x1t_path, dpi=150) - logger.info(f"Saved x1_t predicted state & error comparison to {x1t_path}") - plt.close(fig_x1t) - - # Create side-by-side comparison: No RTC (left) vs RTC (right) - fig, axs = plt.subplots(6, 2, figsize=(24, 12)) - fig.suptitle("Final Action Comparison: No RTC (left) vs RTC (right)", fontsize=16) - - # Plot on left column (No RTC) - self._plot_actions( - axs[:, 0], - prev_chunk_left_over[0].cpu().numpy(), - no_rtc_actions[0].cpu().numpy(), - "No RTC", - ) - - # Plot on right column (RTC) - self._plot_actions( - axs[:, 1], - prev_chunk_left_over[0].cpu().numpy(), - rtc_actions[0].detach().cpu().numpy(), - "RTC", - ) - - plt.tight_layout() - final_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png") - plt.savefig(final_path, dpi=150) - logger.info(f"Saved final actions comparison to {final_path}") - plt.close(fig) - - # Visualize debug information if enabled - self._visualize_debug_info() + rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps() + self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over) logger.info("Evaluation completed successfully") - def _plot_actions(self, axs, prev_chunk, predicted_actions, title): - """Plot actions comparison on given axes.""" - # Ensure arrays are 2D - if prev_chunk.ndim == 1: - prev_chunk = prev_chunk.reshape(1, -1) - if predicted_actions.ndim == 1: - predicted_actions = predicted_actions.reshape(1, -1) - - for j in range(min(prev_chunk.shape[-1], 6)): # Limit to 6 dimensions - axs[j].plot( - np.arange(prev_chunk.shape[0]), - prev_chunk[:, j], - color="green", - label="Previous Chunk", - ) - axs[j].plot( - np.arange(predicted_actions.shape[0]), - predicted_actions[:, j], - color="red" if "RTC" in title else "blue", - label=title, - ) - axs[j].set_ylabel("Joint angle", fontsize=14) - axs[j].grid() - axs[j].legend(loc="upper right", fontsize=14) - axs[j].set_title(title if j == 0 else "", fontsize=12) - if j == 2: - axs[j].set_xlabel("Step #", fontsize=16) - - def _visualize_debug_info(self): - """Visualize debug information from the RTC processor.""" - # Use proxy method to check if debug is enabled - if not self.policy.rtc_processor.is_debug_enabled(): - logger.warning("Debug tracking is disabled. Skipping debug visualization.") - return - - # Get tracker length using proxy method - if self.policy.rtc_processor.get_tracker_length() == 0: - logger.warning("No debug steps recorded. Skipping debug visualization.") - return - - # Create output directory - os.makedirs(self.cfg.output_dir, exist_ok=True) - logger.info(f"Saving debug visualizations to {self.cfg.output_dir}") - - # Still need direct access to tracker for visualization functions - # This is acceptable since RTCDebugVisualizer is part of the RTC package - tracker = self.policy.rtc_processor.tracker - - # Print statistics - RTCDebugVisualizer.print_debug_statistics(tracker) - - # Plot debug summary - summary_path = os.path.join(self.cfg.output_dir, "debug_summary.png") - RTCDebugVisualizer.plot_debug_summary( - tracker, - save_path=summary_path, - show=False, + def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over): + # Create side-by-side figures for denoising visualization + fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)") + fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)") + fig_x1t, axs_x1t = self._create_figure( + "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)" ) - # Plot correction heatmap - heatmap_path = os.path.join(self.cfg.output_dir, "correction_heatmap.png") - RTCDebugVisualizer.plot_correction_heatmap( - tracker, - save_path=heatmap_path, - show=False, + num_steps = self.policy.config.num_steps + self._plot_denoising_steps_from_tracker( + rtc_tracked_steps, + axs_xt[:, 1], # Right column for x_t + axs_vt[:, 1], # Right column for v_t + axs_x1t[:, 1], # Right column for x1_t + num_steps, ) - # Plot step-by-step comparison (last step) - step_path = os.path.join(self.cfg.output_dir, "step_comparison_last.png") - RTCDebugVisualizer.plot_step_by_step_comparison( - tracker, - step_idx=-1, - save_path=step_path, - show=False, + self._plot_denoising_steps_from_tracker( + no_rtc_tracked_steps, + axs_xt[:, 0], # Left column for x_t + axs_vt[:, 0], # Left column for v_t + axs_x1t[:, 0], # Left column for x1_t + num_steps, ) - # Plot step-by-step comparison (first step) - step_path_first = os.path.join(self.cfg.output_dir, "step_comparison_first.png") - if self.policy.rtc_processor.get_tracker_length() > 0: - RTCDebugVisualizer.plot_step_by_step_comparison( - tracker, - step_idx=0, - save_path=step_path_first, - show=False, - ) + # Plot ground truth on x_t axes + RTCDebugVisualizer.plot_waypoints( + axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) - logger.info(f"Debug visualizations saved to {self.cfg.output_dir}") + RTCDebugVisualizer.plot_waypoints( + axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) - def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps): + # Plot ground truth on x1_t axes + RTCDebugVisualizer.plot_waypoints( + axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + + # Save denoising plots + self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) + self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) + self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")) + + def _create_figure(self, title): + fig, axs = plt.subplots(6, 2, figsize=(24, 12)) + fig.suptitle(title, fontsize=16) + + for ax in axs[:, 0]: + ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12) + for ax in axs[:, 1]: + ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12) + + return fig, axs + + def _save_figure(self, fig, path): + fig.tight_layout() + fig.savefig(path, dpi=150) + logger.info(f"Saved figure to {path}") + plt.close(fig) + + def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps): """Plot denoising steps from tracker data. Args: - tracker: Tracker object containing debug steps + tracked_steps: List of DebugStep objects containing debug steps xt_axs: Matplotlib axes for x_t plots (array of 6 axes) vt_axs: Matplotlib axes for v_t plots (array of 6 axes) x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes) num_steps: Total number of denoising steps for colormap """ - if tracker is None: - return - debug_steps = tracker.get_all_steps() + debug_steps = tracked_steps if not debug_steps: return