mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Improve visualization: separate correction plot and fix axis scaling
Changes: - Create separate figure for correction data instead of overlaying on v_t - Add _rescale_axes helper method to properly scale all axes - Add 10% margin to y-axis for better visualization - Fix v_t chart vertical compression issue Benefits: - Clearer v_t plot without correction overlay - Better axis scaling with proper margins - Separate correction figure for focused analysis - Improved readability of all denoising visualizations Output files: - denoising_xt_comparison.png (x_t trajectories) - denoising_vt_comparison.png (v_t velocity - now cleaner) - denoising_correction_comparison.png (NEW - separate corrections) - denoising_x1t_comparison.png (x1_t state with error) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -222,6 +222,7 @@ class RTCEvaluator:
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
|
||||
fig_x1t, axs_x1t = self._create_figure(
|
||||
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
|
||||
)
|
||||
@@ -231,6 +232,7 @@ class RTCEvaluator:
|
||||
rtc_tracked_steps,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_corr[:, 1], # Right column for correction
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
@@ -239,6 +241,7 @@ class RTCEvaluator:
|
||||
no_rtc_tracked_steps,
|
||||
axs_xt[:, 0], # Left column for x_t
|
||||
axs_vt[:, 0], # Left column for v_t
|
||||
axs_corr[:, 0], # Left column for correction
|
||||
axs_x1t[:, 0], # Left column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
@@ -265,6 +268,7 @@ class RTCEvaluator:
|
||||
# Save denoising plots
|
||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||
self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
|
||||
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
|
||||
|
||||
def _create_figure(self, title):
|
||||
@@ -284,13 +288,14 @@ class RTCEvaluator:
|
||||
logging.info(f"Saved figure to {path}")
|
||||
plt.close(fig)
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps):
|
||||
"""Plot denoising steps from tracker data.
|
||||
|
||||
Args:
|
||||
tracked_steps: List of DebugStep objects containing debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
|
||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
@@ -317,14 +322,14 @@ class RTCEvaluator:
|
||||
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||
)
|
||||
|
||||
# Plot correction in red
|
||||
# Plot correction on separate axes
|
||||
if debug_step.correction is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs,
|
||||
corr_axs,
|
||||
debug_step.correction,
|
||||
start_from=0,
|
||||
color="red",
|
||||
label=f"Step corr {step_idx}",
|
||||
color=color,
|
||||
label=f"Step {step_idx}",
|
||||
)
|
||||
|
||||
# Plot x1_t (predicted state)
|
||||
@@ -356,6 +361,29 @@ class RTCEvaluator:
|
||||
label=f"error Step {step_idx}",
|
||||
)
|
||||
|
||||
# Recalculate axis limits after plotting to ensure proper scaling
|
||||
self._rescale_axes(xt_axs)
|
||||
self._rescale_axes(vt_axs)
|
||||
self._rescale_axes(corr_axs)
|
||||
self._rescale_axes(x1t_axs)
|
||||
|
||||
def _rescale_axes(self, axes):
|
||||
"""Rescale axes to show all data with proper margins.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes to rescale
|
||||
"""
|
||||
for ax in axes:
|
||||
ax.relim()
|
||||
ax.autoscale_view()
|
||||
|
||||
# Add 10% margin to y-axis for better visualization
|
||||
ylim = ax.get_ylim()
|
||||
y_range = ylim[1] - ylim[0]
|
||||
if y_range > 0: # Avoid division by zero
|
||||
margin = y_range * 0.1
|
||||
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
|
||||
Reference in New Issue
Block a user