From 6db3afca6fd754284ba6e9cff999dc1ab6b97c66 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 10 Nov 2025 19:41:31 +0700 Subject: [PATCH] Silent validation --- examples/rtc/eval_dataset.py | 45 ++++++++++++++++++++++-- src/lerobot/policies/rtc/modeling_rtc.py | 5 +++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 16f1cfacc..ee801a575 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -18,6 +18,7 @@ Usage: --rtc.execution_horizon=8 \ --device=mps \ --rtc.max_guidance_weight=10.0 \ + --rtc.prefix_attention_schedule=ONES \ --seed=10 # Basic usage with pi0.5 policy @@ -512,9 +513,9 @@ class RTCEvaluator: self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps) # Validate RTC behavior - logging.info("=" * 80) - logging.info("Validating RTC behavior...") - self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) + # logging.info("=" * 80) + # logging.info("Validating RTC behavior...") + # self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) # Plot final actions comparison logging.info("=" * 80) @@ -800,6 +801,9 @@ class RTCEvaluator: num_steps, ) + # Plot no-RTC x_t data on right chart as orange dashed line for comparison + self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps) + # Plot ground truth on x_t axes RTCDebugVisualizer.plot_waypoints( axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" @@ -924,6 +928,41 @@ class RTCEvaluator: self._rescale_axes(corr_axs) self._rescale_axes(x1t_axs) + def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps): + """Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison. + + Args: + no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps + xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column) + num_steps: Total number of denoising steps for colormap + """ + debug_steps = no_rtc_tracked_steps + if not debug_steps: + return + + # Plot only the final x_t step as orange dashed line + final_step = debug_steps[-1] + logging.info("Plotting final no-RTC x_t step as orange dashed reference") + + if final_step.x_t is not None: + x_t_chunk = ( + final_step.x_t[0].cpu().numpy() + if len(final_step.x_t.shape) == 3 + else final_step.x_t.cpu().numpy() + ) + + num_dims = min(x_t_chunk.shape[-1], 6) + for j in range(num_dims): + xt_axs[j].plot( + np.arange(0, x_t_chunk.shape[0]), + x_t_chunk[:, j], + color="orange", + linestyle="--", + alpha=0.7, + linewidth=2, + label="No RTC (final)" if j == 0 else "", + ) + def _rescale_axes(self, axes): """Rescale axes to show all data with proper margins. diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 280905adf..6a02aa3e8 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -217,6 +217,11 @@ class RTCProcessor: grad_outputs = err.clone().detach() correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0] + # Explicitly nullify correction after execution horizon to ensure exact match with no-RTC + # Create a mask that zeros out correction after execution_horizon + correction_mask = weights.clone() # weights already have zeros after execution_horizon + correction = correction * correction_mask + max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight) tau_tensor = torch.as_tensor(tau) squared_one_minus_tau = (1 - tau_tensor) ** 2