mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -152,6 +152,7 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
# Configure RTC
|
# Configure RTC
|
||||||
cfg.rtc.enabled = True
|
cfg.rtc.enabled = True
|
||||||
|
cfg.rtc.debug = True # Enable debug tracking for visualization
|
||||||
self.policy.config.rtc_config = cfg.rtc
|
self.policy.config.rtc_config = cfg.rtc
|
||||||
self.policy.init_rtc_processor()
|
self.policy.init_rtc_processor()
|
||||||
|
|
||||||
@@ -210,18 +211,19 @@ class RTCEvaluator:
|
|||||||
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
||||||
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
||||||
|
|
||||||
# Generate actions WITHOUT RTC (plot on left column)
|
# Generate actions WITHOUT RTC
|
||||||
logger.info("Generating actions WITHOUT RTC")
|
logger.info("Generating actions WITHOUT RTC")
|
||||||
self.policy.config.rtc_config.enabled = False
|
self.policy.config.rtc_config.enabled = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
no_rtc_actions = self.policy.predict_action_chunk(
|
no_rtc_actions = self.policy.predict_action_chunk(
|
||||||
preprocessed_second_sample,
|
preprocessed_second_sample,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
|
||||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate actions WITH RTC (plot on right column)
|
# Plot denoising steps from tracker (no RTC - left column)
|
||||||
|
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
|
||||||
|
|
||||||
|
# Generate actions WITH RTC
|
||||||
logger.info("Generating actions WITH RTC")
|
logger.info("Generating actions WITH RTC")
|
||||||
self.policy.config.rtc_config.enabled = True
|
self.policy.config.rtc_config.enabled = True
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -231,9 +233,27 @@ class RTCEvaluator:
|
|||||||
inference_delay=self.cfg.inference_delay,
|
inference_delay=self.cfg.inference_delay,
|
||||||
prev_chunk_left_over=prev_chunk_left_over,
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||||
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
|
)
|
||||||
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
|
|
||||||
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
|
# Plot denoising steps from tracker (RTC - right column)
|
||||||
|
if self.policy.rtc_processor is not None:
|
||||||
|
num_steps = self.policy.config.num_steps
|
||||||
|
self._plot_denoising_steps_from_tracker(
|
||||||
|
self.policy.rtc_processor.tracker,
|
||||||
|
axs_xt[:, 1], # Right column for x_t
|
||||||
|
axs_vt[:, 1], # Right column for v_t
|
||||||
|
axs_x1t[:, 1], # Right column for x1_t
|
||||||
|
num_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot ground truth on x_t axes
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot ground truth on x1_t axes
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set titles for denoising plots
|
# Set titles for denoising plots
|
||||||
@@ -390,6 +410,80 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
||||||
|
|
||||||
|
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||||
|
"""Plot denoising steps from tracker data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tracker: Tracker object containing debug steps
|
||||||
|
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||||
|
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||||
|
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||||
|
num_steps: Total number of denoising steps for colormap
|
||||||
|
"""
|
||||||
|
if tracker is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
debug_steps = tracker.get_all_steps()
|
||||||
|
if not debug_steps:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Define colors for different denoise steps (using a colormap)
|
||||||
|
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
|
||||||
|
|
||||||
|
for step_idx, debug_step in enumerate(debug_steps):
|
||||||
|
color = colors[step_idx % len(colors)]
|
||||||
|
|
||||||
|
# Plot x_t
|
||||||
|
if debug_step.x_t is not None:
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot v_t
|
||||||
|
if debug_step.v_t is not None:
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot correction in red
|
||||||
|
if debug_step.correction is not None:
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
vt_axs,
|
||||||
|
debug_step.correction,
|
||||||
|
start_from=0,
|
||||||
|
color="red",
|
||||||
|
label=f"Step corr {step_idx}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot x1_t (predicted state)
|
||||||
|
if x1t_axs is not None and debug_step.x1_t is not None:
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
x1t_axs,
|
||||||
|
debug_step.x1_t,
|
||||||
|
start_from=0,
|
||||||
|
color=color,
|
||||||
|
label=f"x1_t Step {step_idx}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot error in orange dashed
|
||||||
|
if x1t_axs is not None and debug_step.err is not None:
|
||||||
|
error_chunk = (
|
||||||
|
debug_step.err[0].cpu().numpy()
|
||||||
|
if len(debug_step.err.shape) == 3
|
||||||
|
else debug_step.err.cpu().numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
num_dims = min(error_chunk.shape[-1], 6)
|
||||||
|
for j in range(num_dims):
|
||||||
|
x1t_axs[j].plot(
|
||||||
|
np.arange(0, error_chunk.shape[0]),
|
||||||
|
error_chunk[:, j],
|
||||||
|
color="orange",
|
||||||
|
linestyle="--",
|
||||||
|
alpha=0.7,
|
||||||
|
label=f"error Step {step_idx}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def main(cfg: RTCEvalConfig):
|
def main(cfg: RTCEvalConfig):
|
||||||
|
|||||||
@@ -263,7 +263,6 @@ class RTCProcessor:
|
|||||||
# Record debug information (all params except x_t which is recorded externally)
|
# Record debug information (all params except x_t which is recorded externally)
|
||||||
self.track(
|
self.track(
|
||||||
time=time,
|
time=time,
|
||||||
v_t=v_t,
|
|
||||||
x1_t=x1_t,
|
x1_t=x1_t,
|
||||||
correction=correction,
|
correction=correction,
|
||||||
err=err,
|
err=err,
|
||||||
|
|||||||
@@ -55,14 +55,11 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||||
@@ -72,9 +69,6 @@ from lerobot.policies.utils import (
|
|||||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||||
from lerobot.utils.utils import get_safe_dtype
|
from lerobot.utils.utils import get_safe_dtype
|
||||||
|
|
||||||
# Make plot_waypoints easily accessible
|
|
||||||
plot_waypoints = RTCDebugVisualizer.plot_waypoints
|
|
||||||
|
|
||||||
|
|
||||||
def create_sinusoidal_pos_embedding(
|
def create_sinusoidal_pos_embedding(
|
||||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
@@ -544,11 +538,6 @@ class VLAFlowMatching(nn.Module):
|
|||||||
self.prefix_length = self.config.prefix_length
|
self.prefix_length = self.config.prefix_length
|
||||||
self.rtc_processor = rtc_processor
|
self.rtc_processor = rtc_processor
|
||||||
|
|
||||||
# For visualization of x_t during denoising
|
|
||||||
self.denoise_step_counter = 0
|
|
||||||
self.viz_fig = None
|
|
||||||
self.viz_axs = None
|
|
||||||
|
|
||||||
def _rtc_enabled(self):
|
def _rtc_enabled(self):
|
||||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||||
|
|
||||||
@@ -750,22 +739,10 @@ class VLAFlowMatching(nn.Module):
|
|||||||
def sample_actions(
|
def sample_actions(
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs
|
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)
|
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||||
|
|
||||||
Args:
|
|
||||||
viz_xt_axs: Optional matplotlib axes for plotting x_t trajectories (array of 6 axes)
|
|
||||||
viz_vt_axs: Optional matplotlib axes for plotting v_t trajectories (array of 6 axes)
|
|
||||||
viz_x1t_axs: Optional matplotlib axes for plotting x1_t predicted state and error (array of 6 axes)
|
|
||||||
When RTC is enabled, plots both x1_t (solid line) and error (orange dashed line)
|
|
||||||
"""
|
|
||||||
bsize = state.shape[0]
|
bsize = state.shape[0]
|
||||||
device = state.device
|
device = state.device
|
||||||
|
|
||||||
# Extract visualization axes from kwargs
|
|
||||||
viz_xt_axs = kwargs.pop("viz_xt_axs", None)
|
|
||||||
viz_vt_axs = kwargs.pop("viz_vt_axs", None)
|
|
||||||
viz_x1t_axs = kwargs.pop("viz_x1t_axs", None)
|
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
@@ -789,7 +766,6 @@ class VLAFlowMatching(nn.Module):
|
|||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||||
use_provided_axes = viz_xt_axs is not None and viz_vt_axs is not None
|
|
||||||
|
|
||||||
while time >= -dt / 2:
|
while time >= -dt / 2:
|
||||||
expanded_time = time.expand(bsize)
|
expanded_time = time.expand(bsize)
|
||||||
@@ -824,132 +800,9 @@ class VLAFlowMatching(nn.Module):
|
|||||||
x_t += dt * v_t
|
x_t += dt * v_t
|
||||||
time += dt
|
time += dt
|
||||||
|
|
||||||
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||||
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
self.rtc_processor.track(time=time, x_t=x_t)
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
# Retrieve data from tracker for plotting
|
|
||||||
correction = None
|
|
||||||
x1_t = None
|
|
||||||
error = None
|
|
||||||
if self._rtc_enabled() and self.rtc_processor.is_debug_enabled():
|
|
||||||
recent_steps = self.rtc_processor.get_recent_debug_steps(n=1)
|
|
||||||
if recent_steps:
|
|
||||||
debug_step = recent_steps[0]
|
|
||||||
correction = debug_step.correction
|
|
||||||
x1_t = debug_step.x1_t
|
|
||||||
error = debug_step.err
|
|
||||||
|
|
||||||
# Visualize x_t using plot_waypoints - accumulate all denoise steps
|
|
||||||
# Use provided axes or create new ones
|
|
||||||
if not use_provided_axes:
|
|
||||||
if self.viz_fig is None:
|
|
||||||
# Create figure once on first denoise step
|
|
||||||
self.viz_fig, self.viz_axs = plt.subplots(6, 1, figsize=(12, 12))
|
|
||||||
self.viz_v_fig, self.viz_v_axs = plt.subplots(6, 1, figsize=(12, 12))
|
|
||||||
xt_axs = self.viz_axs
|
|
||||||
vt_axs = self.viz_v_axs
|
|
||||||
else:
|
|
||||||
xt_axs = viz_xt_axs
|
|
||||||
vt_axs = viz_vt_axs
|
|
||||||
|
|
||||||
# Define colors for different denoise steps (using a colormap)
|
|
||||||
colors = plt.cm.viridis(np.linspace(0, 1, self.config.num_steps))
|
|
||||||
color = colors[self.denoise_step_counter % len(colors)]
|
|
||||||
|
|
||||||
# Plot this denoise step
|
|
||||||
plot_waypoints(xt_axs, x_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
|
|
||||||
|
|
||||||
# Plot this denoise step
|
|
||||||
plot_waypoints(vt_axs, v_t, start_from=0, color=color, label=f"Step {self.denoise_step_counter}")
|
|
||||||
|
|
||||||
if correction is not None:
|
|
||||||
plot_waypoints(
|
|
||||||
vt_axs,
|
|
||||||
correction,
|
|
||||||
start_from=0,
|
|
||||||
color="red",
|
|
||||||
label=f"Step corr {self.denoise_step_counter}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plot x1_t if axes provided and RTC is enabled
|
|
||||||
if viz_x1t_axs is not None and x1_t is not None:
|
|
||||||
plot_waypoints(
|
|
||||||
viz_x1t_axs,
|
|
||||||
x1_t,
|
|
||||||
start_from=0,
|
|
||||||
color=color,
|
|
||||||
label=f"x1_t Step {self.denoise_step_counter}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plot error on the same axes with different color
|
|
||||||
if error is not None:
|
|
||||||
# Use orange color for error
|
|
||||||
# Handle batch dimension if present
|
|
||||||
error_chunk = error[0].cpu().numpy() if len(error.shape) == 3 else error.cpu().numpy()
|
|
||||||
|
|
||||||
num_dims = min(error_chunk.shape[-1], 6)
|
|
||||||
for j in range(num_dims):
|
|
||||||
viz_x1t_axs[j].plot(
|
|
||||||
np.arange(0, error_chunk.shape[0]),
|
|
||||||
error_chunk[:, j],
|
|
||||||
color="orange",
|
|
||||||
linestyle="--",
|
|
||||||
alpha=0.7,
|
|
||||||
label=f"error Step {self.denoise_step_counter}",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.denoise_step_counter += 1
|
|
||||||
|
|
||||||
# Save visualization of x_t denoise steps (only if using internal figures)
|
|
||||||
if not use_provided_axes and self.viz_fig is not None:
|
|
||||||
plt.figure(self.viz_fig.number)
|
|
||||||
|
|
||||||
xt_name = "smolvla_x_t_denoise_steps.png"
|
|
||||||
v_name = "smolvla_v_denoise_steps.png"
|
|
||||||
|
|
||||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
|
||||||
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
|
|
||||||
v_name = "smolvla_v_with_rtc_denoise_steps.png"
|
|
||||||
|
|
||||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
|
||||||
|
|
||||||
if prev_chunk_left_over is not None:
|
|
||||||
plot_waypoints(
|
|
||||||
self.viz_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.savefig(xt_name)
|
|
||||||
plt.close(self.viz_fig)
|
|
||||||
|
|
||||||
# Reset for next inference
|
|
||||||
self.viz_fig = None
|
|
||||||
self.viz_axs = None
|
|
||||||
self.denoise_step_counter = 0
|
|
||||||
|
|
||||||
plt.figure(self.viz_v_fig.number)
|
|
||||||
plt.savefig(v_name)
|
|
||||||
plt.close(self.viz_v_fig)
|
|
||||||
|
|
||||||
self.viz_v_fig = None
|
|
||||||
self.viz_v_axs = None
|
|
||||||
|
|
||||||
# Plot ground truth on provided axes if available
|
|
||||||
if use_provided_axes:
|
|
||||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
|
||||||
if prev_chunk_left_over is not None and self._rtc_enabled():
|
|
||||||
plot_waypoints(
|
|
||||||
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
|
||||||
)
|
|
||||||
# Also plot ground truth on x1_t axes if provided
|
|
||||||
if viz_x1t_axs is not None:
|
|
||||||
plot_waypoints(
|
|
||||||
viz_x1t_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reset counter when using provided axes (for next call)
|
|
||||||
if use_provided_axes:
|
|
||||||
self.denoise_step_counter = 0
|
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user