Silent validation

This commit is contained in:
Eugene Mironov
2025-11-10 19:41:31 +07:00
parent 433ccc9603
commit 6db3afca6f
2 changed files with 47 additions and 3 deletions
+42 -3
View File
@@ -18,6 +18,7 @@ Usage:
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=ONES \
--seed=10
# Basic usage with pi0.5 policy
@@ -512,9 +513,9 @@ class RTCEvaluator:
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Validate RTC behavior
logging.info("=" * 80)
logging.info("Validating RTC behavior...")
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
# logging.info("=" * 80)
# logging.info("Validating RTC behavior...")
# self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
# Plot final actions comparison
logging.info("=" * 80)
@@ -800,6 +801,9 @@ class RTCEvaluator:
num_steps,
)
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
@@ -924,6 +928,41 @@ class RTCEvaluator:
self._rescale_axes(corr_axs)
self._rescale_axes(x1t_axs)
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
Args:
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
num_steps: Total number of denoising steps for colormap
"""
debug_steps = no_rtc_tracked_steps
if not debug_steps:
return
# Plot only the final x_t step as orange dashed line
final_step = debug_steps[-1]
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
if final_step.x_t is not None:
x_t_chunk = (
final_step.x_t[0].cpu().numpy()
if len(final_step.x_t.shape) == 3
else final_step.x_t.cpu().numpy()
)
num_dims = min(x_t_chunk.shape[-1], 6)
for j in range(num_dims):
xt_axs[j].plot(
np.arange(0, x_t_chunk.shape[0]),
x_t_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
linewidth=2,
label="No RTC (final)" if j == 0 else "",
)
def _rescale_axes(self, axes):
"""Rescale axes to show all data with proper margins.