mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 23:19:48 +00:00
Update README
This commit is contained in:
+153
-183
@@ -16,156 +16,161 @@ Real-Time Chunking addresses the challenge of maintaining consistency and reacti
|
|||||||
|
|
||||||
## Scripts
|
## Scripts
|
||||||
|
|
||||||
### 1. `real_time_chunking_evaluate.py`
|
### 1. `eval_dataset.py`
|
||||||
|
|
||||||
Real-time evaluation on physical robots or simulation environments.
|
Offline evaluation on dataset samples with detailed visualization and validation.
|
||||||
|
|
||||||
**Features:**
|
**Features:**
|
||||||
|
|
||||||
- Run policy with RTC on real robot or simulation
|
- Compare RTC vs non-RTC predictions on two random dataset samples
|
||||||
- Compare RTC vs non-RTC actions in real-time
|
- Validate RTC behavior (delay region, blend region, post-horizon region)
|
||||||
- Multi-threaded action execution and inference
|
- Generate debug visualizations:
|
||||||
|
- Denoising step comparisons (x_t, v_t, x1_t, corrections)
|
||||||
|
- Final action predictions comparison
|
||||||
- Support for torch.compile() optimization
|
- Support for torch.compile() optimization
|
||||||
|
- Memory-efficient sequential policy loading for large models
|
||||||
|
|
||||||
**Usage:**
|
**Usage:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# With real robot
|
# Basic usage with SmolVLA policy
|
||||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=lerobot/smolvla_base \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--robot.type=so100 \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
--task="pick up the cup"
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=mps \
|
||||||
|
--rtc.max_guidance_weight=10.0 \
|
||||||
|
--seed=10
|
||||||
|
|
||||||
# With simulation environment
|
# With Pi0.5 policy on CUDA
|
||||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=lerobot/smolvla_base \
|
--policy.path=lerobot/pi05_libero_finetuned \
|
||||||
--env.type=pusht \
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
--duration=60.0
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=cuda
|
||||||
|
|
||||||
# Disable verbose comparison (faster)
|
# With Pi0 policy
|
||||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=lerobot/smolvla_base \
|
--policy.path=lerobot/pi0_libero_finetuned \
|
||||||
--robot.type=so100 \
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
--verbose_rtc_comparison=false
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=cuda
|
||||||
|
|
||||||
# With policy compilation (CUDA only, not MPS)
|
# With torch.compile for faster inference
|
||||||
uv run python examples/rtc/real_time_chunking_evaluate.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=lerobot/smolvla_base \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--robot.type=so100 \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
--compile_policy=true \
|
--rtc.execution_horizon=8 \
|
||||||
--compile_mode=max-autotune
|
--device=cuda \
|
||||||
```
|
--use_torch_compile=true \
|
||||||
|
--torch_compile_mode=max-autotune
|
||||||
|
|
||||||
**Key Parameters:**
|
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
- `--policy.path`: Path to pretrained policy
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
- `--robot.type` or `--env.type`: Robot or environment to use
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
|
--use_torch_compile=true \
|
||||||
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
|
--torch_compile_backend=inductor \
|
||||||
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
--torch_compile_mode=max-autotune \
|
||||||
- `--verbose_rtc_comparison`: Enable detailed RTC comparison logging (default: true)
|
--torch_compile_disable_cudagraphs=false
|
||||||
- `--duration`: How long to run (seconds, default: 30.0)
|
|
||||||
- `--fps`: Action execution frequency (Hz, default: 10.0)
|
|
||||||
|
|
||||||
### 2. `evaluate_rtc_on_dataset.py`
|
|
||||||
|
|
||||||
Offline evaluation on dataset samples to measure RTC effectiveness.
|
|
||||||
|
|
||||||
**Features:**
|
|
||||||
|
|
||||||
- Evaluate RTC on dataset without running robot
|
|
||||||
- Compare RTC vs non-RTC predictions
|
|
||||||
- Measure consistency and ground truth alignment
|
|
||||||
- Simulate different inference delays
|
|
||||||
- Save detailed metrics to JSON
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Basic evaluation
|
|
||||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
|
||||||
--policy.path=lerobot/smolvla_base \
|
|
||||||
--dataset.repo_id=lerobot/pusht \
|
|
||||||
--num_iterations=100
|
|
||||||
|
|
||||||
# Simulate inference delay (every 3rd step)
|
|
||||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
|
||||||
--policy.path=lerobot/smolvla_base \
|
|
||||||
--dataset.repo_id=lerobot/pusht \
|
|
||||||
--num_iterations=200 \
|
|
||||||
--skip_steps=3
|
|
||||||
|
|
||||||
# Custom RTC configuration
|
|
||||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
|
||||||
--policy.path=lerobot/smolvla_base \
|
|
||||||
--dataset.repo_id=lerobot/pusht \
|
|
||||||
--num_iterations=100 \
|
|
||||||
--rtc.execution_horizon=12 \
|
|
||||||
--rtc.max_guidance_weight=5.0 \
|
|
||||||
--rtc.prefix_attention_schedule=LINEAR
|
|
||||||
|
|
||||||
# Save results to file
|
|
||||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
|
||||||
--policy.path=lerobot/smolvla_base \
|
|
||||||
--dataset.repo_id=lerobot/pusht \
|
|
||||||
--num_iterations=100 \
|
|
||||||
--output_path=results/rtc_evaluation.json
|
|
||||||
|
|
||||||
# Verbose mode with detailed logging
|
|
||||||
uv run python examples/rtc/evaluate_rtc_on_dataset.py \
|
|
||||||
--policy.path=lerobot/smolvla_base \
|
|
||||||
--dataset.repo_id=lerobot/pusht \
|
|
||||||
--num_iterations=50 \
|
|
||||||
--verbose=true
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Parameters:**
|
**Key Parameters:**
|
||||||
|
|
||||||
- `--policy.path`: Path to pretrained policy
|
- `--policy.path`: Path to pretrained policy
|
||||||
- `--dataset.repo_id`: Dataset to evaluate on
|
- `--dataset.repo_id`: Dataset to evaluate on
|
||||||
- `--num_iterations`: Number of samples to evaluate (default: 100)
|
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 20)
|
||||||
- `--skip_steps`: Steps to skip between inferences, simulates inference delay (default: 1)
|
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 10.0)
|
||||||
- `--start_episode`: Episode to start from (default: 0)
|
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||||
- `--output_path`: Path to save results JSON
|
- `--inference_delay`: Inference delay for RTC (default: 4)
|
||||||
- `--verbose`: Enable detailed per-sample logging
|
- `--seed`: Random seed for reproducibility (default: 42)
|
||||||
|
- `--output_dir`: Directory to save visualizations (default: rtc_debug_output)
|
||||||
- `--device`: Device to use (cuda, cpu, mps, auto)
|
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||||
|
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||||
|
|
||||||
**Metrics Reported:**
|
**Output:**
|
||||||
|
|
||||||
- **RTC vs Ground Truth MSE**: How close RTC predictions are to actual actions
|
The script generates several visualization files in `rtc_debug_output/`:
|
||||||
- **No-RTC vs Ground Truth MSE**: Baseline without RTC
|
|
||||||
- **RTC Improvement**: Absolute and relative improvement over baseline
|
|
||||||
- **RTC Consistency**: How well RTC maintains consistency in prefix region
|
|
||||||
- Prefix MSE
|
|
||||||
- Mean/Max error in overlap region
|
|
||||||
|
|
||||||
### 3. `run_dataset_evaluation.sh`
|
- `denoising_xt_comparison.png` - Noisy state evolution during denoising
|
||||||
|
- `denoising_vt_comparison.png` - Velocity predictions during denoising
|
||||||
|
- `denoising_x1t_comparison.png` - Predicted final states during denoising
|
||||||
|
- `denoising_correction_comparison.png` - RTC guidance corrections applied
|
||||||
|
- `final_actions_comparison.png` - Final action predictions (prev_chunk, no_rtc, rtc)
|
||||||
|
|
||||||
Convenience script with multiple evaluation scenarios.
|
The script also validates RTC behavior and reports:
|
||||||
|
|
||||||
|
- ✅ Delay region [0:inference_delay]: RTC = prev_chunk
|
||||||
|
- ✅ Blend region [inference_delay:execution_horizon]: prev_chunk ≤ RTC ≤ no_rtc
|
||||||
|
- ✅ Post-horizon [execution_horizon:]: RTC = no_rtc
|
||||||
|
|
||||||
|
### 2. `eval_with_real_robot.py`
|
||||||
|
|
||||||
|
Real-time evaluation on physical robots or simulation environments.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
|
||||||
|
- Run policy with RTC on real robot or simulation
|
||||||
|
- Multi-threaded action execution and inference
|
||||||
|
- Action queue management with proper timing
|
||||||
|
- Latency tracking and adaptive inference delay
|
||||||
|
- Support for both robots and gym environments
|
||||||
|
- Support for torch.compile() optimization
|
||||||
|
|
||||||
**Usage:**
|
**Usage:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Edit the script to set your policy and dataset
|
# With real robot
|
||||||
# Then run all examples:
|
uv run python examples/rtc/eval_with_real_robot.py \
|
||||||
./examples/rtc/run_dataset_evaluation.sh
|
--policy.path=lerobot/smolvla_base \
|
||||||
|
--robot.type=so100 \
|
||||||
|
--task="pick up the cup" \
|
||||||
|
--duration=30.0
|
||||||
|
|
||||||
# Or run individual examples from the script
|
# With simulation environment
|
||||||
|
uv run python examples/rtc/eval_with_real_robot.py \
|
||||||
|
--policy.path=lerobot/smolvla_base \
|
||||||
|
--env.type=pusht \
|
||||||
|
--duration=60.0
|
||||||
|
|
||||||
|
# With policy compilation (CUDA only, not MPS)
|
||||||
|
uv run python examples/rtc/eval_with_real_robot.py \
|
||||||
|
--policy.path=lerobot/smolvla_base \
|
||||||
|
--robot.type=so100 \
|
||||||
|
--use_torch_compile=true \
|
||||||
|
--torch_compile_mode=max-autotune
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Key Parameters:**
|
||||||
|
|
||||||
|
- `--policy.path`: Path to pretrained policy
|
||||||
|
- `--robot.type` or `--env.type`: Robot or environment to use
|
||||||
|
- `--task`: Task description (for VLA models)
|
||||||
|
- `--rtc.execution_horizon`: Number of steps to maintain consistency (default: 10)
|
||||||
|
- `--rtc.max_guidance_weight`: Maximum guidance weight (default: 1.0)
|
||||||
|
- `--rtc.prefix_attention_schedule`: Schedule type (ZEROS, ONES, LINEAR, EXP)
|
||||||
|
- `--duration`: How long to run (seconds, default: 30.0)
|
||||||
|
- `--fps`: Action execution frequency (Hz, default: 10.0)
|
||||||
|
- `--action_queue_size_to_get_new_actions`: Queue size threshold to request new actions (default: 30)
|
||||||
|
- `--device`: Device to use (cuda, cpu, mps, auto)
|
||||||
|
- `--use_torch_compile`: Enable torch.compile() for faster inference
|
||||||
|
|
||||||
## Understanding RTC Parameters
|
## Understanding RTC Parameters
|
||||||
|
|
||||||
### `execution_horizon`
|
### `execution_horizon`
|
||||||
|
|
||||||
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
|
Number of timesteps from previous chunk to maintain consistency with. Higher values mean more consistency but potentially less reactivity.
|
||||||
|
|
||||||
**Typical values:** 8-12 steps
|
**Typical values:** 8-12 steps for dataset evaluation, 10 steps for real-time execution
|
||||||
|
|
||||||
### `max_guidance_weight`
|
### `max_guidance_weight`
|
||||||
|
|
||||||
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
|
Upper bound on guidance strength. Higher values give stronger consistency but may over-constrain new predictions.
|
||||||
|
|
||||||
**Typical values:** 1.0-10.0
|
**Typical values:**
|
||||||
|
|
||||||
|
- Dataset evaluation: 10.0-100.0 (can be higher for analysis)
|
||||||
|
- Real-time execution: 1.0-10.0 (more conservative)
|
||||||
|
|
||||||
### `prefix_attention_schedule`
|
### `prefix_attention_schedule`
|
||||||
|
|
||||||
@@ -178,104 +183,69 @@ How to weight consistency across the overlap region:
|
|||||||
|
|
||||||
**Recommended:** `EXP`
|
**Recommended:** `EXP`
|
||||||
|
|
||||||
### `skip_steps` (evaluation only)
|
### `inference_delay`
|
||||||
|
|
||||||
Simulates inference delay by evaluating every N-th step. This helps understand how RTC performs with realistic delays.
|
Number of timesteps from the prefix to use for guidance. Typically calculated dynamically based on inference latency in real-time execution, but fixed for dataset evaluation.
|
||||||
|
|
||||||
**Example:** `skip_steps=3` means policy infers every 3 steps, simulating 3x action execution frequency vs inference frequency.
|
**Typical values:** 3-5 steps for dataset evaluation
|
||||||
|
|
||||||
## Output Format (Dataset Evaluation)
|
### `action_queue_size_to_get_new_actions` (real-time only)
|
||||||
|
|
||||||
When using `--output_path`, results are saved in JSON format:
|
Threshold for requesting new action chunks. Should be higher than `inference_delay + execution_horizon` to ensure smooth operation.
|
||||||
|
|
||||||
```json
|
**Typical values:** 20-30 steps
|
||||||
{
|
|
||||||
"summary": {
|
## Validation Rules (Dataset Evaluation)
|
||||||
"rtc_vs_ground_truth_mse": {
|
|
||||||
"mean": 0.00123,
|
The dataset evaluation script validates that RTC behavior matches expectations:
|
||||||
"std": 0.00045,
|
|
||||||
"min": 0.00012,
|
1. **Delay Region [0:inference_delay]**: RTC actions should equal previous chunk
|
||||||
"max": 0.00456
|
- Ensures consistency during the inference delay period
|
||||||
},
|
|
||||||
"improvement": {
|
2. **Blend Region [inference_delay:execution_horizon]**: RTC should be between prev_chunk and no_rtc
|
||||||
"absolute": 0.00034,
|
- Smooth transition from previous plan to new predictions
|
||||||
"relative_percent": 12.5
|
|
||||||
},
|
3. **Post-Horizon [execution_horizon:]**: RTC should equal no_rtc
|
||||||
...
|
- Full adoption of new predictions after execution horizon
|
||||||
},
|
|
||||||
"config": {
|
|
||||||
"num_iterations": 100,
|
|
||||||
"skip_steps": 3,
|
|
||||||
"execution_horizon": 10,
|
|
||||||
...
|
|
||||||
},
|
|
||||||
"detailed_results": [
|
|
||||||
{
|
|
||||||
"sample_idx": 0,
|
|
||||||
"rtc_vs_ground_truth_mse": 0.00112,
|
|
||||||
"no_rtc_vs_ground_truth_mse": 0.00145,
|
|
||||||
...
|
|
||||||
},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Tips
|
## Tips
|
||||||
|
|
||||||
1. **Start with dataset evaluation** to understand RTC behavior before running on robot
|
1. **Start with dataset evaluation** (`eval_dataset.py`) to understand RTC behavior and tune parameters before running on robot
|
||||||
2. **Use verbose mode** for debugging unexpected behavior
|
2. **Use visualizations** to debug unexpected behavior - check denoising steps and final actions
|
||||||
3. **Tune execution_horizon** based on your inference latency and action frequency
|
3. **Tune execution_horizon** based on your inference latency and action frequency
|
||||||
4. **Monitor consistency metrics** - very low consistency might indicate execution_horizon is too small
|
4. **Monitor validation output** - failures indicate potential implementation issues or misconfigured parameters
|
||||||
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
|
5. **Compare different schedules** - EXP usually works best but LINEAR can be more interpretable
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
### High RTC vs No-RTC difference but no improvement
|
### Validation fails in delay region
|
||||||
|
|
||||||
- Try reducing `max_guidance_weight`
|
- Check that `prev_chunk_left_over` is properly passed to the policy
|
||||||
- Check if `execution_horizon` is too large
|
- Verify RTC guidance is being applied during denoising
|
||||||
|
- Look at denoising visualizations to see where guidance diverges
|
||||||
|
|
||||||
### Poor consistency metrics
|
### Validation fails in post-horizon region
|
||||||
|
|
||||||
- Increase `execution_horizon`
|
- RTC and no_rtc use different noise - verify same noise is being used for comparison
|
||||||
- Check that `skip_steps` is not larger than your action chunk size
|
- Check that weights are correctly zeroed out after execution horizon
|
||||||
- Verify episodes are being reset correctly
|
- Review prefix_attention_schedule visualization
|
||||||
|
|
||||||
### RTC worse than No-RTC
|
### Poor performance on real robot
|
||||||
|
|
||||||
- RTC may not help if inference is faster than action execution
|
- Increase `action_queue_size_to_get_new_actions` if you see warnings
|
||||||
- Try different `prefix_attention_schedule`
|
- Reduce `max_guidance_weight` if robot is too conservative
|
||||||
- Ensure `execution_horizon` matches your use case
|
- Try different `prefix_attention_schedule` values
|
||||||
|
- Enable torch.compile() for faster inference (CUDA only)
|
||||||
|
|
||||||
## Examples Results
|
### Memory issues with large models
|
||||||
|
|
||||||
Example output from dataset evaluation:
|
- The dataset evaluation script loads policies sequentially to minimize memory
|
||||||
|
- For real-time execution, only one policy is loaded
|
||||||
```
|
- Use smaller batch sizes if needed
|
||||||
================================================================================
|
|
||||||
EVALUATION SUMMARY
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Ground Truth Alignment:
|
|
||||||
RTC MSE: 0.001234 ± 0.000456
|
|
||||||
No-RTC MSE: 0.001567 ± 0.000512
|
|
||||||
|
|
||||||
RTC Improvement:
|
|
||||||
Absolute: 0.000333
|
|
||||||
Relative: 21.23%
|
|
||||||
|
|
||||||
RTC vs No-RTC Difference:
|
|
||||||
MSE: 0.000112 ± 0.000034
|
|
||||||
|
|
||||||
RTC Consistency (Prefix Region):
|
|
||||||
MSE: 0.000089 ± 0.000023
|
|
||||||
Mean Error: 0.007654 ± 0.002341
|
|
||||||
Max Error: 0.023456 ± 0.008765
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related Documentation
|
## Related Documentation
|
||||||
|
|
||||||
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
|
- [RTC Implementation](../../src/lerobot/policies/rtc/modeling_rtc.py)
|
||||||
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
|
- [RTC Configuration](../../src/lerobot/policies/rtc/configuration_rtc.py)
|
||||||
|
- [Action Queue](../../src/lerobot/policies/rtc/action_queue.py)
|
||||||
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
- [Physical Intelligence Paper](https://www.physicalintelligence.company/download/real_time_chunking.pdf)
|
||||||
|
|||||||
+158
-38
@@ -16,7 +16,9 @@ Usage:
|
|||||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||||
--dataset.repo_id=helper2424/check_rtc \
|
--dataset.repo_id=helper2424/check_rtc \
|
||||||
--rtc.execution_horizon=8 \
|
--rtc.execution_horizon=8 \
|
||||||
--device=mps
|
--device=mps \
|
||||||
|
--rtc.max_guidance_weight=10.0 \
|
||||||
|
--seed=10
|
||||||
|
|
||||||
# Basic usage with pi0.5 policy
|
# Basic usage with pi0.5 policy
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
@@ -439,6 +441,8 @@ class RTCEvaluator:
|
|||||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||||
logging.info("=" * 80)
|
logging.info("=" * 80)
|
||||||
|
|
||||||
|
set_seed(self.cfg.seed)
|
||||||
|
|
||||||
# Initialize policy 2
|
# Initialize policy 2
|
||||||
policy_no_rtc_policy = self._init_policy(
|
policy_no_rtc_policy = self._init_policy(
|
||||||
name="policy_no_rtc",
|
name="policy_no_rtc",
|
||||||
@@ -470,6 +474,8 @@ class RTCEvaluator:
|
|||||||
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||||
logging.info("=" * 80)
|
logging.info("=" * 80)
|
||||||
|
|
||||||
|
set_seed(self.cfg.seed)
|
||||||
|
|
||||||
# Initialize policy 3
|
# Initialize policy 3
|
||||||
policy_rtc_policy = self._init_policy(
|
policy_rtc_policy = self._init_policy(
|
||||||
name="policy_rtc",
|
name="policy_rtc",
|
||||||
@@ -510,6 +516,11 @@ class RTCEvaluator:
|
|||||||
logging.info("Validating RTC behavior...")
|
logging.info("Validating RTC behavior...")
|
||||||
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
self.validate_rtc_behavior(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||||
|
|
||||||
|
# Plot final actions comparison
|
||||||
|
logging.info("=" * 80)
|
||||||
|
logging.info("Plotting final actions comparison...")
|
||||||
|
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||||
|
|
||||||
logging.info("=" * 80)
|
logging.info("=" * 80)
|
||||||
logging.info("Evaluation completed successfully")
|
logging.info("Evaluation completed successfully")
|
||||||
|
|
||||||
@@ -527,29 +538,24 @@ class RTCEvaluator:
|
|||||||
no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
|
no_rtc_actions: Final actions from non-RTC policy (batch, time, action_dim)
|
||||||
prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
|
prev_chunk_left_over: Previous chunk used as ground truth (time, action_dim)
|
||||||
"""
|
"""
|
||||||
if rtc_actions is None or no_rtc_actions is None:
|
# Remove batch dimension if present and move to CPU
|
||||||
logging.warning(" ⚠ Cannot validate: missing action predictions")
|
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||||
return
|
no_rtc_actions_t = (
|
||||||
|
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||||
|
)
|
||||||
|
prev_chunk = prev_chunk_left_over.cpu()
|
||||||
|
|
||||||
# Convert to numpy for comparison (remove batch dimension if present)
|
logging.info(f" rtc_actions shape: {rtc_actions_t.shape}")
|
||||||
rtc_actions_np = (
|
logging.info(f" no_rtc_actions shape: {no_rtc_actions_t.shape}")
|
||||||
rtc_actions.squeeze(0).cpu().numpy() if len(rtc_actions.shape) == 3 else rtc_actions.cpu().numpy()
|
logging.info(f" prev_chunk shape: {prev_chunk.shape}")
|
||||||
)
|
|
||||||
no_rtc_actions_np = (
|
|
||||||
no_rtc_actions.squeeze(0).cpu().numpy()
|
|
||||||
if len(no_rtc_actions.shape) == 3
|
|
||||||
else no_rtc_actions.cpu().numpy()
|
|
||||||
)
|
|
||||||
prev_chunk = prev_chunk_left_over.cpu().numpy()
|
|
||||||
|
|
||||||
# Determine chunk length for comparison
|
# Determine chunk length for comparison
|
||||||
chunk_len = min(rtc_actions_np.shape[0], no_rtc_actions_np.shape[0], prev_chunk.shape[0])
|
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk.shape[0])
|
||||||
inference_delay = self.cfg.inference_delay
|
inference_delay = self.cfg.inference_delay
|
||||||
execution_horizon = self.cfg.rtc.execution_horizon
|
execution_horizon = self.cfg.rtc.execution_horizon
|
||||||
|
|
||||||
# Tolerance for floating point comparison
|
# Tolerance for floating point comparison
|
||||||
rtol = 1e-3 # Relative tolerance
|
rtol = 1e-2 # Relative tolerance
|
||||||
atol = 1e-3 # Absolute tolerance
|
|
||||||
|
|
||||||
validation_passed = True
|
validation_passed = True
|
||||||
warnings = []
|
warnings = []
|
||||||
@@ -558,19 +564,26 @@ class RTCEvaluator:
|
|||||||
logging.info(f" Chunk length: {chunk_len}")
|
logging.info(f" Chunk length: {chunk_len}")
|
||||||
logging.info(f" Inference delay: {inference_delay}")
|
logging.info(f" Inference delay: {inference_delay}")
|
||||||
logging.info(f" Execution horizon: {execution_horizon}")
|
logging.info(f" Execution horizon: {execution_horizon}")
|
||||||
logging.info(f" Tolerance: rtol={rtol}, atol={atol}")
|
logging.info(f" Tolerance: rtol={rtol}")
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
|
# Rule 1: During delay [0:inference_delay], RTC should equal prev_chunk
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
if inference_delay > 0:
|
if inference_delay > 0:
|
||||||
delay_end = min(inference_delay, chunk_len)
|
delay_end = min(inference_delay, chunk_len)
|
||||||
rtc_delay = rtc_actions_np[:delay_end]
|
rtc_delay = rtc_actions_t[:delay_end]
|
||||||
prev_delay = prev_chunk[:delay_end]
|
prev_delay = prev_chunk[:delay_end]
|
||||||
|
|
||||||
if not np.allclose(rtc_delay, prev_delay, rtol=rtol, atol=atol):
|
logging.info(f" rtc_delay: {rtc_delay.shape}")
|
||||||
max_diff = np.max(np.abs(rtc_delay - prev_delay))
|
logging.info(f" prev_delay: {prev_delay.shape}")
|
||||||
mean_diff = np.mean(np.abs(rtc_delay - prev_delay))
|
|
||||||
|
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
|
||||||
|
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
|
||||||
|
mean_diff = torch.mean(torch.abs(rtc_delay - prev_delay)).item()
|
||||||
|
logging.info(f" rtc_delay: {rtc_delay}")
|
||||||
|
logging.info(f" prev_delay: {prev_delay}")
|
||||||
|
logging.info(f" max_diff: {max_diff}")
|
||||||
|
logging.info(f" mean_diff: {mean_diff}")
|
||||||
warnings.append(
|
warnings.append(
|
||||||
f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
|
f" ⚠ VALIDATION FAILED: During delay [0:{delay_end}], "
|
||||||
f"RTC does NOT equal prev_chunk!\n"
|
f"RTC does NOT equal prev_chunk!\n"
|
||||||
@@ -589,26 +602,26 @@ class RTCEvaluator:
|
|||||||
blend_end = min(execution_horizon, chunk_len)
|
blend_end = min(execution_horizon, chunk_len)
|
||||||
|
|
||||||
if blend_end > blend_start:
|
if blend_end > blend_start:
|
||||||
rtc_blend = rtc_actions_np[blend_start:blend_end]
|
rtc_blend = rtc_actions_t[blend_start:blend_end]
|
||||||
prev_blend = prev_chunk[blend_start:blend_end]
|
prev_blend = prev_chunk[blend_start:blend_end]
|
||||||
no_rtc_blend = no_rtc_actions_np[blend_start:blend_end]
|
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
|
||||||
|
|
||||||
# Check if RTC is between prev_chunk and no_rtc (element-wise)
|
# Check if RTC is between prev_chunk and no_rtc (element-wise)
|
||||||
# For each element, check if it's between the min and max of prev_chunk and no_rtc
|
# For each element, check if it's between the min and max of prev_chunk and no_rtc
|
||||||
min_bound = np.minimum(prev_blend, no_rtc_blend) - atol
|
min_bound = torch.minimum(prev_blend, no_rtc_blend)
|
||||||
max_bound = np.maximum(prev_blend, no_rtc_blend) + atol
|
max_bound = torch.maximum(prev_blend, no_rtc_blend)
|
||||||
|
|
||||||
within_bounds = np.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
||||||
|
|
||||||
if not np.all(within_bounds):
|
if not torch.all(within_bounds):
|
||||||
violations = np.sum(~within_bounds)
|
violations = torch.sum(~within_bounds).item()
|
||||||
total_elements = within_bounds.size
|
total_elements = within_bounds.numel()
|
||||||
violation_pct = 100.0 * violations / total_elements
|
violation_pct = 100.0 * violations / total_elements
|
||||||
|
|
||||||
# Find max violation
|
# Find max violation
|
||||||
lower_violations = np.maximum(0, min_bound - rtc_blend)
|
lower_violations = torch.maximum(torch.tensor(0.0), min_bound - rtc_blend)
|
||||||
upper_violations = np.maximum(0, rtc_blend - max_bound)
|
upper_violations = torch.maximum(torch.tensor(0.0), rtc_blend - max_bound)
|
||||||
max_violation = np.max(np.maximum(lower_violations, upper_violations))
|
max_violation = torch.max(torch.maximum(lower_violations, upper_violations)).item()
|
||||||
|
|
||||||
warnings.append(
|
warnings.append(
|
||||||
f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
|
f" ⚠ VALIDATION FAILED: In blend region [{blend_start}:{blend_end}], "
|
||||||
@@ -626,12 +639,15 @@ class RTCEvaluator:
|
|||||||
# Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
|
# Rule 3: After execution horizon [execution_horizon:], RTC should equal no_rtc
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
if execution_horizon < chunk_len:
|
if execution_horizon < chunk_len:
|
||||||
rtc_after = rtc_actions_np[execution_horizon:chunk_len]
|
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
|
||||||
no_rtc_after = no_rtc_actions_np[execution_horizon:chunk_len]
|
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
|
||||||
|
|
||||||
if not np.allclose(rtc_after, no_rtc_after, rtol=rtol, atol=atol):
|
logging.info(f" rtc_after: {rtc_after}")
|
||||||
max_diff = np.max(np.abs(rtc_after - no_rtc_after))
|
logging.info(f" no_rtc_after: {no_rtc_after}")
|
||||||
mean_diff = np.mean(np.abs(rtc_after - no_rtc_after))
|
|
||||||
|
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
|
||||||
|
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
|
||||||
|
mean_diff = torch.mean(torch.abs(rtc_after - no_rtc_after)).item()
|
||||||
warnings.append(
|
warnings.append(
|
||||||
f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
|
f" ⚠ VALIDATION FAILED: After execution horizon [{execution_horizon}:{chunk_len}], "
|
||||||
f"RTC does NOT equal no_rtc!\n"
|
f"RTC does NOT equal no_rtc!\n"
|
||||||
@@ -661,6 +677,103 @@ class RTCEvaluator:
|
|||||||
logging.error("")
|
logging.error("")
|
||||||
logging.error(" Please check the implementation of RTC guidance.")
|
logging.error(" Please check the implementation of RTC guidance.")
|
||||||
|
|
||||||
|
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
|
||||||
|
"""Plot final action predictions comparison on a single chart.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rtc_actions: Final actions from RTC policy
|
||||||
|
no_rtc_actions: Final actions from non-RTC policy
|
||||||
|
prev_chunk_left_over: Previous chunk used as ground truth
|
||||||
|
"""
|
||||||
|
# Remove batch dimension if present
|
||||||
|
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||||
|
no_rtc_actions_plot = (
|
||||||
|
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||||
|
)
|
||||||
|
prev_chunk_plot = prev_chunk_left_over.cpu()
|
||||||
|
|
||||||
|
# Create figure with 6 subplots (one per action dimension)
|
||||||
|
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
|
||||||
|
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
|
||||||
|
|
||||||
|
# Plot each action dimension
|
||||||
|
for dim_idx, ax in enumerate(axes):
|
||||||
|
# Plot previous chunk (ground truth) in red
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
[ax],
|
||||||
|
prev_chunk_plot[:, dim_idx : dim_idx + 1],
|
||||||
|
start_from=0,
|
||||||
|
color="red",
|
||||||
|
label="Previous Chunk (Ground Truth)",
|
||||||
|
linewidth=2.5,
|
||||||
|
alpha=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot no-RTC actions in blue
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
[ax],
|
||||||
|
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||||
|
start_from=0,
|
||||||
|
color="blue",
|
||||||
|
label="No RTC",
|
||||||
|
linewidth=2,
|
||||||
|
alpha=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Plot RTC actions in green
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
[ax],
|
||||||
|
rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||||
|
start_from=0,
|
||||||
|
color="green",
|
||||||
|
label="RTC",
|
||||||
|
linewidth=2,
|
||||||
|
alpha=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add vertical lines for inference delay and execution horizon
|
||||||
|
inference_delay = self.cfg.inference_delay
|
||||||
|
execution_horizon = self.cfg.rtc.execution_horizon
|
||||||
|
|
||||||
|
if inference_delay > 0:
|
||||||
|
ax.axvline(
|
||||||
|
x=inference_delay - 1,
|
||||||
|
color="orange",
|
||||||
|
linestyle="--",
|
||||||
|
alpha=0.5,
|
||||||
|
label=f"Inference Delay ({inference_delay})",
|
||||||
|
)
|
||||||
|
|
||||||
|
if execution_horizon > 0:
|
||||||
|
ax.axvline(
|
||||||
|
x=execution_horizon,
|
||||||
|
color="purple",
|
||||||
|
linestyle="--",
|
||||||
|
alpha=0.5,
|
||||||
|
label=f"Execution Horizon ({execution_horizon})",
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Set x-axis ticks to show all integer values
|
||||||
|
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
|
||||||
|
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||||
|
ax.set_xlim(-0.5, max_len - 0.5)
|
||||||
|
|
||||||
|
# Add legend only to first subplot
|
||||||
|
if dim_idx == 0:
|
||||||
|
ax.legend(loc="best", fontsize=9)
|
||||||
|
|
||||||
|
axes[-1].set_xlabel("Step", fontsize=10)
|
||||||
|
|
||||||
|
# Save figure
|
||||||
|
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(output_path, dpi=150)
|
||||||
|
logging.info(f"Saved final actions comparison to {output_path}")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
||||||
# Create side-by-side figures for denoising visualization
|
# Create side-by-side figures for denoising visualization
|
||||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||||
@@ -828,6 +941,13 @@ class RTCEvaluator:
|
|||||||
margin = y_range * 0.1
|
margin = y_range * 0.1
|
||||||
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
|
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
|
||||||
|
|
||||||
|
# Set x-axis ticks to show all integer values
|
||||||
|
xlim = ax.get_xlim()
|
||||||
|
max_len = int(xlim[1]) + 1
|
||||||
|
if max_len > 0:
|
||||||
|
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||||
|
ax.set_xlim(-0.5, max_len - 0.5)
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def main(cfg: RTCEvalConfig):
|
def main(cfg: RTCEvalConfig):
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ RTC can be integrated with any policy that supports flow mathicng for chunking:
|
|||||||
|
|
||||||
- **SmolVLA**: Vision-language-action model with RTC support
|
- **SmolVLA**: Vision-language-action model with RTC support
|
||||||
- **Pi0**: Action prediction model with adaptive chunking
|
- **Pi0**: Action prediction model with adaptive chunking
|
||||||
|
- **Pi05**: Action prediction model with adaptive chunking
|
||||||
|
|
||||||
## Original Implementation
|
## Original Implementation
|
||||||
|
|
||||||
@@ -39,3 +40,10 @@ uv run python examples/rtc/eval_dataset.py \
|
|||||||
--device=mps \
|
--device=mps \
|
||||||
--seed=42
|
--seed=42
|
||||||
```
|
```
|
||||||
|
|
||||||
|
This script will evaluate RTC on a data from a dataset and save the results to a file, u can check the results in the `rtc_debug_output` directory.
|
||||||
|
|
||||||
|
The example output should look like this:
|
||||||
|

|
||||||
|
|
||||||
|
It shows how flow matching works with RTC and without it. The chart shows values of action predictions for each timestep. The colour shows the the generation progress. The blue ones - earlier timesteps, the yellow ones - later timesteps. The red line is the ground truth (previous action chunk).
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 538 KiB |
@@ -168,6 +168,8 @@ class RTCProcessor:
|
|||||||
v_t = original_denoise_step_partial(x_t)
|
v_t = original_denoise_step_partial(x_t)
|
||||||
return v_t
|
return v_t
|
||||||
|
|
||||||
|
x_t = x_t.clone().detach()
|
||||||
|
|
||||||
squeezed = False
|
squeezed = False
|
||||||
if len(x_t.shape) < 3:
|
if len(x_t.shape) < 3:
|
||||||
# Add batch dimension
|
# Add batch dimension
|
||||||
@@ -208,7 +210,6 @@ class RTCProcessor:
|
|||||||
|
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
v_t = original_denoise_step_partial(x_t)
|
v_t = original_denoise_step_partial(x_t)
|
||||||
x_t = x_t.clone().detach()
|
|
||||||
x_t.requires_grad_(True)
|
x_t.requires_grad_(True)
|
||||||
|
|
||||||
x1_t = x_t - time * v_t # noqa: N806
|
x1_t = x_t - time * v_t # noqa: N806
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
|
|
||||||
"""Tests for RTC configuration module."""
|
"""Tests for RTC configuration module."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from lerobot.configs.types import RTCAttentionSchedule
|
from lerobot.configs.types import RTCAttentionSchedule
|
||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
|
|
||||||
@@ -65,259 +63,3 @@ def test_rtc_config_partial_initialization():
|
|||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||||
assert config.execution_horizon == 10
|
assert config.execution_horizon == 10
|
||||||
assert config.debug is False
|
assert config.debug is False
|
||||||
|
|
||||||
|
|
||||||
# ====================== Validation Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_validates_positive_max_guidance_weight():
|
|
||||||
"""Test RTCConfig validates max_guidance_weight is positive."""
|
|
||||||
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
|
|
||||||
RTCConfig(max_guidance_weight=0.0)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="max_guidance_weight must be positive"):
|
|
||||||
RTCConfig(max_guidance_weight=-1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_validates_positive_debug_maxlen():
|
|
||||||
"""Test RTCConfig validates debug_maxlen is positive."""
|
|
||||||
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
|
|
||||||
RTCConfig(debug_maxlen=0)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="debug_maxlen must be positive"):
|
|
||||||
RTCConfig(debug_maxlen=-10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_accepts_valid_max_guidance_weight():
|
|
||||||
"""Test RTCConfig accepts valid positive max_guidance_weight."""
|
|
||||||
config1 = RTCConfig(max_guidance_weight=0.1)
|
|
||||||
assert config1.max_guidance_weight == 0.1
|
|
||||||
|
|
||||||
config2 = RTCConfig(max_guidance_weight=100.0)
|
|
||||||
assert config2.max_guidance_weight == 100.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_accepts_valid_debug_maxlen():
|
|
||||||
"""Test RTCConfig accepts valid positive debug_maxlen."""
|
|
||||||
config1 = RTCConfig(debug_maxlen=1)
|
|
||||||
assert config1.debug_maxlen == 1
|
|
||||||
|
|
||||||
config2 = RTCConfig(debug_maxlen=10000)
|
|
||||||
assert config2.debug_maxlen == 10000
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Attention Schedule Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_linear_schedule():
|
|
||||||
"""Test RTCConfig with LINEAR attention schedule."""
|
|
||||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
|
||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_exp_schedule():
|
|
||||||
"""Test RTCConfig with EXP attention schedule."""
|
|
||||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
|
|
||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_zeros_schedule():
|
|
||||||
"""Test RTCConfig with ZEROS attention schedule."""
|
|
||||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
|
|
||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_ones_schedule():
|
|
||||||
"""Test RTCConfig with ONES attention schedule."""
|
|
||||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
|
|
||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.ONES
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Enabled/Disabled Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_enabled_true():
|
|
||||||
"""Test RTCConfig with enabled=True."""
|
|
||||||
config = RTCConfig(enabled=True)
|
|
||||||
assert config.enabled is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_enabled_false():
|
|
||||||
"""Test RTCConfig with enabled=False."""
|
|
||||||
config = RTCConfig(enabled=False)
|
|
||||||
assert config.enabled is False
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Debug Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_debug_enabled():
|
|
||||||
"""Test RTCConfig with debug enabled."""
|
|
||||||
config = RTCConfig(debug=True, debug_maxlen=500)
|
|
||||||
assert config.debug is True
|
|
||||||
assert config.debug_maxlen == 500
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_debug_disabled():
|
|
||||||
"""Test RTCConfig with debug disabled."""
|
|
||||||
config = RTCConfig(debug=False)
|
|
||||||
assert config.debug is False
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Execution Horizon Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_small_execution_horizon():
|
|
||||||
"""Test RTCConfig with small execution horizon."""
|
|
||||||
config = RTCConfig(execution_horizon=1)
|
|
||||||
assert config.execution_horizon == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_large_execution_horizon():
|
|
||||||
"""Test RTCConfig with large execution horizon."""
|
|
||||||
config = RTCConfig(execution_horizon=100)
|
|
||||||
assert config.execution_horizon == 100
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_zero_execution_horizon():
|
|
||||||
"""Test RTCConfig accepts zero execution horizon."""
|
|
||||||
# No validation on execution_horizon, so this should work
|
|
||||||
config = RTCConfig(execution_horizon=0)
|
|
||||||
assert config.execution_horizon == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_with_negative_execution_horizon():
|
|
||||||
"""Test RTCConfig accepts negative execution horizon."""
|
|
||||||
# No validation on execution_horizon, so this should work
|
|
||||||
config = RTCConfig(execution_horizon=-1)
|
|
||||||
assert config.execution_horizon == -1
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Integration Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_typical_production_settings():
|
|
||||||
"""Test RTCConfig with typical production settings."""
|
|
||||||
config = RTCConfig(
|
|
||||||
enabled=True,
|
|
||||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
|
||||||
max_guidance_weight=10.0,
|
|
||||||
execution_horizon=8,
|
|
||||||
debug=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert config.enabled is True
|
|
||||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
|
||||||
assert config.max_guidance_weight == 10.0
|
|
||||||
assert config.execution_horizon == 8
|
|
||||||
assert config.debug is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_typical_debug_settings():
|
|
||||||
"""Test RTCConfig with typical debug settings."""
|
|
||||||
config = RTCConfig(
|
|
||||||
enabled=True,
|
|
||||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
|
||||||
max_guidance_weight=5.0,
|
|
||||||
execution_horizon=10,
|
|
||||||
debug=True,
|
|
||||||
debug_maxlen=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert config.enabled is True
|
|
||||||
assert config.debug is True
|
|
||||||
assert config.debug_maxlen == 1000
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_disabled_mode():
|
|
||||||
"""Test RTCConfig in disabled mode."""
|
|
||||||
config = RTCConfig(enabled=False)
|
|
||||||
|
|
||||||
assert config.enabled is False
|
|
||||||
# Other settings still accessible even when disabled
|
|
||||||
assert config.max_guidance_weight == 10.0
|
|
||||||
assert config.execution_horizon == 10
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Dataclass Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_is_dataclass():
|
|
||||||
"""Test that RTCConfig is a dataclass."""
|
|
||||||
from dataclasses import is_dataclass
|
|
||||||
|
|
||||||
assert is_dataclass(RTCConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_equality():
|
|
||||||
"""Test RTCConfig equality comparison."""
|
|
||||||
config1 = RTCConfig(enabled=True, max_guidance_weight=5.0)
|
|
||||||
config2 = RTCConfig(enabled=True, max_guidance_weight=5.0)
|
|
||||||
config3 = RTCConfig(enabled=False, max_guidance_weight=5.0)
|
|
||||||
|
|
||||||
assert config1 == config2
|
|
||||||
assert config1 != config3
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_repr():
|
|
||||||
"""Test RTCConfig string representation."""
|
|
||||||
config = RTCConfig(enabled=True, execution_horizon=20)
|
|
||||||
repr_str = repr(config)
|
|
||||||
|
|
||||||
assert "RTCConfig" in repr_str
|
|
||||||
assert "enabled=True" in repr_str
|
|
||||||
assert "execution_horizon=20" in repr_str
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Edge Cases Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_very_small_max_guidance_weight():
|
|
||||||
"""Test RTCConfig with very small positive max_guidance_weight."""
|
|
||||||
config = RTCConfig(max_guidance_weight=1e-10)
|
|
||||||
assert config.max_guidance_weight == pytest.approx(1e-10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_very_large_max_guidance_weight():
|
|
||||||
"""Test RTCConfig with very large max_guidance_weight."""
|
|
||||||
config = RTCConfig(max_guidance_weight=1e10)
|
|
||||||
assert config.max_guidance_weight == pytest.approx(1e10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_minimum_debug_maxlen():
|
|
||||||
"""Test RTCConfig with minimum valid debug_maxlen."""
|
|
||||||
config = RTCConfig(debug_maxlen=1)
|
|
||||||
assert config.debug_maxlen == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_float_max_guidance_weight():
|
|
||||||
"""Test RTCConfig with float max_guidance_weight."""
|
|
||||||
config = RTCConfig(max_guidance_weight=3.14159)
|
|
||||||
assert config.max_guidance_weight == pytest.approx(3.14159)
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Type Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_enabled_type():
|
|
||||||
"""Test RTCConfig enabled field accepts boolean."""
|
|
||||||
config = RTCConfig(enabled=True)
|
|
||||||
assert isinstance(config.enabled, bool)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_execution_horizon_type():
|
|
||||||
"""Test RTCConfig execution_horizon field accepts integer."""
|
|
||||||
config = RTCConfig(execution_horizon=15)
|
|
||||||
assert isinstance(config.execution_horizon, int)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_max_guidance_weight_type():
|
|
||||||
"""Test RTCConfig max_guidance_weight field accepts float."""
|
|
||||||
config = RTCConfig(max_guidance_weight=7.5)
|
|
||||||
assert isinstance(config.max_guidance_weight, float)
|
|
||||||
|
|
||||||
|
|
||||||
def test_rtc_config_debug_maxlen_type():
|
|
||||||
"""Test RTCConfig debug_maxlen field accepts integer."""
|
|
||||||
config = RTCConfig(debug_maxlen=200)
|
|
||||||
assert isinstance(config.debug_maxlen, int)
|
|
||||||
|
|||||||
@@ -1,427 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Tests for RTC debug visualizer module."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
|
||||||
|
|
||||||
# ====================== Fixtures ======================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_axes():
|
|
||||||
"""Create mock matplotlib axes."""
|
|
||||||
axes = []
|
|
||||||
for _ in range(6):
|
|
||||||
ax = MagicMock()
|
|
||||||
ax.xaxis.get_label.return_value.get_text.return_value = ""
|
|
||||||
ax.yaxis.get_label.return_value.get_text.return_value = ""
|
|
||||||
axes.append(ax)
|
|
||||||
return axes
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_tensor_2d():
|
|
||||||
"""Create a 2D sample tensor (time_steps, num_dims)."""
|
|
||||||
return torch.randn(50, 6)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_tensor_3d():
|
|
||||||
"""Create a 3D sample tensor (batch, time_steps, num_dims)."""
|
|
||||||
return torch.randn(1, 50, 6)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_numpy_2d():
|
|
||||||
"""Create a 2D numpy array."""
|
|
||||||
return np.random.randn(50, 6)
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Basic Plotting Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_2d_tensor(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with 2D tensor."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
|
||||||
|
|
||||||
# Should call plot on each axis (6 dimensions)
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_3d_tensor(mock_axes, sample_tensor_3d):
|
|
||||||
"""Test plot_waypoints with 3D tensor (batch dimension)."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_3d)
|
|
||||||
|
|
||||||
# Should still plot 6 dimensions (batch dimension removed)
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_numpy_array(mock_axes, sample_numpy_2d):
|
|
||||||
"""Test plot_waypoints with numpy array."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_numpy_2d)
|
|
||||||
|
|
||||||
# Should work with numpy arrays
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_none_tensor(mock_axes):
|
|
||||||
"""Test plot_waypoints returns early when tensor is None."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, None)
|
|
||||||
|
|
||||||
# Should not call plot on any axis
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Parameter Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_custom_color(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints uses custom color."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, color="red")
|
|
||||||
|
|
||||||
# Check that color was passed to plot
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert call_kwargs["color"] == "red"
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_custom_label(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints uses custom label."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
|
|
||||||
|
|
||||||
# First axis should have label, others should not
|
|
||||||
first_ax_kwargs = mock_axes[0].plot.call_args[1]
|
|
||||||
assert first_ax_kwargs["label"] == "test_label"
|
|
||||||
|
|
||||||
# Other axes should have empty label
|
|
||||||
for ax in mock_axes[1:]:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert call_kwargs["label"] == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_custom_alpha(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints uses custom alpha."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, alpha=0.5)
|
|
||||||
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert call_kwargs["alpha"] == 0.5
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_custom_linewidth(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints uses custom linewidth."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, linewidth=3)
|
|
||||||
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert call_kwargs["linewidth"] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_marker(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with marker style."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker="o", markersize=5)
|
|
||||||
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert call_kwargs["marker"] == "o"
|
|
||||||
assert call_kwargs["markersize"] == 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_without_marker(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints without marker (default)."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, marker=None)
|
|
||||||
|
|
||||||
# Marker should not be in kwargs when None
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_kwargs = ax.plot.call_args[1]
|
|
||||||
assert "marker" not in call_kwargs
|
|
||||||
assert "markersize" not in call_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== start_from Parameter Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_start_from_zero(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with start_from=0."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
|
|
||||||
|
|
||||||
# X indices should start from 0
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_args = ax.plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
assert x_indices[0] == 0
|
|
||||||
assert len(x_indices) == 50
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_start_from_nonzero(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with start_from > 0."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=10)
|
|
||||||
|
|
||||||
# X indices should start from 10
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_args = ax.plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
assert x_indices[0] == 10
|
|
||||||
assert x_indices[-1] == 59 # 10 + 50 - 1
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Tensor Shape Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_1d_tensor(mock_axes):
|
|
||||||
"""Test plot_waypoints with 1D tensor."""
|
|
||||||
tensor_1d = torch.randn(50)
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_1d)
|
|
||||||
|
|
||||||
# Should reshape to (50, 1) and plot on first axis only
|
|
||||||
mock_axes[0].plot.assert_called_once()
|
|
||||||
for ax in mock_axes[1:]:
|
|
||||||
ax.plot.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_fewer_dims_than_axes(sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints when tensor has fewer dims than axes."""
|
|
||||||
# Create tensor with only 3 dimensions
|
|
||||||
tensor_3d = sample_tensor_2d[:, :3]
|
|
||||||
|
|
||||||
# Create 6 axes but tensor only has 3 dims
|
|
||||||
mock_axes = [MagicMock() for _ in range(6)]
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.xaxis.get_label.return_value.get_text.return_value = ""
|
|
||||||
ax.yaxis.get_label.return_value.get_text.return_value = ""
|
|
||||||
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_3d)
|
|
||||||
|
|
||||||
# Should only plot on first 3 axes
|
|
||||||
for i in range(3):
|
|
||||||
mock_axes[i].plot.assert_called_once()
|
|
||||||
for i in range(3, 6):
|
|
||||||
mock_axes[i].plot.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Axis Labeling Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_sets_xlabel(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints sets x-axis label."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
|
||||||
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.set_xlabel.assert_called_once_with("Step", fontsize=10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_sets_ylabel(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints sets y-axis label."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
|
||||||
|
|
||||||
for i, ax in enumerate(mock_axes):
|
|
||||||
ax.set_ylabel.assert_called_once_with(f"Dim {i}", fontsize=10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_skips_label_if_exists(sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints doesn't set labels if they already exist."""
|
|
||||||
mock_axes_with_labels = []
|
|
||||||
for _ in range(6):
|
|
||||||
ax = MagicMock()
|
|
||||||
# Simulate existing labels
|
|
||||||
ax.xaxis.get_label.return_value.get_text.return_value = "Existing X Label"
|
|
||||||
ax.yaxis.get_label.return_value.get_text.return_value = "Existing Y Label"
|
|
||||||
mock_axes_with_labels.append(ax)
|
|
||||||
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes_with_labels, sample_tensor_2d)
|
|
||||||
|
|
||||||
# Should not set labels when they already exist
|
|
||||||
for ax in mock_axes_with_labels:
|
|
||||||
ax.set_xlabel.assert_not_called()
|
|
||||||
ax.set_ylabel.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Grid Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_enables_grid(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints enables grid on all axes."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d)
|
|
||||||
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.grid.assert_called_once_with(True, alpha=0.3)
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Legend Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_adds_legend_with_label(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints adds legend when label is provided."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="test_label")
|
|
||||||
|
|
||||||
# Should add legend to first axis only
|
|
||||||
mock_axes[0].legend.assert_called_once_with(loc="best", fontsize=8)
|
|
||||||
|
|
||||||
# Should not add legend to other axes
|
|
||||||
for ax in mock_axes[1:]:
|
|
||||||
ax.legend.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_no_legend_without_label(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints doesn't add legend when no label provided."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, label="")
|
|
||||||
|
|
||||||
# Should not add legend to any axis
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.legend.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Data Correctness Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_plots_correct_data(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints plots correct tensor values."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, sample_tensor_2d, start_from=0)
|
|
||||||
|
|
||||||
# Check first axis to verify data correctness
|
|
||||||
call_args = mock_axes[0].plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
y_values = call_args[1]
|
|
||||||
|
|
||||||
# X indices should be 0 to 49
|
|
||||||
np.testing.assert_array_equal(x_indices, np.arange(50))
|
|
||||||
|
|
||||||
# Y values should match first dimension of tensor
|
|
||||||
expected_y = sample_tensor_2d[:, 0].cpu().numpy()
|
|
||||||
np.testing.assert_array_almost_equal(y_values, expected_y)
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_handles_gpu_tensor(mock_axes):
|
|
||||||
"""Test plot_waypoints handles GPU tensors (if CUDA available)."""
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
pytest.skip("CUDA not available")
|
|
||||||
|
|
||||||
tensor_gpu = torch.randn(50, 6, device="cuda")
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor_gpu)
|
|
||||||
|
|
||||||
# Should successfully plot without errors
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Edge Cases Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_empty_tensor(mock_axes):
|
|
||||||
"""Test plot_waypoints with empty tensor."""
|
|
||||||
empty_tensor = torch.empty(0, 6)
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, empty_tensor)
|
|
||||||
|
|
||||||
# Should plot empty data
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_args = ax.plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
assert len(x_indices) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_single_timestep(mock_axes):
|
|
||||||
"""Test plot_waypoints with single timestep."""
|
|
||||||
single_step_tensor = torch.randn(1, 6)
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, single_step_tensor)
|
|
||||||
|
|
||||||
# Should plot single point
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_args = ax.plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
assert len(x_indices) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_very_large_tensor(mock_axes):
|
|
||||||
"""Test plot_waypoints with very large tensor."""
|
|
||||||
large_tensor = torch.randn(10000, 6)
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, large_tensor)
|
|
||||||
|
|
||||||
# Should handle large tensors
|
|
||||||
for ax in mock_axes:
|
|
||||||
call_args = ax.plot.call_args[0]
|
|
||||||
x_indices = call_args[0]
|
|
||||||
assert len(x_indices) == 10000
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Multiple Calls Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_multiple_calls_on_same_axes(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test multiple plot_waypoints calls on same axes."""
|
|
||||||
tensor1 = sample_tensor_2d
|
|
||||||
tensor2 = torch.randn(50, 6)
|
|
||||||
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor1, color="blue", label="Series 1")
|
|
||||||
RTCDebugVisualizer.plot_waypoints(mock_axes, tensor2, color="red", label="Series 2")
|
|
||||||
|
|
||||||
# Each axis should have been called twice
|
|
||||||
for ax in mock_axes:
|
|
||||||
assert ax.plot.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Integration Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_typical_usage(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with typical usage pattern."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
|
||||||
mock_axes, sample_tensor_2d, start_from=0, color="blue", label="Trajectory", alpha=0.7, linewidth=2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify all expected calls were made
|
|
||||||
for ax in mock_axes:
|
|
||||||
ax.plot.assert_called_once()
|
|
||||||
ax.set_xlabel.assert_called_once()
|
|
||||||
ax.set_ylabel.assert_called_once()
|
|
||||||
ax.grid.assert_called_once()
|
|
||||||
|
|
||||||
# First axis should have legend
|
|
||||||
mock_axes[0].legend.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
def test_plot_waypoints_with_all_parameters(mock_axes, sample_tensor_2d):
|
|
||||||
"""Test plot_waypoints with all parameters specified."""
|
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
|
||||||
axes=mock_axes,
|
|
||||||
tensor=sample_tensor_2d,
|
|
||||||
start_from=10,
|
|
||||||
color="green",
|
|
||||||
label="Full Test",
|
|
||||||
alpha=0.8,
|
|
||||||
linewidth=3,
|
|
||||||
marker="o",
|
|
||||||
markersize=6,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check first axis for all parameters
|
|
||||||
call_kwargs = mock_axes[0].plot.call_args[1]
|
|
||||||
assert call_kwargs["color"] == "green"
|
|
||||||
assert call_kwargs["label"] == "Full Test"
|
|
||||||
assert call_kwargs["alpha"] == 0.8
|
|
||||||
assert call_kwargs["linewidth"] == 3
|
|
||||||
assert call_kwargs["marker"] == "o"
|
|
||||||
assert call_kwargs["markersize"] == 6
|
|
||||||
@@ -184,74 +184,6 @@ def test_max_after_reset(tracker):
|
|||||||
assert tracker.max() == 0.0
|
assert tracker.max() == 0.0
|
||||||
|
|
||||||
|
|
||||||
# ====================== percentile() Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_returns_zero_when_empty(tracker):
|
|
||||||
"""Test percentile() returns 0.0 when tracker is empty."""
|
|
||||||
assert tracker.percentile(0.5) == 0.0
|
|
||||||
assert tracker.percentile(0.95) == 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_median(tracker):
|
|
||||||
"""Test percentile(0.5) returns median."""
|
|
||||||
# Add sorted values for easier verification
|
|
||||||
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
|
||||||
for v in values:
|
|
||||||
tracker.add(v)
|
|
||||||
|
|
||||||
# Median should be around 0.5
|
|
||||||
median = tracker.percentile(0.5)
|
|
||||||
assert 0.45 <= median <= 0.55
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_minimum_with_zero(tracker):
|
|
||||||
"""Test percentile(0.0) returns minimum."""
|
|
||||||
tracker.add(0.5)
|
|
||||||
tracker.add(0.2)
|
|
||||||
tracker.add(0.8)
|
|
||||||
|
|
||||||
assert tracker.percentile(0.0) == 0.2
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_maximum_with_one(tracker):
|
|
||||||
"""Test percentile(1.0) returns maximum."""
|
|
||||||
tracker.add(0.5)
|
|
||||||
tracker.add(0.2)
|
|
||||||
tracker.add(0.8)
|
|
||||||
|
|
||||||
assert tracker.percentile(1.0) == 0.8
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_95(tracker):
|
|
||||||
"""Test percentile(0.95) returns 95th percentile."""
|
|
||||||
# Add 100 values from 0.0 to 0.99
|
|
||||||
for i in range(100):
|
|
||||||
tracker.add(i / 100.0)
|
|
||||||
|
|
||||||
p95 = tracker.percentile(0.95)
|
|
||||||
# 95th percentile should be around 0.95
|
|
||||||
assert 0.93 <= p95 <= 0.96
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_negative_value_returns_min(tracker):
|
|
||||||
"""Test percentile with negative q returns minimum."""
|
|
||||||
tracker.add(0.5)
|
|
||||||
tracker.add(0.2)
|
|
||||||
tracker.add(0.8)
|
|
||||||
|
|
||||||
assert tracker.percentile(-0.5) == 0.2
|
|
||||||
|
|
||||||
|
|
||||||
def test_percentile_value_greater_than_one_returns_max(tracker):
|
|
||||||
"""Test percentile with q > 1.0 returns maximum."""
|
|
||||||
tracker.add(0.5)
|
|
||||||
tracker.add(0.2)
|
|
||||||
tracker.add(0.8)
|
|
||||||
|
|
||||||
assert tracker.percentile(1.5) == 0.8
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== p95() Tests ======================
|
# ====================== p95() Tests ======================
|
||||||
|
|
||||||
|
|
||||||
@@ -278,79 +210,6 @@ def test_p95_equals_percentile_95(tracker):
|
|||||||
assert tracker.p95() == tracker.percentile(0.95)
|
assert tracker.p95() == tracker.percentile(0.95)
|
||||||
|
|
||||||
|
|
||||||
# ====================== __len__() Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_len_returns_zero_initially(tracker):
|
|
||||||
"""Test __len__ returns 0 for new tracker."""
|
|
||||||
assert len(tracker) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_len_increments_with_add(tracker):
|
|
||||||
"""Test __len__ increments as values are added."""
|
|
||||||
assert len(tracker) == 0
|
|
||||||
|
|
||||||
tracker.add(0.1)
|
|
||||||
assert len(tracker) == 1
|
|
||||||
|
|
||||||
tracker.add(0.2)
|
|
||||||
assert len(tracker) == 2
|
|
||||||
|
|
||||||
tracker.add(0.3)
|
|
||||||
assert len(tracker) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_len_respects_maxlen(small_tracker):
|
|
||||||
"""Test __len__ respects maxlen limit."""
|
|
||||||
# Add more than maxlen values
|
|
||||||
for i in range(10):
|
|
||||||
small_tracker.add(i / 10.0)
|
|
||||||
|
|
||||||
# Should only keep last 5
|
|
||||||
assert len(small_tracker) == 5
|
|
||||||
|
|
||||||
|
|
||||||
def test_len_after_reset(tracker):
|
|
||||||
"""Test __len__ returns 0 after reset."""
|
|
||||||
tracker.add(0.5)
|
|
||||||
tracker.add(0.3)
|
|
||||||
assert len(tracker) == 2
|
|
||||||
|
|
||||||
tracker.reset()
|
|
||||||
assert len(tracker) == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Sliding Window Tests ======================
|
|
||||||
|
|
||||||
|
|
||||||
def test_sliding_window_removes_oldest(small_tracker):
|
|
||||||
"""Test sliding window removes oldest values."""
|
|
||||||
# Add 7 values to tracker with maxlen=5
|
|
||||||
values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
|
||||||
for v in values:
|
|
||||||
small_tracker.add(v)
|
|
||||||
|
|
||||||
# Should only have last 5: [0.3, 0.4, 0.5, 0.6, 0.7]
|
|
||||||
assert len(small_tracker) == 5
|
|
||||||
|
|
||||||
# Median should reflect last 5 values
|
|
||||||
median = small_tracker.percentile(0.5)
|
|
||||||
assert 0.45 <= median <= 0.55
|
|
||||||
|
|
||||||
|
|
||||||
def test_sliding_window_maintains_max(small_tracker):
|
|
||||||
"""Test sliding window maintains correct max even after overflow."""
|
|
||||||
small_tracker.add(0.1)
|
|
||||||
small_tracker.add(0.9)
|
|
||||||
small_tracker.add(0.2)
|
|
||||||
small_tracker.add(0.3)
|
|
||||||
small_tracker.add(0.4)
|
|
||||||
small_tracker.add(0.5) # Pushes out 0.1
|
|
||||||
|
|
||||||
# Max should still be 0.9
|
|
||||||
assert small_tracker.max() == 0.9
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Edge Cases Tests ======================
|
# ====================== Edge Cases Tests ======================
|
||||||
|
|
||||||
|
|
||||||
@@ -436,24 +295,6 @@ def test_reset_and_reuse(tracker):
|
|||||||
assert tracker.percentile(0.5) <= 0.8
|
assert tracker.percentile(0.5) <= 0.8
|
||||||
|
|
||||||
|
|
||||||
def test_continuous_monitoring(small_tracker):
|
|
||||||
"""Test continuous monitoring with sliding window."""
|
|
||||||
# Simulate continuous latency monitoring
|
|
||||||
# First 5 latencies
|
|
||||||
for i in range(5):
|
|
||||||
small_tracker.add(0.1 * (i + 1))
|
|
||||||
|
|
||||||
max_before = small_tracker.max()
|
|
||||||
|
|
||||||
# Add 5 more (window slides)
|
|
||||||
for i in range(5, 10):
|
|
||||||
small_tracker.add(0.1 * (i + 1))
|
|
||||||
|
|
||||||
# Max should have increased
|
|
||||||
assert small_tracker.max() > max_before
|
|
||||||
assert len(small_tracker) == 5 # Window size maintained
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Type Conversion Tests ======================
|
# ====================== Type Conversion Tests ======================
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user