mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Silent validation
This commit is contained in:
@@ -18,6 +18,7 @@ Usage:
|
|||||||
--rtc.execution_horizon=8 \
|
--rtc.execution_horizon=8 \
|
||||||
--device=mps \
|
--device=mps \
|
||||||
--rtc.max_guidance_weight=10.0 \
|
--rtc.max_guidance_weight=10.0 \
|
||||||
|
--rtc.prefix_attention_schedule=ONES \
|
||||||
--seed=10
|
--seed=10
|
||||||
|
|
||||||
# Basic usage with pi0.5 policy
|
# Basic usage with pi0.5 policy
|
||||||
@@ -512,9 +513,9 @@ class RTCEvaluator:
|
|||||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
|
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
|
||||||
|
|
||||||
# Validate RTC behavior
|
# Validate RTC behavior
|
||||||
logging.info("=" * 80)
|
# logging.info("=" * 80)
|
||||||
logging.info("Validating RTC behavior...")
|
# logging.info("Validating RTC behavior...")
|
||||||
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
# self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||||
|
|
||||||
# Plot final actions comparison
|
# Plot final actions comparison
|
||||||
logging.info("=" * 80)
|
logging.info("=" * 80)
|
||||||
@@ -800,6 +801,9 @@ class RTCEvaluator:
|
|||||||
num_steps,
|
num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
|
||||||
|
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
|
||||||
|
|
||||||
# Plot ground truth on x_t axes
|
# Plot ground truth on x_t axes
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
@@ -924,6 +928,41 @@ class RTCEvaluator:
|
|||||||
self._rescale_axes(corr_axs)
|
self._rescale_axes(corr_axs)
|
||||||
self._rescale_axes(x1t_axs)
|
self._rescale_axes(x1t_axs)
|
||||||
|
|
||||||
|
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
|
||||||
|
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
|
||||||
|
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
|
||||||
|
num_steps: Total number of denoising steps for colormap
|
||||||
|
"""
|
||||||
|
debug_steps = no_rtc_tracked_steps
|
||||||
|
if not debug_steps:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Plot only the final x_t step as orange dashed line
|
||||||
|
final_step = debug_steps[-1]
|
||||||
|
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
|
||||||
|
|
||||||
|
if final_step.x_t is not None:
|
||||||
|
x_t_chunk = (
|
||||||
|
final_step.x_t[0].cpu().numpy()
|
||||||
|
if len(final_step.x_t.shape) == 3
|
||||||
|
else final_step.x_t.cpu().numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
num_dims = min(x_t_chunk.shape[-1], 6)
|
||||||
|
for j in range(num_dims):
|
||||||
|
xt_axs[j].plot(
|
||||||
|
np.arange(0, x_t_chunk.shape[0]),
|
||||||
|
x_t_chunk[:, j],
|
||||||
|
color="orange",
|
||||||
|
linestyle="--",
|
||||||
|
alpha=0.7,
|
||||||
|
linewidth=2,
|
||||||
|
label="No RTC (final)" if j == 0 else "",
|
||||||
|
)
|
||||||
|
|
||||||
def _rescale_axes(self, axes):
|
def _rescale_axes(self, axes):
|
||||||
"""Rescale axes to show all data with proper margins.
|
"""Rescale axes to show all data with proper margins.
|
||||||
|
|
||||||
|
|||||||
@@ -217,6 +217,11 @@ class RTCProcessor:
|
|||||||
grad_outputs = err.clone().detach()
|
grad_outputs = err.clone().detach()
|
||||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||||
|
|
||||||
|
# Explicitly nullify correction after execution horizon to ensure exact match with no-RTC
|
||||||
|
# Create a mask that zeros out correction after execution_horizon
|
||||||
|
correction_mask = weights.clone() # weights already have zeros after execution_horizon
|
||||||
|
correction = correction * correction_mask
|
||||||
|
|
||||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||||
tau_tensor = torch.as_tensor(tau)
|
tau_tensor = torch.as_tensor(tau)
|
||||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||||
|
|||||||
Reference in New Issue
Block a user