diff --git a/examples/rtc/README.md b/examples/rtc/README.md index 988f997a9..5128645e7 100644 --- a/examples/rtc/README.md +++ b/examples/rtc/README.md @@ -16,156 +16,161 @@ Real-Time Chunking addresses the challenge of maintaining consistency and reacti ## Scripts -### 1. `real_time_chunking_evaluate.py` +### 1. `eval_dataset.py` -Real-time evaluation on physical robots or simulation environments. +Offline evaluation on dataset samples with detailed visualization and validation. **Features:** -- Run policy with RTC on real robot or simulation -- Compare RTC vs non-RTC actions in real-time -- Multi-threaded action execution and inference +- Compare RTC vs non-RTC predictions on two random dataset samples +- Validate RTC behavior (delay region, blend region, post-horizon region) +- Generate debug visualizations: + - Denoising step comparisons (x_t, v_t, x1_t, corrections) + - Final action predictions comparison - Support for torch.compile() optimization +- Memory-efficient sequential policy loading for large models **Usage:** ```bash -# With real robot -uv run python examples/rtc/real_time_chunking_evaluate.py \ - --policy.path=lerobot/smolvla_base \ - --robot.type=so100 \ - --task="pick up the cup" +# Basic usage with SmolVLA policy +uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --device=mps \ + --rtc.max_guidance_weight=10.0 \ + --seed=10 -# With simulation environment -uv run python examples/rtc/real_time_chunking_evaluate.py \ - --policy.path=lerobot/smolvla_base \ - --env.type=pusht \ - --duration=60.0 +# With Pi0.5 policy on CUDA +uv run python examples/rtc/eval_dataset.py \ + --policy.path=lerobot/pi05_libero_finetuned \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=8 \ + --device=cuda -# Disable verbose comparison (faster) -uv run python examples/rtc/real_time_chunking_evaluate.py \ - --policy.path=lerobot/smolvla_base \ - --robot.type=so100 \ - --verbose_rtc_comparison=false +# With Pi0 policy +uv run python examples/rtc/eval_dataset.py \ + --policy.path=lerobot/pi0_libero_finetuned \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=8 \ + --device=cuda -# With policy compilation (CUDA only, not MPS) -uv run python examples/rtc/real_time_chunking_evaluate.py \ - --policy.path=lerobot/smolvla_base \ - --robot.type=so100 \ - --compile_policy=true \ - --compile_mode=max-autotune -``` +# With torch.compile for faster inference +uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --rtc.execution_horizon=8 \ + --device=cuda \ + --use_torch_compile=true \ + --torch_compile_mode=max-autotune -**Key Parameters:** - -- `--policy.path`: Path to pretrained policy -- `--robot.type` or `--env.type`: Robot or environment to use -- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10) -- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0) -- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP) -- `--verbose_rtc_comparison`: Enable detailed RTC comparison logging (default: true) -- `--duration`: How long to run (seconds, default: 30.0) -- `--fps`: Action execution frequency (Hz, default: 10.0) - -### 2. `evaluate_rtc_on_dataset.py` - -Offline evaluation on dataset samples to measure RTC effectiveness. 
- -**Features:** - -- Evaluate RTC on dataset without running robot -- Compare RTC vs non-RTC predictions -- Measure consistency and ground truth alignment -- Simulate different inference delays -- Save detailed metrics to JSON - -**Usage:** - -```bash -# Basic evaluation -uv run python examples/rtc/evaluate_rtc_on_dataset.py \ - --policy.path=lerobot/smolvla_base \ - --dataset.repo_id=lerobot/pusht \ - --num_iterations=100 - -# Simulate inference delay (every 3rd step) -uv run python examples/rtc/evaluate_rtc_on_dataset.py \ - --policy.path=lerobot/smolvla_base \ - --dataset.repo_id=lerobot/pusht \ - --num_iterations=200 \ - --skip_steps=3 - -# Custom RTC configuration -uv run python examples/rtc/evaluate_rtc_on_dataset.py \ - --policy.path=lerobot/smolvla_base \ - --dataset.repo_id=lerobot/pusht \ - --num_iterations=100 \ - --rtc.execution_horizon=12 \ - --rtc.max_guidance_weight=5.0 \ - --rtc.prefix_attention_schedule=LINEAR - -# Save results to file -uv run python examples/rtc/evaluate_rtc_on_dataset.py \ - --policy.path=lerobot/smolvla_base \ - --dataset.repo_id=lerobot/pusht \ - --num_iterations=100 \ - --output_path=results/rtc_evaluation.json - -# Verbose mode with detailed logging -uv run python examples/rtc/evaluate_rtc_on_dataset.py \ - --policy.path=lerobot/smolvla_base \ - --dataset.repo_id=lerobot/pusht \ - --num_iterations=50 \ - --verbose=true +# Enable CUDA graphs (advanced - may cause tensor aliasing errors) +uv run python examples/rtc/eval_dataset.py \ + --policy.path=helper2424/smolvla_check_rtc_last3 \ + --dataset.repo_id=helper2424/check_rtc \ + --use_torch_compile=true \ + --torch_compile_backend=inductor \ + --torch_compile_mode=max-autotune \ + --torch_compile_disable_cudagraphs=false ``` **Key Parameters:** - `--policy.path`: Path to pretrained policy - `--dataset.repo_id`: Dataset to evaluate on -- `--num_iterations`: Number of samples to evaluate (default: 100) -- `--skip_steps`: Steps to skip between inferences, simulates inference delay (default: 1) -- `--start_episode`: Episode to start from (default: 0) -- `--output_path`: Path to save results JSON -- `--verbose`: Enable detailed per-sample logging +- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20) +- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0) +- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP) +- `--inference_delay`: Inference delay for RTC (default: 4) +- `--seed`: Random seed for reproducibility (default: 42) +- `--output_dir`: Directory to save visualizations (default: rtc_debug_output) - `--device`: Device to use (cuda, cpu, mps, auto) +- `--use_torch_compile`: Enable torch.compile() for faster inference -**Metrics Reported:** +**Output:** -- **RTC vs Ground Truth MSE**: How close RTC predictions are to actual actions -- **No-RTC vs Ground Truth MSE**: Baseline without RTC -- **RTC Improvement**: Absolute and relative improvement over baseline -- **RTC Consistency**: How well RTC maintains consistency in prefix region - - Prefix MSE - - Mean/Max error in overlap region +The script generates several visualization files in `rtc_debug_output/`: -### 3. 
`run_dataset_evaluation.sh` +- `denoising_xt_comparison.png` - Noisy state evolution during denoising +- `denoising_vt_comparison.png` - Velocity predictions during denoising +- `denoising_x1t_comparison.png` - Predicted final states during denoising +- `denoising_correction_comparison.png` - RTC guidance corrections applied +- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc) -Convenience script with multiple evaluation scenarios. +The script also validates RTC behavior and reports: + +- ✅ Delay region [0:inference_delay]: RTC = prev_chunk +- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc +- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc + +### 2. `eval_with_real_robot.py` + +Real-time evaluation on physical robots or simulation environments. + +**Features:** + +- Run policy with RTC on real robot or simulation +- Multi-threaded action execution and inference +- Action queue management with proper timing +- Latency tracking and adaptive inference delay +- Support for both robots and gym environments +- Support for torch.compile() optimization **Usage:** ```bash -# Edit the script to set your policy and dataset -# Then run all examples: -./examples/rtc/run_dataset_evaluation.sh +# With real robot +uv run python examples/rtc/eval_with_real_robot.py \ + --policy.path=lerobot/smolvla_base \ + --robot.type=so100 \ + --task="pick up the cup" \ + --duration=30.0 -# Or run individual examples from the script +# With simulation environment +uv run python examples/rtc/eval_with_real_robot.py \ + --policy.path=lerobot/smolvla_base \ + --env.type=pusht \ + --duration=60.0 + +# With policy compilation (CUDA only, not MPS) +uv run python examples/rtc/eval_with_real_robot.py \ + --policy.path=lerobot/smolvla_base \ + --robot.type=so100 \ + --use_torch_compile=true \ + --torch_compile_mode=max-autotune ``` +**Key Parameters:** + +- `--policy.path`: Path to pretrained policy +- `--robot.type` or `--env.type`: Robot or environment to use +- `--task`: Task description (for VLA models) +- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10) +- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0) +- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP) +- `--duration`: How long to run (seconds, default: 30.0) +- `--fps`: Action execution frequency (Hz, default: 10.0) +- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30) +- `--device`: Device to use (cuda, cpu, mps, auto) +- `--use_torch_compile`: Enable torch.compile() for faster inference + ## Understanding RTC Parameters ### `execution_horizon` Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity. -**Typical values:** 8-12 steps +**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution ### `max_guidance_weight` Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions. -**Typical values:** 1.0-10.0 +**Typical values:** + +- Dataset evaluation: 10.0-100.0 (can be higher for analysis) +- Real-time execution: 1.0-10.0 (more conservative) ### `prefix_attention_schedule` @@ -178,104 +183,69 @@ How to weight consistency across the overlap region: **Recommended:** `EXP` -### `skip_steps` (evaluation only) +### `inference_delay` -Simulates inference delay by evaluating every N-th step. 
This helps understand how RTC performs with realistic delays.
+Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
 
-**Example:** `skip_steps=3` means policy infers every 3 steps, simulating 3x action execution frequency vs inference frequency.
+**Typical values:** 3-5 steps for dataset evaluation
 
-## Output Format (Dataset Evaluation)
+### `action_queue_size_to_get_new_actions` (real-time only)
 
-When using `--output_path`, results are saved in JSON format:
+Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation; for example, with `inference_delay=4` and `execution_horizon=20`, the threshold must exceed 24, so the default of 30 leaves some headroom.
 
-```json
-{
-  "summary": {
-    "rtc_vs_ground_truth_mse": {
-      "mean": 0.00123,
-      "std": 0.00045,
-      "min": 0.00012,
-      "max": 0.00456
-    },
-    "improvement": {
-      "absolute": 0.00034,
-      "relative_percent": 12.5
-    },
-    ...
-  },
-  "config": {
-    "num_iterations": 100,
-    "skip_steps": 3,
-    "execution_horizon": 10,
-    ...
-  },
-  "detailed_results": [
-    {
-      "sample_idx": 0,
-      "rtc_vs_ground_truth_mse": 0.00112,
-      "no_rtc_vs_ground_truth_mse": 0.00145,
-      ...
-    },
-    ...
-  ]
-}
-```
+**Typical values:** 20-30 steps
+
+## Validation Rules (Dataset Evaluation)
+
+The dataset evaluation script validates that RTC behavior matches expectations; a minimal sketch of these checks follows the list:
+
+1. **Delay Region [0:inference_delay]**: RTC actions should equal the previous chunk
+   - Ensures consistency during the inference delay period
+
+2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
+   - Smooth transition from the previous plan to new predictions
+
+3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
+   - Full adoption of new predictions after the execution horizon
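+
+As a quick illustration of these three rules, here is a minimal, self-contained sketch (tensor shapes and names are assumed for illustration; the script itself implements the checks in `validate_rtc_behavior`):
+
+```python
+import torch
+
+
+def check_rtc_regions(rtc, no_rtc, prev_chunk, inference_delay, execution_horizon, rtol=1e-2):
+    """Check the three RTC regions on (time, action_dim) tensors of equal length."""
+    # Rule 1: delay region - RTC must replay the previous chunk.
+    delay_ok = torch.allclose(rtc[:inference_delay], prev_chunk[:inference_delay], rtol=rtol)
+    # Rule 2: blend region - RTC must lie between prev_chunk and no_rtc element-wise.
+    lo = torch.minimum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
+    hi = torch.maximum(prev_chunk, no_rtc)[inference_delay:execution_horizon]
+    blend = rtc[inference_delay:execution_horizon]
+    blend_ok = bool(torch.all((blend >= lo) & (blend <= hi)))
+    # Rule 3: post-horizon region - RTC must match the fresh prediction.
+    post_ok = torch.allclose(rtc[execution_horizon:], no_rtc[execution_horizon:], rtol=rtol)
+    return delay_ok, blend_ok, post_ok
+```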
 
 ## Tips
 
-1. **Start with dataset evaluation** to understand RTC behavior before running on robot
-2. **Use verbose mode** for debugging unexpected behavior
+1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on a robot
+2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
 3. **Tune execution_horizon** based on your inference latency and action frequency
-4. **Monitor consistency metrics** - very low consistency might indicate execution_horizon is too small
+4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
 5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
 
 ## Troubleshooting
 
-### High RTC vs No-RTC difference but no improvement
+### Validation fails in delay region
 
-- Try reducing `max_guidance_weight`
-- Check if `execution_horizon` is too large
+- Check that `prev_chunk_left_over` is properly passed to the policy
+- Verify RTC guidance is being applied during denoising
+- Look at the denoising visualizations to see where guidance diverges
 
-### Poor consistency metrics
+### Validation fails in post-horizon region
 
-- Increase `execution_horizon`
-- Check that `skip_steps` is not larger than your action chunk size
-- Verify episodes are being reset correctly
+- RTC and no_rtc may be using different noise; verify the same noise is used for both
+- Check that weights are correctly zeroed out after the execution horizon
+- Review the prefix_attention_schedule visualization
 
-### RTC worse than No-RTC
+### Poor performance on real robot
 
-- RTC may not help if inference is faster than action execution
-- Try different `prefix_attention_schedule`
-- Ensure `execution_horizon` matches your use case
+- Increase `action_queue_size_to_get_new_actions` if you see warnings
+- Reduce `max_guidance_weight` if the robot is too conservative
+- Try different `prefix_attention_schedule` values
+- Enable torch.compile() for faster inference (CUDA only)
 
-## Examples Results
+### Memory issues with large models
 
-Example output from dataset evaluation:
-
-```
-================================================================================
-EVALUATION SUMMARY
-================================================================================
-
-Ground Truth Alignment:
-  RTC MSE:     0.001234 ± 0.000456
-  No-RTC MSE:  0.001567 ± 0.000512
-
-RTC Improvement:
-  Absolute: 0.000333
-  Relative: 21.23%
-
-RTC vs No-RTC Difference:
-  MSE: 0.000112 ± 0.000034
-
-RTC Consistency (Prefix Region):
-  MSE: 0.000089 ± 0.000023
-  Mean Error: 0.007654 ± 0.002341
-  Max Error: 0.023456 ± 0.008765
-```
+- The dataset evaluation script loads policies sequentially to minimize memory usage
+- For real-time execution, only one policy is loaded
+- Use smaller batch sizes if needed
 
 ## Related Documentation
 
 - [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
 - [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
+- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
 - [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py
index c3d540339..16f1cfacc 100644
--- a/examples/rtc/eval_dataset.py
+++ b/examples/rtc/eval_dataset.py
@@ -16,7 +16,9 @@ Usage:
     --policy.path=helper2424/smolvla_check_rtc_last3 \
     --dataset.repo_id=helper2424/check_rtc \
     --rtc.execution_horizon=8 \
-    --device=mps
+    --device=mps \
+    --rtc.max_guidance_weight=10.0 \
+    --seed=10
 
 # Basic usage with pi0.5 policy
 uv run python examples/rtc/eval_dataset.py \
@@ -439,6 +441,8 @@ class RTCEvaluator:
         logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
         logging.info("=" * 80)
 
+        set_seed(self.cfg.seed)
+
         # Initialize policy 2
         policy_no_rtc_policy = self._init_policy(
             name="policy_no_rtc",
@@ -470,6 +474,8 @@ class RTCEvaluator:
         logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
         logging.info("=" * 80)
 
+        set_seed(self.cfg.seed)
+
         # Initialize policy 3
         policy_rtc_policy = self._init_policy(
             name="policy_rtc",
@@ -510,6 +516,11 @@ class RTCEvaluator:
         
logging.info("Validating RTC behavior...") self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) + # Plot final actions comparison + logging.info("=" * 80) + logging.info("Plotting final actions comparison...") + self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over) + logging.info("=" * 80) logging.info("Evaluation completed successfully") @@ -527,29 +538,24 @@ class RTCEvaluator: no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim) prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim) """ - if rtc_actions is None or no_rtc_actions is None: - logging.warning(" ⚠ Cannot validate: missing action predictions") - return + # Remove batch dimension if present and move to CPU + rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu() + no_rtc_actions_t = ( + no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu() + ) + prev_chunk = prev_chunk_left_over.cpu() - # Convert to numpy for comparison (remove batch dimension if present) - rtc_actions_np = ( - rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy() - ) - no_rtc_actions_np = ( - no_rtc_actions.squeeze(0).cpu().numpy() - if len(no_rtc_actions.shape) == 3 - else no_rtc_actions.cpu().numpy() - ) - prev_chunk = prev_chunk_left_over.cpu().numpy() + logging.info(f" rtc_actions shape: {rtc_actions_t.shape}") + logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}") + logging.info(f" prev_chunk shape: {prev_chunk.shape}") # Determine chunk length for comparison - chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0]) + chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0]) inference_delay = self.cfg.inference_delay execution_horizon = self.cfg.rtc.execution_horizon # Tolerance for floating point comparison - rtol = 1e-3 # Relative tolerance - atol = 1e-3 # Absolute tolerance + rtol = 1e-2 # Relative tolerance validation_passed = True warnings = [] @@ -558,19 +564,26 @@ class RTCEvaluator: logging.info(f" Chunk length: {chunk_len}") logging.info(f" Inference delay: {inference_delay}") logging.info(f" Execution horizon: {execution_horizon}") - logging.info(f" Tolerance: rtol={rtol}, atol={atol}") + logging.info(f" Tolerance: rtol={rtol}") # ============================================================================ # Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk # ============================================================================ if inference_delay > 0: delay_end = min(inference_delay, chunk_len) - rtc_delay = rtc_actions_np[:delay_end] + rtc_delay = rtc_actions_t[:delay_end] prev_delay = prev_chunk[:delay_end] - if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol): - max_diff = np.max(np.abs(rtc_delay - prev_delay)) - mean_diff = np.mean(np.abs(rtc_delay - prev_delay)) + logging.info(f" rtc_delay: {rtc_delay.shape}") + logging.info(f" prev_delay: {prev_delay.shape}") + + if not torch.allclose(rtc_delay, prev_delay, rtol=rtol): + max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item() + mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item() + logging.info(f" rtc_delay: {rtc_delay}") + logging.info(f" prev_delay: {prev_delay}") + logging.info(f" max_diff: {max_diff}") + logging.info(f" mean_diff: {mean_diff}") warnings.append( f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], " f"RTC does 
NOT equal prev_chunk!\n" @@ -589,26 +602,26 @@ class RTCEvaluator: blend_end = min(execution_horizon, chunk_len) if blend_end > blend_start: - rtc_blend = rtc_actions_np[blend_start:blend_end] + rtc_blend = rtc_actions_t[blend_start:blend_end] prev_blend = prev_chunk[blend_start:blend_end] - no_rtc_blend = no_rtc_actions_np[blend_start:blend_end] + no_rtc_blend = no_rtc_actions_t[blend_start:blend_end] # Check if RTC is between prev_chunk and no_rtc (element-wise) # For each element, check if it's between the min and max of prev_chunk and no_rtc - min_bound = np.minimum(prev_blend, no_rtc_blend) - atol - max_bound = np.maximum(prev_blend, no_rtc_blend) + atol + min_bound = torch.minimum(prev_blend, no_rtc_blend) + max_bound = torch.maximum(prev_blend, no_rtc_blend) - within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) + within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) - if not np.all(within_bounds): - violations = np.sum(~within_bounds) - total_elements = within_bounds.size + if not torch.all(within_bounds): + violations = torch.sum(~within_bounds).item() + total_elements = within_bounds.numel() violation_pct = 100.0 * violations / total_elements # Find max violation - lower_violations = np.maximum(0, min_bound - rtc_blend) - upper_violations = np.maximum(0, rtc_blend - max_bound) - max_violation = np.max(np.maximum(lower_violations, upper_violations)) + lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend) + upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound) + max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item() warnings.append( f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], " @@ -626,12 +639,15 @@ class RTCEvaluator: # Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc # ============================================================================ if execution_horizon < chunk_len: - rtc_after = rtc_actions_np[execution_horizon:chunk_len] - no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len] + rtc_after = rtc_actions_t[execution_horizon:chunk_len] + no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len] - if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol): - max_diff = np.max(np.abs(rtc_after - no_rtc_after)) - mean_diff = np.mean(np.abs(rtc_after - no_rtc_after)) + logging.info(f" rtc_after: {rtc_after}") + logging.info(f" no_rtc_after: {no_rtc_after}") + + if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol): + max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item() + mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item() warnings.append( f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], " f"RTC does NOT equal no_rtc!\n" @@ -661,6 +677,103 @@ class RTCEvaluator: logging.error("") logging.error(" Please check the implementation of RTC guidance.") + def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over): + """Plot final action predictions comparison on a single chart. 
+
+        Args:
+            rtc_actions: Final actions from RTC policy
+            no_rtc_actions: Final actions from non-RTC policy
+            prev_chunk_left_over: Previous chunk used as ground truth
+        """
+        # Remove batch dimension if present
+        rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
+        no_rtc_actions_plot = (
+            no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
+        )
+        prev_chunk_plot = prev_chunk_left_over.cpu()
+
+        # Create figure with 6 subplots (one per action dimension)
+        fig, axes = plt.subplots(6, 1, figsize=(16, 12))
+        fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
+
+        # Plot each action dimension
+        for dim_idx, ax in enumerate(axes):
+            # Plot previous chunk (ground truth) in red
+            RTCDebugVisualizer.plot_waypoints(
+                [ax],
+                prev_chunk_plot[:, dim_idx : dim_idx + 1],
+                start_from=0,
+                color="red",
+                label="Previous Chunk (Ground Truth)",
+                linewidth=2.5,
+                alpha=0.8,
+            )
+
+            # Plot no-RTC actions in blue
+            RTCDebugVisualizer.plot_waypoints(
+                [ax],
+                no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
+                start_from=0,
+                color="blue",
+                label="No RTC",
+                linewidth=2,
+                alpha=0.7,
+            )
+
+            # Plot RTC actions in green
+            RTCDebugVisualizer.plot_waypoints(
+                [ax],
+                rtc_actions_plot[:, dim_idx : dim_idx + 1],
+                start_from=0,
+                color="green",
+                label="RTC",
+                linewidth=2,
+                alpha=0.7,
+            )
+
+            # Add vertical lines for inference delay and execution horizon
+            inference_delay = self.cfg.inference_delay
+            execution_horizon = self.cfg.rtc.execution_horizon
+
+            if inference_delay > 0:
+                ax.axvline(
+                    x=inference_delay - 1,
+                    color="orange",
+                    linestyle="--",
+                    alpha=0.5,
+                    label=f"Inference Delay ({inference_delay})",
+                )
+
+            if execution_horizon > 0:
+                ax.axvline(
+                    x=execution_horizon,
+                    color="purple",
+                    linestyle="--",
+                    alpha=0.5,
+                    label=f"Execution Horizon ({execution_horizon})",
+                )
+
+            ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
+            ax.grid(True, alpha=0.3)
+
+            # Set evenly spaced integer x-axis ticks (about 20 across the range)
+            max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
+            ax.set_xticks(range(0, max_len, max(1, max_len // 20)))  # Show ~20 ticks
+            ax.set_xlim(-0.5, max_len - 0.5)
+
+            # Add legend only to first subplot
+            if dim_idx == 0:
+                ax.legend(loc="best", fontsize=9)
+
+        axes[-1].set_xlabel("Step", fontsize=10)
+
+        # Save figure
+        output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
+        fig.tight_layout()
+        fig.savefig(output_path, dpi=150)
+        logging.info(f"Saved final actions comparison to {output_path}")
+        plt.close(fig)
+
     def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
         # Create side-by-side figures for denoising visualization
         fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
@@ -828,6 +941,13 @@ class RTCEvaluator:
             margin = y_range * 0.1
             ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
 
+        # Set evenly spaced integer x-axis ticks (about 20 across the range)
+        xlim = ax.get_xlim()
+        max_len = int(xlim[1]) + 1
+        if max_len > 0:
+            ax.set_xticks(range(0, max_len, max(1, max_len // 20)))  # Show ~20 ticks
+            ax.set_xlim(-0.5, max_len - 0.5)
+
 
 @parser.wrap()
 def main(cfg: RTCEvalConfig):
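A note on the two `set_seed(self.cfg.seed)` calls added above: the post-horizon validation rule (RTC equals no-RTC after `execution_horizon`) only holds if both generation passes start denoising from identical noise, so the evaluator reseeds immediately before each pass. A minimal sketch of the idea, with an illustrative `sample_chunk` helper that is not part of the script:

```python
import torch


def sample_chunk(denoise_fn, seed: int, shape: tuple) -> torch.Tensor:
    # Reseeding right before sampling pins the initial noise, so two passes
    # (with and without RTC guidance) are directly comparable afterwards.
    torch.manual_seed(seed)
    noise = torch.randn(shape)
    return denoise_fn(noise)
```

With the seed fixed, any remaining difference between the two outputs comes from the RTC guidance itself rather than from the noise draw.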
diff --git a/src/lerobot/policies/rtc/README.md b/src/lerobot/policies/rtc/README.md
index 94ed41c5f..2b72b33ab 100644
--- a/src/lerobot/policies/rtc/README.md
+++ b/src/lerobot/policies/rtc/README.md
@@ -14,6 +14,7 @@ RTC can be integrated with any policy that supports flow matching for chunking:
 - **SmolVLA**: Vision-language-action model with RTC support
 - **Pi0**: Action prediction model with adaptive chunking
+- **Pi05**: Action prediction model with adaptive chunking
 
 ## Original Implementation
 
@@ -39,3 +40,10 @@ uv run python examples/rtc/eval_dataset.py \
     --device=mps \
     --seed=42
 ```
+
+This script evaluates RTC on samples from a dataset and saves the results; you can check the outputs in the `rtc_debug_output` directory.
+
+The example output should look like this:
+![Flow Matching with RTC](./flow_matching.png)
+
+It shows how flow matching behaves with and without RTC. The chart shows the predicted action values at each timestep. The colour indicates generation progress: blue marks earlier steps in the generation, yellow marks later ones. The red line is the ground truth (previous action chunk).
diff --git a/src/lerobot/policies/rtc/flow_matching.png b/src/lerobot/policies/rtc/flow_matching.png
new file mode 100644
index 000000000..af7c7bf50
Binary files /dev/null and b/src/lerobot/policies/rtc/flow_matching.png differ
diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py
index 0445aa982..280905adf 100644
--- a/src/lerobot/policies/rtc/modeling_rtc.py
+++ b/src/lerobot/policies/rtc/modeling_rtc.py
@@ -168,6 +168,8 @@ class RTCProcessor:
                 v_t = original_denoise_step_partial(x_t)
                 return v_t
 
+        x_t = x_t.clone().detach()
+
         squeezed = False
         if len(x_t.shape) < 3:
             # Add batch dimension
@@ -208,7 +210,6 @@ class RTCProcessor:
         with torch.enable_grad():
             v_t = original_denoise_step_partial(x_t)
 
-            x_t = x_t.clone().detach()
             x_t.requires_grad_(True)
 
             x1_t = x_t - time * v_t  # noqa: N806
diff --git a/tests/policies/rtc/test_configuration_rtc.py b/tests/policies/rtc/test_configuration_rtc.py
index 2251e007c..bb4550eaa 100644
--- a/tests/policies/rtc/test_configuration_rtc.py
+++ b/tests/policies/rtc/test_configuration_rtc.py
@@ -16,8 +16,6 @@
 
 """Tests for RTC configuration module."""
 
-import pytest
-
 from lerobot.configs.types import RTCAttentionSchedule
 from lerobot.policies.rtc.configuration_rtc import RTCConfig
 
@@ -65,259 +63,3 @@ def test_rtc_config_partial_initialization():
     assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
     assert config.execution_horizon == 10
     assert config.debug is False
-
-
-# ====================== Validation Tests ======================
-
-
-def test_rtc_config_validates_positive_max_guidance_weight():
-    """Test RTCConfig validates max_guidance_weight is positive."""
-    with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
-        RTCConfig(max_guidance_weight=0.0)
-
-    with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
-        RTCConfig(max_guidance_weight=-1.0)
-
-
-def test_rtc_config_validates_positive_debug_maxlen():
-    """Test RTCConfig validates debug_maxlen is positive."""
-    with pytest.raises(ValueError, match="debug_maxlen must be positive"):
-        RTCConfig(debug_maxlen=0)
-
-    with pytest.raises(ValueError, match="debug_maxlen must be positive"):
-        RTCConfig(debug_maxlen=-10)
-
-
-def test_rtc_config_accepts_valid_max_guidance_weight():
-    """Test RTCConfig accepts valid positive max_guidance_weight."""
-    config1 = RTCConfig(max_guidance_weight=0.1)
-    assert config1.max_guidance_weight == 0.1
-
-    config2 = RTCConfig(max_guidance_weight=100.0)
-    assert config2.max_guidance_weight == 100.0
-
-
-def test_rtc_config_accepts_valid_debug_maxlen():
-    """Test RTCConfig accepts valid positive debug_maxlen."""
-    config1 = RTCConfig(debug_maxlen=1)
-    
assert config1.debug_maxlen == 1 - - config2 = RTCConfig(debug_maxlen=10000) - assert config2.debug_maxlen == 10000 - - -# ====================== Attention Schedule Tests ====================== - - -def test_rtc_config_with_linear_schedule(): - """Test RTCConfig with LINEAR attention schedule.""" - config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR) - assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR - - -def test_rtc_config_with_exp_schedule(): - """Test RTCConfig with EXP attention schedule.""" - config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP) - assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP - - -def test_rtc_config_with_zeros_schedule(): - """Test RTCConfig with ZEROS attention schedule.""" - config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS) - assert config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS - - -def test_rtc_config_with_ones_schedule(): - """Test RTCConfig with ONES attention schedule.""" - config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES) - assert config.prefix_attention_schedule == RTCAttentionSchedule.ONES - - -# ====================== Enabled/Disabled Tests ====================== - - -def test_rtc_config_enabled_true(): - """Test RTCConfig with enabled=True.""" - config = RTCConfig(enabled=True) - assert config.enabled is True - - -def test_rtc_config_enabled_false(): - """Test RTCConfig with enabled=False.""" - config = RTCConfig(enabled=False) - assert config.enabled is False - - -# ====================== Debug Tests ====================== - - -def test_rtc_config_debug_enabled(): - """Test RTCConfig with debug enabled.""" - config = RTCConfig(debug=True, debug_maxlen=500) - assert config.debug is True - assert config.debug_maxlen == 500 - - -def test_rtc_config_debug_disabled(): - """Test RTCConfig with debug disabled.""" - config = RTCConfig(debug=False) - assert config.debug is False - - -# ====================== Execution Horizon Tests ====================== - - -def test_rtc_config_with_small_execution_horizon(): - """Test RTCConfig with small execution horizon.""" - config = RTCConfig(execution_horizon=1) - assert config.execution_horizon == 1 - - -def test_rtc_config_with_large_execution_horizon(): - """Test RTCConfig with large execution horizon.""" - config = RTCConfig(execution_horizon=100) - assert config.execution_horizon == 100 - - -def test_rtc_config_with_zero_execution_horizon(): - """Test RTCConfig accepts zero execution horizon.""" - # No validation on execution_horizon, so this should work - config = RTCConfig(execution_horizon=0) - assert config.execution_horizon == 0 - - -def test_rtc_config_with_negative_execution_horizon(): - """Test RTCConfig accepts negative execution horizon.""" - # No validation on execution_horizon, so this should work - config = RTCConfig(execution_horizon=-1) - assert config.execution_horizon == -1 - - -# ====================== Integration Tests ====================== - - -def test_rtc_config_typical_production_settings(): - """Test RTCConfig with typical production settings.""" - config = RTCConfig( - enabled=True, - prefix_attention_schedule=RTCAttentionSchedule.EXP, - max_guidance_weight=10.0, - execution_horizon=8, - debug=False, - ) - - assert config.enabled is True - assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP - assert config.max_guidance_weight == 10.0 - assert config.execution_horizon == 8 - assert config.debug is False - - -def 
test_rtc_config_typical_debug_settings(): - """Test RTCConfig with typical debug settings.""" - config = RTCConfig( - enabled=True, - prefix_attention_schedule=RTCAttentionSchedule.LINEAR, - max_guidance_weight=5.0, - execution_horizon=10, - debug=True, - debug_maxlen=1000, - ) - - assert config.enabled is True - assert config.debug is True - assert config.debug_maxlen == 1000 - - -def test_rtc_config_disabled_mode(): - """Test RTCConfig in disabled mode.""" - config = RTCConfig(enabled=False) - - assert config.enabled is False - # Other settings still accessible even when disabled - assert config.max_guidance_weight == 10.0 - assert config.execution_horizon == 10 - - -# ====================== Dataclass Tests ====================== - - -def test_rtc_config_is_dataclass(): - """Test that RTCConfig is a dataclass.""" - from dataclasses import is_dataclass - - assert is_dataclass(RTCConfig) - - -def test_rtc_config_equality(): - """Test RTCConfig equality comparison.""" - config1 = RTCConfig(enabled=True, max_guidance_weight=5.0) - config2 = RTCConfig(enabled=True, max_guidance_weight=5.0) - config3 = RTCConfig(enabled=False, max_guidance_weight=5.0) - - assert config1 == config2 - assert config1 != config3 - - -def test_rtc_config_repr(): - """Test RTCConfig string representation.""" - config = RTCConfig(enabled=True, execution_horizon=20) - repr_str = repr(config) - - assert "RTCConfig" in repr_str - assert "enabled=True" in repr_str - assert "execution_horizon=20" in repr_str - - -# ====================== Edge Cases Tests ====================== - - -def test_rtc_config_very_small_max_guidance_weight(): - """Test RTCConfig with very small positive max_guidance_weight.""" - config = RTCConfig(max_guidance_weight=1e-10) - assert config.max_guidance_weight == pytest.approx(1e-10) - - -def test_rtc_config_very_large_max_guidance_weight(): - """Test RTCConfig with very large max_guidance_weight.""" - config = RTCConfig(max_guidance_weight=1e10) - assert config.max_guidance_weight == pytest.approx(1e10) - - -def test_rtc_config_minimum_debug_maxlen(): - """Test RTCConfig with minimum valid debug_maxlen.""" - config = RTCConfig(debug_maxlen=1) - assert config.debug_maxlen == 1 - - -def test_rtc_config_float_max_guidance_weight(): - """Test RTCConfig with float max_guidance_weight.""" - config = RTCConfig(max_guidance_weight=3.14159) - assert config.max_guidance_weight == pytest.approx(3.14159) - - -# ====================== Type Tests ====================== - - -def test_rtc_config_enabled_type(): - """Test RTCConfig enabled field accepts boolean.""" - config = RTCConfig(enabled=True) - assert isinstance(config.enabled, bool) - - -def test_rtc_config_execution_horizon_type(): - """Test RTCConfig execution_horizon field accepts integer.""" - config = RTCConfig(execution_horizon=15) - assert isinstance(config.execution_horizon, int) - - -def test_rtc_config_max_guidance_weight_type(): - """Test RTCConfig max_guidance_weight field accepts float.""" - config = RTCConfig(max_guidance_weight=7.5) - assert isinstance(config.max_guidance_weight, float) - - -def test_rtc_config_debug_maxlen_type(): - """Test RTCConfig debug_maxlen field accepts integer.""" - config = RTCConfig(debug_maxlen=200) - assert isinstance(config.debug_maxlen, int) diff --git a/tests/policies/rtc/test_debug_visualizer.py b/tests/policies/rtc/test_debug_visualizer.py deleted file mode 100644 index 41b2926fe..000000000 --- a/tests/policies/rtc/test_debug_visualizer.py +++ /dev/null @@ -1,427 +0,0 @@ -#!/usr/bin/env python - -# 
Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for RTC debug visualizer module.""" - -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch - -from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer - -# ====================== Fixtures ====================== - - -@pytest.fixture -def mock_axes(): - """Create mock matplotlib axes.""" - axes = [] - for _ in range(6): - ax = MagicMock() - ax.xaxis.get_label.return_value.get_text.return_value = "" - ax.yaxis.get_label.return_value.get_text.return_value = "" - axes.append(ax) - return axes - - -@pytest.fixture -def sample_tensor_2d(): - """Create a 2D sample tensor (time_steps, num_dims).""" - return torch.randn(50, 6) - - -@pytest.fixture -def sample_tensor_3d(): - """Create a 3D sample tensor (batch, time_steps, num_dims).""" - return torch.randn(1, 50, 6) - - -@pytest.fixture -def sample_numpy_2d(): - """Create a 2D numpy array.""" - return np.random.randn(50, 6) - - -# ====================== Basic Plotting Tests ====================== - - -def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d): - """Test plot_waypoints with 2D tensor.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) - - # Should call plot on each axis (6 dimensions) - for ax in mock_axes: - ax.plot.assert_called_once() - - -def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d): - """Test plot_waypoints with 3D tensor (batch dimension).""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d) - - # Should still plot 6 dimensions (batch dimension removed) - for ax in mock_axes: - ax.plot.assert_called_once() - - -def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d): - """Test plot_waypoints with numpy array.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d) - - # Should work with numpy arrays - for ax in mock_axes: - ax.plot.assert_called_once() - - -def test_plot_waypoints_with_none_tensor(mock_axes): - """Test plot_waypoints returns early when tensor is None.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, None) - - # Should not call plot on any axis - for ax in mock_axes: - ax.plot.assert_not_called() - - -# ====================== Parameter Tests ====================== - - -def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d): - """Test plot_waypoints uses custom color.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red") - - # Check that color was passed to plot - for ax in mock_axes: - call_kwargs = ax.plot.call_args[1] - assert call_kwargs["color"] == "red" - - -def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d): - """Test plot_waypoints uses custom label.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label") - - # First axis should have label, others should not - first_ax_kwargs = mock_axes[0].plot.call_args[1] - assert first_ax_kwargs["label"] == "test_label" - - 
# Other axes should have empty label - for ax in mock_axes[1:]: - call_kwargs = ax.plot.call_args[1] - assert call_kwargs["label"] == "" - - -def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d): - """Test plot_waypoints uses custom alpha.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5) - - for ax in mock_axes: - call_kwargs = ax.plot.call_args[1] - assert call_kwargs["alpha"] == 0.5 - - -def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d): - """Test plot_waypoints uses custom linewidth.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3) - - for ax in mock_axes: - call_kwargs = ax.plot.call_args[1] - assert call_kwargs["linewidth"] == 3 - - -def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d): - """Test plot_waypoints with marker style.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5) - - for ax in mock_axes: - call_kwargs = ax.plot.call_args[1] - assert call_kwargs["marker"] == "o" - assert call_kwargs["markersize"] == 5 - - -def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d): - """Test plot_waypoints without marker (default).""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None) - - # Marker should not be in kwargs when None - for ax in mock_axes: - call_kwargs = ax.plot.call_args[1] - assert "marker" not in call_kwargs - assert "markersize" not in call_kwargs - - -# ====================== start_from Parameter Tests ====================== - - -def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d): - """Test plot_waypoints with start_from=0.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0) - - # X indices should start from 0 - for ax in mock_axes: - call_args = ax.plot.call_args[0] - x_indices = call_args[0] - assert x_indices[0] == 0 - assert len(x_indices) == 50 - - -def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d): - """Test plot_waypoints with start_from > 0.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10) - - # X indices should start from 10 - for ax in mock_axes: - call_args = ax.plot.call_args[0] - x_indices = call_args[0] - assert x_indices[0] == 10 - assert x_indices[-1] == 59 # 10 + 50 - 1 - - -# ====================== Tensor Shape Tests ====================== - - -def test_plot_waypoints_with_1d_tensor(mock_axes): - """Test plot_waypoints with 1D tensor.""" - tensor_1d = torch.randn(50) - RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_1d) - - # Should reshape to (50, 1) and plot on first axis only - mock_axes[0].plot.assert_called_once() - for ax in mock_axes[1:]: - ax.plot.assert_not_called() - - -def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d): - """Test plot_waypoints when tensor has fewer dims than axes.""" - # Create tensor with only 3 dimensions - tensor_3d = sample_tensor_2d[:, :3] - - # Create 6 axes but tensor only has 3 dims - mock_axes = [MagicMock() for _ in range(6)] - for ax in mock_axes: - ax.xaxis.get_label.return_value.get_text.return_value = "" - ax.yaxis.get_label.return_value.get_text.return_value = "" - - RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_3d) - - # Should only plot on first 3 axes - for i in range(3): - mock_axes[i].plot.assert_called_once() - for i in range(3, 6): - mock_axes[i].plot.assert_not_called() - - -# ====================== Axis Labeling Tests ====================== - - -def 
test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d): - """Test plot_waypoints sets x-axis label.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) - - for ax in mock_axes: - ax.set_xlabel.assert_called_once_with("Step", fontsize=10) - - -def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d): - """Test plot_waypoints sets y-axis label.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) - - for i, ax in enumerate(mock_axes): - ax.set_ylabel.assert_called_once_with(f"Dim {i}", fontsize=10) - - -def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d): - """Test plot_waypoints doesn't set labels if they already exist.""" - mock_axes_with_labels = [] - for _ in range(6): - ax = MagicMock() - # Simulate existing labels - ax.xaxis.get_label.return_value.get_text.return_value = "Existing X Label" - ax.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label" - mock_axes_with_labels.append(ax) - - RTCDebugVisualizer.plot_waypoints(mock_axes_with_labels, sample_tensor_2d) - - # Should not set labels when they already exist - for ax in mock_axes_with_labels: - ax.set_xlabel.assert_not_called() - ax.set_ylabel.assert_not_called() - - -# ====================== Grid Tests ====================== - - -def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d): - """Test plot_waypoints enables grid on all axes.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d) - - for ax in mock_axes: - ax.grid.assert_called_once_with(True, alpha=0.3) - - -# ====================== Legend Tests ====================== - - -def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d): - """Test plot_waypoints adds legend when label is provided.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label") - - # Should add legend to first axis only - mock_axes[0].legend.assert_called_once_with(loc="best", fontsize=8) - - # Should not add legend to other axes - for ax in mock_axes[1:]: - ax.legend.assert_not_called() - - -def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d): - """Test plot_waypoints doesn't add legend when no label provided.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="") - - # Should not add legend to any axis - for ax in mock_axes: - ax.legend.assert_not_called() - - -# ====================== Data Correctness Tests ====================== - - -def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d): - """Test plot_waypoints plots correct tensor values.""" - RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0) - - # Check first axis to verify data correctness - call_args = mock_axes[0].plot.call_args[0] - x_indices = call_args[0] - y_values = call_args[1] - - # X indices should be 0 to 49 - np.testing.assert_array_equal(x_indices, np.arange(50)) - - # Y values should match first dimension of tensor - expected_y = sample_tensor_2d[:, 0].cpu().numpy() - np.testing.assert_array_almost_equal(y_values, expected_y) - - -def test_plot_waypoints_handles_gpu_tensor(mock_axes): - """Test plot_waypoints handles GPU tensors (if CUDA available).""" - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - - tensor_gpu = torch.randn(50, 6, device="cuda") - RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_gpu) - - # Should successfully plot without errors - for ax in mock_axes: - ax.plot.assert_called_once() - - -# ====================== Edge Cases Tests 
====================== - - -def test_plot_waypoints_with_empty_tensor(mock_axes): - """Test plot_waypoints with empty tensor.""" - empty_tensor = torch.empty(0, 6) - RTCDebugVisualizer.plot_waypoints(mock_axes, empty_tensor) - - # Should plot empty data - for ax in mock_axes: - call_args = ax.plot.call_args[0] - x_indices = call_args[0] - assert len(x_indices) == 0 - - -def test_plot_waypoints_with_single_timestep(mock_axes): - """Test plot_waypoints with single timestep.""" - single_step_tensor = torch.randn(1, 6) - RTCDebugVisualizer.plot_waypoints(mock_axes, single_step_tensor) - - # Should plot single point - for ax in mock_axes: - call_args = ax.plot.call_args[0] - x_indices = call_args[0] - assert len(x_indices) == 1 - - -def test_plot_waypoints_with_very_large_tensor(mock_axes): - """Test plot_waypoints with very large tensor.""" - large_tensor = torch.randn(10000, 6) - RTCDebugVisualizer.plot_waypoints(mock_axes, large_tensor) - - # Should handle large tensors - for ax in mock_axes: - call_args = ax.plot.call_args[0] - x_indices = call_args[0] - assert len(x_indices) == 10000 - - -# ====================== Multiple Calls Tests ====================== - - -def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d): - """Test multiple plot_waypoints calls on same axes.""" - tensor1 = sample_tensor_2d - tensor2 = torch.randn(50, 6) - - RTCDebugVisualizer.plot_waypoints(mock_axes, tensor1, color="blue", label="Series 1") - RTCDebugVisualizer.plot_waypoints(mock_axes, tensor2, color="red", label="Series 2") - - # Each axis should have been called twice - for ax in mock_axes: - assert ax.plot.call_count == 2 - - -# ====================== Integration Tests ====================== - - -def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d): - """Test plot_waypoints with typical usage pattern.""" - RTCDebugVisualizer.plot_waypoints( - mock_axes, sample_tensor_2d, start_from=0, color="blue", label="Trajectory", alpha=0.7, linewidth=2 - ) - - # Verify all expected calls were made - for ax in mock_axes: - ax.plot.assert_called_once() - ax.set_xlabel.assert_called_once() - ax.set_ylabel.assert_called_once() - ax.grid.assert_called_once() - - # First axis should have legend - mock_axes[0].legend.assert_called_once() - - -def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d): - """Test plot_waypoints with all parameters specified.""" - RTCDebugVisualizer.plot_waypoints( - axes=mock_axes, - tensor=sample_tensor_2d, - start_from=10, - color="green", - label="Full Test", - alpha=0.8, - linewidth=3, - marker="o", - markersize=6, - ) - - # Check first axis for all parameters - call_kwargs = mock_axes[0].plot.call_args[1] - assert call_kwargs["color"] == "green" - assert call_kwargs["label"] == "Full Test" - assert call_kwargs["alpha"] == 0.8 - assert call_kwargs["linewidth"] == 3 - assert call_kwargs["marker"] == "o" - assert call_kwargs["markersize"] == 6 diff --git a/tests/policies/rtc/test_latency_tracker.py b/tests/policies/rtc/test_latency_tracker.py index af6b89431..ee8ca9e11 100644 --- a/tests/policies/rtc/test_latency_tracker.py +++ b/tests/policies/rtc/test_latency_tracker.py @@ -184,74 +184,6 @@ def test_max_after_reset(tracker): assert tracker.max() == 0.0 -# ====================== percentile() Tests ====================== - - -def test_percentile_returns_zero_when_empty(tracker): - """Test percentile() returns 0.0 when tracker is empty.""" - assert tracker.percentile(0.5) == 0.0 - assert tracker.percentile(0.95) == 0.0 - - -def 
test_percentile_median(tracker): - """Test percentile(0.5) returns median.""" - # Add sorted values for easier verification - values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - for v in values: - tracker.add(v) - - # Median should be around 0.5 - median = tracker.percentile(0.5) - assert 0.45 <= median <= 0.55 - - -def test_percentile_minimum_with_zero(tracker): - """Test percentile(0.0) returns minimum.""" - tracker.add(0.5) - tracker.add(0.2) - tracker.add(0.8) - - assert tracker.percentile(0.0) == 0.2 - - -def test_percentile_maximum_with_one(tracker): - """Test percentile(1.0) returns maximum.""" - tracker.add(0.5) - tracker.add(0.2) - tracker.add(0.8) - - assert tracker.percentile(1.0) == 0.8 - - -def test_percentile_95(tracker): - """Test percentile(0.95) returns 95th percentile.""" - # Add 100 values from 0.0 to 0.99 - for i in range(100): - tracker.add(i / 100.0) - - p95 = tracker.percentile(0.95) - # 95th percentile should be around 0.95 - assert 0.93 <= p95 <= 0.96 - - -def test_percentile_negative_value_returns_min(tracker): - """Test percentile with negative q returns minimum.""" - tracker.add(0.5) - tracker.add(0.2) - tracker.add(0.8) - - assert tracker.percentile(-0.5) == 0.2 - - -def test_percentile_value_greater_than_one_returns_max(tracker): - """Test percentile with q > 1.0 returns maximum.""" - tracker.add(0.5) - tracker.add(0.2) - tracker.add(0.8) - - assert tracker.percentile(1.5) == 0.8 - - # ====================== p95() Tests ====================== @@ -278,79 +210,6 @@ def test_p95_equals_percentile_95(tracker): assert tracker.p95() == tracker.percentile(0.95) -# ====================== __len__() Tests ====================== - - -def test_len_returns_zero_initially(tracker): - """Test __len__ returns 0 for new tracker.""" - assert len(tracker) == 0 - - -def test_len_increments_with_add(tracker): - """Test __len__ increments as values are added.""" - assert len(tracker) == 0 - - tracker.add(0.1) - assert len(tracker) == 1 - - tracker.add(0.2) - assert len(tracker) == 2 - - tracker.add(0.3) - assert len(tracker) == 3 - - -def test_len_respects_maxlen(small_tracker): - """Test __len__ respects maxlen limit.""" - # Add more than maxlen values - for i in range(10): - small_tracker.add(i / 10.0) - - # Should only keep last 5 - assert len(small_tracker) == 5 - - -def test_len_after_reset(tracker): - """Test __len__ returns 0 after reset.""" - tracker.add(0.5) - tracker.add(0.3) - assert len(tracker) == 2 - - tracker.reset() - assert len(tracker) == 0 - - -# ====================== Sliding Window Tests ====================== - - -def test_sliding_window_removes_oldest(small_tracker): - """Test sliding window removes oldest values.""" - # Add 7 values to tracker with maxlen=5 - values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - for v in values: - small_tracker.add(v) - - # Should only have last 5: [0.3, 0.4, 0.5, 0.6, 0.7] - assert len(small_tracker) == 5 - - # Median should reflect last 5 values - median = small_tracker.percentile(0.5) - assert 0.45 <= median <= 0.55 - - -def test_sliding_window_maintains_max(small_tracker): - """Test sliding window maintains correct max even after overflow.""" - small_tracker.add(0.1) - small_tracker.add(0.9) - small_tracker.add(0.2) - small_tracker.add(0.3) - small_tracker.add(0.4) - small_tracker.add(0.5) # Pushes out 0.1 - - # Max should still be 0.9 - assert small_tracker.max() == 0.9 - - # ====================== Edge Cases Tests ====================== @@ -436,24 +295,6 @@ def test_reset_and_reuse(tracker): assert 
tracker.percentile(0.5) <= 0.8 -def test_continuous_monitoring(small_tracker): - """Test continuous monitoring with sliding window.""" - # Simulate continuous latency monitoring - # First 5 latencies - for i in range(5): - small_tracker.add(0.1 * (i + 1)) - - max_before = small_tracker.max() - - # Add 5 more (window slides) - for i in range(5, 10): - small_tracker.add(0.1 * (i + 1)) - - # Max should have increased - assert small_tracker.max() > max_before - assert len(small_tracker) == 5 # Window size maintained - - # ====================== Type Conversion Tests ======================