Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Eugene Mironov
2025-11-03 19:24:35 +07:00
parent 2204a45020
commit 26db4b64d8
3 changed files with 105 additions and 159 deletions
+101 -7
View File
@@ -152,6 +152,7 @@ class RTCEvaluator:
# Configure RTC
cfg.rtc.enabled = True
cfg.rtc.debug = True # Enable debug tracking for visualization
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
@@ -210,18 +211,19 @@ class RTCEvaluator:
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
# Generate actions WITHOUT RTC (plot on left column)
# Generate actions WITHOUT RTC
logger.info("Generating actions WITHOUT RTC")
self.policy.config.rtc_config.enabled = False
with torch.no_grad():
no_rtc_actions = self.policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
)
# Generate actions WITH RTC (plot on right column)
# Plot denoising steps from tracker (no RTC - left column)
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
# Generate actions WITH RTC
logger.info("Generating actions WITH RTC")
self.policy.config.rtc_config.enabled = True
with torch.no_grad():
@@ -231,9 +233,27 @@ class RTCEvaluator:
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
)
# Plot denoising steps from tracker (RTC - right column)
if self.policy.rtc_processor is not None:
num_steps = self.policy.config.num_steps
self._plot_denoising_steps_from_tracker(
self.policy.rtc_processor.tracker,
axs_xt[:, 1], # Right column for x_t
axs_vt[:, 1], # Right column for v_t
axs_x1t[:, 1], # Right column for x1_t
num_steps,
)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Set titles for denoising plots
@@ -390,6 +410,80 @@ class RTCEvaluator:
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
"""Plot denoising steps from tracker data.
Args:
tracker: Tracker object containing debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
"""
if tracker is None:
return
debug_steps = tracker.get_all_steps()
if not debug_steps:
return
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
)
# Plot correction in red
if debug_step.correction is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs,
debug_step.correction,
start_from=0,
color="red",
label=f"Step corr {step_idx}",
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=f"x1_t Step {step_idx}",
)
# Plot error in orange dashed
if x1t_axs is not None and debug_step.err is not None:
error_chunk = (
debug_step.err[0].cpu().numpy()
if len(debug_step.err.shape) == 3
else debug_step.err.cpu().numpy()
)
num_dims = min(error_chunk.shape[-1], 6)
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=f"error Step {step_idx}",
)
@parser.wrap()
def main(cfg: RTCEvalConfig):