From 26db4b64d82d7e97c0aaf55a4f10953d33916f55 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 3 Nov 2025 19:24:35 +0700 Subject: [PATCH] Move plotting logic from modeling_smolvla to eval_dataset script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor to improve separation of concerns: modeling_smolvla.py changes: - Remove all plotting logic from sample_actions method - Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters - Remove matplotlib and RTCDebugVisualizer imports - Remove viz_fig, viz_axs, denoise_step_counter instance variables - Simplify denoising loop to only track data in rtc_processor eval_dataset.py changes: - Add _plot_denoising_steps_from_tracker helper method - Retrieve debug steps from tracker after inference - Plot x_t, v_t, x1_t, correction, and error from tracker data - Enable debug tracking (cfg.rtc.debug = True) for visualization - Remove viz axes parameters from predict_action_chunk calls modeling_rtc.py changes: - Remove v_t from track() call (handled by user change) Benefits: - Cleaner modeling code focused on inference - Evaluation script owns all visualization logic - Better separation of concerns - Tracker is single source of truth for debug data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Co-Authored-By: Alexander Soare --- examples/rtc/eval_dataset.py | 108 +++++++++++- src/lerobot/policies/rtc/modeling_rtc.py | 1 - .../policies/smolvla/modeling_smolvla.py | 155 +----------------- 3 files changed, 105 insertions(+), 159 deletions(-) diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 7dd5710cb..4729856aa 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -152,6 +152,7 @@ class RTCEvaluator: # Configure RTC cfg.rtc.enabled = True + cfg.rtc.debug = True # Enable debug tracking for visualization self.policy.config.rtc_config = cfg.rtc self.policy.init_rtc_processor() @@ -210,18 +211,19 @@ class RTCEvaluator: fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12)) fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16) - # Generate actions WITHOUT RTC (plot on left column) + # Generate actions WITHOUT RTC logger.info("Generating actions WITHOUT RTC") self.policy.config.rtc_config.enabled = False with torch.no_grad(): no_rtc_actions = self.policy.predict_action_chunk( preprocessed_second_sample, noise=noise, - viz_xt_axs=axs_xt[:, 0], # Left column for x_t - viz_vt_axs=axs_vt[:, 0], # Left column for v_t ) - # Generate actions WITH RTC (plot on right column) + # Plot denoising steps from tracker (no RTC - left column) + # Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists + + # Generate actions WITH RTC logger.info("Generating actions WITH RTC") self.policy.config.rtc_config.enabled = True with torch.no_grad(): @@ -231,9 +233,27 @@ class RTCEvaluator: inference_delay=self.cfg.inference_delay, prev_chunk_left_over=prev_chunk_left_over, execution_horizon=self.cfg.rtc.execution_horizon, - viz_xt_axs=axs_xt[:, 1], # Right column for x_t - viz_vt_axs=axs_vt[:, 1], # Right column for v_t - viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t + ) + + # Plot denoising steps from tracker (RTC - right column) + if self.policy.rtc_processor is not None: + num_steps = self.policy.config.num_steps + self._plot_denoising_steps_from_tracker( + self.policy.rtc_processor.tracker, + axs_xt[:, 1], # Right column for x_t + axs_vt[:, 1], # Right column for v_t + axs_x1t[:, 1], # Right column for x1_t + num_steps, + ) + + # Plot ground truth on x_t axes + RTCDebugVisualizer.plot_waypoints( + axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + + # Plot ground truth on x1_t axes + RTCDebugVisualizer.plot_waypoints( + axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" ) # Set titles for denoising plots @@ -390,6 +410,80 @@ class RTCEvaluator: logger.info(f"Debug visualizations saved to {self.cfg.output_dir}") + def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps): + """Plot denoising steps from tracker data. + + Args: + tracker: Tracker object containing debug steps + xt_axs: Matplotlib axes for x_t plots (array of 6 axes) + vt_axs: Matplotlib axes for v_t plots (array of 6 axes) + x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes) + num_steps: Total number of denoising steps for colormap + """ + if tracker is None: + return + + debug_steps = tracker.get_all_steps() + if not debug_steps: + return + + # Define colors for different denoise steps (using a colormap) + colors = plt.cm.viridis(np.linspace(0, 1, num_steps)) + + for step_idx, debug_step in enumerate(debug_steps): + color = colors[step_idx % len(colors)] + + # Plot x_t + if debug_step.x_t is not None: + RTCDebugVisualizer.plot_waypoints( + xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}" + ) + + # Plot v_t + if debug_step.v_t is not None: + RTCDebugVisualizer.plot_waypoints( + vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}" + ) + + # Plot correction in red + if debug_step.correction is not None: + RTCDebugVisualizer.plot_waypoints( + vt_axs, + debug_step.correction, + start_from=0, + color="red", + label=f"Step corr {step_idx}", + ) + + # Plot x1_t (predicted state) + if x1t_axs is not None and debug_step.x1_t is not None: + RTCDebugVisualizer.plot_waypoints( + x1t_axs, + debug_step.x1_t, + start_from=0, + color=color, + label=f"x1_t Step {step_idx}", + ) + + # Plot error in orange dashed + if x1t_axs is not None and debug_step.err is not None: + error_chunk = ( + debug_step.err[0].cpu().numpy() + if len(debug_step.err.shape) == 3 + else debug_step.err.cpu().numpy() + ) + + num_dims = min(error_chunk.shape[-1], 6) + for j in range(num_dims): + x1t_axs[j].plot( + np.arange(0, error_chunk.shape[0]), + error_chunk[:, j], + color="orange", + linestyle="--", + alpha=0.7, + label=f"error Step {step_idx}", + ) + @parser.wrap() def main(cfg: RTCEvalConfig): diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 17db3762d..b6feed47a 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -263,7 +263,6 @@ class RTCProcessor: # Record debug information (all params except x_t which is recorded externally) self.track( time=time, - v_t=v_t, x1_t=x1_t, correction=correction, err=err, diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index f30141acc..3cebe38c2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -55,14 +55,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") import math from collections import deque -import matplotlib.pyplot as plt -import numpy as np import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel @@ -72,9 +69,6 @@ from lerobot.policies.utils import ( from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.utils import get_safe_dtype -# Make plot_waypoints easily accessible -plot_waypoints = RTCDebugVisualizer.plot_waypoints - def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -544,11 +538,6 @@ class VLAFlowMatching(nn.Module): self.prefix_length = self.config.prefix_length self.rtc_processor = rtc_processor - # For visualization of x_t during denoising - self.denoise_step_counter = 0 - self.viz_fig = None - self.viz_axs = None - def _rtc_enabled(self): return self.config.rtc_config is not None and self.config.rtc_config.enabled @@ -750,22 +739,10 @@ class VLAFlowMatching(nn.Module): def sample_actions( self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs ) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors) - - Args: - viz_xt_axs: Optional matplotlib axes for plotting x_t trajectories (array of 6 axes) - viz_vt_axs: Optional matplotlib axes for plotting v_t trajectories (array of 6 axes) - viz_x1t_axs: Optional matplotlib axes for plotting x1_t predicted state and error (array of 6 axes) - When RTC is enabled, plots both x1_t (solid line) and error (orange dashed line) - """ + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device - # Extract visualization axes from kwargs - viz_xt_axs = kwargs.pop("viz_xt_axs", None) - viz_vt_axs = kwargs.pop("viz_vt_axs", None) - viz_x1t_axs = kwargs.pop("viz_x1t_axs", None) - if noise is None: actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) noise = self.sample_noise(actions_shape, device) @@ -789,7 +766,6 @@ class VLAFlowMatching(nn.Module): x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) - use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None while time >= -dt / 2: expanded_time = time.expand(bsize) @@ -824,132 +800,9 @@ class VLAFlowMatching(nn.Module): x_t += dt * v_t time += dt - # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): - self.rtc_processor.track(time=time, x_t=x_t) - - # Retrieve data from tracker for plotting - correction = None - x1_t = None - error = None - if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): - recent_steps = self.rtc_processor.get_recent_debug_steps(n=1) - if recent_steps: - debug_step = recent_steps[0] - correction = debug_step.correction - x1_t = debug_step.x1_t - error = debug_step.err - - # Visualize x_t using plot_waypoints - accumulate all denoise steps - # Use provided axes or create new ones - if not use_provided_axes: - if self.viz_fig is None: - # Create figure once on first denoise step - self.viz_fig, self.viz_axs = plt.subplots(6, 1, figsize=(12, 12)) - self.viz_v_fig, self.viz_v_axs = plt.subplots(6, 1, figsize=(12, 12)) - xt_axs = self.viz_axs - vt_axs = self.viz_v_axs - else: - xt_axs = viz_xt_axs - vt_axs = viz_vt_axs - - # Define colors for different denoise steps (using a colormap) - colors = plt.cm.viridis(np.linspace(0, 1, self.config.num_steps)) - color = colors[self.denoise_step_counter % len(colors)] - - # Plot this denoise step - plot_waypoints(xt_axs, x_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}") - - # Plot this denoise step - plot_waypoints(vt_axs, v_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}") - - if correction is not None: - plot_waypoints( - vt_axs, - correction, - start_from=0, - color="red", - label=f"Step corr {self.denoise_step_counter}", - ) - - # Plot x1_t if axes provided and RTC is enabled - if viz_x1t_axs is not None and x1_t is not None: - plot_waypoints( - viz_x1t_axs, - x1_t, - start_from=0, - color=color, - label=f"x1_t Step {self.denoise_step_counter}", - ) - - # Plot error on the same axes with different color - if error is not None: - # Use orange color for error - # Handle batch dimension if present - error_chunk = error[0].cpu().numpy() if len(error.shape) == 3 else error.cpu().numpy() - - num_dims = min(error_chunk.shape[-1], 6) - for j in range(num_dims): - viz_x1t_axs[j].plot( - np.arange(0, error_chunk.shape[0]), - error_chunk[:, j], - color="orange", - linestyle="--", - alpha=0.7, - label=f"error Step {self.denoise_step_counter}", - ) - - self.denoise_step_counter += 1 - - # Save visualization of x_t denoise steps (only if using internal figures) - if not use_provided_axes and self.viz_fig is not None: - plt.figure(self.viz_fig.number) - - xt_name = "smolvla_x_t_denoise_steps.png" - v_name = "smolvla_v_denoise_steps.png" - - if self.config.rtc_config is not None and self.config.rtc_config.enabled: - xt_name = "smolvla_x_t_with_rtc_denoise_steps.png" - v_name = "smolvla_v_with_rtc_denoise_steps.png" - - prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - - if prev_chunk_left_over is not None: - plot_waypoints( - self.viz_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - - plt.savefig(xt_name) - plt.close(self.viz_fig) - - # Reset for next inference - self.viz_fig = None - self.viz_axs = None - self.denoise_step_counter = 0 - - plt.figure(self.viz_v_fig.number) - plt.savefig(v_name) - plt.close(self.viz_v_fig) - - self.viz_v_fig = None - self.viz_v_axs = None - - # Plot ground truth on provided axes if available - if use_provided_axes: - prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - if prev_chunk_left_over is not None and self._rtc_enabled(): - plot_waypoints( - viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - # Also plot ground truth on x1_t axes if provided - if viz_x1t_axs is not None: - plot_waypoints( - viz_x1t_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - - # Reset counter when using provided axes (for next call) - if use_provided_axes: - self.denoise_step_counter = 0 + # Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step) + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) return x_t