mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns: modeling_smolvla.py changes: - Remove all plotting logic from sample_actions method - Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters - Remove matplotlib and RTCDebugVisualizer imports - Remove viz_fig, viz_axs, denoise_step_counter instance variables - Simplify denoising loop to only track data in rtc_processor eval_dataset.py changes: - Add _plot_denoising_steps_from_tracker helper method - Retrieve debug steps from tracker after inference - Plot x_t, v_t, x1_t, correction, and error from tracker data - Enable debug tracking (cfg.rtc.debug = True) for visualization - Remove viz axes parameters from predict_action_chunk calls modeling_rtc.py changes: - Remove v_t from track() call (handled by user change) Benefits: - Cleaner modeling code focused on inference - Evaluation script owns all visualization logic - Better separation of concerns - Tracker is single source of truth for debug data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -152,6 +152,7 @@ class RTCEvaluator:
|
||||
|
||||
# Configure RTC
|
||||
cfg.rtc.enabled = True
|
||||
cfg.rtc.debug = True # Enable debug tracking for visualization
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor()
|
||||
|
||||
@@ -210,18 +211,19 @@ class RTCEvaluator:
|
||||
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Generate actions WITHOUT RTC (plot on left column)
|
||||
# Generate actions WITHOUT RTC
|
||||
logger.info("Generating actions WITHOUT RTC")
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC (plot on right column)
|
||||
# Plot denoising steps from tracker (no RTC - left column)
|
||||
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
|
||||
|
||||
# Generate actions WITH RTC
|
||||
logger.info("Generating actions WITH RTC")
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
@@ -231,9 +233,27 @@ class RTCEvaluator:
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
|
||||
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
|
||||
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
|
||||
)
|
||||
|
||||
# Plot denoising steps from tracker (RTC - right column)
|
||||
if self.policy.rtc_processor is not None:
|
||||
num_steps = self.policy.config.num_steps
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
self.policy.rtc_processor.tracker,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Set titles for denoising plots
|
||||
@@ -390,6 +410,80 @@ class RTCEvaluator:
|
||||
|
||||
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
"""Plot denoising steps from tracker data.
|
||||
|
||||
Args:
|
||||
tracker: Tracker object containing debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
if tracker is None:
|
||||
return
|
||||
|
||||
debug_steps = tracker.get_all_steps()
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
# Define colors for different denoise steps (using a colormap)
|
||||
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
|
||||
|
||||
for step_idx, debug_step in enumerate(debug_steps):
|
||||
color = colors[step_idx % len(colors)]
|
||||
|
||||
# Plot x_t
|
||||
if debug_step.x_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||
)
|
||||
|
||||
# Plot v_t
|
||||
if debug_step.v_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||
)
|
||||
|
||||
# Plot correction in red
|
||||
if debug_step.correction is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs,
|
||||
debug_step.correction,
|
||||
start_from=0,
|
||||
color="red",
|
||||
label=f"Step corr {step_idx}",
|
||||
)
|
||||
|
||||
# Plot x1_t (predicted state)
|
||||
if x1t_axs is not None and debug_step.x1_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
x1t_axs,
|
||||
debug_step.x1_t,
|
||||
start_from=0,
|
||||
color=color,
|
||||
label=f"x1_t Step {step_idx}",
|
||||
)
|
||||
|
||||
# Plot error in orange dashed
|
||||
if x1t_axs is not None and debug_step.err is not None:
|
||||
error_chunk = (
|
||||
debug_step.err[0].cpu().numpy()
|
||||
if len(debug_step.err.shape) == 3
|
||||
else debug_step.err.cpu().numpy()
|
||||
)
|
||||
|
||||
num_dims = min(error_chunk.shape[-1], 6)
|
||||
for j in range(num_dims):
|
||||
x1t_axs[j].plot(
|
||||
np.arange(0, error_chunk.shape[0]),
|
||||
error_chunk[:, j],
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.7,
|
||||
label=f"error Step {step_idx}",
|
||||
)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
|
||||
@@ -263,7 +263,6 @@ class RTCProcessor:
|
||||
# Record debug information (all params except x_t which is recorded externally)
|
||||
self.track(
|
||||
time=time,
|
||||
v_t=v_t,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
|
||||
@@ -55,14 +55,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
@@ -72,9 +69,6 @@ from lerobot.policies.utils import (
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Make plot_waypoints easily accessible
|
||||
plot_waypoints = RTCDebugVisualizer.plot_waypoints
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -544,11 +538,6 @@ class VLAFlowMatching(nn.Module):
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
# For visualization of x_t during denoising
|
||||
self.denoise_step_counter = 0
|
||||
self.viz_fig = None
|
||||
self.viz_axs = None
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
@@ -750,22 +739,10 @@ class VLAFlowMatching(nn.Module):
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)
|
||||
|
||||
Args:
|
||||
viz_xt_axs: Optional matplotlib axes for plotting x_t trajectories (array of 6 axes)
|
||||
viz_vt_axs: Optional matplotlib axes for plotting v_t trajectories (array of 6 axes)
|
||||
viz_x1t_axs: Optional matplotlib axes for plotting x1_t predicted state and error (array of 6 axes)
|
||||
When RTC is enabled, plots both x1_t (solid line) and error (orange dashed line)
|
||||
"""
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
# Extract visualization axes from kwargs
|
||||
viz_xt_axs = kwargs.pop("viz_xt_axs", None)
|
||||
viz_vt_axs = kwargs.pop("viz_vt_axs", None)
|
||||
viz_x1t_axs = kwargs.pop("viz_x1t_axs", None)
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
@@ -789,7 +766,6 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None
|
||||
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
@@ -824,132 +800,9 @@ class VLAFlowMatching(nn.Module):
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
|
||||
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t)
|
||||
|
||||
# Retrieve data from tracker for plotting
|
||||
correction = None
|
||||
x1_t = None
|
||||
error = None
|
||||
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
|
||||
recent_steps = self.rtc_processor.get_recent_debug_steps(n=1)
|
||||
if recent_steps:
|
||||
debug_step = recent_steps[0]
|
||||
correction = debug_step.correction
|
||||
x1_t = debug_step.x1_t
|
||||
error = debug_step.err
|
||||
|
||||
# Visualize x_t using plot_waypoints - accumulate all denoise steps
|
||||
# Use provided axes or create new ones
|
||||
if not use_provided_axes:
|
||||
if self.viz_fig is None:
|
||||
# Create figure once on first denoise step
|
||||
self.viz_fig, self.viz_axs = plt.subplots(6, 1, figsize=(12, 12))
|
||||
self.viz_v_fig, self.viz_v_axs = plt.subplots(6, 1, figsize=(12, 12))
|
||||
xt_axs = self.viz_axs
|
||||
vt_axs = self.viz_v_axs
|
||||
else:
|
||||
xt_axs = viz_xt_axs
|
||||
vt_axs = viz_vt_axs
|
||||
|
||||
# Define colors for different denoise steps (using a colormap)
|
||||
colors = plt.cm.viridis(np.linspace(0, 1, self.config.num_steps))
|
||||
color = colors[self.denoise_step_counter % len(colors)]
|
||||
|
||||
# Plot this denoise step
|
||||
plot_waypoints(xt_axs, x_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
|
||||
|
||||
# Plot this denoise step
|
||||
plot_waypoints(vt_axs, v_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
|
||||
|
||||
if correction is not None:
|
||||
plot_waypoints(
|
||||
vt_axs,
|
||||
correction,
|
||||
start_from=0,
|
||||
color="red",
|
||||
label=f"Step corr {self.denoise_step_counter}",
|
||||
)
|
||||
|
||||
# Plot x1_t if axes provided and RTC is enabled
|
||||
if viz_x1t_axs is not None and x1_t is not None:
|
||||
plot_waypoints(
|
||||
viz_x1t_axs,
|
||||
x1_t,
|
||||
start_from=0,
|
||||
color=color,
|
||||
label=f"x1_t Step {self.denoise_step_counter}",
|
||||
)
|
||||
|
||||
# Plot error on the same axes with different color
|
||||
if error is not None:
|
||||
# Use orange color for error
|
||||
# Handle batch dimension if present
|
||||
error_chunk = error[0].cpu().numpy() if len(error.shape) == 3 else error.cpu().numpy()
|
||||
|
||||
num_dims = min(error_chunk.shape[-1], 6)
|
||||
for j in range(num_dims):
|
||||
viz_x1t_axs[j].plot(
|
||||
np.arange(0, error_chunk.shape[0]),
|
||||
error_chunk[:, j],
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.7,
|
||||
label=f"error Step {self.denoise_step_counter}",
|
||||
)
|
||||
|
||||
self.denoise_step_counter += 1
|
||||
|
||||
# Save visualization of x_t denoise steps (only if using internal figures)
|
||||
if not use_provided_axes and self.viz_fig is not None:
|
||||
plt.figure(self.viz_fig.number)
|
||||
|
||||
xt_name = "smolvla_x_t_denoise_steps.png"
|
||||
v_name = "smolvla_v_denoise_steps.png"
|
||||
|
||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
||||
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
|
||||
v_name = "smolvla_v_with_rtc_denoise_steps.png"
|
||||
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
|
||||
if prev_chunk_left_over is not None:
|
||||
plot_waypoints(
|
||||
self.viz_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
plt.savefig(xt_name)
|
||||
plt.close(self.viz_fig)
|
||||
|
||||
# Reset for next inference
|
||||
self.viz_fig = None
|
||||
self.viz_axs = None
|
||||
self.denoise_step_counter = 0
|
||||
|
||||
plt.figure(self.viz_v_fig.number)
|
||||
plt.savefig(v_name)
|
||||
plt.close(self.viz_v_fig)
|
||||
|
||||
self.viz_v_fig = None
|
||||
self.viz_v_axs = None
|
||||
|
||||
# Plot ground truth on provided axes if available
|
||||
if use_provided_axes:
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
if prev_chunk_left_over is not None and self._rtc_enabled():
|
||||
plot_waypoints(
|
||||
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
# Also plot ground truth on x1_t axes if provided
|
||||
if viz_x1t_axs is not None:
|
||||
plot_waypoints(
|
||||
viz_x1t_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Reset counter when using provided axes (for next call)
|
||||
if use_provided_axes:
|
||||
self.denoise_step_counter = 0
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
return x_t
|
||||
|
||||
|
||||
Reference in New Issue
Block a user