mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Refactor plotting loging
This commit is contained in:
+63
-182
@@ -201,33 +201,23 @@ class RTCEvaluator:
|
||||
noise = self.policy.model.sample_noise(noise_size, self.device)
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_xt.suptitle("x_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
fig_vt, axs_vt = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_vt.suptitle("v_t Denoising: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Generate actions WITHOUT RTC
|
||||
logger.info("Generating actions WITHOUT RTC")
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = self.policy.predict_action_chunk(
|
||||
_ = self.policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
)
|
||||
|
||||
# Plot denoising steps from tracker (no RTC - left column)
|
||||
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
|
||||
no_rtc_tracked_steps = self.policy.rtc_processor.tracker.get_all_steps()
|
||||
self.policy.rtc_processor.reset_tracker()
|
||||
|
||||
# Generate actions WITH RTC
|
||||
logger.info("Generating actions WITH RTC")
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
rtc_actions = self.policy.predict_action_chunk(
|
||||
_ = self.policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise_clone,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
@@ -235,195 +225,86 @@ class RTCEvaluator:
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
)
|
||||
|
||||
# Plot denoising steps from tracker (RTC - right column)
|
||||
if self.policy.rtc_processor is not None:
|
||||
num_steps = self.policy.config.num_steps
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
self.policy.rtc_processor.tracker,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
# ================================================================
|
||||
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Set titles for denoising plots
|
||||
for ax in axs_xt[:, 0]:
|
||||
ax.set_title("No RTC" if ax == axs_xt[0, 0] else "", fontsize=12)
|
||||
for ax in axs_xt[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_xt[0, 1] else "", fontsize=12)
|
||||
|
||||
for ax in axs_vt[:, 0]:
|
||||
ax.set_title("No RTC" if ax == axs_vt[0, 0] else "", fontsize=12)
|
||||
for ax in axs_vt[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_vt[0, 1] else "", fontsize=12)
|
||||
|
||||
for ax in axs_x1t[:, 0]:
|
||||
ax.set_title("No RTC (N/A)" if ax == axs_x1t[0, 0] else "", fontsize=12)
|
||||
for ax in axs_x1t[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs_x1t[0, 1] else "", fontsize=12)
|
||||
|
||||
# Save denoising plots
|
||||
fig_xt.tight_layout()
|
||||
xt_path = os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")
|
||||
fig_xt.savefig(xt_path, dpi=150)
|
||||
logger.info(f"Saved x_t denoising comparison to {xt_path}")
|
||||
plt.close(fig_xt)
|
||||
|
||||
fig_vt.tight_layout()
|
||||
vt_path = os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")
|
||||
fig_vt.savefig(vt_path, dpi=150)
|
||||
logger.info(f"Saved v_t denoising comparison to {vt_path}")
|
||||
plt.close(fig_vt)
|
||||
|
||||
fig_x1t.tight_layout()
|
||||
x1t_path = os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")
|
||||
fig_x1t.savefig(x1t_path, dpi=150)
|
||||
logger.info(f"Saved x1_t predicted state & error comparison to {x1t_path}")
|
||||
plt.close(fig_x1t)
|
||||
|
||||
# Create side-by-side comparison: No RTC (left) vs RTC (right)
|
||||
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig.suptitle("Final Action Comparison: No RTC (left) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Plot on left column (No RTC)
|
||||
self._plot_actions(
|
||||
axs[:, 0],
|
||||
prev_chunk_left_over[0].cpu().numpy(),
|
||||
no_rtc_actions[0].cpu().numpy(),
|
||||
"No RTC",
|
||||
)
|
||||
|
||||
# Plot on right column (RTC)
|
||||
self._plot_actions(
|
||||
axs[:, 1],
|
||||
prev_chunk_left_over[0].cpu().numpy(),
|
||||
rtc_actions[0].detach().cpu().numpy(),
|
||||
"RTC",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
final_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||
plt.savefig(final_path, dpi=150)
|
||||
logger.info(f"Saved final actions comparison to {final_path}")
|
||||
plt.close(fig)
|
||||
|
||||
# Visualize debug information if enabled
|
||||
self._visualize_debug_info()
|
||||
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
|
||||
|
||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
|
||||
logger.info("Evaluation completed successfully")
|
||||
|
||||
def _plot_actions(self, axs, prev_chunk, predicted_actions, title):
|
||||
"""Plot actions comparison on given axes."""
|
||||
# Ensure arrays are 2D
|
||||
if prev_chunk.ndim == 1:
|
||||
prev_chunk = prev_chunk.reshape(1, -1)
|
||||
if predicted_actions.ndim == 1:
|
||||
predicted_actions = predicted_actions.reshape(1, -1)
|
||||
|
||||
for j in range(min(prev_chunk.shape[-1], 6)): # Limit to 6 dimensions
|
||||
axs[j].plot(
|
||||
np.arange(prev_chunk.shape[0]),
|
||||
prev_chunk[:, j],
|
||||
color="green",
|
||||
label="Previous Chunk",
|
||||
)
|
||||
axs[j].plot(
|
||||
np.arange(predicted_actions.shape[0]),
|
||||
predicted_actions[:, j],
|
||||
color="red" if "RTC" in title else "blue",
|
||||
label=title,
|
||||
)
|
||||
axs[j].set_ylabel("Joint angle", fontsize=14)
|
||||
axs[j].grid()
|
||||
axs[j].legend(loc="upper right", fontsize=14)
|
||||
axs[j].set_title(title if j == 0 else "", fontsize=12)
|
||||
if j == 2:
|
||||
axs[j].set_xlabel("Step #", fontsize=16)
|
||||
|
||||
def _visualize_debug_info(self):
|
||||
"""Visualize debug information from the RTC processor."""
|
||||
# Use proxy method to check if debug is enabled
|
||||
if not self.policy.rtc_processor.is_debug_enabled():
|
||||
logger.warning("Debug tracking is disabled. Skipping debug visualization.")
|
||||
return
|
||||
|
||||
# Get tracker length using proxy method
|
||||
if self.policy.rtc_processor.get_tracker_length() == 0:
|
||||
logger.warning("No debug steps recorded. Skipping debug visualization.")
|
||||
return
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||
logger.info(f"Saving debug visualizations to {self.cfg.output_dir}")
|
||||
|
||||
# Still need direct access to tracker for visualization functions
|
||||
# This is acceptable since RTCDebugVisualizer is part of the RTC package
|
||||
tracker = self.policy.rtc_processor.tracker
|
||||
|
||||
# Print statistics
|
||||
RTCDebugVisualizer.print_debug_statistics(tracker)
|
||||
|
||||
# Plot debug summary
|
||||
summary_path = os.path.join(self.cfg.output_dir, "debug_summary.png")
|
||||
RTCDebugVisualizer.plot_debug_summary(
|
||||
tracker,
|
||||
save_path=summary_path,
|
||||
show=False,
|
||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_x1t, axs_x1t = self._create_figure(
|
||||
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
|
||||
)
|
||||
|
||||
# Plot correction heatmap
|
||||
heatmap_path = os.path.join(self.cfg.output_dir, "correction_heatmap.png")
|
||||
RTCDebugVisualizer.plot_correction_heatmap(
|
||||
tracker,
|
||||
save_path=heatmap_path,
|
||||
show=False,
|
||||
num_steps = self.policy.config.num_steps
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
rtc_tracked_steps,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
|
||||
# Plot step-by-step comparison (last step)
|
||||
step_path = os.path.join(self.cfg.output_dir, "step_comparison_last.png")
|
||||
RTCDebugVisualizer.plot_step_by_step_comparison(
|
||||
tracker,
|
||||
step_idx=-1,
|
||||
save_path=step_path,
|
||||
show=False,
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
no_rtc_tracked_steps,
|
||||
axs_xt[:, 0], # Left column for x_t
|
||||
axs_vt[:, 0], # Left column for v_t
|
||||
axs_x1t[:, 0], # Left column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
|
||||
# Plot step-by-step comparison (first step)
|
||||
step_path_first = os.path.join(self.cfg.output_dir, "step_comparison_first.png")
|
||||
if self.policy.rtc_processor.get_tracker_length() > 0:
|
||||
RTCDebugVisualizer.plot_step_by_step_comparison(
|
||||
tracker,
|
||||
step_idx=0,
|
||||
save_path=step_path_first,
|
||||
show=False,
|
||||
)
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Save denoising plots
|
||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
|
||||
|
||||
def _create_figure(self, title):
|
||||
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig.suptitle(title, fontsize=16)
|
||||
|
||||
for ax in axs[:, 0]:
|
||||
ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
|
||||
for ax in axs[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
|
||||
|
||||
return fig, axs
|
||||
|
||||
def _save_figure(self, fig, path):
|
||||
fig.tight_layout()
|
||||
fig.savefig(path, dpi=150)
|
||||
logger.info(f"Saved figure to {path}")
|
||||
plt.close(fig)
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
"""Plot denoising steps from tracker data.
|
||||
|
||||
Args:
|
||||
tracker: Tracker object containing debug steps
|
||||
tracked_steps: List of DebugStep objects containing debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
if tracker is None:
|
||||
return
|
||||
|
||||
debug_steps = tracker.get_all_steps()
|
||||
debug_steps = tracked_steps
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user