Update images

This commit is contained in:
Eugene Mironov
2025-11-19 03:02:26 +07:00
parent 045f9c02f7
commit 8008dbb02c
3 changed files with 83 additions and 22 deletions
+83 -18
View File
@@ -39,8 +39,9 @@ Usage:
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \ --policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \ --dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \ --rtc.execution_horizon=10 \
--device=mps --device=mps
--seed=10
# Basic usage with pi0.5 policy with cuda device # Basic usage with pi0.5 policy with cuda device
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
@@ -795,16 +796,34 @@ class RTCEvaluator:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5) ax.set_xlim(-0.5, max_len - 0.5)
# Add legend only to first subplot
if dim_idx == 0:
ax.legend(loc="best", fontsize=9)
axes[-1].set_xlabel("Step", fontsize=10) axes[-1].set_xlabel("Step", fontsize=10)
# Collect legend handles and labels from first subplot
handles, labels = axes[0].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right)
fig.legend(
unique_handles,
unique_labels,
loc="center right",
fontsize=9,
bbox_to_anchor=(1.0, 0.5),
framealpha=0.9,
)
# Save figure # Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png") output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout() fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
fig.savefig(output_path, dpi=150) fig.savefig(output_path, dpi=150, bbox_inches="tight")
logging.info(f"Saved final actions comparison to {output_path}") logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig) plt.close(fig)
@@ -825,6 +844,7 @@ class RTCEvaluator:
axs_corr[:, 1], # Right column for correction axs_corr[:, 1], # Right column for correction
axs_x1t[:, 1], # Right column for x1_t axs_x1t[:, 1], # Right column for x1_t
num_steps, num_steps,
add_labels=True, # Add labels for RTC (right column)
) )
self._plot_denoising_steps_from_tracker( self._plot_denoising_steps_from_tracker(
@@ -834,6 +854,7 @@ class RTCEvaluator:
axs_corr[:, 0], # Left column for correction axs_corr[:, 0], # Left column for correction
axs_x1t[:, 0], # Left column for x1_t axs_x1t[:, 0], # Left column for x1_t
num_steps, num_steps,
add_labels=False, # No labels for No RTC (left column)
) )
# Plot no-RTC x_t data on right chart as orange dashed line for comparison # Plot no-RTC x_t data on right chart as orange dashed line for comparison
@@ -849,15 +870,21 @@ class RTCEvaluator:
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
) )
# Plot ground truth on x_t axes # Plot ground truth on x_t axes (no labels for left column)
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
) )
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
) )
# Add legends outside the plot area for each figure
self._add_figure_legend(fig_xt, axs_xt)
self._add_figure_legend(fig_vt, axs_vt)
self._add_figure_legend(fig_corr, axs_corr)
self._add_figure_legend(fig_x1t, axs_x1t)
# Save denoising plots # Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
@@ -875,13 +902,47 @@ class RTCEvaluator:
return fig, axs return fig, axs
def _add_figure_legend(self, fig, axs):
"""Add a legend outside the plot area on the right side.
Args:
fig: Matplotlib figure to add legend to
axs: Array of axes to collect legend handles from
"""
# Collect all handles and labels from the first row of axes (right column)
handles, labels = axs[0, 1].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right, close to charts)
if unique_handles:
fig.legend(
unique_handles,
unique_labels,
loc="center left",
fontsize=8,
bbox_to_anchor=(0.87, 0.5),
framealpha=0.9,
ncol=1,
)
def _save_figure(self, fig, path): def _save_figure(self, fig, path):
fig.tight_layout() fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
fig.savefig(path, dpi=150) fig.savefig(path, dpi=150, bbox_inches="tight")
logging.info(f"Saved figure to {path}") logging.info(f"Saved figure to {path}")
plt.close(fig) plt.close(fig)
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps): def _plot_denoising_steps_from_tracker(
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
):
"""Plot denoising steps from tracker data. """Plot denoising steps from tracker data.
Args: Args:
@@ -891,6 +952,7 @@ class RTCEvaluator:
corr_axs: Matplotlib axes for correction plots (array of 6 axes) corr_axs: Matplotlib axes for correction plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes) x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap num_steps: Total number of denoising steps for colormap
add_labels: Whether to add legend labels for the plots
""" """
logging.info("=" * 80) logging.info("=" * 80)
@@ -905,17 +967,18 @@ class RTCEvaluator:
for step_idx, debug_step in enumerate(debug_steps): for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)] color = colors[step_idx % len(colors)]
label = f"Step {step_idx}" if add_labels else None
# Plot x_t # Plot x_t
if debug_step.x_t is not None: if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}" xt_axs, debug_step.x_t, start_from=0, color=color, label=label
) )
# Plot v_t # Plot v_t
if debug_step.v_t is not None: if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}" vt_axs, debug_step.v_t, start_from=0, color=color, label=label
) )
# Plot correction on separate axes # Plot correction on separate axes
@@ -925,17 +988,18 @@ class RTCEvaluator:
debug_step.correction, debug_step.correction,
start_from=0, start_from=0,
color=color, color=color,
label=f"Step {step_idx}", label=label,
) )
# Plot x1_t (predicted state) # Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None: if x1t_axs is not None and debug_step.x1_t is not None:
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
x1t_axs, x1t_axs,
debug_step.x1_t, debug_step.x1_t,
start_from=0, start_from=0,
color=color, color=color,
label=f"x1_t Step {step_idx}", label=x1t_label,
) )
# Plot error in orange dashed # Plot error in orange dashed
@@ -947,6 +1011,7 @@ class RTCEvaluator:
) )
num_dims = min(error_chunk.shape[-1], 6) num_dims = min(error_chunk.shape[-1], 6)
error_label = f"error Step {step_idx}" if add_labels else None
for j in range(num_dims): for j in range(num_dims):
x1t_axs[j].plot( x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]), np.arange(0, error_chunk.shape[0]),
@@ -954,7 +1019,7 @@ class RTCEvaluator:
color="orange", color="orange",
linestyle="--", linestyle="--",
alpha=0.7, alpha=0.7,
label=f"error Step {step_idx}", label=error_label,
) )
# Recalculate axis limits after plotting to ensure proper scaling # Recalculate axis limits after plotting to ensure proper scaling
@@ -111,7 +111,3 @@ class RTCDebugVisualizer:
if not ax.yaxis.get_label().get_text(): if not ax.yaxis.get_label().get_text():
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10) ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
# Add legend if label provided and this is the first dimension
if label and dim_idx == 0:
ax.legend(loc="best", fontsize=8)
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 MiB

After

Width:  |  Height:  |  Size: 1.3 MiB