Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Eugene Mironov
2025-11-03 19:24:35 +07:00
parent 2204a45020
commit 26db4b64d8
3 changed files with 105 additions and 159 deletions
+101 -7
View File
@@ -152,6 +152,7 @@ class RTCEvaluator:
# Configure RTC
cfg.rtc.enabled = True
cfg.rtc.debug = True # Enable debug tracking for visualization
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
@@ -210,18 +211,19 @@ class RTCEvaluator:
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
# Generate actions WITHOUT RTC (plot on left column)
# Generate actions WITHOUT RTC
logger.info("Generating actions WITHOUT RTC")
self.policy.config.rtc_config.enabled = False
with torch.no_grad():
no_rtc_actions = self.policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
)
# Generate actions WITH RTC (plot on right column)
# Plot denoising steps from tracker (no RTC - left column)
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
# Generate actions WITH RTC
logger.info("Generating actions WITH RTC")
self.policy.config.rtc_config.enabled = True
with torch.no_grad():
@@ -231,9 +233,27 @@ class RTCEvaluator:
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
)
# Plot denoising steps from tracker (RTC - right column)
if self.policy.rtc_processor is not None:
num_steps = self.policy.config.num_steps
self._plot_denoising_steps_from_tracker(
self.policy.rtc_processor.tracker,
axs_xt[:, 1], # Right column for x_t
axs_vt[:, 1], # Right column for v_t
axs_x1t[:, 1], # Right column for x1_t
num_steps,
)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Set titles for denoising plots
@@ -390,6 +410,80 @@ class RTCEvaluator:
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
"""Plot denoising steps from tracker data.
Args:
tracker: Tracker object containing debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
"""
if tracker is None:
return
debug_steps = tracker.get_all_steps()
if not debug_steps:
return
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot correction in red
if debug_step.correction is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs,
debug_step.correction,
start_from=0,
color="red",
label=f"Step corr {step_idx}",
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=f"x1_t Step {step_idx}",
)
# Plot error in orange dashed
if x1t_axs is not None and debug_step.err is not None:
error_chunk = (
debug_step.err[0].cpu().numpy()
if len(debug_step.err.shape) == 3
else debug_step.err.cpu().numpy()
)
num_dims = min(error_chunk.shape[-1], 6)
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=f"error Step {step_idx}",
)
@parser.wrap()
def main(cfg: RTCEvalConfig):
-1
View File
@@ -263,7 +263,6 @@ class RTCProcessor:
# Record debug information (all params except x_t which is recorded externally)
self.track(
time=time,
v_t=v_t,
x1_t=x1_t,
correction=correction,
err=err,
@@ -55,14 +55,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -72,9 +69,6 @@ from lerobot.policies.utils import (
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.utils import get_safe_dtype
# Make plot_waypoints easily accessible
plot_waypoints = RTCDebugVisualizer.plot_waypoints
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -544,11 +538,6 @@ class VLAFlowMatching(nn.Module):
self.prefix_length = self.config.prefix_length
self.rtc_processor = rtc_processor
# For visualization of x_t during denoising
self.denoise_step_counter = 0
self.viz_fig = None
self.viz_axs = None
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
@@ -750,22 +739,10 @@ class VLAFlowMatching(nn.Module):
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)
Args:
viz_xt_axs: Optional matplotlib axes for plotting x_t trajectories (array of 6 axes)
viz_vt_axs: Optional matplotlib axes for plotting v_t trajectories (array of 6 axes)
viz_x1t_axs: Optional matplotlib axes for plotting x1_t predicted state and error (array of 6 axes)
When RTC is enabled, plots both x1_t (solid line) and error (orange dashed line)
"""
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
# Extract visualization axes from kwargs
viz_xt_axs = kwargs.pop("viz_xt_axs", None)
viz_vt_axs = kwargs.pop("viz_vt_axs", None)
viz_x1t_axs = kwargs.pop("viz_x1t_axs", None)
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
@@ -789,7 +766,6 @@ class VLAFlowMatching(nn.Module):
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None
while time >= -dt / 2:
expanded_time = time.expand(bsize)
@@ -824,132 +800,9 @@ class VLAFlowMatching(nn.Module):
x_t += dt * v_t
time += dt
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t)
# Retrieve data from tracker for plotting
correction = None
x1_t = None
error = None
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
recent_steps = self.rtc_processor.get_recent_debug_steps(n=1)
if recent_steps:
debug_step = recent_steps[0]
correction = debug_step.correction
x1_t = debug_step.x1_t
error = debug_step.err
# Visualize x_t using plot_waypoints - accumulate all denoise steps
# Use provided axes or create new ones
if not use_provided_axes:
if self.viz_fig is None:
# Create figure once on first denoise step
self.viz_fig, self.viz_axs = plt.subplots(6, 1, figsize=(12, 12))
self.viz_v_fig, self.viz_v_axs = plt.subplots(6, 1, figsize=(12, 12))
xt_axs = self.viz_axs
vt_axs = self.viz_v_axs
else:
xt_axs = viz_xt_axs
vt_axs = viz_vt_axs
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, self.config.num_steps))
color = colors[self.denoise_step_counter % len(colors)]
# Plot this denoise step
plot_waypoints(xt_axs, x_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
# Plot this denoise step
plot_waypoints(vt_axs, v_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
if correction is not None:
plot_waypoints(
vt_axs,
correction,
start_from=0,
color="red",
label=f"Step corr {self.denoise_step_counter}",
)
# Plot x1_t if axes provided and RTC is enabled
if viz_x1t_axs is not None and x1_t is not None:
plot_waypoints(
viz_x1t_axs,
x1_t,
start_from=0,
color=color,
label=f"x1_t Step {self.denoise_step_counter}",
)
# Plot error on the same axes with different color
if error is not None:
# Use orange color for error
# Handle batch dimension if present
error_chunk = error[0].cpu().numpy() if len(error.shape) == 3 else error.cpu().numpy()
num_dims = min(error_chunk.shape[-1], 6)
for j in range(num_dims):
viz_x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=f"error Step {self.denoise_step_counter}",
)
self.denoise_step_counter += 1
# Save visualization of x_t denoise steps (only if using internal figures)
if not use_provided_axes and self.viz_fig is not None:
plt.figure(self.viz_fig.number)
xt_name = "smolvla_x_t_denoise_steps.png"
v_name = "smolvla_v_denoise_steps.png"
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
v_name = "smolvla_v_with_rtc_denoise_steps.png"
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
if prev_chunk_left_over is not None:
plot_waypoints(
self.viz_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
plt.savefig(xt_name)
plt.close(self.viz_fig)
# Reset for next inference
self.viz_fig = None
self.viz_axs = None
self.denoise_step_counter = 0
plt.figure(self.viz_v_fig.number)
plt.savefig(v_name)
plt.close(self.viz_v_fig)
self.viz_v_fig = None
self.viz_v_axs = None
# Plot ground truth on provided axes if available
if use_provided_axes:
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
if prev_chunk_left_over is not None and self._rtc_enabled():
plot_waypoints(
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Also plot ground truth on x1_t axes if provided
if viz_x1t_axs is not None:
plot_waypoints(
viz_x1t_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Reset counter when using provided axes (for next call)
if use_provided_axes:
self.denoise_step_counter = 0
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
return x_t