diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index af37636f9..a8d0a1d15 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -39,8 +39,9 @@ Usage: uv run python examples/rtc/eval_dataset.py \ --policy.path=lerobot/pi05_libero_finetuned \ --dataset.repo_id=HuggingFaceVLA/libero \ - --rtc.execution_horizon=8 \ - --device=mps + --rtc.execution_horizon=10 \ + --device=mps \ + --seed=10 # Basic usage with pi0.5 policy with cuda device uv run python examples/rtc/eval_dataset.py \ @@ -795,16 +796,34 @@ class RTCEvaluator: ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks ax.set_xlim(-0.5, max_len - 0.5) - # Add legend only to first subplot - if dim_idx == 0: - ax.legend(loc="best", fontsize=9) - axes[-1].set_xlabel("Step", fontsize=10) + # Collect legend handles and labels from first subplot + handles, labels = axes[0].get_legend_handles_labels() + # Remove duplicates while preserving order + seen = set() + unique_handles = [] + unique_labels = [] + for handle, label in zip(handles, labels, strict=True): + if label not in seen: + seen.add(label) + unique_handles.append(handle) + unique_labels.append(label) + + # Add legend outside the plot area (to the right) + fig.legend( + unique_handles, + unique_labels, + loc="center right", + fontsize=9, + bbox_to_anchor=(1.0, 0.5), + framealpha=0.9, + ) + # Save figure output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png") - fig.tight_layout() - fig.savefig(output_path, dpi=150) + fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right + fig.savefig(output_path, dpi=150, bbox_inches="tight") logging.info(f"Saved final actions comparison to {output_path}") plt.close(fig) @@ -825,6 +844,7 @@ class RTCEvaluator: axs_corr[:, 1], # Right column for correction axs_x1t[:, 1], # Right column for x1_t num_steps, + add_labels=True, # Add labels for RTC (right column) ) self._plot_denoising_steps_from_tracker( @@ -834,6 +854,7 @@ class 
RTCEvaluator: axs_corr[:, 0], # Left column for correction axs_x1t[:, 0], # Left column for x1_t num_steps, + add_labels=False, # No labels for No RTC (left column) ) # Plot no-RTC x_t data on right chart as orange dashed line for comparison @@ -849,15 +870,21 @@ class RTCEvaluator: axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" ) - # Plot ground truth on x_t axes + # Plot ground truth on x_t axes (no labels for left column) RTCDebugVisualizer.plot_waypoints( - axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None ) RTCDebugVisualizer.plot_waypoints( - axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None ) + # Add legends outside the plot area for each figure + self._add_figure_legend(fig_xt, axs_xt) + self._add_figure_legend(fig_vt, axs_vt) + self._add_figure_legend(fig_corr, axs_corr) + self._add_figure_legend(fig_x1t, axs_x1t) + # Save denoising plots self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) @@ -875,13 +902,47 @@ class RTCEvaluator: return fig, axs + def _add_figure_legend(self, fig, axs): + """Add a legend outside the plot area on the right side. 
+ + Args: + fig: Matplotlib figure to add legend to + axs: Array of axes to collect legend handles from + """ + # Collect all handles and labels from the first row of axes (right column) + handles, labels = axs[0, 1].get_legend_handles_labels() + + # Remove duplicates while preserving order + seen = set() + unique_handles = [] + unique_labels = [] + for handle, label in zip(handles, labels, strict=True): + if label not in seen: + seen.add(label) + unique_handles.append(handle) + unique_labels.append(label) + + # Add legend outside the plot area (to the right, close to charts) + if unique_handles: + fig.legend( + unique_handles, + unique_labels, + loc="center left", + fontsize=8, + bbox_to_anchor=(0.87, 0.5), + framealpha=0.9, + ncol=1, + ) + def _save_figure(self, fig, path): - fig.tight_layout() - fig.savefig(path, dpi=150) + fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right + fig.savefig(path, dpi=150, bbox_inches="tight") logging.info(f"Saved figure to {path}") plt.close(fig) - def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps): + def _plot_denoising_steps_from_tracker( + self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True + ): """Plot denoising steps from tracker data. 
Args: @@ -891,6 +952,7 @@ class RTCEvaluator: corr_axs: Matplotlib axes for correction plots (array of 6 axes) x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes) num_steps: Total number of denoising steps for colormap + add_labels: Whether to add legend labels for the plots """ logging.info("=" * 80) @@ -905,17 +967,18 @@ class RTCEvaluator: for step_idx, debug_step in enumerate(debug_steps): color = colors[step_idx % len(colors)] + label = f"Step {step_idx}" if add_labels else None # Plot x_t if debug_step.x_t is not None: RTCDebugVisualizer.plot_waypoints( - xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}" + xt_axs, debug_step.x_t, start_from=0, color=color, label=label ) # Plot v_t if debug_step.v_t is not None: RTCDebugVisualizer.plot_waypoints( - vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}" + vt_axs, debug_step.v_t, start_from=0, color=color, label=label ) # Plot correction on separate axes @@ -925,17 +988,18 @@ class RTCEvaluator: debug_step.correction, start_from=0, color=color, - label=f"Step {step_idx}", + label=label, ) # Plot x1_t (predicted state) if x1t_axs is not None and debug_step.x1_t is not None: + x1t_label = f"x1_t Step {step_idx}" if add_labels else None RTCDebugVisualizer.plot_waypoints( x1t_axs, debug_step.x1_t, start_from=0, color=color, - label=f"x1_t Step {step_idx}", + label=x1t_label, ) # Plot error in orange dashed @@ -947,6 +1011,7 @@ class RTCEvaluator: ) num_dims = min(error_chunk.shape[-1], 6) + error_label = f"error Step {step_idx}" if add_labels else None for j in range(num_dims): x1t_axs[j].plot( np.arange(0, error_chunk.shape[0]), @@ -954,7 +1019,7 @@ class RTCEvaluator: color="orange", linestyle="--", alpha=0.7, - label=f"error Step {step_idx}", + label=error_label, ) # Recalculate axis limits after plotting to ensure proper scaling diff --git a/src/lerobot/policies/rtc/debug_visualizer.py b/src/lerobot/policies/rtc/debug_visualizer.py index 8b831dfd9..589c86c95 
100644 --- a/src/lerobot/policies/rtc/debug_visualizer.py +++ b/src/lerobot/policies/rtc/debug_visualizer.py @@ -111,7 +111,3 @@ class RTCDebugVisualizer: if not ax.yaxis.get_label().get_text(): ax.set_ylabel(f"Dim {dim_idx}", fontsize=10) ax.grid(True, alpha=0.3) - - # Add legend if label provided and this is the first dimension - if label and dim_idx == 0: - ax.legend(loc="best", fontsize=8) diff --git a/src/lerobot/policies/rtc/flow_matching.png b/src/lerobot/policies/rtc/flow_matching.png index 173ae7001..3ef86edfd 100644 Binary files a/src/lerobot/policies/rtc/flow_matching.png and b/src/lerobot/policies/rtc/flow_matching.png differ