Silent validation

This commit is contained in:
Eugene Mironov
2025-11-10 19:41:31 +07:00
parent dd39d7a037
commit 36dc58d05e
2 changed files with 47 additions and 3 deletions
+42 -3
View File
@@ -18,6 +18,7 @@ Usage:
--rtc.execution_horizon=8 \ --rtc.execution_horizon=8 \
--device=mps \ --device=mps \
--rtc.max_guidance_weight=10.0 \ --rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=ONES \
--seed=10 --seed=10
# Basic usage with pi0.5 policy # Basic usage with pi0.5 policy
@@ -512,9 +513,9 @@ class RTCEvaluator:
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps) self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Validate RTC behavior # Validate RTC behavior
logging.info("=" * 80) # logging.info("=" * 80)
logging.info("Validating RTC behavior...") # logging.info("Validating RTC behavior...")
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) # self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
# Plot final actions comparison # Plot final actions comparison
logging.info("=" * 80) logging.info("=" * 80)
@@ -800,6 +801,9 @@ class RTCEvaluator:
num_steps, num_steps,
) )
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
# Plot ground truth on x_t axes # Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
@@ -924,6 +928,41 @@ class RTCEvaluator:
self._rescale_axes(corr_axs) self._rescale_axes(corr_axs)
self._rescale_axes(x1t_axs) self._rescale_axes(x1t_axs)
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
Args:
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
num_steps: Total number of denoising steps for colormap
"""
debug_steps = no_rtc_tracked_steps
if not debug_steps:
return
# Plot only the final x_t step as orange dashed line
final_step = debug_steps[-1]
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
if final_step.x_t is not None:
x_t_chunk = (
final_step.x_t[0].cpu().numpy()
if len(final_step.x_t.shape) == 3
else final_step.x_t.cpu().numpy()
)
num_dims = min(x_t_chunk.shape[-1], 6)
for j in range(num_dims):
xt_axs[j].plot(
np.arange(0, x_t_chunk.shape[0]),
x_t_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
linewidth=2,
label="No RTC (final)" if j == 0 else "",
)
def _rescale_axes(self, axes): def _rescale_axes(self, axes):
"""Rescale axes to show all data with proper margins. """Rescale axes to show all data with proper margins.
+5
View File
@@ -217,6 +217,11 @@ class RTCProcessor:
grad_outputs = err.clone().detach() grad_outputs = err.clone().detach()
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
# Explicitly nullify correction after execution horizon to ensure exact match with no-RTC
# Create a mask that zeros out correction after execution_horizon
correction_mask = weights.clone() # weights already have zeros after execution_horizon
correction = correction * correction_mask
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
tau_tensor = torch.as_tensor(tau) tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2 squared_one_minus_tau = (1 - tau_tensor) ** 2