diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 98158db67..f30141acc 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -789,9 +789,6 @@ class VLAFlowMatching(nn.Module): x_t = noise time = torch.tensor(1.0, dtype=torch.float32, device=device) - correction = None - x1_t = None - error = None use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None while time >= -dt / 2: @@ -828,9 +825,21 @@ class VLAFlowMatching(nn.Module): time += dt # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - if self._rtc_enabled() and correction is not None: + if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t) + # Retrieve data from tracker for plotting + correction = None + x1_t = None + error = None + if self._rtc_enabled() and self.rtc_processor.is_debug_enabled(): + recent_steps = self.rtc_processor.get_recent_debug_steps(n=1) + if recent_steps: + debug_step = recent_steps[0] + correction = debug_step.correction + x1_t = debug_step.x1_t + error = debug_step.err + # Visualize x_t using plot_waypoints - accumulate all denoise steps # Use provided axes or create new ones if not use_provided_axes: