diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 77a283b70..f3dbb8d37 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -222,6 +222,7 @@ class RTCEvaluator: # Create side-by-side figures for denoising visualization fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)") fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)") + fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)") fig_x1t, axs_x1t = self._create_figure( "x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)" ) @@ -231,6 +232,7 @@ class RTCEvaluator: rtc_tracked_steps, axs_xt[:, 1], # Right column for x_t axs_vt[:, 1], # Right column for v_t + axs_corr[:, 1], # Right column for correction axs_x1t[:, 1], # Right column for x1_t num_steps, ) @@ -239,6 +241,7 @@ class RTCEvaluator: no_rtc_tracked_steps, axs_xt[:, 0], # Left column for x_t axs_vt[:, 0], # Left column for v_t + axs_corr[:, 0], # Left column for correction axs_x1t[:, 0], # Left column for x1_t num_steps, ) @@ -265,6 +268,7 @@ class RTCEvaluator: # Save denoising plots self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) + self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png")) self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")) def _create_figure(self, title): @@ -284,13 +288,14 @@ class RTCEvaluator: logging.info(f"Saved figure to {path}") plt.close(fig) - def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps): + def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps): """Plot denoising steps from tracker data. Args: tracked_steps: List of DebugStep objects containing debug steps xt_axs: Matplotlib axes for x_t plots (array of 6 axes) vt_axs: Matplotlib axes for v_t plots (array of 6 axes) + corr_axs: Matplotlib axes for correction plots (array of 6 axes) x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes) num_steps: Total number of denoising steps for colormap """ @@ -317,14 +322,14 @@ class RTCEvaluator: vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}" ) - # Plot correction in red + # Plot correction on separate axes if debug_step.correction is not None: RTCDebugVisualizer.plot_waypoints( - vt_axs, + corr_axs, debug_step.correction, start_from=0, - color="red", - label=f"Step corr {step_idx}", + color=color, + label=f"Step {step_idx}", ) # Plot x1_t (predicted state) @@ -356,6 +361,29 @@ class RTCEvaluator: label=f"error Step {step_idx}", ) + # Recalculate axis limits after plotting to ensure proper scaling + self._rescale_axes(xt_axs) + self._rescale_axes(vt_axs) + self._rescale_axes(corr_axs) + self._rescale_axes(x1t_axs) + + def _rescale_axes(self, axes): + """Rescale axes to show all data with proper margins. + + Args: + axes: Array of matplotlib axes to rescale + """ + for ax in axes: + ax.relim() + ax.autoscale_view() + + # Add 10% margin to y-axis for better visualization + ylim = ax.get_ylim() + y_range = ylim[1] - ylim[0] + if y_range > 0: # Avoid division by zero + margin = y_range * 0.1 + ax.set_ylim(ylim[0] - margin, ylim[1] + margin) + @parser.wrap() def main(cfg: RTCEvalConfig):