From 2204a45020eeb97aa13d89c95d5c27fc133944fd Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 3 Nov 2025 19:17:11 +0700 Subject: [PATCH] Refactor SmolVLA plotting to use tracker data instead of local variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove local tracking variables (correction, x1_t, error) from the denoising loop and instead retrieve plotting data from the RTC tracker after each denoise step. This makes the code cleaner and uses the tracker as the single source of truth for debug/visualization data. Changes: - Remove initialization of correction, x1_t, error before denoising loop - After each Euler step, retrieve most recent debug step from tracker - Extract correction, x1_t, err from debug step for plotting - Update tracking condition to use is_debug_enabled() method 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Co-Authored-By: Alexander Soare --- .../policies/smolvla/modeling_smolvla.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 98158db67..f30141acc 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -789,9 +789,6 @@ class VLAFlowMatching(nn.Module): x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) - correction = None - x1_t = None - error = None use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None while time >= -dt / 2: @@ -828,9 +825,21 @@ class VLAFlowMatching(nn.Module): time += dt # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - if self._rtc_enabled() and correction is not None: + if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t) + # Retrieve data from tracker for plotting + correction = None + x1_t = None + error = None + if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): + recent_steps = self.rtc_processor.get_recent_debug_steps(n=1) + if recent_steps: + debug_step = recent_steps[0] + correction = debug_step.correction + x1_t = debug_step.x1_t + error = debug_step.err + # Visualize x_t using plot_waypoints - accumulate all denoise steps # Use provided axes or create new ones if not use_provided_axes: