mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
Update README
This commit is contained in:
+158
-38
@@ -16,7 +16,9 @@ Usage:
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--seed=10
|
||||
|
||||
# Basic usage with pi0.5 policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
@@ -439,6 +441,8 @@ class RTCEvaluator:
|
||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
set_seed(self.cfg.seed)
|
||||
|
||||
# Initialize policy 2
|
||||
policy_no_rtc_policy = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
@@ -470,6 +474,8 @@ class RTCEvaluator:
|
||||
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
set_seed(self.cfg.seed)
|
||||
|
||||
# Initialize policy 3
|
||||
policy_rtc_policy = self._init_policy(
|
||||
name="policy_rtc",
|
||||
@@ -510,6 +516,11 @@ class RTCEvaluator:
|
||||
logging.info("Validating RTC behavior...")
|
||||
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||
|
||||
# Plot final actions comparison
|
||||
logging.info("=" * 80)
|
||||
logging.info("Plotting final actions comparison...")
|
||||
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("Evaluation completed successfully")
|
||||
|
||||
@@ -527,29 +538,24 @@ class RTCEvaluator:
|
||||
no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
|
||||
prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
|
||||
"""
|
||||
if rtc_actions is None or no_rtc_actions is None:
|
||||
logging.warning(" ⚠ Cannot validate: missing action predictions")
|
||||
return
|
||||
# Remove batch dimension if present and move to CPU
|
||||
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||
no_rtc_actions_t = (
|
||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||
)
|
||||
prev_chunk = prev_chunk_left_over.cpu()
|
||||
|
||||
# Convert to numpy for comparison (remove batch dimension if present)
|
||||
rtc_actions_np = (
|
||||
rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy()
|
||||
)
|
||||
no_rtc_actions_np = (
|
||||
no_rtc_actions.squeeze(0).cpu().numpy()
|
||||
if len(no_rtc_actions.shape) == 3
|
||||
else no_rtc_actions.cpu().numpy()
|
||||
)
|
||||
prev_chunk = prev_chunk_left_over.cpu().numpy()
|
||||
logging.info(f" rtc_actions shape: {rtc_actions_t.shape}")
|
||||
logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}")
|
||||
logging.info(f" prev_chunk shape: {prev_chunk.shape}")
|
||||
|
||||
# Determine chunk length for comparison
|
||||
chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0])
|
||||
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0])
|
||||
inference_delay = self.cfg.inference_delay
|
||||
execution_horizon = self.cfg.rtc.execution_horizon
|
||||
|
||||
# Tolerance for floating point comparison
|
||||
rtol = 1e-3 # Relative tolerance
|
||||
atol = 1e-3 # Absolute tolerance
|
||||
rtol = 1e-2 # Relative tolerance
|
||||
|
||||
validation_passed = True
|
||||
warnings = []
|
||||
@@ -558,19 +564,26 @@ class RTCEvaluator:
|
||||
logging.info(f" Chunk length: {chunk_len}")
|
||||
logging.info(f" Inference delay: {inference_delay}")
|
||||
logging.info(f" Execution horizon: {execution_horizon}")
|
||||
logging.info(f" Tolerance: rtol={rtol}, atol={atol}")
|
||||
logging.info(f" Tolerance: rtol={rtol}")
|
||||
|
||||
# ============================================================================
|
||||
# Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
|
||||
# ============================================================================
|
||||
if inference_delay > 0:
|
||||
delay_end = min(inference_delay, chunk_len)
|
||||
rtc_delay = rtc_actions_np[:delay_end]
|
||||
rtc_delay = rtc_actions_t[:delay_end]
|
||||
prev_delay = prev_chunk[:delay_end]
|
||||
|
||||
if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol):
|
||||
max_diff = np.max(np.abs(rtc_delay - prev_delay))
|
||||
mean_diff = np.mean(np.abs(rtc_delay - prev_delay))
|
||||
logging.info(f" rtc_delay: {rtc_delay.shape}")
|
||||
logging.info(f" prev_delay: {prev_delay.shape}")
|
||||
|
||||
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
|
||||
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
|
||||
mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item()
|
||||
logging.info(f" rtc_delay: {rtc_delay}")
|
||||
logging.info(f" prev_delay: {prev_delay}")
|
||||
logging.info(f" max_diff: {max_diff}")
|
||||
logging.info(f" mean_diff: {mean_diff}")
|
||||
warnings.append(
|
||||
f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
|
||||
f"RTC does NOT equal prev_chunk!\n"
|
||||
@@ -589,26 +602,26 @@ class RTCEvaluator:
|
||||
blend_end = min(execution_horizon, chunk_len)
|
||||
|
||||
if blend_end > blend_start:
|
||||
rtc_blend = rtc_actions_np[blend_start:blend_end]
|
||||
rtc_blend = rtc_actions_t[blend_start:blend_end]
|
||||
prev_blend = prev_chunk[blend_start:blend_end]
|
||||
no_rtc_blend = no_rtc_actions_np[blend_start:blend_end]
|
||||
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
|
||||
|
||||
# Check if RTC is between prev_chunk and no_rtc (element-wise)
|
||||
# For each element, check if it's between the min and max of prev_chunk and no_rtc
|
||||
min_bound = np.minimum(prev_blend, no_rtc_blend) - atol
|
||||
max_bound = np.maximum(prev_blend, no_rtc_blend) + atol
|
||||
min_bound = torch.minimum(prev_blend, no_rtc_blend)
|
||||
max_bound = torch.maximum(prev_blend, no_rtc_blend)
|
||||
|
||||
within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
||||
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
||||
|
||||
if not np.all(within_bounds):
|
||||
violations = np.sum(~within_bounds)
|
||||
total_elements = within_bounds.size
|
||||
if not torch.all(within_bounds):
|
||||
violations = torch.sum(~within_bounds).item()
|
||||
total_elements = within_bounds.numel()
|
||||
violation_pct = 100.0 * violations / total_elements
|
||||
|
||||
# Find max violation
|
||||
lower_violations = np.maximum(0, min_bound - rtc_blend)
|
||||
upper_violations = np.maximum(0, rtc_blend - max_bound)
|
||||
max_violation = np.max(np.maximum(lower_violations, upper_violations))
|
||||
lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend)
|
||||
upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound)
|
||||
max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item()
|
||||
|
||||
warnings.append(
|
||||
f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
|
||||
@@ -626,12 +639,15 @@ class RTCEvaluator:
|
||||
# Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
|
||||
# ============================================================================
|
||||
if execution_horizon < chunk_len:
|
||||
rtc_after = rtc_actions_np[execution_horizon:chunk_len]
|
||||
no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len]
|
||||
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
|
||||
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
|
||||
|
||||
if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol):
|
||||
max_diff = np.max(np.abs(rtc_after - no_rtc_after))
|
||||
mean_diff = np.mean(np.abs(rtc_after - no_rtc_after))
|
||||
logging.info(f" rtc_after: {rtc_after}")
|
||||
logging.info(f" no_rtc_after: {no_rtc_after}")
|
||||
|
||||
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
|
||||
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
|
||||
mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item()
|
||||
warnings.append(
|
||||
f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
|
||||
f"RTC does NOT equal no_rtc!\n"
|
||||
@@ -661,6 +677,103 @@ class RTCEvaluator:
|
||||
logging.error("")
|
||||
logging.error(" Please check the implementation of RTC guidance.")
|
||||
|
||||
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
|
||||
"""Plot final action predictions comparison on a single chart.
|
||||
|
||||
Args:
|
||||
rtc_actions: Final actions from RTC policy
|
||||
no_rtc_actions: Final actions from non-RTC policy
|
||||
prev_chunk_left_over: Previous chunk used as ground truth
|
||||
"""
|
||||
# Remove batch dimension if present
|
||||
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||
no_rtc_actions_plot = (
|
||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||
)
|
||||
prev_chunk_plot = prev_chunk_left_over.cpu()
|
||||
|
||||
# Create figure with 6 subplots (one per action dimension)
|
||||
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
|
||||
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
|
||||
|
||||
# Plot each action dimension
|
||||
for dim_idx, ax in enumerate(axes):
|
||||
# Plot previous chunk (ground truth) in red
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
prev_chunk_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="red",
|
||||
label="Previous Chunk (Ground Truth)",
|
||||
linewidth=2.5,
|
||||
alpha=0.8,
|
||||
)
|
||||
|
||||
# Plot no-RTC actions in blue
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="blue",
|
||||
label="No RTC",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# Plot RTC actions in green
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="green",
|
||||
label="RTC",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# Add vertical lines for inference delay and execution horizon
|
||||
inference_delay = self.cfg.inference_delay
|
||||
execution_horizon = self.cfg.rtc.execution_horizon
|
||||
|
||||
if inference_delay > 0:
|
||||
ax.axvline(
|
||||
x=inference_delay - 1,
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.5,
|
||||
label=f"Inference Delay ({inference_delay})",
|
||||
)
|
||||
|
||||
if execution_horizon > 0:
|
||||
ax.axvline(
|
||||
x=execution_horizon,
|
||||
color="purple",
|
||||
linestyle="--",
|
||||
alpha=0.5,
|
||||
label=f"Execution Horizon ({execution_horizon})",
|
||||
)
|
||||
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Set x-axis ticks to show all integer values
|
||||
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
|
||||
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||
ax.set_xlim(-0.5, max_len - 0.5)
|
||||
|
||||
# Add legend only to first subplot
|
||||
if dim_idx == 0:
|
||||
ax.legend(loc="best", fontsize=9)
|
||||
|
||||
axes[-1].set_xlabel("Step", fontsize=10)
|
||||
|
||||
# Save figure
|
||||
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=150)
|
||||
logging.info(f"Saved final actions comparison to {output_path}")
|
||||
plt.close(fig)
|
||||
|
||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||
@@ -828,6 +941,13 @@ class RTCEvaluator:
|
||||
margin = y_range * 0.1
|
||||
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
|
||||
|
||||
# Set x-axis ticks to show all integer values
|
||||
xlim = ax.get_xlim()
|
||||
max_len = int(xlim[1]) + 1
|
||||
if max_len > 0:
|
||||
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||
ax.set_xlim(-0.5, max_len - 0.5)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
|
||||
Reference in New Issue
Block a user