Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-15 16:49:55 +00:00
Add validation at the end
@@ -452,12 +452,13 @@ class RTCEvaluator:
             noise_clone = noise.clone()
             policy_no_rtc_policy.rtc_processor.reset_tracker()
             with torch.no_grad():
-                _ = policy_no_rtc_policy.predict_action_chunk(
+                no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
                     preprocessed_second_sample,
                     noise=noise,
                 )
             no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
             logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
+            logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
 
             # Destroy policy_no_rtc to free memory before loading policy_rtc
             self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
@@ -477,7 +478,7 @@ class RTCEvaluator:
             )
             policy_rtc_policy.rtc_processor.reset_tracker()
             with torch.no_grad():
-                _ = policy_rtc_policy.predict_action_chunk(
+                rtc_actions = policy_rtc_policy.predict_action_chunk(
                     preprocessed_second_sample,
                     noise=noise_clone,
                     inference_delay=self.cfg.inference_delay,
@@ -486,6 +487,7 @@ class RTCEvaluator:
                 )
             rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
             logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
+            logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
 
             # Save num_steps before destroying policy (needed for plotting)
             try:
@@ -502,9 +504,163 @@ class RTCEvaluator:
         logging.info("=" * 80)
         logging.info("Plotting results...")
         self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
+
+        # Validate RTC behavior
+        logging.info("=" * 80)
+        logging.info("Validating RTC behavior...")
+        self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
 
         logging.info("=" * 80)
         logging.info("Evaluation completed successfully")
+
+    def validate_rtc_behavior(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
+        """Validate RTC behavior by comparing final action predictions with expected values.
+
+        Validation rules:
+        1. During delay [0:inference_delay]: RTC should equal prev_chunk
+        2. After delay, within execution horizon [inference_delay:execution_horizon]:
+           RTC should be between prev_chunk and no_rtc
+        3. After execution horizon [execution_horizon:]: RTC should equal no_rtc
+
+        Args:
+            rtc_actions: Final actions from RTC policy (batch, time, action_dim)
+            no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
+            prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
+        """
+        if rtc_actions is None or no_rtc_actions is None:
+            logging.warning(" ⚠ Cannot validate: missing action predictions")
+            return
+
+        # Convert to numpy for comparison (remove batch dimension if present)
+        rtc_actions_np = (
+            rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy()
+        )
+        no_rtc_actions_np = (
+            no_rtc_actions.squeeze(0).cpu().numpy()
+            if len(no_rtc_actions.shape) == 3
+            else no_rtc_actions.cpu().numpy()
+        )
+        prev_chunk = prev_chunk_left_over.cpu().numpy()
+
+        # Determine chunk length for comparison
+        chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0])
+        inference_delay = self.cfg.inference_delay
+        execution_horizon = self.cfg.rtc.execution_horizon
+
+        # Tolerance for floating point comparison
+        rtol = 1e-3  # Relative tolerance
+        atol = 1e-3  # Absolute tolerance
+
+        validation_passed = True
+        warnings = []
+
+        logging.info(" Validating RTC behavior:")
+        logging.info(f" Chunk length: {chunk_len}")
+        logging.info(f" Inference delay: {inference_delay}")
+        logging.info(f" Execution horizon: {execution_horizon}")
+        logging.info(f" Tolerance: rtol={rtol}, atol={atol}")
+
+        # ============================================================================
+        # Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
+        # ============================================================================
+        if inference_delay > 0:
+            delay_end = min(inference_delay, chunk_len)
+            rtc_delay = rtc_actions_np[:delay_end]
+            prev_delay = prev_chunk[:delay_end]
+
+            if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol):
+                max_diff = np.max(np.abs(rtc_delay - prev_delay))
+                mean_diff = np.mean(np.abs(rtc_delay - prev_delay))
+                warnings.append(
+                    f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
+                    f"RTC does NOT equal prev_chunk!\n"
+                    f" Max difference: {max_diff:.6f}\n"
+                    f" Mean difference: {mean_diff:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(f" ✓ During delay [0:{delay_end}]: RTC equals prev_chunk")
+
+        # ============================================================================
+        # Rule 2: After delay, within execution horizon [inference_delay:execution_horizon]
+        # RTC should be between prev_chunk and no_rtc
+        # ============================================================================
+        blend_start = inference_delay
+        blend_end = min(execution_horizon, chunk_len)
+
+        if blend_end > blend_start:
+            rtc_blend = rtc_actions_np[blend_start:blend_end]
+            prev_blend = prev_chunk[blend_start:blend_end]
+            no_rtc_blend = no_rtc_actions_np[blend_start:blend_end]
+
+            # Check if RTC is between prev_chunk and no_rtc (element-wise)
+            # For each element, check if it's between the min and max of prev_chunk and no_rtc
+            min_bound = np.minimum(prev_blend, no_rtc_blend) - atol
+            max_bound = np.maximum(prev_blend, no_rtc_blend) + atol
+
+            within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
+
+            if not np.all(within_bounds):
+                violations = np.sum(~within_bounds)
+                total_elements = within_bounds.size
+                violation_pct = 100.0 * violations / total_elements
+
+                # Find max violation
+                lower_violations = np.maximum(0, min_bound - rtc_blend)
+                upper_violations = np.maximum(0, rtc_blend - max_bound)
+                max_violation = np.max(np.maximum(lower_violations, upper_violations))
+
+                warnings.append(
+                    f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
+                    f"RTC is NOT always between prev_chunk and no_rtc!\n"
+                    f" Violations: {violations}/{total_elements} elements ({violation_pct:.1f}%)\n"
+                    f" Max violation distance: {max_violation:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(
+                    f" ✓ Blend region [{blend_start}:{blend_end}]: RTC is between prev_chunk and no_rtc"
+                )
+
+        # ============================================================================
+        # Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
+        # ============================================================================
+        if execution_horizon < chunk_len:
+            rtc_after = rtc_actions_np[execution_horizon:chunk_len]
+            no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len]
+
+            if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol):
+                max_diff = np.max(np.abs(rtc_after - no_rtc_after))
+                mean_diff = np.mean(np.abs(rtc_after - no_rtc_after))
+                warnings.append(
+                    f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
+                    f"RTC does NOT equal no_rtc!\n"
+                    f" Max difference: {max_diff:.6f}\n"
+                    f" Mean difference: {mean_diff:.6f}"
+                )
+                validation_passed = False
+            else:
+                logging.info(
+                    f" ✓ After execution horizon [{execution_horizon}:{chunk_len}]: RTC equals no_rtc"
+                )
+
+        # ============================================================================
+        # Report results
+        # ============================================================================
+        logging.info("=" * 80)
+        if validation_passed:
+            logging.info(" ✅ VALIDATION PASSED: All RTC behavior checks passed!")
+            logging.info(" • During delay: RTC = prev_chunk ✓")
+            logging.info(" • Blend region: prev_chunk ≤ RTC ≤ no_rtc ✓")
+            logging.info(" • After execution horizon: RTC = no_rtc ✓")
+        else:
+            logging.error(" ❌ VALIDATION FAILED: RTC behavior does not match expected!")
+            logging.error("")
+            for warning in warnings:
+                logging.error(warning)
+            logging.error("")
+            logging.error(" Please check the implementation of RTC guidance.")
+
     def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
         # Create side-by-side figures for denoising visualization
         fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
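The method added above encodes the expected shape of RTC inference: the chunk is pinned to prev_chunk while previously committed actions are still executing (the inference delay), pulled toward the fresh prediction inside the execution horizon, and unconstrained past it. As a sanity reference, here is a minimal standalone numpy sketch, not part of the commit, that builds synthetic chunks with a linear blend weight (an illustrative assumption; the actual RTC guidance schedule is configurable and need not be linear) and passes all three checks:

import numpy as np

# Synthetic setup; all names and values below are illustrative, not from the commit.
chunk_len, action_dim = 16, 4
inference_delay, execution_horizon = 3, 10
rng = np.random.default_rng(0)

prev_chunk = rng.normal(size=(chunk_len, action_dim))  # leftover previous chunk
no_rtc = rng.normal(size=(chunk_len, action_dim))      # fresh prediction without RTC

# Assumed schedule: weight 1.0 during the delay, linear decay to 0.0 at the horizon.
w = np.zeros(chunk_len)
w[:inference_delay] = 1.0
w[inference_delay:execution_horizon] = np.linspace(
    1.0, 0.0, execution_horizon - inference_delay, endpoint=False
)
rtc = w[:, None] * prev_chunk + (1.0 - w[:, None]) * no_rtc

# Rule 1: during the delay, RTC equals prev_chunk.
assert np.allclose(rtc[:inference_delay], prev_chunk[:inference_delay])
# Rule 2: in the blend region, RTC lies between prev_chunk and no_rtc.
lo = np.minimum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
hi = np.maximum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
mid = rtc[inference_delay:execution_horizon]
assert np.all((mid >= lo - 1e-6) & (mid <= hi + 1e-6))
# Rule 3: past the execution horizon, RTC equals no_rtc.
assert np.allclose(rtc[execution_horizon:], no_rtc[execution_horizon:])

Any weight that stays in [0, 1], equals 1 throughout the delay, and reaches 0 by the execution horizon passes the same three checks, which is why the validator only bounds the blend region rather than pinning an exact schedule.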
@@ -1,75 +0,0 @@
-#!/bin/bash
-
-# Example script to run RTC evaluation on dataset
-# This shows different usage scenarios
-
-set -e  # Exit on error
-
-POLICY_PATH="lerobot/smolvla_base"
-DATASET="lerobot/pusht"
-DEVICE="cuda"  # Change to "cpu" or "mps" if needed
-
-echo "========================================"
-echo "RTC Dataset Evaluation Examples"
-echo "========================================"
-
-# Example 1: Quick evaluation (100 samples, every step)
-echo -e "\n[Example 1] Quick evaluation - 100 samples, every step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=100 \
-    --skip_steps=1 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_quick.json"
-
-# Example 2: Simulating realistic inference delay (every 3rd step)
-echo -e "\n[Example 2] Realistic inference delay - 200 samples, every 3rd step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=200 \
-    --skip_steps=3 \
-    --rtc.execution_horizon=10 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_delay3.json"
-
-# Example 3: Higher inference delay (every 5th step)
-echo -e "\n[Example 3] High inference delay - 200 samples, every 5th step"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=200 \
-    --skip_steps=5 \
-    --rtc.execution_horizon=12 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_delay5.json"
-
-# Example 4: Testing different RTC configurations
-echo -e "\n[Example 4] Different RTC config - LINEAR schedule"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=100 \
-    --skip_steps=3 \
-    --rtc.execution_horizon=8 \
-    --rtc.prefix_attention_schedule=LINEAR \
-    --rtc.max_guidance_weight=5.0 \
-    --device="${DEVICE}" \
-    --output_path="results/rtc_eval_linear.json"
-
-# Example 5: Verbose mode for debugging
-echo -e "\n[Example 5] Verbose mode - 20 samples with detailed output"
-python examples/rtc/evaluate_rtc_on_dataset.py \
-    --policy.path="${POLICY_PATH}" \
-    --dataset.repo_id="${DATASET}" \
-    --num_iterations=20 \
-    --skip_steps=3 \
-    --device="${DEVICE}" \
-    --verbose=true \
-    --output_path="results/rtc_eval_verbose.json"
-
-echo -e "\n========================================"
-echo "All evaluations completed!"
-echo "Results saved in results/ directory"
-echo "========================================"
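One detail worth noting about the tolerances in validate_rtc_behavior: np.allclose(a, b, rtol=1e-3, atol=1e-3) accepts an element pair when |a - b| <= atol + rtol * |b|, so the permitted deviation grows with the magnitude of the reference actions. A quick standalone check of that semantics:

import numpy as np

rtol, atol = 1e-3, 1e-3
a = np.array([1.0, 10.0])
b = np.array([1.0015, 10.005])

# np.allclose passes element-wise when |a - b| <= atol + rtol * |b|.
print(np.abs(a - b) <= atol + rtol * np.abs(b))  # [ True  True]
print(np.allclose(a, b, rtol=rtol, atol=atol))   # True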