Update README

This commit is contained in:
Eugene Mironov
2025-11-10 19:04:12 +07:00
parent 9e92337f24
commit 433ccc9603
8 changed files with 321 additions and 1066 deletions
+153 -183
View File
@@ -16,156 +16,161 @@ Real-Time Chunking addresses the challenge of maintaining consistency and reacti
## Scripts ## Scripts
### 1. `real_time_chunking_evaluate.py` ### 1. `eval_dataset.py`
Real-time evaluation on physical robots or simulation environments. Offline evaluation on dataset samples with detailed visualization and validation.
**Features:** **Features:**
- Run policy with RTC on real robot or simulation - Compare RTC vs non-RTC predictions on two random dataset samples
- Compare RTC vs non-RTC actions in real-time - Validate RTC behavior (delay region, blend region, post-horizon region)
- Multi-threaded action execution and inference - Generate debug visualizations:
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
- Final action predictions comparison
- Support for torch.compile() optimization - Support for torch.compile() optimization
- Memory-efficient sequential policy loading for large models
**Usage:** **Usage:**
```bash ```bash
# With real robot # Basic usage with SmolVLA policy
uv run python examples/rtc/real_time_chunking_evaluate.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/smolvla_base \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--robot.type=so100 \ --dataset.repo_id=helper2424/check_rtc \
--task="pick up the cup" --rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--seed=10
# With simulation environment # With Pi0.5 policy on CUDA
uv run python examples/rtc/real_time_chunking_evaluate.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/smolvla_base \ --policy.path=lerobot/pi05_libero_finetuned \
--env.type=pusht \ --dataset.repo_id=HuggingFaceVLA/libero \
--duration=60.0 --rtc.execution_horizon=8 \
--device=cuda
# Disable verbose comparison (faster) # With Pi0 policy
uv run python examples/rtc/real_time_chunking_evaluate.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/smolvla_base \ --policy.path=lerobot/pi0_libero_finetuned \
--robot.type=so100 \ --dataset.repo_id=HuggingFaceVLA/libero \
--verbose_rtc_comparison=false --rtc.execution_horizon=8 \
--device=cuda
# With policy compilation (CUDA only, not MPS) # With torch.compile for faster inference
uv run python examples/rtc/real_time_chunking_evaluate.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/smolvla_base \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--robot.type=so100 \ --dataset.repo_id=helper2424/check_rtc \
--compile_policy=true \ --rtc.execution_horizon=8 \
--compile_mode=max-autotune --device=cuda \
``` --use_torch_compile=true \
--torch_compile_mode=max-autotune
**Key Parameters:** # Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
- `--policy.path`: Path to pretrained policy --policy.path=helper2424/smolvla_check_rtc_last3 \
- `--robot.type` or `--env.type`: Robot or environment to use --dataset.repo_id=helper2424/check_rtc \
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10) --use_torch_compile=true \
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0) --torch_compile_backend=inductor \
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP) --torch_compile_mode=max-autotune \
- `--verbose_rtc_comparison`: Enable detailed RTC comparison logging (default: true) --torch_compile_disable_cudagraphs=false
- `--duration`: How long to run (seconds, default: 30.0)
- `--fps`: Action execution frequency (Hz, default: 10.0)
### 2. `evaluate_rtc_on_dataset.py`
Offline evaluation on dataset samples to measure RTC effectiveness.
**Features:**
- Evaluate RTC on dataset without running robot
- Compare RTC vs non-RTC predictions
- Measure consistency and ground truth alignment
- Simulate different inference delays
- Save detailed metrics to JSON
**Usage:**
```bash
# Basic evaluation
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/pusht \
--num_iterations=100
# Simulate inference delay (every 3rd step)
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/pusht \
--num_iterations=200 \
--skip_steps=3
# Custom RTC configuration
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/pusht \
--num_iterations=100 \
--rtc.execution_horizon=12 \
--rtc.max_guidance_weight=5.0 \
--rtc.prefix_attention_schedule=LINEAR
# Save results to file
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/pusht \
--num_iterations=100 \
--output_path=results/rtc_evaluation.json
# Verbose mode with detailed logging
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=lerobot/pusht \
--num_iterations=50 \
--verbose=true
``` ```
**Key Parameters:** **Key Parameters:**
- `--policy.path`: Path to pretrained policy - `--policy.path`: Path to pretrained policy
- `--dataset.repo_id`: Dataset to evaluate on - `--dataset.repo_id`: Dataset to evaluate on
- `--num_iterations`: Number of samples to evaluate (default: 100) - `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
- `--skip_steps`: Steps to skip between inferences, simulates inference delay (default: 1) - `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
- `--start_episode`: Episode to start from (default: 0) - `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--output_path`: Path to save results JSON - `--inference_delay`: Inference delay for RTC (default: 4)
- `--verbose`: Enable detailed per-sample logging - `--seed`: Random seed for reproducibility (default: 42)
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
- `--device`: Device to use (cuda, cpu, mps, auto) - `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
**Metrics Reported:** **Output:**
- **RTC vs Ground Truth MSE**: How close RTC predictions are to actual actions The script generates several visualization files in `rtc_debug_output/`:
- **No-RTC vs Ground Truth MSE**: Baseline without RTC
- **RTC Improvement**: Absolute and relative improvement over baseline
- **RTC Consistency**: How well RTC maintains consistency in prefix region
- Prefix MSE
- Mean/Max error in overlap region
### 3. `run_dataset_evaluation.sh` - `denoising_xt_comparison.png` - Noisy state evolution during denoising
- `denoising_vt_comparison.png` - Velocity predictions during denoising
- `denoising_x1t_comparison.png` - Predicted final states during denoising
- `denoising_correction_comparison.png` - RTC guidance corrections applied
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
Convenience script with multiple evaluation scenarios. The script also validates RTC behavior and reports:
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
### 2. `eval_with_real_robot.py`
Real-time evaluation on physical robots or simulation environments.
**Features:**
- Run policy with RTC on real robot or simulation
- Multi-threaded action execution and inference
- Action queue management with proper timing
- Latency tracking and adaptive inference delay
- Support for both robots and gym environments
- Support for torch.compile() optimization
**Usage:** **Usage:**
```bash ```bash
# Edit the script to set your policy and dataset # With real robot
# Then run all examples: uv run python examples/rtc/eval_with_real_robot.py \
./examples/rtc/run_dataset_evaluation.sh --policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--task="pick up the cup" \
--duration=30.0
# Or run individual examples from the script # With simulation environment
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--env.type=pusht \
--duration=60.0
# With policy compilation (CUDA only, not MPS)
uv run python examples/rtc/eval_with_real_robot.py \
--policy.path=lerobot/smolvla_base \
--robot.type=so100 \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
``` ```
**Key Parameters:**
- `--policy.path`: Path to pretrained policy
- `--robot.type` or `--env.type`: Robot or environment to use
- `--task`: Task description (for VLA models)
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
- `--duration`: How long to run (seconds, default: 30.0)
- `--fps`: Action execution frequency (Hz, default: 10.0)
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
- `--device`: Device to use (cuda, cpu, mps, auto)
- `--use_torch_compile`: Enable torch.compile() for faster inference
## Understanding RTC Parameters ## Understanding RTC Parameters
### `execution_horizon` ### `execution_horizon`
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity. Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
**Typical values:** 8-12 steps **Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
### `max_guidance_weight` ### `max_guidance_weight`
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions. Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
**Typical values:** 1.0-10.0 **Typical values:**
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
- Real-time execution: 1.0-10.0 (more conservative)
### `prefix_attention_schedule` ### `prefix_attention_schedule`
@@ -178,104 +183,69 @@ How to weight consistency across the overlap region:
**Recommended:** `EXP` **Recommended:** `EXP`
### `skip_steps` (evaluation only) ### `inference_delay`
Simulates inference delay by evaluating every N-th step. This helps understand how RTC performs with realistic delays. Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
**Example:** `skip_steps=3` means policy infers every 3 steps, simulating 3x action execution frequency vs inference frequency. **Typical values:** 3-5 steps for dataset evaluation
## Output Format (Dataset Evaluation) ### `action_queue_size_to_get_new_actions` (real-time only)
When using `--output_path`, results are saved in JSON format: Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
```json **Typical values:** 20-30 steps
{
"summary": { ## Validation Rules (Dataset Evaluation)
"rtc_vs_ground_truth_mse": {
"mean": 0.00123, The dataset evaluation script validates that RTC behavior matches expectations:
"std": 0.00045,
"min": 0.00012, 1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
"max": 0.00456 - Ensures consistency during the inference delay period
},
"improvement": { 2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
"absolute": 0.00034, - Smooth transition from previous plan to new predictions
"relative_percent": 12.5
}, 3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
... - Full adoption of new predictions after execution horizon
},
"config": {
"num_iterations": 100,
"skip_steps": 3,
"execution_horizon": 10,
...
},
"detailed_results": [
{
"sample_idx": 0,
"rtc_vs_ground_truth_mse": 0.00112,
"no_rtc_vs_ground_truth_mse": 0.00145,
...
},
...
]
}
```
## Tips ## Tips
1. **Start with dataset evaluation** to understand RTC behavior before running on robot 1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
2. **Use verbose mode** for debugging unexpected behavior 2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
3. **Tune execution_horizon** based on your inference latency and action frequency 3. **Tune execution_horizon** based on your inference latency and action frequency
4. **Monitor consistency metrics** - very low consistency might indicate execution_horizon is too small 4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable 5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
## Troubleshooting ## Troubleshooting
### High RTC vs No-RTC difference but no improvement ### Validation fails in delay region
- Try reducing `max_guidance_weight` - Check that `prev_chunk_left_over` is properly passed to the policy
- Check if `execution_horizon` is too large - Verify RTC guidance is being applied during denoising
- Look at denoising visualizations to see where guidance diverges
### Poor consistency metrics ### Validation fails in post-horizon region
- Increase `execution_horizon` - RTC and no_rtc use different noise - verify same noise is being used for comparison
- Check that `skip_steps` is not larger than your action chunk size - Check that weights are correctly zeroed out after execution horizon
- Verify episodes are being reset correctly - Review prefix_attention_schedule visualization
### RTC worse than No-RTC ### Poor performance on real robot
- RTC may not help if inference is faster than action execution - Increase `action_queue_size_to_get_new_actions` if you see warnings
- Try different `prefix_attention_schedule` - Reduce `max_guidance_weight` if robot is too conservative
- Ensure `execution_horizon` matches your use case - Try different `prefix_attention_schedule` values
- Enable torch.compile() for faster inference (CUDA only)
## Examples Results ### Memory issues with large models
Example output from dataset evaluation: - The dataset evaluation script loads policies sequentially to minimize memory
- For real-time execution, only one policy is loaded
``` - Use smaller batch sizes if needed
================================================================================
EVALUATION SUMMARY
================================================================================
Ground Truth Alignment:
RTC MSE: 0.001234 ± 0.000456
No-RTC MSE: 0.001567 ± 0.000512
RTC Improvement:
Absolute: 0.000333
Relative: 21.23%
RTC vs No-RTC Difference:
MSE: 0.000112 ± 0.000034
RTC Consistency (Prefix Region):
MSE: 0.000089 ± 0.000023
Mean Error: 0.007654 ± 0.002341
Max Error: 0.023456 ± 0.008765
```
## Related Documentation ## Related Documentation
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py) - [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py) - [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf) - [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
+158 -38
View File
@@ -16,7 +16,9 @@ Usage:
--policy.path=helper2424/smolvla_check_rtc_last3 \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \ --dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \ --rtc.execution_horizon=8 \
--device=mps --device=mps \
--rtc.max_guidance_weight=10.0 \
--seed=10
# Basic usage with pi0.5 policy # Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
@@ -439,6 +441,8 @@ class RTCEvaluator:
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc") logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
logging.info("=" * 80) logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 2 # Initialize policy 2
policy_no_rtc_policy = self._init_policy( policy_no_rtc_policy = self._init_policy(
name="policy_no_rtc", name="policy_no_rtc",
@@ -470,6 +474,8 @@ class RTCEvaluator:
logging.info("Step 3: Generating actions WITH RTC with policy_rtc") logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
logging.info("=" * 80) logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 3 # Initialize policy 3
policy_rtc_policy = self._init_policy( policy_rtc_policy = self._init_policy(
name="policy_rtc", name="policy_rtc",
@@ -510,6 +516,11 @@ class RTCEvaluator:
logging.info("Validating RTC behavior...") logging.info("Validating RTC behavior...")
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over) self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
# Plot final actions comparison
logging.info("=" * 80)
logging.info("Plotting final actions comparison...")
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
logging.info("=" * 80) logging.info("=" * 80)
logging.info("Evaluation completed successfully") logging.info("Evaluation completed successfully")
@@ -527,29 +538,24 @@ class RTCEvaluator:
no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim) no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim) prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
""" """
if rtc_actions is None or no_rtc_actions is None: # Remove batch dimension if present and move to CPU
logging.warning(" ⚠ Cannot validate: missing action predictions") rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
return no_rtc_actions_t = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk = prev_chunk_left_over.cpu()
# Convert to numpy for comparison (remove batch dimension if present) logging.info(f" rtc_actions shape: {rtc_actions_t.shape}")
rtc_actions_np = ( logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}")
rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy() logging.info(f" prev_chunk shape: {prev_chunk.shape}")
)
no_rtc_actions_np = (
no_rtc_actions.squeeze(0).cpu().numpy()
if len(no_rtc_actions.shape) == 3
else no_rtc_actions.cpu().numpy()
)
prev_chunk = prev_chunk_left_over.cpu().numpy()
# Determine chunk length for comparison # Determine chunk length for comparison
chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0]) chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0])
inference_delay = self.cfg.inference_delay inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon execution_horizon = self.cfg.rtc.execution_horizon
# Tolerance for floating point comparison # Tolerance for floating point comparison
rtol = 1e-3 # Relative tolerance rtol = 1e-2 # Relative tolerance
atol = 1e-3 # Absolute tolerance
validation_passed = True validation_passed = True
warnings = [] warnings = []
@@ -558,19 +564,26 @@ class RTCEvaluator:
logging.info(f" Chunk length: {chunk_len}") logging.info(f" Chunk length: {chunk_len}")
logging.info(f" Inference delay: {inference_delay}") logging.info(f" Inference delay: {inference_delay}")
logging.info(f" Execution horizon: {execution_horizon}") logging.info(f" Execution horizon: {execution_horizon}")
logging.info(f" Tolerance: rtol={rtol}, atol={atol}") logging.info(f" Tolerance: rtol={rtol}")
# ============================================================================ # ============================================================================
# Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk # Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
# ============================================================================ # ============================================================================
if inference_delay > 0: if inference_delay > 0:
delay_end = min(inference_delay, chunk_len) delay_end = min(inference_delay, chunk_len)
rtc_delay = rtc_actions_np[:delay_end] rtc_delay = rtc_actions_t[:delay_end]
prev_delay = prev_chunk[:delay_end] prev_delay = prev_chunk[:delay_end]
if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol): logging.info(f" rtc_delay: {rtc_delay.shape}")
max_diff = np.max(np.abs(rtc_delay - prev_delay)) logging.info(f" prev_delay: {prev_delay.shape}")
mean_diff = np.mean(np.abs(rtc_delay - prev_delay))
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item()
logging.info(f" rtc_delay: {rtc_delay}")
logging.info(f" prev_delay: {prev_delay}")
logging.info(f" max_diff: {max_diff}")
logging.info(f" mean_diff: {mean_diff}")
warnings.append( warnings.append(
f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], " f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
f"RTC does NOT equal prev_chunk!\n" f"RTC does NOT equal prev_chunk!\n"
@@ -589,26 +602,26 @@ class RTCEvaluator:
blend_end = min(execution_horizon, chunk_len) blend_end = min(execution_horizon, chunk_len)
if blend_end > blend_start: if blend_end > blend_start:
rtc_blend = rtc_actions_np[blend_start:blend_end] rtc_blend = rtc_actions_t[blend_start:blend_end]
prev_blend = prev_chunk[blend_start:blend_end] prev_blend = prev_chunk[blend_start:blend_end]
no_rtc_blend = no_rtc_actions_np[blend_start:blend_end] no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
# Check if RTC is between prev_chunk and no_rtc (element-wise) # Check if RTC is between prev_chunk and no_rtc (element-wise)
# For each element, check if it's between the min and max of prev_chunk and no_rtc # For each element, check if it's between the min and max of prev_chunk and no_rtc
min_bound = np.minimum(prev_blend, no_rtc_blend) - atol min_bound = torch.minimum(prev_blend, no_rtc_blend)
max_bound = np.maximum(prev_blend, no_rtc_blend) + atol max_bound = torch.maximum(prev_blend, no_rtc_blend)
within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
if not np.all(within_bounds): if not torch.all(within_bounds):
violations = np.sum(~within_bounds) violations = torch.sum(~within_bounds).item()
total_elements = within_bounds.size total_elements = within_bounds.numel()
violation_pct = 100.0 * violations / total_elements violation_pct = 100.0 * violations / total_elements
# Find max violation # Find max violation
lower_violations = np.maximum(0, min_bound - rtc_blend) lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend)
upper_violations = np.maximum(0, rtc_blend - max_bound) upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound)
max_violation = np.max(np.maximum(lower_violations, upper_violations)) max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item()
warnings.append( warnings.append(
f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], " f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
@@ -626,12 +639,15 @@ class RTCEvaluator:
# Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc # Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
# ============================================================================ # ============================================================================
if execution_horizon < chunk_len: if execution_horizon < chunk_len:
rtc_after = rtc_actions_np[execution_horizon:chunk_len] rtc_after = rtc_actions_t[execution_horizon:chunk_len]
no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len] no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol): logging.info(f" rtc_after: {rtc_after}")
max_diff = np.max(np.abs(rtc_after - no_rtc_after)) logging.info(f" no_rtc_after: {no_rtc_after}")
mean_diff = np.mean(np.abs(rtc_after - no_rtc_after))
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item()
warnings.append( warnings.append(
f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], " f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
f"RTC does NOT equal no_rtc!\n" f"RTC does NOT equal no_rtc!\n"
@@ -661,6 +677,103 @@ class RTCEvaluator:
logging.error("") logging.error("")
logging.error(" Please check the implementation of RTC guidance.") logging.error(" Please check the implementation of RTC guidance.")
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Plot final action predictions comparison on a single chart.
Args:
rtc_actions: Final actions from RTC policy
no_rtc_actions: Final actions from non-RTC policy
prev_chunk_left_over: Previous chunk used as ground truth
"""
# Remove batch dimension if present
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
no_rtc_actions_plot = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk_plot = prev_chunk_left_over.cpu()
# Create figure with 6 subplots (one per action dimension)
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
# Plot each action dimension
for dim_idx, ax in enumerate(axes):
# Plot previous chunk (ground truth) in red
RTCDebugVisualizer.plot_waypoints(
[ax],
prev_chunk_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="red",
label="Previous Chunk (Ground Truth)",
linewidth=2.5,
alpha=0.8,
)
# Plot no-RTC actions in blue
RTCDebugVisualizer.plot_waypoints(
[ax],
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="blue",
label="No RTC",
linewidth=2,
alpha=0.7,
)
# Plot RTC actions in green
RTCDebugVisualizer.plot_waypoints(
[ax],
rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="green",
label="RTC",
linewidth=2,
alpha=0.7,
)
# Add vertical lines for inference delay and execution horizon
inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon
if inference_delay > 0:
ax.axvline(
x=inference_delay - 1,
color="orange",
linestyle="--",
alpha=0.5,
label=f"Inference Delay ({inference_delay})",
)
if execution_horizon > 0:
ax.axvline(
x=execution_horizon,
color="purple",
linestyle="--",
alpha=0.5,
label=f"Execution Horizon ({execution_horizon})",
)
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Set x-axis ticks to show all integer values
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
# Add legend only to first subplot
if dim_idx == 0:
ax.legend(loc="best", fontsize=9)
axes[-1].set_xlabel("Step", fontsize=10)
# Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout()
fig.savefig(output_path, dpi=150)
logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig)
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps): def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
# Create side-by-side figures for denoising visualization # Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)") fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
@@ -828,6 +941,13 @@ class RTCEvaluator:
margin = y_range * 0.1 margin = y_range * 0.1
ax.set_ylim(ylim[0] - margin, ylim[1] + margin) ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
# Set x-axis ticks to show all integer values
xlim = ax.get_xlim()
max_len = int(xlim[1]) + 1
if max_len > 0:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
@parser.wrap() @parser.wrap()
def main(cfg: RTCEvalConfig): def main(cfg: RTCEvalConfig):
+8
View File
@@ -14,6 +14,7 @@ RTC can be integrated with any policy that supports flow mathicng for chunking:
- **SmolVLA**: Vision-language-action model with RTC support - **SmolVLA**: Vision-language-action model with RTC support
- **Pi0**: Action prediction model with adaptive chunking - **Pi0**: Action prediction model with adaptive chunking
- **Pi05**: Action prediction model with adaptive chunking
## Original Implementation ## Original Implementation
@@ -39,3 +40,10 @@ uv run python examples/rtc/eval_dataset.py \
--device=mps \ --device=mps \
--seed=42 --seed=42
``` ```
This script will evaluate RTC on a data from a dataset and save the results to a file, u can check the results in the `rtc_debug_output` directory.
The example output should look like this:
![Flow Matching with RTC](./flow_matching.png)
It shows how flow matching works with RTC and without it. The chart shows values of action predictions for each timestep. The colour shows the the generation progress. The blue ones - earlier timesteps, the yellow ones - later timesteps. The red line is the ground truth (previous action chunk).
Binary file not shown.

After

Width:  |  Height:  |  Size: 538 KiB

+2 -1
View File
@@ -168,6 +168,8 @@ class RTCProcessor:
v_t = original_denoise_step_partial(x_t) v_t = original_denoise_step_partial(x_t)
return v_t return v_t
x_t = x_t.clone().detach()
squeezed = False squeezed = False
if len(x_t.shape) < 3: if len(x_t.shape) < 3:
# Add batch dimension # Add batch dimension
@@ -208,7 +210,6 @@ class RTCProcessor:
with torch.enable_grad(): with torch.enable_grad():
v_t = original_denoise_step_partial(x_t) v_t = original_denoise_step_partial(x_t)
x_t = x_t.clone().detach()
x_t.requires_grad_(True) x_t.requires_grad_(True)
x1_t = x_t - time * v_t # noqa: N806 x1_t = x_t - time * v_t # noqa: N806
@@ -16,8 +16,6 @@
"""Tests for RTC configuration module.""" """Tests for RTC configuration module."""
import pytest
from lerobot.configs.types import RTCAttentionSchedule from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig
@@ -65,259 +63,3 @@ def test_rtc_config_partial_initialization():
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
assert config.execution_horizon == 10 assert config.execution_horizon == 10
assert config.debug is False assert config.debug is False
# ====================== Validation Tests ======================
def test_rtc_config_validates_positive_max_guidance_weight():
"""Test RTCConfig validates max_guidance_weight is positive."""
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
RTCConfig(max_guidance_weight=0.0)
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
RTCConfig(max_guidance_weight=-1.0)
def test_rtc_config_validates_positive_debug_maxlen():
"""Test RTCConfig validates debug_maxlen is positive."""
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
RTCConfig(debug_maxlen=0)
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
RTCConfig(debug_maxlen=-10)
def test_rtc_config_accepts_valid_max_guidance_weight():
"""Test RTCConfig accepts valid positive max_guidance_weight."""
config1 = RTCConfig(max_guidance_weight=0.1)
assert config1.max_guidance_weight == 0.1
config2 = RTCConfig(max_guidance_weight=100.0)
assert config2.max_guidance_weight == 100.0
def test_rtc_config_accepts_valid_debug_maxlen():
"""Test RTCConfig accepts valid positive debug_maxlen."""
config1 = RTCConfig(debug_maxlen=1)
assert config1.debug_maxlen == 1
config2 = RTCConfig(debug_maxlen=10000)
assert config2.debug_maxlen == 10000
# ====================== Attention Schedule Tests ======================
def test_rtc_config_with_linear_schedule():
"""Test RTCConfig with LINEAR attention schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
def test_rtc_config_with_exp_schedule():
"""Test RTCConfig with EXP attention schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
def test_rtc_config_with_zeros_schedule():
"""Test RTCConfig with ZEROS attention schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
assert config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS
def test_rtc_config_with_ones_schedule():
"""Test RTCConfig with ONES attention schedule."""
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
assert config.prefix_attention_schedule == RTCAttentionSchedule.ONES
# ====================== Enabled/Disabled Tests ======================
def test_rtc_config_enabled_true():
"""Test RTCConfig with enabled=True."""
config = RTCConfig(enabled=True)
assert config.enabled is True
def test_rtc_config_enabled_false():
"""Test RTCConfig with enabled=False."""
config = RTCConfig(enabled=False)
assert config.enabled is False
# ====================== Debug Tests ======================
def test_rtc_config_debug_enabled():
"""Test RTCConfig with debug enabled."""
config = RTCConfig(debug=True, debug_maxlen=500)
assert config.debug is True
assert config.debug_maxlen == 500
def test_rtc_config_debug_disabled():
"""Test RTCConfig with debug disabled."""
config = RTCConfig(debug=False)
assert config.debug is False
# ====================== Execution Horizon Tests ======================
def test_rtc_config_with_small_execution_horizon():
"""Test RTCConfig with small execution horizon."""
config = RTCConfig(execution_horizon=1)
assert config.execution_horizon == 1
def test_rtc_config_with_large_execution_horizon():
"""Test RTCConfig with large execution horizon."""
config = RTCConfig(execution_horizon=100)
assert config.execution_horizon == 100
def test_rtc_config_with_zero_execution_horizon():
"""Test RTCConfig accepts zero execution horizon."""
# No validation on execution_horizon, so this should work
config = RTCConfig(execution_horizon=0)
assert config.execution_horizon == 0
def test_rtc_config_with_negative_execution_horizon():
"""Test RTCConfig accepts negative execution horizon."""
# No validation on execution_horizon, so this should work
config = RTCConfig(execution_horizon=-1)
assert config.execution_horizon == -1
# ====================== Integration Tests ======================
def test_rtc_config_typical_production_settings():
"""Test RTCConfig with typical production settings."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
max_guidance_weight=10.0,
execution_horizon=8,
debug=False,
)
assert config.enabled is True
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
assert config.max_guidance_weight == 10.0
assert config.execution_horizon == 8
assert config.debug is False
def test_rtc_config_typical_debug_settings():
"""Test RTCConfig with typical debug settings."""
config = RTCConfig(
enabled=True,
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
max_guidance_weight=5.0,
execution_horizon=10,
debug=True,
debug_maxlen=1000,
)
assert config.enabled is True
assert config.debug is True
assert config.debug_maxlen == 1000
def test_rtc_config_disabled_mode():
"""Test RTCConfig in disabled mode."""
config = RTCConfig(enabled=False)
assert config.enabled is False
# Other settings still accessible even when disabled
assert config.max_guidance_weight == 10.0
assert config.execution_horizon == 10
# ====================== Dataclass Tests ======================
def test_rtc_config_is_dataclass():
"""Test that RTCConfig is a dataclass."""
from dataclasses import is_dataclass
assert is_dataclass(RTCConfig)
def test_rtc_config_equality():
"""Test RTCConfig equality comparison."""
config1 = RTCConfig(enabled=True, max_guidance_weight=5.0)
config2 = RTCConfig(enabled=True, max_guidance_weight=5.0)
config3 = RTCConfig(enabled=False, max_guidance_weight=5.0)
assert config1 == config2
assert config1 != config3
def test_rtc_config_repr():
"""Test RTCConfig string representation."""
config = RTCConfig(enabled=True, execution_horizon=20)
repr_str = repr(config)
assert "RTCConfig" in repr_str
assert "enabled=True" in repr_str
assert "execution_horizon=20" in repr_str
# ====================== Edge Cases Tests ======================
def test_rtc_config_very_small_max_guidance_weight():
"""Test RTCConfig with very small positive max_guidance_weight."""
config = RTCConfig(max_guidance_weight=1e-10)
assert config.max_guidance_weight == pytest.approx(1e-10)
def test_rtc_config_very_large_max_guidance_weight():
"""Test RTCConfig with very large max_guidance_weight."""
config = RTCConfig(max_guidance_weight=1e10)
assert config.max_guidance_weight == pytest.approx(1e10)
def test_rtc_config_minimum_debug_maxlen():
"""Test RTCConfig with minimum valid debug_maxlen."""
config = RTCConfig(debug_maxlen=1)
assert config.debug_maxlen == 1
def test_rtc_config_float_max_guidance_weight():
"""Test RTCConfig with float max_guidance_weight."""
config = RTCConfig(max_guidance_weight=3.14159)
assert config.max_guidance_weight == pytest.approx(3.14159)
# ====================== Type Tests ======================
def test_rtc_config_enabled_type():
"""Test RTCConfig enabled field accepts boolean."""
config = RTCConfig(enabled=True)
assert isinstance(config.enabled, bool)
def test_rtc_config_execution_horizon_type():
"""Test RTCConfig execution_horizon field accepts integer."""
config = RTCConfig(execution_horizon=15)
assert isinstance(config.execution_horizon, int)
def test_rtc_config_max_guidance_weight_type():
"""Test RTCConfig max_guidance_weight field accepts float."""
config = RTCConfig(max_guidance_weight=7.5)
assert isinstance(config.max_guidance_weight, float)
def test_rtc_config_debug_maxlen_type():
"""Test RTCConfig debug_maxlen field accepts integer."""
config = RTCConfig(debug_maxlen=200)
assert isinstance(config.debug_maxlen, int)
-427
View File
@@ -1,427 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RTC debug visualizer module."""
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
# ====================== Fixtures ======================
@pytest.fixture
def mock_axes():
"""Create mock matplotlib axes."""
axes = []
for _ in range(6):
ax = MagicMock()
ax.xaxis.get_label.return_value.get_text.return_value = ""
ax.yaxis.get_label.return_value.get_text.return_value = ""
axes.append(ax)
return axes
@pytest.fixture
def sample_tensor_2d():
"""Create a 2D sample tensor (time_steps, num_dims)."""
return torch.randn(50, 6)
@pytest.fixture
def sample_tensor_3d():
"""Create a 3D sample tensor (batch, time_steps, num_dims)."""
return torch.randn(1, 50, 6)
@pytest.fixture
def sample_numpy_2d():
"""Create a 2D numpy array."""
return np.random.randn(50, 6)
# ====================== Basic Plotting Tests ======================
def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with 2D tensor."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
# Should call plot on each axis (6 dimensions)
for ax in mock_axes:
ax.plot.assert_called_once()
def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d):
"""Test plot_waypoints with 3D tensor (batch dimension)."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d)
# Should still plot 6 dimensions (batch dimension removed)
for ax in mock_axes:
ax.plot.assert_called_once()
def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d):
"""Test plot_waypoints with numpy array."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d)
# Should work with numpy arrays
for ax in mock_axes:
ax.plot.assert_called_once()
def test_plot_waypoints_with_none_tensor(mock_axes):
"""Test plot_waypoints returns early when tensor is None."""
RTCDebugVisualizer.plot_waypoints(mock_axes, None)
# Should not call plot on any axis
for ax in mock_axes:
ax.plot.assert_not_called()
# ====================== Parameter Tests ======================
def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d):
"""Test plot_waypoints uses custom color."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red")
# Check that color was passed to plot
for ax in mock_axes:
call_kwargs = ax.plot.call_args[1]
assert call_kwargs["color"] == "red"
def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d):
"""Test plot_waypoints uses custom label."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
# First axis should have label, others should not
first_ax_kwargs = mock_axes[0].plot.call_args[1]
assert first_ax_kwargs["label"] == "test_label"
# Other axes should have empty label
for ax in mock_axes[1:]:
call_kwargs = ax.plot.call_args[1]
assert call_kwargs["label"] == ""
def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d):
"""Test plot_waypoints uses custom alpha."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5)
for ax in mock_axes:
call_kwargs = ax.plot.call_args[1]
assert call_kwargs["alpha"] == 0.5
def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d):
"""Test plot_waypoints uses custom linewidth."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3)
for ax in mock_axes:
call_kwargs = ax.plot.call_args[1]
assert call_kwargs["linewidth"] == 3
def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with marker style."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5)
for ax in mock_axes:
call_kwargs = ax.plot.call_args[1]
assert call_kwargs["marker"] == "o"
assert call_kwargs["markersize"] == 5
def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d):
"""Test plot_waypoints without marker (default)."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None)
# Marker should not be in kwargs when None
for ax in mock_axes:
call_kwargs = ax.plot.call_args[1]
assert "marker" not in call_kwargs
assert "markersize" not in call_kwargs
# ====================== start_from Parameter Tests ======================
def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with start_from=0."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
# X indices should start from 0
for ax in mock_axes:
call_args = ax.plot.call_args[0]
x_indices = call_args[0]
assert x_indices[0] == 0
assert len(x_indices) == 50
def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with start_from > 0."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10)
# X indices should start from 10
for ax in mock_axes:
call_args = ax.plot.call_args[0]
x_indices = call_args[0]
assert x_indices[0] == 10
assert x_indices[-1] == 59 # 10 + 50 - 1
# ====================== Tensor Shape Tests ======================
def test_plot_waypoints_with_1d_tensor(mock_axes):
"""Test plot_waypoints with 1D tensor."""
tensor_1d = torch.randn(50)
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_1d)
# Should reshape to (50, 1) and plot on first axis only
mock_axes[0].plot.assert_called_once()
for ax in mock_axes[1:]:
ax.plot.assert_not_called()
def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d):
"""Test plot_waypoints when tensor has fewer dims than axes."""
# Create tensor with only 3 dimensions
tensor_3d = sample_tensor_2d[:, :3]
# Create 6 axes but tensor only has 3 dims
mock_axes = [MagicMock() for _ in range(6)]
for ax in mock_axes:
ax.xaxis.get_label.return_value.get_text.return_value = ""
ax.yaxis.get_label.return_value.get_text.return_value = ""
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_3d)
# Should only plot on first 3 axes
for i in range(3):
mock_axes[i].plot.assert_called_once()
for i in range(3, 6):
mock_axes[i].plot.assert_not_called()
# ====================== Axis Labeling Tests ======================
def test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d):
"""Test plot_waypoints sets x-axis label."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
for ax in mock_axes:
ax.set_xlabel.assert_called_once_with("Step", fontsize=10)
def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d):
"""Test plot_waypoints sets y-axis label."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
for i, ax in enumerate(mock_axes):
ax.set_ylabel.assert_called_once_with(f"Dim {i}", fontsize=10)
def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d):
"""Test plot_waypoints doesn't set labels if they already exist."""
mock_axes_with_labels = []
for _ in range(6):
ax = MagicMock()
# Simulate existing labels
ax.xaxis.get_label.return_value.get_text.return_value = "Existing X Label"
ax.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label"
mock_axes_with_labels.append(ax)
RTCDebugVisualizer.plot_waypoints(mock_axes_with_labels, sample_tensor_2d)
# Should not set labels when they already exist
for ax in mock_axes_with_labels:
ax.set_xlabel.assert_not_called()
ax.set_ylabel.assert_not_called()
# ====================== Grid Tests ======================
def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d):
"""Test plot_waypoints enables grid on all axes."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
for ax in mock_axes:
ax.grid.assert_called_once_with(True, alpha=0.3)
# ====================== Legend Tests ======================
def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d):
"""Test plot_waypoints adds legend when label is provided."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
# Should add legend to first axis only
mock_axes[0].legend.assert_called_once_with(loc="best", fontsize=8)
# Should not add legend to other axes
for ax in mock_axes[1:]:
ax.legend.assert_not_called()
def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d):
"""Test plot_waypoints doesn't add legend when no label provided."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="")
# Should not add legend to any axis
for ax in mock_axes:
ax.legend.assert_not_called()
# ====================== Data Correctness Tests ======================
def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d):
"""Test plot_waypoints plots correct tensor values."""
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
# Check first axis to verify data correctness
call_args = mock_axes[0].plot.call_args[0]
x_indices = call_args[0]
y_values = call_args[1]
# X indices should be 0 to 49
np.testing.assert_array_equal(x_indices, np.arange(50))
# Y values should match first dimension of tensor
expected_y = sample_tensor_2d[:, 0].cpu().numpy()
np.testing.assert_array_almost_equal(y_values, expected_y)
def test_plot_waypoints_handles_gpu_tensor(mock_axes):
"""Test plot_waypoints handles GPU tensors (if CUDA available)."""
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
tensor_gpu = torch.randn(50, 6, device="cuda")
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_gpu)
# Should successfully plot without errors
for ax in mock_axes:
ax.plot.assert_called_once()
# ====================== Edge Cases Tests ======================
def test_plot_waypoints_with_empty_tensor(mock_axes):
"""Test plot_waypoints with empty tensor."""
empty_tensor = torch.empty(0, 6)
RTCDebugVisualizer.plot_waypoints(mock_axes, empty_tensor)
# Should plot empty data
for ax in mock_axes:
call_args = ax.plot.call_args[0]
x_indices = call_args[0]
assert len(x_indices) == 0
def test_plot_waypoints_with_single_timestep(mock_axes):
"""Test plot_waypoints with single timestep."""
single_step_tensor = torch.randn(1, 6)
RTCDebugVisualizer.plot_waypoints(mock_axes, single_step_tensor)
# Should plot single point
for ax in mock_axes:
call_args = ax.plot.call_args[0]
x_indices = call_args[0]
assert len(x_indices) == 1
def test_plot_waypoints_with_very_large_tensor(mock_axes):
"""Test plot_waypoints with very large tensor."""
large_tensor = torch.randn(10000, 6)
RTCDebugVisualizer.plot_waypoints(mock_axes, large_tensor)
# Should handle large tensors
for ax in mock_axes:
call_args = ax.plot.call_args[0]
x_indices = call_args[0]
assert len(x_indices) == 10000
# ====================== Multiple Calls Tests ======================
def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d):
"""Test multiple plot_waypoints calls on same axes."""
tensor1 = sample_tensor_2d
tensor2 = torch.randn(50, 6)
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor1, color="blue", label="Series 1")
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor2, color="red", label="Series 2")
# Each axis should have been called twice
for ax in mock_axes:
assert ax.plot.call_count == 2
# ====================== Integration Tests ======================
def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with typical usage pattern."""
RTCDebugVisualizer.plot_waypoints(
mock_axes, sample_tensor_2d, start_from=0, color="blue", label="Trajectory", alpha=0.7, linewidth=2
)
# Verify all expected calls were made
for ax in mock_axes:
ax.plot.assert_called_once()
ax.set_xlabel.assert_called_once()
ax.set_ylabel.assert_called_once()
ax.grid.assert_called_once()
# First axis should have legend
mock_axes[0].legend.assert_called_once()
def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d):
"""Test plot_waypoints with all parameters specified."""
RTCDebugVisualizer.plot_waypoints(
axes=mock_axes,
tensor=sample_tensor_2d,
start_from=10,
color="green",
label="Full Test",
alpha=0.8,
linewidth=3,
marker="o",
markersize=6,
)
# Check first axis for all parameters
call_kwargs = mock_axes[0].plot.call_args[1]
assert call_kwargs["color"] == "green"
assert call_kwargs["label"] == "Full Test"
assert call_kwargs["alpha"] == 0.8
assert call_kwargs["linewidth"] == 3
assert call_kwargs["marker"] == "o"
assert call_kwargs["markersize"] == 6
-159
View File
@@ -184,74 +184,6 @@ def test_max_after_reset(tracker):
assert tracker.max() == 0.0 assert tracker.max() == 0.0
# ====================== percentile() Tests ======================
def test_percentile_returns_zero_when_empty(tracker):
"""Test percentile() returns 0.0 when tracker is empty."""
assert tracker.percentile(0.5) == 0.0
assert tracker.percentile(0.95) == 0.0
def test_percentile_median(tracker):
"""Test percentile(0.5) returns median."""
# Add sorted values for easier verification
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
for v in values:
tracker.add(v)
# Median should be around 0.5
median = tracker.percentile(0.5)
assert 0.45 <= median <= 0.55
def test_percentile_minimum_with_zero(tracker):
"""Test percentile(0.0) returns minimum."""
tracker.add(0.5)
tracker.add(0.2)
tracker.add(0.8)
assert tracker.percentile(0.0) == 0.2
def test_percentile_maximum_with_one(tracker):
"""Test percentile(1.0) returns maximum."""
tracker.add(0.5)
tracker.add(0.2)
tracker.add(0.8)
assert tracker.percentile(1.0) == 0.8
def test_percentile_95(tracker):
"""Test percentile(0.95) returns 95th percentile."""
# Add 100 values from 0.0 to 0.99
for i in range(100):
tracker.add(i / 100.0)
p95 = tracker.percentile(0.95)
# 95th percentile should be around 0.95
assert 0.93 <= p95 <= 0.96
def test_percentile_negative_value_returns_min(tracker):
"""Test percentile with negative q returns minimum."""
tracker.add(0.5)
tracker.add(0.2)
tracker.add(0.8)
assert tracker.percentile(-0.5) == 0.2
def test_percentile_value_greater_than_one_returns_max(tracker):
"""Test percentile with q > 1.0 returns maximum."""
tracker.add(0.5)
tracker.add(0.2)
tracker.add(0.8)
assert tracker.percentile(1.5) == 0.8
# ====================== p95() Tests ====================== # ====================== p95() Tests ======================
@@ -278,79 +210,6 @@ def test_p95_equals_percentile_95(tracker):
assert tracker.p95() == tracker.percentile(0.95) assert tracker.p95() == tracker.percentile(0.95)
# ====================== __len__() Tests ======================
def test_len_returns_zero_initially(tracker):
"""Test __len__ returns 0 for new tracker."""
assert len(tracker) == 0
def test_len_increments_with_add(tracker):
"""Test __len__ increments as values are added."""
assert len(tracker) == 0
tracker.add(0.1)
assert len(tracker) == 1
tracker.add(0.2)
assert len(tracker) == 2
tracker.add(0.3)
assert len(tracker) == 3
def test_len_respects_maxlen(small_tracker):
"""Test __len__ respects maxlen limit."""
# Add more than maxlen values
for i in range(10):
small_tracker.add(i / 10.0)
# Should only keep last 5
assert len(small_tracker) == 5
def test_len_after_reset(tracker):
"""Test __len__ returns 0 after reset."""
tracker.add(0.5)
tracker.add(0.3)
assert len(tracker) == 2
tracker.reset()
assert len(tracker) == 0
# ====================== Sliding Window Tests ======================
def test_sliding_window_removes_oldest(small_tracker):
"""Test sliding window removes oldest values."""
# Add 7 values to tracker with maxlen=5
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
for v in values:
small_tracker.add(v)
# Should only have last 5: [0.3, 0.4, 0.5, 0.6, 0.7]
assert len(small_tracker) == 5
# Median should reflect last 5 values
median = small_tracker.percentile(0.5)
assert 0.45 <= median <= 0.55
def test_sliding_window_maintains_max(small_tracker):
"""Test sliding window maintains correct max even after overflow."""
small_tracker.add(0.1)
small_tracker.add(0.9)
small_tracker.add(0.2)
small_tracker.add(0.3)
small_tracker.add(0.4)
small_tracker.add(0.5) # Pushes out 0.1
# Max should still be 0.9
assert small_tracker.max() == 0.9
# ====================== Edge Cases Tests ====================== # ====================== Edge Cases Tests ======================
@@ -436,24 +295,6 @@ def test_reset_and_reuse(tracker):
assert tracker.percentile(0.5) <= 0.8 assert tracker.percentile(0.5) <= 0.8
def test_continuous_monitoring(small_tracker):
"""Test continuous monitoring with sliding window."""
# Simulate continuous latency monitoring
# First 5 latencies
for i in range(5):
small_tracker.add(0.1 * (i + 1))
max_before = small_tracker.max()
# Add 5 more (window slides)
for i in range(5, 10):
small_tracker.add(0.1 * (i + 1))
# Max should have increased
assert small_tracker.max() > max_before
assert len(small_tracker) == 5 # Window size maintained
# ====================== Type Conversion Tests ====================== # ====================== Type Conversion Tests ======================