mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns: modeling_smolvla.py changes: - Remove all plotting logic from sample_actions method - Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters - Remove matplotlib and RTCDebugVisualizer imports - Remove viz_fig, viz_axs, denoise_step_counter instance variables - Simplify denoising loop to only track data in rtc_processor eval_dataset.py changes: - Add _plot_denoising_steps_from_tracker helper method - Retrieve debug steps from tracker after inference - Plot x_t, v_t, x1_t, correction, and error from tracker data - Enable debug tracking (cfg.rtc.debug = True) for visualization - Remove viz axes parameters from predict_action_chunk calls modeling_rtc.py changes: - Remove v_t from track() call (handled by user change) Benefits: - Cleaner modeling code focused on inference - Evaluation script owns all visualization logic - Better separation of concerns - Tracker is single source of truth for debug data 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -152,6 +152,7 @@ class RTCEvaluator:
|
||||
|
||||
# Configure RTC
|
||||
cfg.rtc.enabled = True
|
||||
cfg.rtc.debug = True # Enable debug tracking for visualization
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor()
|
||||
|
||||
@@ -210,18 +211,19 @@ class RTCEvaluator:
|
||||
fig_x1t, axs_x1t = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig_x1t.suptitle("x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)", fontsize=16)
|
||||
|
||||
# Generate actions WITHOUT RTC (plot on left column)
|
||||
# Generate actions WITHOUT RTC
|
||||
logger.info("Generating actions WITHOUT RTC")
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC (plot on right column)
|
||||
# Plot denoising steps from tracker (no RTC - left column)
|
||||
# Note: No tracker data for non-RTC case since tracking is only done when RTC processor exists
|
||||
|
||||
# Generate actions WITH RTC
|
||||
logger.info("Generating actions WITH RTC")
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
@@ -231,9 +233,27 @@ class RTCEvaluator:
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
viz_xt_axs=axs_xt[:, 1], # Right column for x_t
|
||||
viz_vt_axs=axs_vt[:, 1], # Right column for v_t
|
||||
viz_x1t_axs=axs_x1t[:, 1], # Right column for x1_t
|
||||
)
|
||||
|
||||
# Plot denoising steps from tracker (RTC - right column)
|
||||
if self.policy.rtc_processor is not None:
|
||||
num_steps = self.policy.config.num_steps
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
self.policy.rtc_processor.tracker,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
)
|
||||
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Set titles for denoising plots
|
||||
@@ -390,6 +410,80 @@ class RTCEvaluator:
|
||||
|
||||
logger.info(f"Debug visualizations saved to {self.cfg.output_dir}")
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracker, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
"""Plot denoising steps from tracker data.
|
||||
|
||||
Args:
|
||||
tracker: Tracker object containing debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
if tracker is None:
|
||||
return
|
||||
|
||||
debug_steps = tracker.get_all_steps()
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
# Define colors for different denoise steps (using a colormap)
|
||||
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
|
||||
|
||||
for step_idx, debug_step in enumerate(debug_steps):
|
||||
color = colors[step_idx % len(colors)]
|
||||
|
||||
# Plot x_t
|
||||
if debug_step.x_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||
)
|
||||
|
||||
# Plot v_t
|
||||
if debug_step.v_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
|
||||
)
|
||||
|
||||
# Plot correction in red
|
||||
if debug_step.correction is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs,
|
||||
debug_step.correction,
|
||||
start_from=0,
|
||||
color="red",
|
||||
label=f"Step corr {step_idx}",
|
||||
)
|
||||
|
||||
# Plot x1_t (predicted state)
|
||||
if x1t_axs is not None and debug_step.x1_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
x1t_axs,
|
||||
debug_step.x1_t,
|
||||
start_from=0,
|
||||
color=color,
|
||||
label=f"x1_t Step {step_idx}",
|
||||
)
|
||||
|
||||
# Plot error in orange dashed
|
||||
if x1t_axs is not None and debug_step.err is not None:
|
||||
error_chunk = (
|
||||
debug_step.err[0].cpu().numpy()
|
||||
if len(debug_step.err.shape) == 3
|
||||
else debug_step.err.cpu().numpy()
|
||||
)
|
||||
|
||||
num_dims = min(error_chunk.shape[-1], 6)
|
||||
for j in range(num_dims):
|
||||
x1t_axs[j].plot(
|
||||
np.arange(0, error_chunk.shape[0]),
|
||||
error_chunk[:, j],
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.7,
|
||||
label=f"error Step {step_idx}",
|
||||
)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
|
||||
Reference in New Issue
Block a user